diff --git a/bench-tmp_test.go b/bench-tmp_test.go new file mode 100644 index 00000000..a8e3f7db --- /dev/null +++ b/bench-tmp_test.go @@ -0,0 +1,55 @@ +package pgx_test + +import ( + "testing" +) + +func BenchmarkPgtypeInt4ParseBinary(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("selectBinary", "select n::int4 from generate_series(1, 100) n") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var n int32 + + rows, err := conn.Query("selectBinary") + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + err := rows.Scan(&n) + if err != nil { + b.Fatal(err) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } +} + +func BenchmarkPgtypeInt4EncodeBinary(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("encodeBinary", "select $1::int4, $2::int4, $3::int4, $4::int4, $5::int4, $6::int4, $7::int4") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rows, err := conn.Query("encodeBinary", int32(i), int32(i), int32(i), int32(i), int32(i), int32(i), int32(i)) + if err != nil { + b.Fatal(err) + } + rows.Close() + } +} diff --git a/conn.go b/conn.go index 9303fb74..09dada10 100644 --- a/conn.go +++ b/conn.go @@ -7,7 +7,6 @@ import ( "encoding/hex" "errors" "fmt" - "golang.org/x/net/context" "io" "net" "net/url" @@ -20,7 +19,10 @@ import ( "sync/atomic" "time" + "golang.org/x/net/context" + "github.com/jackc/pgx/chunkreader" + "github.com/jackc/pgx/pgtype" ) const ( @@ -102,6 +104,8 @@ type Conn struct { ctxInProgress bool doneChan chan struct{} closedChan chan error + + oidPgtypeValues map[OID]pgtype.Value } // PreparedStatement is a description of a prepared statement @@ -275,6 +279,16 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.doneChan = make(chan struct{}) c.closedChan = make(chan error) + c.oidPgtypeValues = map[OID]pgtype.Value{ + BoolOID: &pgtype.Bool{}, + DateOID: &pgtype.Date{}, + Int2OID: &pgtype.Int2{}, + Int2ArrayOID: &pgtype.Int2Array{}, + Int4OID: &pgtype.Int4{}, + Int8OID: &pgtype.Int8{}, + TimestampTzOID: &pgtype.Timestamptz{}, + } + if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "Starting TLS handshake") @@ -961,6 +975,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} switch arg := arguments[i].(type) { case Encoder: wbuf.WriteInt16(arg.FormatCode()) + case pgtype.BinaryEncoder: + wbuf.WriteInt16(BinaryFormatCode) + case pgtype.TextEncoder: + wbuf.WriteInt16(TextFormatCode) case string, *string: wbuf.WriteInt16(TextFormatCode) default: diff --git a/copy_to_test.go b/copy_to_test.go index 43cb5acc..7d5f2509 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -26,7 +26,7 @@ func TestConnCopyToSmall(t *testing.T) { )`) inputRows := [][]interface{}{ - {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -83,7 +83,7 @@ func TestConnCopyToLarge(t *testing.T) { inputRows := [][]interface{}{} for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) } copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows)) diff --git a/messages.go b/messages.go index c2964b82..f6be9ff9 100644 --- a/messages.go +++ b/messages.go @@ -101,6 +101,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf { // by the Encoder interface when implementing custom encoders. type WriteBuf struct { buf []byte + convBuf [8]byte sizeIdx int conn *Conn } @@ -125,35 +126,40 @@ func (wb *WriteBuf) WriteCString(s string) { } func (wb *WriteBuf) WriteInt16(n int16) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, uint16(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint16(uint16(n)) } -func (wb *WriteBuf) WriteUint16(n uint16) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, n) - wb.buf = append(wb.buf, b...) +func (wb *WriteBuf) WriteUint16(n uint16) (int, error) { + binary.BigEndian.PutUint16(wb.convBuf[:2], n) + wb.buf = append(wb.buf, wb.convBuf[:2]...) + return 2, nil } func (wb *WriteBuf) WriteInt32(n int32) { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, uint32(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint32(uint32(n)) } -func (wb *WriteBuf) WriteUint32(n uint32) { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, n) - wb.buf = append(wb.buf, b...) +func (wb *WriteBuf) WriteUint32(n uint32) (int, error) { + binary.BigEndian.PutUint32(wb.convBuf[:4], n) + wb.buf = append(wb.buf, wb.convBuf[:4]...) + return 4, nil } func (wb *WriteBuf) WriteInt64(n int64) { - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint64(uint64(n)) +} + +func (wb *WriteBuf) WriteUint64(n uint64) (int, error) { + binary.BigEndian.PutUint64(wb.convBuf[:8], n) + wb.buf = append(wb.buf, wb.convBuf[:8]...) + return 8, nil } func (wb *WriteBuf) WriteBytes(b []byte) { wb.buf = append(wb.buf, b...) } + +func (wb *WriteBuf) Write(b []byte) (int, error) { + wb.buf = append(wb.buf, b...) + return len(b), nil +} diff --git a/pgio/doc.go b/pgio/doc.go new file mode 100644 index 00000000..36233a47 --- /dev/null +++ b/pgio/doc.go @@ -0,0 +1,8 @@ +// Package pgio a extremely low-level IO toolkit for the PostgreSQL wire protocol. +/* +pgio provides functions for reading and writing integers from io.Reader and +io.Writer while doing byte order conversion. It publishes interfaces which +readers and writers may implement to decode and encode messages with the minimum +of memory allocations. +*/ +package pgio diff --git a/pgio/read.go b/pgio/read.go new file mode 100644 index 00000000..7c39162c --- /dev/null +++ b/pgio/read.go @@ -0,0 +1,104 @@ +package pgio + +import ( + "encoding/binary" + "io" +) + +type Uint16Reader interface { + ReadUint16() (n uint16, err error) +} + +type Uint32Reader interface { + ReadUint32() (n uint32, err error) +} + +type Uint64Reader interface { + ReadUint64() (n uint64, err error) +} + +// ReadByte reads a byte from r. +func ReadByte(r io.Reader) (byte, error) { + if r, ok := r.(io.ByteReader); ok { + return r.ReadByte() + } + + buf := make([]byte, 1) + _, err := r.Read(buf) + return buf[0], err +} + +// ReadUint16 reads an uint16 from r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint16 +// method. +func ReadUint16(r io.Reader) (uint16, error) { + if r, ok := r.(Uint16Reader); ok { + return r.ReadUint16() + } + + buf := make([]byte, 2) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint16(buf), nil +} + +// ReadInt16 reads an int16 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint16 +// method. +func ReadInt16(r io.Reader) (int16, error) { + n, err := ReadUint16(r) + return int16(n), err +} + +// ReadUint32 reads an uint32 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint32 +// method. +func ReadUint32(r io.Reader) (uint32, error) { + if r, ok := r.(Uint32Reader); ok { + return r.ReadUint32() + } + + buf := make([]byte, 4) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint32(buf), nil +} + +// ReadInt32 reads an int32 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint32 +// method. +func ReadInt32(r io.Reader) (int32, error) { + n, err := ReadUint32(r) + return int32(n), err +} + +// ReadUint64 reads an uint64 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint64 +// method. +func ReadUint64(r io.Reader) (uint64, error) { + if r, ok := r.(Uint64Reader); ok { + return r.ReadUint64() + } + + buf := make([]byte, 8) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint64(buf), nil +} + +// ReadInt64 reads an int64 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint64 +// method. +func ReadInt64(r io.Reader) (int64, error) { + n, err := ReadUint64(r) + return int64(n), err +} diff --git a/pgio/write.go b/pgio/write.go new file mode 100644 index 00000000..823fbd00 --- /dev/null +++ b/pgio/write.go @@ -0,0 +1,97 @@ +package pgio + +import ( + "encoding/binary" + "io" +) + +type Uint16Writer interface { + WriteUint16(uint16) (n int, err error) +} + +type Uint32Writer interface { + WriteUint32(uint32) (n int, err error) +} + +type Uint64Writer interface { + WriteUint64(uint64) (n int, err error) +} + +// WriteByte writes b to w. +func WriteByte(w io.Writer, b byte) error { + if w, ok := w.(io.ByteWriter); ok { + return w.WriteByte(b) + } + _, err := w.Write([]byte{b}) + return err +} + +// WriteUint16 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint16 +// method. +func WriteUint16(w io.Writer, n uint16) (int, error) { + if w, ok := w.(Uint16Writer); ok { + return w.WriteUint16(n) + } + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, n) + return w.Write(b) +} + +// WriteInt16 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint16 +// method. +func WriteInt16(w io.Writer, n int16) (int, error) { + return WriteUint16(w, uint16(n)) +} + +// WriteUint32 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint32 +// method. +func WriteUint32(w io.Writer, n uint32) (int, error) { + if w, ok := w.(Uint32Writer); ok { + return w.WriteUint32(n) + } + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, n) + return w.Write(b) +} + +// WriteInt32 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint32 +// method. +func WriteInt32(w io.Writer, n int32) (int, error) { + return WriteUint32(w, uint32(n)) +} + +// WriteUint64 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint64 +// method. +func WriteUint64(w io.Writer, n uint64) (int, error) { + if w, ok := w.(Uint64Writer); ok { + return w.WriteUint64(n) + } + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, n) + return w.Write(b) +} + +// WriteInt64 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint64 +// method. +func WriteInt64(w io.Writer, n int64) (int, error) { + return WriteUint64(w, uint64(n)) +} + +// WriteCString writes s to w followed by a null byte. +func WriteCString(w io.Writer, s string) (int, error) { + n, err := io.WriteString(w, s) + if err != nil { + return n, err + } + err = WriteByte(w, 0) + if err != nil { + return n, err + } + return n + 1, nil +} diff --git a/pgtype/array.go b/pgtype/array.go new file mode 100644 index 00000000..75d2e440 --- /dev/null +++ b/pgtype/array.go @@ -0,0 +1,375 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "strconv" + "unicode" + + "github.com/jackc/pgx/pgio" +) + +// Information on the internals of PostgreSQL arrays can be found in +// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of +// particular interest is the array_send function. + +type ArrayHeader struct { + ContainsNull bool + ElementOID int32 + Dimensions []ArrayDimension +} + +type ArrayDimension struct { + Length int32 + LowerBound int32 +} + +func (ah *ArrayHeader) DecodeBinary(r io.Reader) error { + numDims, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if numDims > 0 { + ah.Dimensions = make([]ArrayDimension, numDims) + } + + containsNull, err := pgio.ReadInt32(r) + if err != nil { + return err + } + ah.ContainsNull = containsNull == 1 + + ah.ElementOID, err = pgio.ReadInt32(r) + if err != nil { + return err + } + + for i := range ah.Dimensions { + ah.Dimensions[i].Length, err = pgio.ReadInt32(r) + if err != nil { + return err + } + + ah.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) + if err != nil { + return err + } + } + + return nil +} + +func (ah *ArrayHeader) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, int32(len(ah.Dimensions))) + if err != nil { + return err + } + + var containsNull int32 + if ah.ContainsNull { + containsNull = 1 + } + _, err = pgio.WriteInt32(w, containsNull) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, ah.ElementOID) + if err != nil { + return err + } + + for i := range ah.Dimensions { + _, err = pgio.WriteInt32(w, ah.Dimensions[i].Length) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, ah.Dimensions[i].LowerBound) + if err != nil { + return err + } + } + + return nil +} + +type UntypedTextArray struct { + Elements []string + Dimensions []ArrayDimension +} + +func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { + uta := &UntypedTextArray{} + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + var explicitDimensions []ArrayDimension + + // Array has explicit dimensions + if r == '[' { + buf.UnreadRune() + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '=' { + break + } else if r != '[' { + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + } + + lower, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ':' { + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + } + + upper, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ']' { + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + } + + explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + } + + if r != '{' { + return nil, fmt.Errorf("invalid array, expected '{': %v", err) + } + + implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} + + // Consume all initial opening brackets. This provides number of dimensions. + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions[len(implicitDimensions)-1].Length = 1 + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break + } + } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + currentDim++ + case ',': + case '}': + currentDim-- + if currentDim < counterDim { + counterDim = currentDim + } + default: + buf.UnreadRune() + value, err := arrayParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid array value: %v", err) + } + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + uta.Elements = append(uta.Elements, value) + } + + if currentDim < 0 { + break + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + if len(uta.Elements) == 0 { + uta.Dimensions = nil + } else if len(explicitDimensions) > 0 { + uta.Dimensions = explicitDimensions + } else { + uta.Dimensions = implicitDimensions + } + + return uta, nil +} + +func skipWhitespace(buf *bytes.Buffer) { + var r rune + var err error + for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { + } + + if err != io.EOF { + buf.UnreadRune() + } +} + +func arrayParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return arrayParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + buf.UnreadRune() + return s.String(), nil + } + s.WriteRune(r) + } +} + +func arrayParseInteger(buf *bytes.Buffer) (int32, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return 0, err + } + + if '0' <= r && r <= '9' { + s.WriteRune(r) + } else { + buf.UnreadRune() + n, err := strconv.ParseInt(s.String(), 10, 32) + if err != nil { + return 0, err + } + return int32(n), nil + } + } +} + +func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return nil + } + + for _, dim := range dimensions { + err := pgio.WriteByte(w, '[') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ':') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ']') + if err != nil { + return err + } + } + + return pgio.WriteByte(w, '=') +} diff --git a/pgtype/array_test.go b/pgtype/array_test.go new file mode 100644 index 00000000..5e5f00e7 --- /dev/null +++ b/pgtype/array_test.go @@ -0,0 +1,98 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result pgtype.UntypedTextArray + }{ + { + source: "{}", + result: pgtype.UntypedTextArray{ + Elements: nil, + Dimensions: nil, + }, + }, + { + source: "{1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b"}, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: pgtype.UntypedTextArray{ + Elements: []string{"NULL"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{{a,b},{c,d},{e,f}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + }, + }, + { + source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + }, + }, + { + source: "[4:4]={1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + }, + }, + { + source: "[4:5][2:3]={{a,b},{c,d}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + }, + }, + } + + for i, tt := range tests { + r, err := pgtype.ParseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +} diff --git a/pgtype/bool.go b/pgtype/bool.go new file mode 100644 index 00000000..81c72472 --- /dev/null +++ b/pgtype/bool.go @@ -0,0 +1,166 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Bool struct { + Bool bool + Status Status +} + +func (b *Bool) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Bool: + *b = value + case bool: + *b = Bool{Bool: value, Status: Present} + case string: + bb, err := strconv.ParseBool(value) + if err != nil { + return err + } + *b = Bool{Bool: bb, Status: Present} + default: + if originalSrc, ok := underlyingBoolType(src); ok { + return b.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (b *Bool) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + *v = b.Bool + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if b.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return b.AssignTo(el.Interface()) + case reflect.Bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + el.SetBool(b.Bool) + return nil + } + } + return fmt.Errorf("cannot put decode %v into %T", b, dst) + } + + return nil +} + +func (b *Bool) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *b = Bool{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool{Bool: byt == 't', Status: Present} + return nil +} + +func (b *Bool) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *b = Bool{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool{Bool: byt == 1, Status: Present} + return nil +} + +func (b Bool) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b.Bool { + buf = []byte{'t'} + } else { + buf = []byte{'f'} + } + + _, err = w.Write(buf) + return err +} + +func (b Bool) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b.Bool { + buf = []byte{1} + } else { + buf = []byte{0} + } + + _, err = w.Write(buf) + return err +} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go new file mode 100644 index 00000000..53df1747 --- /dev/null +++ b/pgtype/bool_test.go @@ -0,0 +1,43 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestBoolTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bool", []interface{}{ + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Null}, + }) +} + +func TestBoolConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Bool + }{ + {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/convert.go b/pgtype/convert.go new file mode 100644 index 00000000..3f3d9e5f --- /dev/null +++ b/pgtype/convert.go @@ -0,0 +1,239 @@ +package pgtype + +import ( + "fmt" + "math" + "reflect" + "time" +) + +const maxUint = ^uint(0) +const maxInt = int(maxUint >> 1) +const minInt = -maxInt - 1 + +// underlyingIntType gets the underlying type that can be converted to Int2, Int4, or Int8 +func underlyingIntType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int8: + convVal := int8(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := int64(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := uint64(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return time.Time{}, false +} + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} + +func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} diff --git a/pgtype/date.go b/pgtype/date.go new file mode 100644 index 00000000..f3e3e4c6 --- /dev/null +++ b/pgtype/date.go @@ -0,0 +1,191 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +type Date struct { + Time time.Time + Status Status + InfinityModifier +} + +const ( + negativeInfinityDayOffset = -2147483648 + infinityDayOffset = 2147483647 +) + +func (d *Date) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Date: + *d = value + case time.Time: + *d = Date{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return d.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (d *Date) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if d.Status != Present || d.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", d, dst) + } + *v = d.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if d.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return d.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot decode %v into %T", d, dst) + } + + return nil +} + +func (d *Date) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *d = Date{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *d = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *d = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + if err != nil { + return err + } + + *d = Date{Time: t, Status: Present} + } + + return nil +} + +func (d *Date) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *d = Date{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for date: %v", size) + } + + dayOffset, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + switch dayOffset { + case infinityDayOffset: + *d = Date{Status: Present, InfinityModifier: Infinity} + case negativeInfinityDayOffset: + *d = Date{Status: Present, InfinityModifier: -Infinity} + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + *d = Date{Time: t, Status: Present} + } + + return nil +} + +func (d Date) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, d.Status); done { + return err + } + + var s string + + switch d.InfinityModifier { + case None: + s = d.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +func (d Date) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, d.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + var daysSinceDateEpoch int32 + switch d.InfinityModifier { + case None: + tUnix := time.Date(d.Time.Year(), d.Time.Month(), d.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + _, err = pgio.WriteInt32(w, daysSinceDateEpoch) + return err +} diff --git a/pgtype/date_test.go b/pgtype/date_test.go new file mode 100644 index 00000000..c3e971d0 --- /dev/null +++ b/pgtype/date_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestDateTranscode(t *testing.T) { + testSuccessfulTranscode(t, "date", []interface{}{ + pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }) +} + +func TestDateConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Date + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Date + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} diff --git a/pgtype/extra-interface.txt b/pgtype/extra-interface.txt new file mode 100644 index 00000000..16453823 --- /dev/null +++ b/pgtype/extra-interface.txt @@ -0,0 +1,3 @@ +Can pass function to get inet data and function to get oid/name mapping as optional interface with io.Reader or io.Writer + +Could be useful for arrays of types without defined OIDs like hstore. diff --git a/pgtype/int2.go b/pgtype/int2.go new file mode 100644 index 00000000..2da8a96d --- /dev/null +++ b/pgtype/int2.go @@ -0,0 +1,167 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int2 struct { + Int int16 + Status Status +} + +func (i *Int2) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2: + *i = value + case int8: + *i = Int2{Int: int16(value), Status: Present} + case uint8: + *i = Int2{Int: int16(value), Status: Present} + case int16: + *i = Int2{Int: int16(value), Status: Present} + case uint16: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int32: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint32: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int64: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint64: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return err + } + *i = Int2{Int: int16(num), Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (i *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int2) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int2{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 16) + if err != nil { + return err + } + + *i = Int2{Int: int16(n), Status: Present} + return nil +} + +func (i *Int2) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int2{Status: Null} + return nil + } + + if size != 2 { + return fmt.Errorf("invalid length for int2: %v", size) + } + + n, err := pgio.ReadInt16(r) + if err != nil { + return err + } + + *i = Int2{Int: int16(n), Status: Present} + return nil +} + +func (i Int2) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(int64(i.Int), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int2) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = pgio.WriteInt16(w, i.Int) + return err +} diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go new file mode 100644 index 00000000..a8493a16 --- /dev/null +++ b/pgtype/int2_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt2Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int2", []interface{}{ + pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, + pgtype.Int2{Int: -1, Status: pgtype.Present}, + pgtype.Int2{Int: 0, Status: pgtype.Present}, + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, + pgtype.Int2{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt2ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int2 + }{ + {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int2array.go b/pgtype/int2array.go new file mode 100644 index 00000000..86375516 --- /dev/null +++ b/pgtype/int2array.go @@ -0,0 +1,308 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int2Array struct { + Elements []Int2 + Dimensions []ArrayDimension + Status Status +} + +func (a *Int2Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2Array: + *a = value + case []int16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return a.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (a *Int2Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]int16: + if a.Status == Present { + *v = make([]int16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + case *[]uint16: + if a.Status == Present { + *v = make([]uint16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + default: + return fmt.Errorf("cannot put decode %v into %T", a, dst) + } + + return nil +} + +func (a *Int2Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int2 + + if len(uta.Elements) > 0 { + elements = make([]Int2, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int2 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *a = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (a *Int2Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *a = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int2, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *a = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (a *Int2Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, a.Status); done { + return err + } + + if len(a.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, a.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(a.Dimensions)) + dimElemCounts[len(a.Dimensions)-1] = int(a.Dimensions[len(a.Dimensions)-1].Length) + for i := len(a.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(a.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range a.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (a *Int2Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, a.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range a.Elements { + err := a.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if a.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Int2OID + arrayHeader.Dimensions = a.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/int2array_test.go b/pgtype/int2array_test.go new file mode 100644 index 00000000..5ea81990 --- /dev/null +++ b/pgtype/int2array_test.go @@ -0,0 +1,87 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt2ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int2[]", []interface{}{ + &pgtype.Int2Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{Status: pgtype.Null}, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + pgtype.Int2{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +// func TestInt2ConvertFrom(t *testing.T) { +// type _int8 int8 + +// successfulTests := []struct { +// source interface{} +// result pgtype.Int2 +// }{ +// {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// } + +// for i, tt := range successfulTests { +// var r pgtype.Int2 +// err := r.ConvertFrom(tt.source) +// if err != nil { +// t.Errorf("%d: %v", i, err) +// } + +// if r != tt.result { +// t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) +// } +// } +// } diff --git a/pgtype/int4.go b/pgtype/int4.go new file mode 100644 index 00000000..84c45522 --- /dev/null +++ b/pgtype/int4.go @@ -0,0 +1,158 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int4 struct { + Int int32 + Status Status +} + +func (i *Int4) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int4: + *i = value + case int8: + *i = Int4{Int: int32(value), Status: Present} + case uint8: + *i = Int4{Int: int32(value), Status: Present} + case int16: + *i = Int4{Int: int32(value), Status: Present} + case uint16: + *i = Int4{Int: int32(value), Status: Present} + case int32: + *i = Int4{Int: int32(value), Status: Present} + case uint32: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case int64: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case uint64: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case int: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case uint: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return err + } + *i = Int4{Int: int32(num), Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (i *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int4) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int4{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 32) + if err != nil { + return err + } + + *i = Int4{Int: int32(n), Status: Present} + return nil +} + +func (i *Int4) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int4{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for int4: %v", size) + } + + n, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + *i = Int4{Int: n, Status: Present} + return nil +} + +func (i Int4) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(int64(i.Int), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int4) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, i.Int) + return err +} diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go new file mode 100644 index 00000000..04411849 --- /dev/null +++ b/pgtype/int4_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt4Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int4", []interface{}{ + pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, + pgtype.Int4{Int: -1, Status: pgtype.Present}, + pgtype.Int4{Int: 0, Status: pgtype.Present}, + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, + pgtype.Int4{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt4ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int4 + }{ + {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int8.go b/pgtype/int8.go new file mode 100644 index 00000000..c0e14e44 --- /dev/null +++ b/pgtype/int8.go @@ -0,0 +1,149 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int8 struct { + Int int64 + Status Status +} + +func (i *Int8) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int8: + *i = value + case int8: + *i = Int8{Int: int64(value), Status: Present} + case uint8: + *i = Int8{Int: int64(value), Status: Present} + case int16: + *i = Int8{Int: int64(value), Status: Present} + case uint16: + *i = Int8{Int: int64(value), Status: Present} + case int32: + *i = Int8{Int: int64(value), Status: Present} + case uint32: + *i = Int8{Int: int64(value), Status: Present} + case int64: + *i = Int8{Int: int64(value), Status: Present} + case uint64: + if value > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case int: + if int64(value) < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + if int64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case uint: + if uint64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + *i = Int8{Int: num, Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (i *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int8) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int8{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 64) + if err != nil { + return err + } + + *i = Int8{Int: n, Status: Present} + return nil +} + +func (i *Int8) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int8{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for int8: %v", size) + } + + n, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + *i = Int8{Int: n, Status: Present} + return nil +} + +func (i Int8) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(i.Int, 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int8) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + _, err = pgio.WriteInt64(w, i.Int) + return err +} diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go new file mode 100644 index 00000000..ba246224 --- /dev/null +++ b/pgtype/int8_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int8", []interface{}{ + pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, + pgtype.Int8{Int: -1, Status: pgtype.Present}, + pgtype.Int8{Int: 0, Status: pgtype.Present}, + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, + pgtype.Int8{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt8ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int8 + }{ + {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go new file mode 100644 index 00000000..f9833363 --- /dev/null +++ b/pgtype/pgtype.go @@ -0,0 +1,102 @@ +package pgtype + +import ( + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// PostgreSQL oids for common types +const ( + BoolOID = 16 + ByteaOID = 17 + CharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TidOID = 27 + XidOID = 28 + CidOID = 29 + JSONOID = 114 + CidrOID = 650 + CidrArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + UnknownOID = 705 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + AclItemOID = 1033 + AclItemArrayOID = 1034 + InetArrayOID = 1041 + VarcharOID = 1043 + DateOID = 1082 + TimestampOID = 1114 + TimestampArrayOID = 1115 + TimestampTzOID = 1184 + TimestampTzArrayOID = 1185 + RecordOID = 2249 + UUIDOID = 2950 + JSONBOID = 3802 +) + +type Status byte + +const ( + Undefined Status = iota + Null + Present +) + +type InfinityModifier int8 + +const ( + Infinity InfinityModifier = 1 + None InfinityModifier = 0 + NegativeInfinity InfinityModifier = -Infinity +) + +type Value interface { + ConvertFrom(src interface{}) error + AssignTo(dst interface{}) error +} + +type BinaryDecoder interface { + DecodeBinary(r io.Reader) error +} + +type TextDecoder interface { + DecodeText(r io.Reader) error +} + +type BinaryEncoder interface { + EncodeBinary(w io.Writer) error +} + +type TextEncoder interface { + EncodeText(w io.Writer) error +} + +var errUndefined = errors.New("cannot encode status undefined") + +func encodeNotPresent(w io.Writer, status Status) (done bool, err error) { + switch status { + case Undefined: + return true, errUndefined + case Null: + _, err = pgio.WriteInt32(w, -1) + return true, err + } + return false, nil +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go new file mode 100644 index 00000000..a1a575f7 --- /dev/null +++ b/pgtype/pgtype_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "fmt" + "io" + "os" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func mustConnectPgx(t testing.TB) *pgx.Conn { + config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(config) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func mustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(w io.Writer) error { + return f.e.EncodeText(w) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { + return f.e.EncodeBinary(w) +} + +func forceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + return forceTextEncoder{e: e.(pgtype.TextEncoder)} + case pgx.BinaryFormatCode: + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + default: + panic("bad encoder") + } +} + +func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + testSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} diff --git a/pgtype/text_element.go b/pgtype/text_element.go new file mode 100644 index 00000000..1a585d08 --- /dev/null +++ b/pgtype/text_element.go @@ -0,0 +1,112 @@ +package pgtype + +import ( + "bytes" + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// TextElementWriter is a wrapper that makes TextEncoders composable into other +// TextEncoders. TextEncoder first writes the length of the subsequent value. +// This is not necessary when the value is part of another value such as an +// array. TextElementWriter requires one int32 to be written first which it +// ignores. No other integer writes are valid. +type TextElementWriter struct { + w io.Writer + lengthHeaderIgnored bool +} + +func NewTextElementWriter(w io.Writer) *TextElementWriter { + return &TextElementWriter{w: w} +} + +func (w *TextElementWriter) WriteUint16(n uint16) (int, error) { + return 0, errors.New("WriteUint16 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint32(n uint32) (int, error) { + if !w.lengthHeaderIgnored { + w.lengthHeaderIgnored = true + + if int32(n) == -1 { + return io.WriteString(w.w, "NULL") + } + + return 4, nil + } + + return 0, errors.New("WriteUint32 should only be called once on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint64(n uint64) (int, error) { + if w.lengthHeaderIgnored { + return pgio.WriteUint64(w.w, n) + } + + return 0, errors.New("WriteUint64 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) Write(buf []byte) (int, error) { + if w.lengthHeaderIgnored { + return w.w.Write(buf) + } + + return 0, errors.New("int32 must be written first") +} + +func (w *TextElementWriter) Reset() { + w.lengthHeaderIgnored = false +} + +// TextElementReader is a wrapper that makes TextDecoders composable into other +// TextDecoders. TextEncoders first read the length of the subsequent value. +// This length value is not present when the value is part of another value such +// as an array. TextElementReader provides a substitute length value from the +// length of the string. No other integer reads are valid. Each time DecodeText +// is called with a TextElementReader as the source the TextElementReader must +// first have Reset called with the new element string data. +type TextElementReader struct { + buf *bytes.Buffer + lengthHeaderIgnored bool +} + +func NewTextElementReader(r io.Reader) *TextElementReader { + return &TextElementReader{buf: &bytes.Buffer{}} +} + +func (r *TextElementReader) ReadUint16() (uint16, error) { + return 0, errors.New("ReadUint16 should never be called on TextElementReader") +} + +func (r *TextElementReader) ReadUint32() (uint32, error) { + if !r.lengthHeaderIgnored { + r.lengthHeaderIgnored = true + if r.buf.String() == "NULL" { + n32 := int32(-1) + return uint32(n32), nil + } + return uint32(r.buf.Len()), nil + } + + return 0, errors.New("ReadUint32 should only be called once on TextElementReader") +} + +func (r *TextElementReader) WriteUint64(n uint64) (int, error) { + return 0, errors.New("ReadUint64 should never be called on TextElementReader") +} + +func (r *TextElementReader) Read(buf []byte) (int, error) { + if r.lengthHeaderIgnored { + return r.buf.Read(buf) + } + + return 0, errors.New("int32 must be read first") +} + +func (r *TextElementReader) Reset(s string) { + r.lengthHeaderIgnored = false + r.buf.Reset() + r.buf.WriteString(s) +} diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go new file mode 100644 index 00000000..cc33b296 --- /dev/null +++ b/pgtype/timestamptz.go @@ -0,0 +1,203 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( + negativeInfinityMicrosecondOffset = -9223372036854775808 + infinityMicrosecondOffset = 9223372036854775807 +) + +type Timestamptz struct { + Time time.Time + Status Status + InfinityModifier +} + +func (t *Timestamptz) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Timestamptz: + *t = value + case time.Time: + *t = Timestamptz{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return t.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (t *Timestamptz) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if t.Status != Present || t.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", t, dst) + } + *v = t.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if t.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return t.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot assign %v into %T", t, dst) + } + + return nil +} + +func (t *Timestamptz) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *t = Timestamptz{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + var format string + if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { + format = pgTimestamptzSecondFormat + } else if sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+' { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + *t = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (t *Timestamptz) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *t = Timestamptz{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", size) + } + + microsecSinceY2K, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + *t = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (t Timestamptz) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, t.Status); done { + return err + } + + var s string + + switch t.InfinityModifier { + case None: + s = t.Time.UTC().Format(pgTimestamptzSecondFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +func (t Timestamptz) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, t.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + var microsecSinceY2K int64 + switch t.InfinityModifier { + case None: + microsecSinceUnixEpoch := t.Time.Unix()*1000000 + int64(t.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + _, err = pgio.WriteInt64(w, microsecSinceY2K) + return err +} diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go new file mode 100644 index 00000000..795195f8 --- /dev/null +++ b/pgtype/timestamptz_test.go @@ -0,0 +1,60 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestamptzTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamptz) + bt := b.(pgtype.Timestamptz) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestamptzConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamptz + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/query.go b/query.go index 99b383e0..4af1de10 100644 --- a/query.go +++ b/query.go @@ -4,8 +4,11 @@ import ( "database/sql" "errors" "fmt" - "golang.org/x/net/context" "time" + + "golang.org/x/net/context" + + "github.com/jackc/pgx/pgtype" ) // Row is a convenience wrapper over Rows that is returned by QueryRow. @@ -219,6 +222,27 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } + } else if s, ok := d.(ScannerV3); ok { + val, err := decodeByOID(vr) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + err = s.ScanPgxV3(nil, val) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { + vr.err = errRewoundLen + err = s.DecodeBinary(&valueReader2{vr}) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { + vr.err = errRewoundLen + err = s.DecodeText(&valueReader2{vr}) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } } else if s, ok := d.(sql.Scanner); ok { var val interface{} if 0 <= vr.Len() { @@ -265,8 +289,39 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { d2 := d decodeJSONB(vr, &d2) } else { - if err := Decode(vr, d); err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + if pgVal, present := rows.conn.oidPgtypeValues[vr.Type().DataType]; present { + switch vr.Type().FormatCode { + case TextFormatCode: + if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok { + vr.err = errRewoundLen + err = textDecoder.DecodeText(&valueReader2{vr}) + if err != nil { + vr.Fatal(err) + } + } else { + vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", pgVal)) + } + case BinaryFormatCode: + if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok { + vr.err = errRewoundLen + err = binaryDecoder.DecodeBinary(&valueReader2{vr}) + if err != nil { + vr.Fatal(err) + } + } else { + vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", pgVal)) + } + default: + vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) + } + + if err := pgVal.AssignTo(d); err != nil { + vr.Fatal(err) + } + } else { + if err := Decode(vr, d); err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } } } if vr.Err() != nil { @@ -296,7 +351,7 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, nil) continue } - + // TODO - consider what are the implications of returning complex types since database/sql uses this method switch vr.Type().FormatCode { // All intrinsic types (except string) are encoded with binary // encoding so anything else should be treated as a string diff --git a/query_test.go b/query_test.go index a78914b6..fd5d2e5b 100644 --- a/query_test.go +++ b/query_test.go @@ -4,11 +4,12 @@ import ( "bytes" "database/sql" "fmt" - "golang.org/x/net/context" "strings" "testing" "time" + "golang.org/x/net/context" + "github.com/jackc/pgx" "github.com/shopspring/decimal" @@ -110,7 +111,7 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) var s string err := conn.QueryRow("select 1").Scan(&s) - if err == nil || !strings.Contains(err.Error(), "cannot decode binary value into string") { + if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) { t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) } @@ -199,7 +200,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" { + if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -518,7 +519,7 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}}, - {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, + {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, {"select $1::oid", []interface{}{pgx.OID(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } @@ -541,7 +542,7 @@ func TestQueryRowCoreTypes(t *testing.T) { if err == nil { t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") && !strings.Contains(err.Error(), "cannot assign") { t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql) } @@ -944,7 +945,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into any integer type"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot encode int8 into oid 600"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 600"}, } for i, tt := range tests { @@ -1017,7 +1018,7 @@ func TestQueryRowCoreInt16Slice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + if err != nil && !(strings.Contains(err.Error(), "Cannot decode null") || strings.Contains(err.Error(), "cannot assign")) { t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) } diff --git a/value_reader.go b/value_reader.go index 249b8ba3..c91a21af 100644 --- a/value_reader.go +++ b/value_reader.go @@ -4,6 +4,8 @@ import ( "errors" ) +var errRewoundLen = errors.New("len was rewound") + // ValueReader is used by the Scanner interface to decode values. type ValueReader struct { mr *msgReader @@ -154,3 +156,28 @@ func (r *ValueReader) ReadBytes(count int32) []byte { return r.mr.readBytes(count) } + +type valueReader2 struct { + *ValueReader +} + +func (r *valueReader2) Read(dst []byte) (int, error) { + if r.err != nil { + return 0, r.err + } + + src := r.ReadBytes(int32(len(dst))) + + copy(dst, src) + + return len(dst), nil +} + +func (r *valueReader2) ReadUint32() (uint32, error) { + if r.err == errRewoundLen { + r.err = nil + return uint32(r.Len()), nil + } + + return r.ValueReader.ReadUint32(), nil +} diff --git a/values.go b/values.go index 45ed914c..a9c4c209 100644 --- a/values.go +++ b/values.go @@ -13,6 +13,8 @@ import ( "strconv" "strings" "time" + + "github.com/jackc/pgx/pgtype" ) // PostgreSQL oids for common types @@ -200,6 +202,10 @@ type Encoder interface { FormatCode() int16 } +type ScannerV3 interface { + ScanPgxV3(fieldDescription interface{}, src interface{}) error +} + // NullFloat32 represents an float4 that may be null. NullFloat32 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -497,7 +503,7 @@ func (n NullInt16) Encode(w *WriteBuf, oid OID) error { return nil } - return encodeInt16(w, oid, n.Int16) + return pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w) } // NullInt32 represents an integer that may be null. NullInt32 implements the @@ -536,7 +542,7 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { return nil } - return encodeInt32(w, oid, n.Int32) + return pgtype.Int4{Int: n.Int32, Status: pgtype.Present}.EncodeBinary(w) } // OID (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, @@ -782,7 +788,7 @@ func (n NullInt64) Encode(w *WriteBuf, oid OID) error { return nil } - return encodeInt64(w, oid, n.Int64) + return pgtype.Int8{Int: n.Int64, Status: pgtype.Present}.EncodeBinary(w) } // NullBool represents an bool that may be null. NullBool implements the Scanner @@ -1020,6 +1026,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { switch arg := arg.(type) { case Encoder: return arg.Encode(wbuf, oid) + case pgtype.BinaryEncoder: + return arg.EncodeBinary(wbuf) + case pgtype.TextEncoder: + return arg.EncodeText(wbuf) case driver.Valuer: v, err := arg.Value() if err != nil { @@ -1054,17 +1064,19 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeJSONB(wbuf, oid, arg) } + if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { + err := value.ConvertFrom(arg) + if err != nil { + return err + } + return value.(pgtype.BinaryEncoder).EncodeBinary(wbuf) + } + switch arg := arg.(type) { case []string: return encodeStringSlice(wbuf, oid, arg) - case bool: - return encodeBool(wbuf, oid, arg) case []bool: return encodeBoolSlice(wbuf, oid, arg) - case int: - return encodeInt(wbuf, oid, arg) - case uint: - return encodeUInt(wbuf, oid, arg) case Char: return encodeChar(wbuf, oid, arg) case AclItem: @@ -1075,32 +1087,12 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case int8: - return encodeInt8(wbuf, oid, arg) - case uint8: - return encodeUInt8(wbuf, oid, arg) - case int16: - return encodeInt16(wbuf, oid, arg) - case []int16: - return encodeInt16Slice(wbuf, oid, arg) - case uint16: - return encodeUInt16(wbuf, oid, arg) - case []uint16: - return encodeUInt16Slice(wbuf, oid, arg) - case int32: - return encodeInt32(wbuf, oid, arg) case []int32: return encodeInt32Slice(wbuf, oid, arg) - case uint32: - return encodeUInt32(wbuf, oid, arg) case []uint32: return encodeUInt32Slice(wbuf, oid, arg) - case int64: - return encodeInt64(wbuf, oid, arg) case []int64: return encodeInt64Slice(wbuf, oid, arg) - case uint64: - return encodeUInt64(wbuf, oid, arg) case []uint64: return encodeUInt64Slice(wbuf, oid, arg) case float32: @@ -1140,32 +1132,57 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { func stripNamedType(val *reflect.Value) (interface{}, bool) { switch val.Kind() { case reflect.Int: - return int(val.Int()), true + convVal := int(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int8: - return int8(val.Int()), true + convVal := int8(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int16: - return int16(val.Int()), true + convVal := int16(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int32: - return int32(val.Int()), true + convVal := int32(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int64: - return int64(val.Int()), true + convVal := int64(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint: - return uint(val.Uint()), true + convVal := uint(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint8: - return uint8(val.Uint()), true + convVal := uint8(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint16: - return uint16(val.Uint()), true + convVal := uint16(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint32: - return uint32(val.Uint()), true + convVal := uint32(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint64: - return uint64(val.Uint()), true + convVal := uint64(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.String: - return val.String(), true + convVal := val.String() + return convVal, reflect.TypeOf(convVal) != val.Type() } return nil, false } +func decodeByOID(vr *ValueReader) (interface{}, error) { + switch vr.Type().DataType { + case Int2OID, Int4OID, Int8OID: + n := decodeInt(vr) + return n, vr.Err() + case BoolOID: + b := decodeBool(vr) + return b, vr.Err() + default: + buf := vr.ReadBytes(vr.Len()) + return buf, vr.Err() + } +} + // Decode decodes from vr into d. d must be a pointer. This allows // implementations of the Decoder interface to delegate the actual work of // decoding to the built-in functionality. @@ -1381,28 +1398,36 @@ func Decode(vr *ValueReader, d interface{}) error { } func decodeBool(vr *ValueReader) bool { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into bool")) - return false - } - if vr.Type().DataType != BoolOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) return false } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var b pgtype.Bool + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = b.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = b.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return false } - if vr.Len() != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return false } - b := vr.ReadByte() - return b != 0 + if b.Status != pgtype.Present { + vr.Fatal(fmt.Errorf("Cannot decode null into bool")) + return false + } + + return b.Bool } func encodeBool(w *WriteBuf, oid OID, value bool) error { @@ -1410,16 +1435,8 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error { return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) } - w.WriteInt32(1) - - var n byte - if value { - n = 1 - } - - w.WriteByte(n) - - return nil + b := pgtype.Bool{Bool: value, Status: pgtype.Present} + return b.EncodeBinary(w) } func decodeInt(vr *ValueReader) int64 { @@ -1447,17 +1464,31 @@ func decodeInt8(vr *ValueReader) int64 { return 0 } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var n pgtype.Int8 + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = n.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = n.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return 0 } - return vr.ReadInt64() + if n.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return 0 + } + + return n.Int } func decodeChar(vr *ValueReader) Char { @@ -1485,88 +1516,37 @@ func decodeChar(vr *ValueReader) Char { } func decodeInt2(vr *ValueReader) int16 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } if vr.Type().DataType != Int2OID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) return 0 } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var n pgtype.Int2 + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = n.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = n.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } - if vr.Len() != 2 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return 0 } - return vr.ReadInt16() -} - -func encodeInt(w *WriteBuf, oid OID, value int) error { - switch oid { - case Int2OID: - if value < math.MinInt16 { - return fmt.Errorf("%d is less than min pg:int2", value) - } else if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than max pg:int2", value) - } - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4OID: - if value < math.MinInt32 { - return fmt.Errorf("%d is less than min pg:int4", value) - } else if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than max pg:int4", value) - } - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - if int64(value) <= int64(math.MaxInt64) { - w.WriteInt32(8) - w.WriteInt64(int64(value)) - } else { - return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64)) - } - default: - return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) + if n.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return 0 } - return nil -} - -func encodeUInt(w *WriteBuf, oid OID, value uint) error { - switch oid { - case Int2OID: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than max pg:int2", value) - } - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4OID: - if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than max pg:int4", value) - } - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - //****** Changed value to int64(value) and math.MaxInt64 to int64(math.MaxInt64) - if int64(value) > int64(math.MaxInt64) { - return fmt.Errorf("%d is greater than max pg:int8", value) - } - w.WriteInt32(8) - w.WriteInt64(int64(value)) - - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) - } - - return nil + return n.Int } func encodeChar(w *WriteBuf, oid OID, value Char) error { @@ -1575,187 +1555,6 @@ func encodeChar(w *WriteBuf, oid OID, value Char) error { return nil } -func encodeInt8(w *WriteBuf, oid OID, value int8) error { - switch oid { - case Int2OID: - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) - } - - return nil -} - -func encodeUInt8(w *WriteBuf, oid OID, value uint8) error { - switch oid { - case Int2OID: - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) - } - - return nil -} - -func encodeInt16(w *WriteBuf, oid OID, value int16) error { - switch oid { - case Int2OID: - w.WriteInt32(2) - w.WriteInt16(value) - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) - } - - return nil -} - -func encodeUInt16(w *WriteBuf, oid OID, value uint16) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) - } - - return nil -} - -func encodeInt32(w *WriteBuf, oid OID, value int32) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(value) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int32", oid) - } - - return nil -} - -func encodeUInt32(w *WriteBuf, oid OID, value uint32) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint32", oid) - } - - return nil -} - -func encodeInt64(w *WriteBuf, oid OID, value int64) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(value) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int64", oid) - } - - return nil -} - -func encodeUInt64(w *WriteBuf, oid OID, value uint64) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case Int8OID: - - if value <= math.MaxInt64 { - w.WriteInt32(8) - w.WriteInt64(int64(value)) - } else { - return fmt.Errorf("%d is greater than max int64 %d", value, int64(math.MaxInt64)) - } - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid) - } - - return nil -} - func decodeInt4(vr *ValueReader) int32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int32")) @@ -1767,17 +1566,31 @@ func decodeInt4(vr *ValueReader) int32 { return 0 } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var n pgtype.Int4 + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = n.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = n.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return 0 } - return vr.ReadInt32() + if n.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return 0 + } + + return n.Int } func decodeOID(vr *ValueReader) OID { @@ -2179,51 +1992,54 @@ func encodeJSONB(w *WriteBuf, oid OID, value interface{}) error { } func decodeDate(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return zeroTime - } - if vr.Type().DataType != DateOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime + return time.Time{} } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var d pgtype.Date + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = d.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = d.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime + return time.Time{} } - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) + return time.Time{} } - dayOffset := vr.ReadInt32() - return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) + + if d.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return time.Time{} + } + + return d.Time } func encodeTime(w *WriteBuf, oid OID, value time.Time) error { switch oid { case DateOID: - tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix() - dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() - - secSinceDateEpoch := tUnix - dateEpoch - daysSinceDateEpoch := secSinceDateEpoch / 86400 - - w.WriteInt32(4) - w.WriteInt32(int32(daysSinceDateEpoch)) - - return nil + var d pgtype.Date + err := d.ConvertFrom(value) + if err != nil { + return err + } + return d.EncodeBinary(w) case TimestampTzOID, TimestampOID: - microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000 - microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - - w.WriteInt32(8) - w.WriteInt64(microsecSinceY2K) - - return nil + var t pgtype.Timestamptz + err := t.ConvertFrom(value) + if err != nil { + return err + } + return t.EncodeBinary(w) default: return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) } @@ -2244,19 +2060,31 @@ func decodeTimestampTz(vr *ValueReader) time.Time { return zeroTime } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var t pgtype.Timestamptz + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = t.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = t.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime + return time.Time{} } - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len()))) - return zeroTime + if err != nil { + vr.Fatal(err) + return time.Time{} } - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + if t.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into time.Time")) + return time.Time{} + } + + return t.Time } func decodeTimestamp(vr *ValueReader) time.Time { @@ -2578,42 +2406,45 @@ func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error { } func decodeInt2Array(vr *ValueReader) []int16 { - if vr.Len() == -1 { - return nil - } - if vr.Type().DataType != Int2ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType))) return nil } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var a pgtype.Int2Array + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = a.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = a.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return nil } - numElems, err := decode1dArrayHeader(vr) if err != nil { vr.Fatal(err) return nil } - a := make([]int16, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 2: - a[i] = vr.ReadInt16() - case -1: + if a.Status == pgtype.Null { + return nil + } + + rawArray := make([]int16, len(a.Elements)) + for i := range a.Elements { + if a.Elements[i].Status == pgtype.Present { + rawArray[i] = a.Elements[i].Int + } else { vr.Fatal(ProtocolError("Cannot decode null element")) return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) - return nil } } - return a + return rawArray } func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { @@ -2660,38 +2491,6 @@ func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { return a } -func encodeInt16Slice(w *WriteBuf, oid OID, slice []int16) error { - if oid != Int2ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid) - } - - encodeArrayHeader(w, Int2OID, len(slice), 6) - for _, v := range slice { - w.WriteInt32(2) - w.WriteInt16(v) - } - - return nil -} - -func encodeUInt16Slice(w *WriteBuf, oid OID, slice []uint16) error { - if oid != Int2ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid) - } - - encodeArrayHeader(w, Int2OID, len(slice), 6) - for _, v := range slice { - if v <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(v)) - } else { - return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16) - } - } - - return nil -} - func decodeInt4Array(vr *ValueReader) []int32 { if vr.Len() == -1 { return nil diff --git a/values_test.go b/values_test.go index 6ab221f7..ef13ccdf 100644 --- a/values_test.go +++ b/values_test.go @@ -18,24 +18,24 @@ func TestDateTranscode(t *testing.T) { defer closeConn(t, conn) dates := []time.Time{ - time.Date(1, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1000, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1600, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1700, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), - time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2001, 1, 2, 0, 0, 0, 0, time.Local), - time.Date(2004, 2, 29, 0, 0, 0, 0, time.Local), - time.Date(2013, 7, 4, 0, 0, 0, 0, time.Local), - time.Date(2013, 12, 25, 0, 0, 0, 0, time.Local), - time.Date(2029, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2081, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2096, 2, 29, 0, 0, 0, 0, time.Local), - time.Date(2550, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(9999, 12, 31, 0, 0, 0, 0, time.Local), + time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), + time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), + time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), + time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), } for _, actualDate := range dates { @@ -629,8 +629,8 @@ func TestNullX(t *testing.T) { {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, } @@ -1048,11 +1048,11 @@ func TestEncodeTypeRename(t *testing.T) { defer closeConn(t, conn) type _int int - inInt := _int(3) + inInt := _int(1) var outInt _int type _int8 int8 - inInt8 := _int8(3) + inInt8 := _int8(2) var outInt8 _int8 type _int16 int16