From 2bf3fac5940d659e94e3cbfa2cd7902b3875b6b9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 Aug 2016 13:22:09 -0500 Subject: [PATCH 001/264] Add note to README noting the experimental status of v3 --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index c90bf966..b2795fca 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # Pgx +## Experimental Branch + +This is the experimental v3 branch. v2 is the stable branch. + Pgx is a pure Go database connection library designed specifically for PostgreSQL. Pgx is different from other drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a From 214443deb726478c0cc251abe94278e97e9a98fb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 Aug 2016 13:31:55 -0500 Subject: [PATCH 002/264] Rename Oid to OID --- conn.go | 50 ++-- conn_pool.go | 2 +- conn_test.go | 2 +- example_custom_type_test.go | 2 +- example_json_test.go | 2 +- fastpath.go | 12 +- large_objects.go | 36 +-- messages.go | 6 +- msg_reader.go | 4 +- query.go | 76 +++--- query_test.go | 8 +- stdlib/sql.go | 28 +- v3.md | 3 + value_reader.go | 4 +- values.go | 492 ++++++++++++++++++------------------ values_test.go | 2 +- 16 files changed, 366 insertions(+), 363 deletions(-) create mode 100644 v3.md diff --git a/conn.go b/conn.go index c2519003..e5c2a401 100644 --- a/conn.go +++ b/conn.go @@ -51,7 +51,7 @@ type Conn struct { Pid int32 // backend pid SecretKey int32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server - PgTypes map[Oid]PgType // oids to PgTypes + PgTypes map[OID]PgType // oids to PgTypes config ConnConfig // config used when establishing this connection TxStatus byte preparedStatements map[string]*PreparedStatement @@ -75,12 +75,12 @@ type PreparedStatement struct { Name string SQL string FieldDescriptions []FieldDescription - ParameterOids []Oid + ParameterOIDs []OID } // PrepareExOptions is an option struct that can be passed to PrepareEx type PrepareExOptions struct { - ParameterOids []Oid + ParameterOIDs []OID } // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system @@ -145,13 +145,13 @@ func Connect(config ConnConfig) (c *Conn, err error) { return connect(config, nil, nil, nil) } -func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) { +func connect(config ConnConfig, pgTypes map[OID]PgType, pgsql_af_inet *byte, pgsql_af_inet6 *byte) (c *Conn, err error) { c = new(Conn) c.config = config if pgTypes != nil { - c.PgTypes = make(map[Oid]PgType, len(pgTypes)) + c.PgTypes = make(map[OID]PgType, len(pgTypes)) for k, v := range pgTypes { c.PgTypes[k] = v } @@ -344,10 +344,10 @@ where ( return err } - c.PgTypes = make(map[Oid]PgType, 128) + c.PgTypes = make(map[OID]PgType, 128) for rows.Next() { - var oid Oid + var oid OID var t PgType rows.Scan(&oid, &t.Name) @@ -626,11 +626,11 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared wbuf.WriteCString(sql) if opts != nil { - if len(opts.ParameterOids) > 65535 { - return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))) + if len(opts.ParameterOIDs) > 65535 { + return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs))) } - wbuf.WriteInt16(int16(len(opts.ParameterOids))) - for _, oid := range opts.ParameterOids { + wbuf.WriteInt16(int16(len(opts.ParameterOIDs))) + for _, oid := range opts.ParameterOIDs { wbuf.WriteInt32(int32(oid)) } } else { @@ -667,10 +667,10 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared switch t { case parseComplete: case parameterDescription: - ps.ParameterOids = c.rxParameterDescription(r) + ps.ParameterOIDs = c.rxParameterDescription(r) - if len(ps.ParameterOids) > 65535 && softErr == nil { - softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) + if len(ps.ParameterOIDs) > 65535 && softErr == nil { + softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) } case rowDescription: ps.FieldDescriptions = c.rxRowDescription(r) @@ -899,8 +899,8 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { } func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { - if len(ps.ParameterOids) != len(arguments) { - return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) + if len(ps.ParameterOIDs) != len(arguments) { + return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) } // bind @@ -908,8 +908,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteByte(0) wbuf.WriteCString(ps.Name) - wbuf.WriteInt16(int16(len(ps.ParameterOids))) - for i, oid := range ps.ParameterOids { + wbuf.WriteInt16(int16(len(ps.ParameterOIDs))) + for i, oid := range ps.ParameterOIDs { switch arg := arguments[i].(type) { case Encoder: wbuf.WriteInt16(arg.FormatCode()) @@ -917,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid: + case BoolOID, ByteaOID, Int2OID, Int4OID, Int8OID, Float4OID, Float8OID, TimestampTzOID, TimestampTzArrayOID, TimestampOID, TimestampArrayOID, DateOID, BoolArrayOID, ByteaArrayOID, Int2ArrayOID, Int4ArrayOID, Int8ArrayOID, Float4ArrayOID, Float8ArrayOID, TextArrayOID, VarcharArrayOID, OIDOID, InetOID, CidrOID, InetArrayOID, CidrArrayOID, RecordOID: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -926,7 +926,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} } wbuf.WriteInt16(int16(len(arguments))) - for i, oid := range ps.ParameterOids { + for i, oid := range ps.ParameterOIDs { if err := Encode(wbuf, oid, arguments[i]); err != nil { return err } @@ -1151,9 +1151,9 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { for i := int16(0); i < fieldCount; i++ { f := &fields[i] f.Name = r.readCString() - f.Table = r.readOid() + f.Table = r.readOID() f.AttributeNumber = r.readInt16() - f.DataType = r.readOid() + f.DataType = r.readOID() f.DataTypeSize = r.readInt16() f.Modifier = r.readInt32() f.FormatCode = r.readInt16() @@ -1161,7 +1161,7 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { return } -func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { +func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) { // Internally, PostgreSQL supports greater than 64k parameters to a prepared // statement. But the parameter description uses a 16-bit integer for the // count of parameters. If there are more than 64K parameters, this count is @@ -1170,10 +1170,10 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { r.readInt16() parameterCount := r.msgBytesRemaining / 4 - parameters = make([]Oid, 0, parameterCount) + parameters = make([]OID, 0, parameterCount) for i := int32(0); i < parameterCount; i++ { - parameters = append(parameters, r.readOid()) + parameters = append(parameters, r.readOID()) } return } diff --git a/conn_pool.go b/conn_pool.go index 775fb091..a72d5daf 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -27,7 +27,7 @@ type ConnPool struct { closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration - pgTypes map[Oid]PgType + pgTypes map[OID]PgType pgsql_af_inet *byte pgsql_af_inet6 *byte } diff --git a/conn_test.go b/conn_test.go index 181a3ed2..4067118c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -987,7 +987,7 @@ func TestPrepareEx(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}}) + _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOIDs: []pgx.OID{pgx.TextOID}}) if err != nil { t.Errorf("Unable to prepare statement: %v", err) return diff --git a/example_custom_type_test.go b/example_custom_type_test.go index c8d8e220..ddf4732d 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -57,7 +57,7 @@ func (p *NullPoint) Scan(vr *pgx.ValueReader) error { func (p NullPoint) FormatCode() int16 { return pgx.BinaryFormatCode } -func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { +func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.OID) error { if !p.Valid { w.WriteInt32(-1) return nil diff --git a/example_json_test.go b/example_json_test.go index c1534158..513cc90b 100644 --- a/example_json_test.go +++ b/example_json_test.go @@ -12,7 +12,7 @@ func Example_JSON() { return } - if _, ok := conn.PgTypes[pgx.JsonOid]; !ok { + if _, ok := conn.PgTypes[pgx.JsonOID]; !ok { // No JSON type -- must be running against very old PostgreSQL // Pretend it works fmt.Println("John", 42) diff --git a/fastpath.go b/fastpath.go index 8814e559..a1212e1b 100644 --- a/fastpath.go +++ b/fastpath.go @@ -7,26 +7,26 @@ import ( type fastpathArg []byte func newFastpath(cn *Conn) *fastpath { - return &fastpath{cn: cn, fns: make(map[string]Oid)} + return &fastpath{cn: cn, fns: make(map[string]OID)} } type fastpath struct { cn *Conn - fns map[string]Oid + fns map[string]OID } -func (f *fastpath) functionOID(name string) Oid { +func (f *fastpath) functionOID(name string) OID { return f.fns[name] } -func (f *fastpath) addFunction(name string, oid Oid) { +func (f *fastpath) addFunction(name string, oid OID) { f.fns[name] = oid } func (f *fastpath) addFunctions(rows *Rows) error { for rows.Next() { var name string - var oid Oid + var oid OID if err := rows.Scan(&name, &oid); err != nil { return err } @@ -49,7 +49,7 @@ func fpInt64Arg(n int64) fpArg { return res } -func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { +func (f *fastpath) Call(oid OID, args []fpArg) (res []byte, err error) { wbuf := newWriteBuf(f.cn, 'F') // function call wbuf.WriteInt32(int32(oid)) // function object id wbuf.WriteInt16(1) // # of argument format codes diff --git a/large_objects.go b/large_objects.go index a4922ef1..5b3e3a33 100644 --- a/large_objects.go +++ b/large_objects.go @@ -14,20 +14,20 @@ type LargeObjects struct { fp *fastpath } -const largeObjectFns = `select proname, oid from pg_catalog.pg_proc +const largeObjectFns = `select proname, oid from pg_catalog.pg_proc where proname in ( -'lo_open', -'lo_close', -'lo_create', -'lo_unlink', -'lo_lseek', -'lo_lseek64', -'lo_tell', -'lo_tell64', -'lo_truncate', -'lo_truncate64', -'loread', -'lowrite') +'lo_open', +'lo_close', +'lo_create', +'lo_unlink', +'lo_lseek', +'lo_lseek64', +'lo_tell', +'lo_tell64', +'lo_truncate', +'lo_truncate64', +'loread', +'lowrite') and pronamespace = (select oid from pg_catalog.pg_namespace where nspname = 'pg_catalog')` // LargeObjects returns a LargeObjects instance for the transaction. @@ -60,19 +60,19 @@ const ( // Create creates a new large object. If id is zero, the server assigns an // unused OID. -func (o *LargeObjects) Create(id Oid) (Oid, error) { - newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) - return Oid(newOid), err +func (o *LargeObjects) Create(id OID) (OID, error) { + newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) + return OID(newOID), err } // Open opens an existing large object with the given mode. -func (o *LargeObjects) Open(oid Oid, mode LargeObjectMode) (*LargeObject, error) { +func (o *LargeObjects) Open(oid OID, mode LargeObjectMode) (*LargeObject, error) { fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))})) return &LargeObject{fd: fd, lo: o}, err } // Unlink removes a large object from the database. -func (o *LargeObjects) Unlink(oid Oid) error { +func (o *LargeObjects) Unlink(oid OID) error { _, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))}) return err } diff --git a/messages.go b/messages.go index 1fbd9cbc..e4bdfb2c 100644 --- a/messages.go +++ b/messages.go @@ -49,13 +49,13 @@ func (self *startupMessage) Bytes() (buf []byte) { return buf } -type Oid int32 +type OID int32 type FieldDescription struct { Name string - Table Oid + Table OID AttributeNumber int16 - DataType Oid + DataType OID DataTypeSize int16 DataTypeName string Modifier int32 diff --git a/msg_reader.go b/msg_reader.go index fd74a63b..b5848946 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -158,8 +158,8 @@ func (r *msgReader) readInt64() int64 { return n } -func (r *msgReader) readOid() Oid { - return Oid(r.readInt32()) +func (r *msgReader) readOID() OID { + return OID(r.readInt32()) } // readCString reads a null terminated string diff --git a/query.go b/query.go index 50c8e290..abe9860e 100644 --- a/query.go +++ b/query.go @@ -250,7 +250,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if b, ok := d.(*[]byte); ok { // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) // Otherwise read the bytes directly regardless of what the actual type is. - if vr.Type().DataType == ByteaOid { + if vr.Type().DataType == ByteaOID { *b = decodeBytea(vr) } else { if vr.Len() != -1 { @@ -268,27 +268,27 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { var val interface{} if 0 <= vr.Len() { switch vr.Type().DataType { - case BoolOid: + case BoolOID: val = decodeBool(vr) - case Int8Oid: + case Int8OID: val = int64(decodeInt8(vr)) - case Int2Oid: + case Int2OID: val = int64(decodeInt2(vr)) - case Int4Oid: + case Int4OID: val = int64(decodeInt4(vr)) - case TextOid, VarcharOid: + case TextOID, VarcharOID: val = decodeText(vr) - case OidOid: - val = int64(decodeOid(vr)) - case Float4Oid: + case OIDOID: + val = int64(decodeOID(vr)) + case Float4OID: val = float64(decodeFloat4(vr)) - case Float8Oid: + case Float8OID: val = decodeFloat8(vr) - case DateOid: + case DateOID: val = decodeDate(vr) - case TimestampOid: + case TimestampOID: val = decodeTimestamp(vr) - case TimestampTzOid: + case TimestampTzOID: val = decodeTimestampTz(vr) default: val = vr.ReadBytes(vr.Len()) @@ -298,7 +298,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if vr.Type().DataType == JsonOid || vr.Type().DataType == JsonbOid { + } else if vr.Type().DataType == JsonOID || vr.Type().DataType == JsonbOID { // Because the argument passed to decodeJSON will escape the heap. // This allows d to be stack allocated and only copied to the heap when // we actually are decoding JSON. This saves one memory allocation per @@ -345,53 +345,53 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, vr.ReadString(vr.Len())) case BinaryFormatCode: switch vr.Type().DataType { - case TextOid, VarcharOid: + case TextOID, VarcharOID: values = append(values, decodeText(vr)) - case BoolOid: + case BoolOID: values = append(values, decodeBool(vr)) - case ByteaOid: + case ByteaOID: values = append(values, decodeBytea(vr)) - case Int8Oid: + case Int8OID: values = append(values, decodeInt8(vr)) - case Int2Oid: + case Int2OID: values = append(values, decodeInt2(vr)) - case Int4Oid: + case Int4OID: values = append(values, decodeInt4(vr)) - case OidOid: - values = append(values, decodeOid(vr)) - case Float4Oid: + case OIDOID: + values = append(values, decodeOID(vr)) + case Float4OID: values = append(values, decodeFloat4(vr)) - case Float8Oid: + case Float8OID: values = append(values, decodeFloat8(vr)) - case BoolArrayOid: + case BoolArrayOID: values = append(values, decodeBoolArray(vr)) - case Int2ArrayOid: + case Int2ArrayOID: values = append(values, decodeInt2Array(vr)) - case Int4ArrayOid: + case Int4ArrayOID: values = append(values, decodeInt4Array(vr)) - case Int8ArrayOid: + case Int8ArrayOID: values = append(values, decodeInt8Array(vr)) - case Float4ArrayOid: + case Float4ArrayOID: values = append(values, decodeFloat4Array(vr)) - case Float8ArrayOid: + case Float8ArrayOID: values = append(values, decodeFloat8Array(vr)) - case TextArrayOid, VarcharArrayOid: + case TextArrayOID, VarcharArrayOID: values = append(values, decodeTextArray(vr)) - case TimestampArrayOid, TimestampTzArrayOid: + case TimestampArrayOID, TimestampTzArrayOID: values = append(values, decodeTimestampArray(vr)) - case DateOid: + case DateOID: values = append(values, decodeDate(vr)) - case TimestampTzOid: + case TimestampTzOID: values = append(values, decodeTimestampTz(vr)) - case TimestampOid: + case TimestampOID: values = append(values, decodeTimestamp(vr)) - case InetOid, CidrOid: + case InetOID, CidrOID: values = append(values, decodeInet(vr)) - case JsonOid: + case JsonOID: var d interface{} decodeJSON(vr, &d) values = append(values, d) - case JsonbOid: + case JsonbOID: var d interface{} decodeJSON(vr, &d) values = append(values, d) diff --git a/query_test.go b/query_test.go index 2cf8b3cd..21496c19 100644 --- a/query_test.go +++ b/query_test.go @@ -83,7 +83,7 @@ func TestConnQueryValues(t *testing.T) { t.Errorf(`Expected values[3] to be %v, but it was %d`, nil, values[3]) } - if values[4] != pgx.Oid(rowCount) { + if values[4] != pgx.OID(rowCount) { t.Errorf(`Expected values[4] to be %d, but it was %d`, rowCount, values[4]) } } @@ -385,7 +385,7 @@ type coreEncoder struct{} func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode } -func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { +func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.OID) error { w.WriteInt32(int32(2)) w.WriteBytes([]byte("42")) return nil @@ -420,7 +420,7 @@ func TestQueryRowCoreTypes(t *testing.T) { f64 float64 b bool t time.Time - oid pgx.Oid + oid pgx.OID } var actual, zero allTypes @@ -438,7 +438,7 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}}, {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, - {"select $1::oid", []interface{}{pgx.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + {"select $1::oid", []interface{}{pgx.OID(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { diff --git a/stdlib/sql.go b/stdlib/sql.go index 5bf2c113..6e55996f 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -57,23 +57,23 @@ var openFromConnPoolCount int // oids that map to intrinsic database/sql types. These will be allowed to be // binary, anything else will be forced to text format -var databaseSqlOids map[pgx.Oid]bool +var databaseSqlOIDs map[pgx.OID]bool func init() { d := &Driver{} sql.Register("pgx", d) - databaseSqlOids = make(map[pgx.Oid]bool) - databaseSqlOids[pgx.BoolOid] = true - databaseSqlOids[pgx.ByteaOid] = true - databaseSqlOids[pgx.Int2Oid] = true - databaseSqlOids[pgx.Int4Oid] = true - databaseSqlOids[pgx.Int8Oid] = true - databaseSqlOids[pgx.Float4Oid] = true - databaseSqlOids[pgx.Float8Oid] = true - databaseSqlOids[pgx.DateOid] = true - databaseSqlOids[pgx.TimestampTzOid] = true - databaseSqlOids[pgx.TimestampOid] = true + databaseSqlOIDs = make(map[pgx.OID]bool) + databaseSqlOIDs[pgx.BoolOID] = true + databaseSqlOIDs[pgx.ByteaOID] = true + databaseSqlOIDs[pgx.Int2OID] = true + databaseSqlOIDs[pgx.Int4OID] = true + databaseSqlOIDs[pgx.Int8OID] = true + databaseSqlOIDs[pgx.Float4OID] = true + databaseSqlOIDs[pgx.Float8OID] = true + databaseSqlOIDs[pgx.DateOID] = true + databaseSqlOIDs[pgx.TimestampTzOID] = true + databaseSqlOIDs[pgx.TimestampOID] = true } type Driver struct { @@ -231,7 +231,7 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er // (e.g. []int32) func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { for i, _ := range ps.FieldDescriptions { - intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType] + intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType] if !intrinsic { ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode } @@ -248,7 +248,7 @@ func (s *Stmt) Close() error { } func (s *Stmt) NumInput() int { - return len(s.ps.ParameterOids) + return len(s.ps.ParameterOIDs) } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { diff --git a/v3.md b/v3.md new file mode 100644 index 00000000..fdf8dcac --- /dev/null +++ b/v3.md @@ -0,0 +1,3 @@ +# V3 Changes + +Rename Oid to OID in accordance with Go conventions. diff --git a/value_reader.go b/value_reader.go index 4936b887..a47a1d17 100644 --- a/value_reader.go +++ b/value_reader.go @@ -88,8 +88,8 @@ func (r *ValueReader) ReadInt64() int64 { return r.mr.readInt64() } -func (r *ValueReader) ReadOid() Oid { - return Oid(r.ReadInt32()) +func (r *ValueReader) ReadOID() OID { + return OID(r.ReadInt32()) } // ReadString reads count bytes and returns as string diff --git a/values.go b/values.go index b6e0a84b..e49721ff 100644 --- a/values.go +++ b/values.go @@ -15,39 +15,39 @@ import ( // PostgreSQL oids for common types const ( - BoolOid = 16 - ByteaOid = 17 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - JsonOid = 114 - CidrOid = 650 - CidrArrayOid = 651 - Float4Oid = 700 - Float8Oid = 701 - UnknownOid = 705 - InetOid = 869 - BoolArrayOid = 1000 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - ByteaArrayOid = 1001 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - InetArrayOid = 1041 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampArrayOid = 1115 - TimestampTzOid = 1184 - TimestampTzArrayOid = 1185 - RecordOid = 2249 - UuidOid = 2950 - JsonbOid = 3802 + BoolOID = 16 + ByteaOID = 17 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + JsonOID = 114 + CidrOID = 650 + CidrArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + UnknownOID = 705 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + InetArrayOID = 1041 + VarcharOID = 1043 + DateOID = 1082 + TimestampOID = 1114 + TimestampArrayOID = 1115 + TimestampTzOID = 1184 + TimestampTzArrayOID = 1185 + RecordOID = 2249 + UuidOID = 2950 + JsonbOID = 3802 ) // PostgreSQL format codes @@ -130,7 +130,7 @@ type Encoder interface { // expected data size or format of the encoded data does not match. But if // the encoded data is a valid representation of the data type PostgreSQL // expects such as date and int4, incorrect data may be stored. - Encode(w *WriteBuf, oid Oid) error + Encode(w *WriteBuf, oid OID) error // FormatCode returns the format that the encoder writes the value. It must be // either pgx.TextFormatCode or pgx.BinaryFormatCode. @@ -148,7 +148,7 @@ type NullFloat32 struct { } func (n *NullFloat32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float4Oid { + if vr.Type().DataType != Float4OID { return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType)) } @@ -163,8 +163,8 @@ func (n *NullFloat32) Scan(vr *ValueReader) error { func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode } -func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error { - if oid != Float4Oid { +func (n NullFloat32) Encode(w *WriteBuf, oid OID) error { + if oid != Float4OID { return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid)) } @@ -187,7 +187,7 @@ type NullFloat64 struct { } func (n *NullFloat64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float8Oid { + if vr.Type().DataType != Float8OID { return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType)) } @@ -202,8 +202,8 @@ func (n *NullFloat64) Scan(vr *ValueReader) error { func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode } -func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { - if oid != Float8Oid { +func (n NullFloat64) Encode(w *WriteBuf, oid OID) error { + if oid != Float8OID { return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into OID %d", oid)) } @@ -240,7 +240,7 @@ func (s *NullString) Scan(vr *ValueReader) error { func (n NullString) FormatCode() int16 { return TextFormatCode } -func (s NullString) Encode(w *WriteBuf, oid Oid) error { +func (s NullString) Encode(w *WriteBuf, oid OID) error { if !s.Valid { w.WriteInt32(-1) return nil @@ -260,7 +260,7 @@ type NullInt16 struct { } func (n *NullInt16) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int2Oid { + if vr.Type().DataType != Int2OID { return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType)) } @@ -275,8 +275,8 @@ func (n *NullInt16) Scan(vr *ValueReader) error { func (n NullInt16) FormatCode() int16 { return BinaryFormatCode } -func (n NullInt16) Encode(w *WriteBuf, oid Oid) error { - if oid != Int2Oid { +func (n NullInt16) Encode(w *WriteBuf, oid OID) error { + if oid != Int2OID { return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid)) } @@ -299,7 +299,7 @@ type NullInt32 struct { } func (n *NullInt32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int4Oid { + if vr.Type().DataType != Int4OID { return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType)) } @@ -314,8 +314,8 @@ func (n *NullInt32) Scan(vr *ValueReader) error { func (n NullInt32) FormatCode() int16 { return BinaryFormatCode } -func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { - if oid != Int4Oid { +func (n NullInt32) Encode(w *WriteBuf, oid OID) error { + if oid != Int4OID { return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid)) } @@ -338,7 +338,7 @@ type NullInt64 struct { } func (n *NullInt64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int8Oid { + if vr.Type().DataType != Int8OID { return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType)) } @@ -353,8 +353,8 @@ func (n *NullInt64) Scan(vr *ValueReader) error { func (n NullInt64) FormatCode() int16 { return BinaryFormatCode } -func (n NullInt64) Encode(w *WriteBuf, oid Oid) error { - if oid != Int8Oid { +func (n NullInt64) Encode(w *WriteBuf, oid OID) error { + if oid != Int8OID { return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid)) } @@ -377,7 +377,7 @@ type NullBool struct { } func (n *NullBool) Scan(vr *ValueReader) error { - if vr.Type().DataType != BoolOid { + if vr.Type().DataType != BoolOID { return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType)) } @@ -392,8 +392,8 @@ func (n *NullBool) Scan(vr *ValueReader) error { func (n NullBool) FormatCode() int16 { return BinaryFormatCode } -func (n NullBool) Encode(w *WriteBuf, oid Oid) error { - if oid != BoolOid { +func (n NullBool) Encode(w *WriteBuf, oid OID) error { + if oid != BoolOID { return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid)) } @@ -418,7 +418,7 @@ type NullTime struct { func (n *NullTime) Scan(vr *ValueReader) error { oid := vr.Type().DataType - if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { + if oid != TimestampTzOID && oid != TimestampOID && oid != DateOID { return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType)) } @@ -429,11 +429,11 @@ func (n *NullTime) Scan(vr *ValueReader) error { n.Valid = true switch oid { - case TimestampTzOid: + case TimestampTzOID: n.Time = decodeTimestampTz(vr) - case TimestampOid: + case TimestampOID: n.Time = decodeTimestamp(vr) - case DateOid: + case DateOID: n.Time = decodeDate(vr) } @@ -442,8 +442,8 @@ func (n *NullTime) Scan(vr *ValueReader) error { func (n NullTime) FormatCode() int16 { return BinaryFormatCode } -func (n NullTime) Encode(w *WriteBuf, oid Oid) error { - if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { +func (n NullTime) Encode(w *WriteBuf, oid OID) error { + if oid != TimestampTzOID && oid != TimestampOID && oid != DateOID { return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid)) } @@ -494,7 +494,7 @@ func (h *Hstore) Scan(vr *ValueReader) error { func (h Hstore) FormatCode() int16 { return TextFormatCode } -func (h Hstore) Encode(w *WriteBuf, oid Oid) error { +func (h Hstore) Encode(w *WriteBuf, oid OID) error { var buf bytes.Buffer i := 0 @@ -560,7 +560,7 @@ func (h *NullHstore) Scan(vr *ValueReader) error { func (h NullHstore) FormatCode() int16 { return TextFormatCode } -func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { +func (h NullHstore) Encode(w *WriteBuf, oid OID) error { var buf bytes.Buffer if !h.Valid { @@ -592,7 +592,7 @@ func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { // Encode encodes arg into wbuf as the type oid. This allows implementations // of the Encoder interface to delegate the actual work of encoding to the // built-in functionality. -func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { +func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { if arg == nil { wbuf.WriteInt32(-1) return nil @@ -627,7 +627,7 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { } } - if oid == JsonOid || oid == JsonbOid { + if oid == JsonOID || oid == JsonbOID { return encodeJSON(wbuf, oid, arg) } @@ -690,8 +690,8 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeIPNet(wbuf, oid, arg) case []net.IPNet: return encodeIPNetSlice(wbuf, oid, arg) - case Oid: - return encodeOid(wbuf, oid, arg) + case OID: + return encodeOID(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -814,8 +814,8 @@ func Decode(vr *ValueReader, d interface{}) error { return fmt.Errorf("%d is less than zero for uint64", n) } *v = uint64(n) - case *Oid: - *v = decodeOid(vr) + case *OID: + *v = decodeOID(vr) case *string: *v = decodeText(vr) case *float32: @@ -850,11 +850,11 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeRecord(vr) case *time.Time: switch vr.Type().DataType { - case DateOid: + case DateOID: *v = decodeDate(vr) - case TimestampTzOid: + case TimestampTzOID: *v = decodeTimestampTz(vr) - case TimestampOid: + case TimestampOID: *v = decodeTimestamp(vr) default: return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType) @@ -934,7 +934,7 @@ func decodeBool(vr *ValueReader) bool { return false } - if vr.Type().DataType != BoolOid { + if vr.Type().DataType != BoolOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) return false } @@ -953,8 +953,8 @@ func decodeBool(vr *ValueReader) bool { return b != 0 } -func encodeBool(w *WriteBuf, oid Oid, value bool) error { - if oid != BoolOid { +func encodeBool(w *WriteBuf, oid OID, value bool) error { + if oid != BoolOID { return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) } @@ -972,11 +972,11 @@ func encodeBool(w *WriteBuf, oid Oid, value bool) error { func decodeInt(vr *ValueReader) int64 { switch vr.Type().DataType { - case Int2Oid: + case Int2OID: return int64(decodeInt2(vr)) - case Int4Oid: + case Int4OID: return int64(decodeInt4(vr)) - case Int8Oid: + case Int8OID: return int64(decodeInt8(vr)) } @@ -990,7 +990,7 @@ func decodeInt8(vr *ValueReader) int64 { return 0 } - if vr.Type().DataType != Int8Oid { + if vr.Type().DataType != Int8OID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) return 0 } @@ -1014,7 +1014,7 @@ func decodeInt2(vr *ValueReader) int16 { return 0 } - if vr.Type().DataType != Int2Oid { + if vr.Type().DataType != Int2OID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) return 0 } @@ -1032,9 +1032,9 @@ func decodeInt2(vr *ValueReader) int16 { return vr.ReadInt16() } -func encodeInt(w *WriteBuf, oid Oid, value int) error { +func encodeInt(w *WriteBuf, oid OID, value int) error { switch oid { - case Int2Oid: + case Int2OID: if value < math.MinInt16 { return fmt.Errorf("%d is less than min pg:int2", value) } else if value > math.MaxInt16 { @@ -1042,7 +1042,7 @@ func encodeInt(w *WriteBuf, oid Oid, value int) error { } w.WriteInt32(2) w.WriteInt16(int16(value)) - case Int4Oid: + case Int4OID: if value < math.MinInt32 { return fmt.Errorf("%d is less than min pg:int4", value) } else if value > math.MaxInt32 { @@ -1050,7 +1050,7 @@ func encodeInt(w *WriteBuf, oid Oid, value int) error { } w.WriteInt32(4) w.WriteInt32(int32(value)) - case Int8Oid: + case Int8OID: if int64(value) <= int64(math.MaxInt64) { w.WriteInt32(8) w.WriteInt64(int64(value)) @@ -1064,21 +1064,21 @@ func encodeInt(w *WriteBuf, oid Oid, value int) error { return nil } -func encodeUInt(w *WriteBuf, oid Oid, value uint) error { +func encodeUInt(w *WriteBuf, oid OID, value uint) error { switch oid { - case Int2Oid: + case Int2OID: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than max pg:int2", value) } w.WriteInt32(2) w.WriteInt16(int16(value)) - case Int4Oid: + case Int4OID: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than max pg:int4", value) } w.WriteInt32(4) w.WriteInt32(int32(value)) - case Int8Oid: + case Int8OID: //****** Changed value to int64(value) and math.MaxInt64 to int64(math.MaxInt64) if int64(value) > int64(math.MaxInt64) { return fmt.Errorf("%d is greater than max pg:int8", value) @@ -1093,15 +1093,15 @@ func encodeUInt(w *WriteBuf, oid Oid, value uint) error { return nil } -func encodeInt8(w *WriteBuf, oid Oid, value int8) error { +func encodeInt8(w *WriteBuf, oid OID, value int8) error { switch oid { - case Int2Oid: + case Int2OID: w.WriteInt32(2) w.WriteInt16(int16(value)) - case Int4Oid: + case Int4OID: w.WriteInt32(4) w.WriteInt32(int32(value)) - case Int8Oid: + case Int8OID: w.WriteInt32(8) w.WriteInt64(int64(value)) default: @@ -1111,15 +1111,15 @@ func encodeInt8(w *WriteBuf, oid Oid, value int8) error { return nil } -func encodeUInt8(w *WriteBuf, oid Oid, value uint8) error { +func encodeUInt8(w *WriteBuf, oid OID, value uint8) error { switch oid { - case Int2Oid: + case Int2OID: w.WriteInt32(2) w.WriteInt16(int16(value)) - case Int4Oid: + case Int4OID: w.WriteInt32(4) w.WriteInt32(int32(value)) - case Int8Oid: + case Int8OID: w.WriteInt32(8) w.WriteInt64(int64(value)) default: @@ -1129,15 +1129,15 @@ func encodeUInt8(w *WriteBuf, oid Oid, value uint8) error { return nil } -func encodeInt16(w *WriteBuf, oid Oid, value int16) error { +func encodeInt16(w *WriteBuf, oid OID, value int16) error { switch oid { - case Int2Oid: + case Int2OID: w.WriteInt32(2) w.WriteInt16(value) - case Int4Oid: + case Int4OID: w.WriteInt32(4) w.WriteInt32(int32(value)) - case Int8Oid: + case Int8OID: w.WriteInt32(8) w.WriteInt64(int64(value)) default: @@ -1147,19 +1147,19 @@ func encodeInt16(w *WriteBuf, oid Oid, value int16) error { return nil } -func encodeUInt16(w *WriteBuf, oid Oid, value uint16) error { +func encodeUInt16(w *WriteBuf, oid OID, value uint16) error { switch oid { - case Int2Oid: + case Int2OID: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) } - case Int4Oid: + case Int4OID: w.WriteInt32(4) w.WriteInt32(int32(value)) - case Int8Oid: + case Int8OID: w.WriteInt32(8) w.WriteInt64(int64(value)) default: @@ -1169,19 +1169,19 @@ func encodeUInt16(w *WriteBuf, oid Oid, value uint16) error { return nil } -func encodeInt32(w *WriteBuf, oid Oid, value int32) error { +func encodeInt32(w *WriteBuf, oid OID, value int32) error { switch oid { - case Int2Oid: + case Int2OID: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) } - case Int4Oid: + case Int4OID: w.WriteInt32(4) w.WriteInt32(value) - case Int8Oid: + case Int8OID: w.WriteInt32(8) w.WriteInt64(int64(value)) default: @@ -1191,23 +1191,23 @@ func encodeInt32(w *WriteBuf, oid Oid, value int32) error { return nil } -func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error { +func encodeUInt32(w *WriteBuf, oid OID, value uint32) error { switch oid { - case Int2Oid: + case Int2OID: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) } - case Int4Oid: + case Int4OID: if value <= math.MaxInt32 { w.WriteInt32(4) w.WriteInt32(int32(value)) } else { return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) } - case Int8Oid: + case Int8OID: w.WriteInt32(8) w.WriteInt64(int64(value)) default: @@ -1217,23 +1217,23 @@ func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error { return nil } -func encodeInt64(w *WriteBuf, oid Oid, value int64) error { +func encodeInt64(w *WriteBuf, oid OID, value int64) error { switch oid { - case Int2Oid: + case Int2OID: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) } - case Int4Oid: + case Int4OID: if value <= math.MaxInt32 { w.WriteInt32(4) w.WriteInt32(int32(value)) } else { return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) } - case Int8Oid: + case Int8OID: w.WriteInt32(8) w.WriteInt64(value) default: @@ -1243,23 +1243,23 @@ func encodeInt64(w *WriteBuf, oid Oid, value int64) error { return nil } -func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error { +func encodeUInt64(w *WriteBuf, oid OID, value uint64) error { switch oid { - case Int2Oid: + case Int2OID: if value <= math.MaxInt16 { w.WriteInt32(2) w.WriteInt16(int16(value)) } else { return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) } - case Int4Oid: + case Int4OID: if value <= math.MaxInt32 { w.WriteInt32(4) w.WriteInt32(int32(value)) } else { return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) } - case Int8Oid: + case Int8OID: if value <= math.MaxInt64 { w.WriteInt32(8) @@ -1280,7 +1280,7 @@ func decodeInt4(vr *ValueReader) int32 { return 0 } - if vr.Type().DataType != Int4Oid { + if vr.Type().DataType != Int4OID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) return 0 } @@ -1298,41 +1298,41 @@ func decodeInt4(vr *ValueReader) int32 { return vr.ReadInt32() } -func decodeOid(vr *ValueReader) Oid { +func decodeOID(vr *ValueReader) OID { if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Oid")) + vr.Fatal(ProtocolError("Cannot decode null into OID")) return 0 } - if vr.Type().DataType != OidOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Oid", vr.Type().DataType))) + if vr.Type().DataType != OIDOID { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.OID", vr.Type().DataType))) return 0 } - // Oid needs to decode text format because it is used in loadPgTypes + // OID needs to decode text format because it is used in loadPgTypes switch vr.Type().FormatCode { case TextFormatCode: s := vr.ReadString(vr.Len()) n, err := strconv.ParseInt(s, 10, 32) if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) } - return Oid(n) + return OID(n) case BinaryFormatCode: if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) return 0 } - return Oid(vr.ReadInt32()) + return OID(vr.ReadInt32()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Oid(0) + return OID(0) } } -func encodeOid(w *WriteBuf, oid Oid, value Oid) error { - if oid != OidOid { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Oid", oid) +func encodeOID(w *WriteBuf, oid OID, value OID) error { + if oid != OIDOID { + return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.OID", oid) } w.WriteInt32(4) @@ -1347,7 +1347,7 @@ func decodeFloat4(vr *ValueReader) float32 { return 0 } - if vr.Type().DataType != Float4Oid { + if vr.Type().DataType != Float4OID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) return 0 } @@ -1366,12 +1366,12 @@ func decodeFloat4(vr *ValueReader) float32 { return math.Float32frombits(uint32(i)) } -func encodeFloat32(w *WriteBuf, oid Oid, value float32) error { +func encodeFloat32(w *WriteBuf, oid OID, value float32) error { switch oid { - case Float4Oid: + case Float4OID: w.WriteInt32(4) w.WriteInt32(int32(math.Float32bits(value))) - case Float8Oid: + case Float8OID: w.WriteInt32(8) w.WriteInt64(int64(math.Float64bits(float64(value)))) default: @@ -1387,7 +1387,7 @@ func decodeFloat8(vr *ValueReader) float64 { return 0 } - if vr.Type().DataType != Float8Oid { + if vr.Type().DataType != Float8OID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType))) return 0 } @@ -1406,9 +1406,9 @@ func decodeFloat8(vr *ValueReader) float64 { return math.Float64frombits(uint64(i)) } -func encodeFloat64(w *WriteBuf, oid Oid, value float64) error { +func encodeFloat64(w *WriteBuf, oid OID, value float64) error { switch oid { - case Float8Oid: + case Float8OID: w.WriteInt32(8) w.WriteInt64(int64(math.Float64bits(value))) default: @@ -1427,7 +1427,7 @@ func decodeText(vr *ValueReader) string { return vr.ReadString(vr.Len()) } -func encodeString(w *WriteBuf, oid Oid, value string) error { +func encodeString(w *WriteBuf, oid OID, value string) error { w.WriteInt32(int32(len(value))) w.WriteBytes([]byte(value)) return nil @@ -1438,7 +1438,7 @@ func decodeBytea(vr *ValueReader) []byte { return nil } - if vr.Type().DataType != ByteaOid { + if vr.Type().DataType != ByteaOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType))) return nil } @@ -1451,7 +1451,7 @@ func decodeBytea(vr *ValueReader) []byte { return vr.ReadBytes(vr.Len()) } -func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error { +func encodeByteSlice(w *WriteBuf, oid OID, value []byte) error { w.WriteInt32(int32(len(value))) w.WriteBytes(value) @@ -1463,7 +1463,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return nil } - if vr.Type().DataType != JsonOid && vr.Type().DataType != JsonbOid { + if vr.Type().DataType != JsonOID && vr.Type().DataType != JsonbOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) } @@ -1475,8 +1475,8 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return err } -func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { - if oid != JsonOid && oid != JsonbOid { +func encodeJSON(w *WriteBuf, oid OID, value interface{}) error { + if oid != JsonOID && oid != JsonbOID { return fmt.Errorf("cannot encode JSON into oid %v", oid) } @@ -1499,7 +1499,7 @@ func decodeDate(vr *ValueReader) time.Time { return zeroTime } - if vr.Type().DataType != DateOid { + if vr.Type().DataType != DateOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return zeroTime } @@ -1516,9 +1516,9 @@ func decodeDate(vr *ValueReader) time.Time { return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) } -func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { +func encodeTime(w *WriteBuf, oid OID, value time.Time) error { switch oid { - case DateOid: + case DateOID: tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix() dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() @@ -1529,7 +1529,7 @@ func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { w.WriteInt32(int32(daysSinceDateEpoch)) return nil - case TimestampTzOid, TimestampOid: + case TimestampTzOID, TimestampOID: microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000 microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K @@ -1552,7 +1552,7 @@ func decodeTimestampTz(vr *ValueReader) time.Time { return zeroTime } - if vr.Type().DataType != TimestampTzOid { + if vr.Type().DataType != TimestampTzOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return zeroTime } @@ -1580,7 +1580,7 @@ func decodeTimestamp(vr *ValueReader) time.Time { return zeroTime } - if vr.Type().DataType != TimestampOid { + if vr.Type().DataType != TimestampOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return zeroTime } @@ -1614,7 +1614,7 @@ func decodeInet(vr *ValueReader) net.IPNet { } pgType := vr.Type() - if pgType.DataType != InetOid && pgType.DataType != CidrOid { + if pgType.DataType != InetOID && pgType.DataType != CidrOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name))) return zero } @@ -1635,8 +1635,8 @@ func decodeInet(vr *ValueReader) net.IPNet { return ipnet } -func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error { - if oid != InetOid && oid != CidrOid { +func encodeIPNet(w *WriteBuf, oid OID, value net.IPNet) error { + if oid != InetOID && oid != CidrOID { return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid) } @@ -1664,8 +1664,8 @@ func encodeIPNet(w *WriteBuf, oid Oid, value net.IPNet) error { return nil } -func encodeIP(w *WriteBuf, oid Oid, value net.IP) error { - if oid != InetOid && oid != CidrOid { +func encodeIP(w *WriteBuf, oid OID, value net.IP) error { + if oid != InetOID && oid != CidrOID { return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid) } @@ -1686,7 +1686,7 @@ func decodeRecord(vr *ValueReader) []interface{} { return nil } - if vr.Type().DataType != RecordOid { + if vr.Type().DataType != RecordOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType))) return nil } @@ -1697,36 +1697,36 @@ func decodeRecord(vr *ValueReader) []interface{} { for i := int32(0); i < valueCount; i++ { fd := FieldDescription{FormatCode: BinaryFormatCode} fieldVR := ValueReader{mr: vr.mr, fd: &fd} - fd.DataType = vr.ReadOid() + fd.DataType = vr.ReadOID() fieldVR.valueBytesRemaining = vr.ReadInt32() vr.valueBytesRemaining -= fieldVR.valueBytesRemaining switch fd.DataType { - case BoolOid: + case BoolOID: record = append(record, decodeBool(&fieldVR)) - case ByteaOid: + case ByteaOID: record = append(record, decodeBytea(&fieldVR)) - case Int8Oid: + case Int8OID: record = append(record, decodeInt8(&fieldVR)) - case Int2Oid: + case Int2OID: record = append(record, decodeInt2(&fieldVR)) - case Int4Oid: + case Int4OID: record = append(record, decodeInt4(&fieldVR)) - case OidOid: - record = append(record, decodeOid(&fieldVR)) - case Float4Oid: + case OIDOID: + record = append(record, decodeOID(&fieldVR)) + case Float4OID: record = append(record, decodeFloat4(&fieldVR)) - case Float8Oid: + case Float8OID: record = append(record, decodeFloat8(&fieldVR)) - case DateOid: + case DateOID: record = append(record, decodeDate(&fieldVR)) - case TimestampTzOid: + case TimestampTzOID: record = append(record, decodeTimestampTz(&fieldVR)) - case TimestampOid: + case TimestampOID: record = append(record, decodeTimestamp(&fieldVR)) - case InetOid, CidrOid: + case InetOID, CidrOID: record = append(record, decodeInet(&fieldVR)) - case TextOid, VarcharOid, UnknownOid: + case TextOID, VarcharOID, UnknownOID: record = append(record, decodeText(&fieldVR)) default: vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) @@ -1775,7 +1775,7 @@ func decodeBoolArray(vr *ValueReader) []bool { return nil } - if vr.Type().DataType != BoolArrayOid { + if vr.Type().DataType != BoolArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType))) return nil } @@ -1811,12 +1811,12 @@ func decodeBoolArray(vr *ValueReader) []bool { return a } -func encodeBoolSlice(w *WriteBuf, oid Oid, slice []bool) error { - if oid != BoolArrayOid { +func encodeBoolSlice(w *WriteBuf, oid OID, slice []bool) error { + if oid != BoolArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]bool", oid) } - encodeArrayHeader(w, BoolOid, len(slice), 5) + encodeArrayHeader(w, BoolOID, len(slice), 5) for _, v := range slice { w.WriteInt32(1) var b byte @@ -1834,7 +1834,7 @@ func decodeByteaArray(vr *ValueReader) [][]byte { return nil } - if vr.Type().DataType != ByteaArrayOid { + if vr.Type().DataType != ByteaArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into [][]byte", vr.Type().DataType))) return nil } @@ -1865,8 +1865,8 @@ func decodeByteaArray(vr *ValueReader) [][]byte { return a } -func encodeByteSliceSlice(w *WriteBuf, oid Oid, value [][]byte) error { - if oid != ByteaArrayOid { +func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error { + if oid != ByteaArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[][]byte", oid) } @@ -1879,12 +1879,12 @@ func encodeByteSliceSlice(w *WriteBuf, oid Oid, value [][]byte) error { w.WriteInt32(1) // number of dimensions w.WriteInt32(0) // no nulls - w.WriteInt32(int32(ByteaOid)) // type of elements + w.WriteInt32(int32(ByteaOID)) // type of elements w.WriteInt32(int32(len(value))) // number of elements w.WriteInt32(1) // index of first element for _, el := range value { - encodeByteSlice(w, ByteaOid, el) + encodeByteSlice(w, ByteaOID, el) } return nil @@ -1895,7 +1895,7 @@ func decodeInt2Array(vr *ValueReader) []int16 { return nil } - if vr.Type().DataType != Int2ArrayOid { + if vr.Type().DataType != Int2ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType))) return nil } @@ -1934,7 +1934,7 @@ func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { return nil } - if vr.Type().DataType != Int2ArrayOid { + if vr.Type().DataType != Int2ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint16", vr.Type().DataType))) return nil } @@ -1973,12 +1973,12 @@ func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { return a } -func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error { - if oid != Int2ArrayOid { +func encodeInt16Slice(w *WriteBuf, oid OID, slice []int16) error { + if oid != Int2ArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid) } - encodeArrayHeader(w, Int2Oid, len(slice), 6) + encodeArrayHeader(w, Int2OID, len(slice), 6) for _, v := range slice { w.WriteInt32(2) w.WriteInt16(v) @@ -1987,12 +1987,12 @@ func encodeInt16Slice(w *WriteBuf, oid Oid, slice []int16) error { return nil } -func encodeUInt16Slice(w *WriteBuf, oid Oid, slice []uint16) error { - if oid != Int2ArrayOid { +func encodeUInt16Slice(w *WriteBuf, oid OID, slice []uint16) error { + if oid != Int2ArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid) } - encodeArrayHeader(w, Int2Oid, len(slice), 6) + encodeArrayHeader(w, Int2OID, len(slice), 6) for _, v := range slice { if v <= math.MaxInt16 { w.WriteInt32(2) @@ -2010,7 +2010,7 @@ func decodeInt4Array(vr *ValueReader) []int32 { return nil } - if vr.Type().DataType != Int4ArrayOid { + if vr.Type().DataType != Int4ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int32", vr.Type().DataType))) return nil } @@ -2049,7 +2049,7 @@ func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 { return nil } - if vr.Type().DataType != Int4ArrayOid { + if vr.Type().DataType != Int4ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint32", vr.Type().DataType))) return nil } @@ -2088,12 +2088,12 @@ func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 { return a } -func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error { - if oid != Int4ArrayOid { +func encodeInt32Slice(w *WriteBuf, oid OID, slice []int32) error { + if oid != Int4ArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid) } - encodeArrayHeader(w, Int4Oid, len(slice), 8) + encodeArrayHeader(w, Int4OID, len(slice), 8) for _, v := range slice { w.WriteInt32(4) w.WriteInt32(v) @@ -2102,12 +2102,12 @@ func encodeInt32Slice(w *WriteBuf, oid Oid, slice []int32) error { return nil } -func encodeUInt32Slice(w *WriteBuf, oid Oid, slice []uint32) error { - if oid != Int4ArrayOid { +func encodeUInt32Slice(w *WriteBuf, oid OID, slice []uint32) error { + if oid != Int4ArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint32", oid) } - encodeArrayHeader(w, Int4Oid, len(slice), 8) + encodeArrayHeader(w, Int4OID, len(slice), 8) for _, v := range slice { if v <= math.MaxInt32 { w.WriteInt32(4) @@ -2125,7 +2125,7 @@ func decodeInt8Array(vr *ValueReader) []int64 { return nil } - if vr.Type().DataType != Int8ArrayOid { + if vr.Type().DataType != Int8ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int64", vr.Type().DataType))) return nil } @@ -2164,7 +2164,7 @@ func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 { return nil } - if vr.Type().DataType != Int8ArrayOid { + if vr.Type().DataType != Int8ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint64", vr.Type().DataType))) return nil } @@ -2203,12 +2203,12 @@ func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 { return a } -func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error { - if oid != Int8ArrayOid { +func encodeInt64Slice(w *WriteBuf, oid OID, slice []int64) error { + if oid != Int8ArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid) } - encodeArrayHeader(w, Int8Oid, len(slice), 12) + encodeArrayHeader(w, Int8OID, len(slice), 12) for _, v := range slice { w.WriteInt32(8) w.WriteInt64(v) @@ -2217,12 +2217,12 @@ func encodeInt64Slice(w *WriteBuf, oid Oid, slice []int64) error { return nil } -func encodeUInt64Slice(w *WriteBuf, oid Oid, slice []uint64) error { - if oid != Int8ArrayOid { +func encodeUInt64Slice(w *WriteBuf, oid OID, slice []uint64) error { + if oid != Int8ArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint64", oid) } - encodeArrayHeader(w, Int8Oid, len(slice), 12) + encodeArrayHeader(w, Int8OID, len(slice), 12) for _, v := range slice { if v <= math.MaxInt64 { w.WriteInt32(8) @@ -2240,7 +2240,7 @@ func decodeFloat4Array(vr *ValueReader) []float32 { return nil } - if vr.Type().DataType != Float4ArrayOid { + if vr.Type().DataType != Float4ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float32", vr.Type().DataType))) return nil } @@ -2275,12 +2275,12 @@ func decodeFloat4Array(vr *ValueReader) []float32 { return a } -func encodeFloat32Slice(w *WriteBuf, oid Oid, slice []float32) error { - if oid != Float4ArrayOid { +func encodeFloat32Slice(w *WriteBuf, oid OID, slice []float32) error { + if oid != Float4ArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]float32", oid) } - encodeArrayHeader(w, Float4Oid, len(slice), 8) + encodeArrayHeader(w, Float4OID, len(slice), 8) for _, v := range slice { w.WriteInt32(4) w.WriteInt32(int32(math.Float32bits(v))) @@ -2294,7 +2294,7 @@ func decodeFloat8Array(vr *ValueReader) []float64 { return nil } - if vr.Type().DataType != Float8ArrayOid { + if vr.Type().DataType != Float8ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float64", vr.Type().DataType))) return nil } @@ -2329,12 +2329,12 @@ func decodeFloat8Array(vr *ValueReader) []float64 { return a } -func encodeFloat64Slice(w *WriteBuf, oid Oid, slice []float64) error { - if oid != Float8ArrayOid { +func encodeFloat64Slice(w *WriteBuf, oid OID, slice []float64) error { + if oid != Float8ArrayOID { return fmt.Errorf("cannot encode Go %s into oid %d", "[]float64", oid) } - encodeArrayHeader(w, Float8Oid, len(slice), 12) + encodeArrayHeader(w, Float8OID, len(slice), 12) for _, v := range slice { w.WriteInt32(8) w.WriteInt64(int64(math.Float64bits(v))) @@ -2348,7 +2348,7 @@ func decodeTextArray(vr *ValueReader) []string { return nil } - if vr.Type().DataType != TextArrayOid && vr.Type().DataType != VarcharArrayOid { + if vr.Type().DataType != TextArrayOID && vr.Type().DataType != VarcharArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType))) return nil } @@ -2378,13 +2378,13 @@ func decodeTextArray(vr *ValueReader) []string { return a } -func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { - var elOid Oid +func encodeStringSlice(w *WriteBuf, oid OID, slice []string) error { + var elOID OID switch oid { - case VarcharArrayOid: - elOid = VarcharOid - case TextArrayOid: - elOid = TextOid + case VarcharArrayOID: + elOID = VarcharOID + case TextArrayOID: + elOID = TextOID default: return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid) } @@ -2399,7 +2399,7 @@ func encodeStringSlice(w *WriteBuf, oid Oid, slice []string) error { w.WriteInt32(1) // number of dimensions w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOid)) // type of elements + w.WriteInt32(int32(elOID)) // type of elements w.WriteInt32(int32(len(slice))) // number of elements w.WriteInt32(1) // index of first element @@ -2416,7 +2416,7 @@ func decodeTimestampArray(vr *ValueReader) []time.Time { return nil } - if vr.Type().DataType != TimestampArrayOid && vr.Type().DataType != TimestampTzArrayOid { + if vr.Type().DataType != TimestampArrayOID && vr.Type().DataType != TimestampTzArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) return nil } @@ -2452,18 +2452,18 @@ func decodeTimestampArray(vr *ValueReader) []time.Time { return a } -func encodeTimeSlice(w *WriteBuf, oid Oid, slice []time.Time) error { - var elOid Oid +func encodeTimeSlice(w *WriteBuf, oid OID, slice []time.Time) error { + var elOID OID switch oid { - case TimestampArrayOid: - elOid = TimestampOid - case TimestampTzArrayOid: - elOid = TimestampTzOid + case TimestampArrayOID: + elOID = TimestampOID + case TimestampTzArrayOID: + elOID = TimestampTzOID default: return fmt.Errorf("cannot encode Go %s into oid %d", "[]time.Time", oid) } - encodeArrayHeader(w, int(elOid), len(slice), 12) + encodeArrayHeader(w, int(elOID), len(slice), 12) for _, t := range slice { w.WriteInt32(8) microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 @@ -2479,7 +2479,7 @@ func decodeInetArray(vr *ValueReader) []net.IPNet { return nil } - if vr.Type().DataType != InetArrayOid && vr.Type().DataType != CidrArrayOid { + if vr.Type().DataType != InetArrayOID && vr.Type().DataType != CidrArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType))) return nil } @@ -2518,13 +2518,13 @@ func decodeInetArray(vr *ValueReader) []net.IPNet { return a } -func encodeIPNetSlice(w *WriteBuf, oid Oid, slice []net.IPNet) error { - var elOid Oid +func encodeIPNetSlice(w *WriteBuf, oid OID, slice []net.IPNet) error { + var elOID OID switch oid { - case InetArrayOid: - elOid = InetOid - case CidrArrayOid: - elOid = CidrOid + case InetArrayOID: + elOID = InetOID + case CidrArrayOID: + elOID = CidrOID default: return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) } @@ -2537,24 +2537,24 @@ func encodeIPNetSlice(w *WriteBuf, oid Oid, slice []net.IPNet) error { w.WriteInt32(1) // number of dimensions w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOid)) // type of elements + w.WriteInt32(int32(elOID)) // type of elements w.WriteInt32(int32(len(slice))) // number of elements w.WriteInt32(1) // index of first element for _, ipnet := range slice { - encodeIPNet(w, elOid, ipnet) + encodeIPNet(w, elOID, ipnet) } return nil } -func encodeIPSlice(w *WriteBuf, oid Oid, slice []net.IP) error { - var elOid Oid +func encodeIPSlice(w *WriteBuf, oid OID, slice []net.IP) error { + var elOID OID switch oid { - case InetArrayOid: - elOid = InetOid - case CidrArrayOid: - elOid = CidrOid + case InetArrayOID: + elOID = InetOID + case CidrArrayOID: + elOID = CidrOID default: return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) } @@ -2567,12 +2567,12 @@ func encodeIPSlice(w *WriteBuf, oid Oid, slice []net.IP) error { w.WriteInt32(1) // number of dimensions w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOid)) // type of elements + w.WriteInt32(int32(elOID)) // type of elements w.WriteInt32(int32(len(slice))) // number of elements w.WriteInt32(1) // index of first element for _, ip := range slice { - encodeIP(w, elOid, ip) + encodeIP(w, elOID, ip) } return nil diff --git a/values_test.go b/values_test.go index 063598d9..9d7d1700 100644 --- a/values_test.go +++ b/values_test.go @@ -84,7 +84,7 @@ func TestJsonAndJsonbTranscode(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + for _, oid := range []pgx.OID{pgx.JsonOID, pgx.JsonbOID} { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } From 04c02cf3d3e6d824e13c1497d6650d0660d9bc5a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 Aug 2016 13:35:52 -0500 Subject: [PATCH 003/264] Rename Json(b) to JSON(B) --- example_json_test.go | 2 +- query.go | 6 +++--- v3.md | 2 ++ values.go | 10 +++++----- values_test.go | 36 ++++++++++++++++++------------------ 5 files changed, 29 insertions(+), 27 deletions(-) diff --git a/example_json_test.go b/example_json_test.go index 513cc90b..631430b8 100644 --- a/example_json_test.go +++ b/example_json_test.go @@ -12,7 +12,7 @@ func Example_JSON() { return } - if _, ok := conn.PgTypes[pgx.JsonOID]; !ok { + if _, ok := conn.PgTypes[pgx.JSONOID]; !ok { // No JSON type -- must be running against very old PostgreSQL // Pretend it works fmt.Println("John", 42) diff --git a/query.go b/query.go index abe9860e..34035794 100644 --- a/query.go +++ b/query.go @@ -298,7 +298,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if vr.Type().DataType == JsonOID || vr.Type().DataType == JsonbOID { + } else if vr.Type().DataType == JSONOID || vr.Type().DataType == JSONBOID { // Because the argument passed to decodeJSON will escape the heap. // This allows d to be stack allocated and only copied to the heap when // we actually are decoding JSON. This saves one memory allocation per @@ -387,11 +387,11 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeTimestamp(vr)) case InetOID, CidrOID: values = append(values, decodeInet(vr)) - case JsonOID: + case JSONOID: var d interface{} decodeJSON(vr, &d) values = append(values, d) - case JsonbOID: + case JSONBOID: var d interface{} decodeJSON(vr, &d) values = append(values, d) diff --git a/v3.md b/v3.md index fdf8dcac..5bb7162a 100644 --- a/v3.md +++ b/v3.md @@ -1,3 +1,5 @@ # V3 Changes Rename Oid to OID in accordance with Go conventions. + +Rename Json(b) to JSON(B) in accordance with Go conventions. diff --git a/values.go b/values.go index e49721ff..4fecbf74 100644 --- a/values.go +++ b/values.go @@ -22,7 +22,7 @@ const ( Int4OID = 23 TextOID = 25 OIDOID = 26 - JsonOID = 114 + JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 Float4OID = 700 @@ -47,7 +47,7 @@ const ( TimestampTzArrayOID = 1185 RecordOID = 2249 UuidOID = 2950 - JsonbOID = 3802 + JSONBOID = 3802 ) // PostgreSQL format codes @@ -627,7 +627,7 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { } } - if oid == JsonOID || oid == JsonbOID { + if oid == JSONOID || oid == JSONBOID { return encodeJSON(wbuf, oid, arg) } @@ -1463,7 +1463,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return nil } - if vr.Type().DataType != JsonOID && vr.Type().DataType != JsonbOID { + if vr.Type().DataType != JSONOID && vr.Type().DataType != JSONBOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) } @@ -1476,7 +1476,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { } func encodeJSON(w *WriteBuf, oid OID, value interface{}) error { - if oid != JsonOID && oid != JsonbOID { + if oid != JSONOID && oid != JSONBOID { return fmt.Errorf("cannot encode JSON into oid %v", oid) } diff --git a/values_test.go b/values_test.go index 9d7d1700..2ef9c774 100644 --- a/values_test.go +++ b/values_test.go @@ -78,30 +78,30 @@ func TestTimestampTzTranscode(t *testing.T) { } } -func TestJsonAndJsonbTranscode(t *testing.T) { +func TestJSONAndJSONBTranscode(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.OID{pgx.JsonOID, pgx.JsonbOID} { + for _, oid := range []pgx.OID{pgx.JSONOID, pgx.JSONBOID} { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } typename := conn.PgTypes[oid].Name - testJsonString(t, conn, typename) - testJsonStringPointer(t, conn, typename) - testJsonSingleLevelStringMap(t, conn, typename) - testJsonNestedMap(t, conn, typename) - testJsonStringArray(t, conn, typename) - testJsonInt64Array(t, conn, typename) - testJsonInt16ArrayFailureDueToOverflow(t, conn, typename) - testJsonStruct(t, conn, typename) + testJSONString(t, conn, typename) + testJSONStringPointer(t, conn, typename) + testJSONSingleLevelStringMap(t, conn, typename) + testJSONNestedMap(t, conn, typename) + testJSONStringArray(t, conn, typename) + testJSONInt64Array(t, conn, typename) + testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) + testJSONStruct(t, conn, typename) } } -func testJsonString(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string @@ -117,7 +117,7 @@ func testJsonString(t *testing.T, conn *pgx.Conn, typename string) { } } -func testJsonStringPointer(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string @@ -133,7 +133,7 @@ func testJsonStringPointer(t *testing.T, conn *pgx.Conn, typename string) { } } -func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { input := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) @@ -148,7 +148,7 @@ func testJsonSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) } } -func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { input := map[string]interface{}{ "name": "Uncanny", "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, @@ -167,7 +167,7 @@ func testJsonNestedMap(t *testing.T, conn *pgx.Conn, typename string) { } } -func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) { input := []string{"foo", "bar", "baz"} var output []string err := conn.QueryRow("select $1::"+typename, input).Scan(&output) @@ -180,7 +180,7 @@ func testJsonStringArray(t *testing.T, conn *pgx.Conn, typename string) { } } -func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) { input := []int64{1, 2, 234432} var output []int64 err := conn.QueryRow("select $1::"+typename, input).Scan(&output) @@ -193,7 +193,7 @@ func testJsonInt64Array(t *testing.T, conn *pgx.Conn, typename string) { } } -func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { input := []int{1, 2, 234432} var output []int16 err := conn.QueryRow("select $1::"+typename, input).Scan(&output) @@ -202,7 +202,7 @@ func testJsonInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typena } } -func testJsonStruct(t *testing.T, conn *pgx.Conn, typename string) { +func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { type person struct { Name string `json:"name"` Age int `json:"age"` From 390f75c0e15d06cc88c5675d24fe613596764262 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 Aug 2016 14:42:31 -0500 Subject: [PATCH 004/264] Reduce Logger interface to Log method --- bench_test.go | 89 ++++++++++++++++----------------------------------- conn.go | 13 +------- conn_pool.go | 2 +- conn_test.go | 13 ++------ doc.go | 4 +-- logger.go | 11 ++----- v3.md | 2 ++ 7 files changed, 36 insertions(+), 98 deletions(-) diff --git a/bench_test.go b/bench_test.go index eb9c0595..99f65c2b 100644 --- a/bench_test.go +++ b/bench_test.go @@ -5,7 +5,6 @@ import ( "time" "github.com/jackc/pgx" - log "gopkg.in/inconshreveable/log15.v2" ) func BenchmarkConnPool(b *testing.B) { @@ -294,71 +293,51 @@ func BenchmarkSelectWithoutLogging(b *testing.B) { benchmarkSelectWithLog(b, conn) } -func BenchmarkSelectWithLoggingTraceWithLog15(b *testing.B) { - connConfig := *defaultConnConfig +type discardLogger struct{} - logger := log.New() - lvl, err := log.LvlFromString("debug") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - connConfig.Logger = logger - connConfig.LogLevel = pgx.LogLevelTrace - conn := mustConnect(b, connConfig) +func (dl discardLogger) Log(level int, msg string, ctx ...interface{}) {} + +func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) + var logger discardLogger + conn.SetLogger(logger) + conn.SetLogLevel(pgx.LogLevelTrace) + benchmarkSelectWithLog(b, conn) } -func BenchmarkSelectWithLoggingDebugWithLog15(b *testing.B) { - connConfig := *defaultConnConfig - - logger := log.New() - lvl, err := log.LvlFromString("debug") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - connConfig.Logger = logger - connConfig.LogLevel = pgx.LogLevelDebug - conn := mustConnect(b, connConfig) +func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) + var logger discardLogger + conn.SetLogger(logger) + conn.SetLogLevel(pgx.LogLevelDebug) + benchmarkSelectWithLog(b, conn) } -func BenchmarkSelectWithLoggingInfoWithLog15(b *testing.B) { - connConfig := *defaultConnConfig - - logger := log.New() - lvl, err := log.LvlFromString("info") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - connConfig.Logger = logger - connConfig.LogLevel = pgx.LogLevelInfo - conn := mustConnect(b, connConfig) +func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) + var logger discardLogger + conn.SetLogger(logger) + conn.SetLogLevel(pgx.LogLevelInfo) + benchmarkSelectWithLog(b, conn) } -func BenchmarkSelectWithLoggingErrorWithLog15(b *testing.B) { - connConfig := *defaultConnConfig - - logger := log.New() - lvl, err := log.LvlFromString("error") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - connConfig.Logger = logger - connConfig.LogLevel = pgx.LogLevelError - conn := mustConnect(b, connConfig) +func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) + var logger discardLogger + conn.SetLogger(logger) + conn.SetLogLevel(pgx.LogLevelError) + benchmarkSelectWithLog(b, conn) } @@ -418,17 +397,3 @@ func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { } } } - -func BenchmarkLog15Discard(b *testing.B) { - logger := log.New() - lvl, err := log.LvlFromString("error") - if err != nil { - b.Fatal(err) - } - logger.SetHandler(log.LvlFilterHandler(lvl, log.DiscardHandler())) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - logger.Debug("benchmark", "i", i, "b.N", b.N) - } -} diff --git a/conn.go b/conn.go index e5c2a401..6dfd5d5a 100644 --- a/conn.go +++ b/conn.go @@ -1252,18 +1252,7 @@ func (c *Conn) log(lvl int, msg string, ctx ...interface{}) { ctx = append(ctx, "pid", c.Pid) } - switch lvl { - case LogLevelTrace: - c.logger.Debug(msg, ctx...) - case LogLevelDebug: - c.logger.Debug(msg, ctx...) - case LogLevelInfo: - c.logger.Info(msg, ctx...) - case LogLevelWarn: - c.logger.Warn(msg, ctx...) - case LogLevelError: - c.logger.Error(msg, ctx...) - } + c.logger.Log(lvl, msg, ctx...) } // SetLogger replaces the current logger and returns the previous logger. diff --git a/conn_pool.go b/conn_pool.go index a72d5daf..1627af10 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -148,7 +148,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { } else { // All connections are in use and we cannot create more if p.logLevel >= LogLevelWarn { - p.logger.Warn("All connections in pool are busy - waiting...") + p.logger.Log(LogLevelWarn, "All connections in pool are busy - waiting...") } // Wait until there is an available connection OR room to create a new connection diff --git a/conn_test.go b/conn_test.go index 4067118c..62e8b7d5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1401,17 +1401,8 @@ type testLogger struct { logs []testLog } -func (l *testLogger) Debug(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx}) -} -func (l *testLogger) Info(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx}) -} -func (l *testLogger) Warn(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx}) -} -func (l *testLogger) Error(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx}) +func (l *testLogger) Log(level int, msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: level, msg: msg, ctx: ctx}) } func TestSetLogger(t *testing.T) { diff --git a/doc.go b/doc.go index 0fd3d2f6..7964aa82 100644 --- a/doc.go +++ b/doc.go @@ -202,9 +202,7 @@ connection. Logging pgx defines a simple logger interface. Connections optionally accept a logger -that satisfies this interface. The log15 package -(http://gopkg.in/inconshreveable/log15.v2) satisfies this interface and it is -simple to define adapters for other loggers. Set LogLevel to control logging +that satisfies this interface. Set LogLevel to control logging verbosity. */ package pgx diff --git a/logger.go b/logger.go index f85d4bd0..8cadee4e 100644 --- a/logger.go +++ b/logger.go @@ -7,8 +7,7 @@ import ( ) // The values for log levels are chosen such that the zero value means that no -// log level was specified and we can default to LogLevelDebug to preserve -// the behavior that existed prior to log level introduction. +// log level was specified. const ( LogLevelTrace = 6 LogLevelDebug = 5 @@ -19,15 +18,9 @@ const ( ) // Logger is the interface used to get logging from pgx internals. -// https://github.com/inconshreveable/log15 is the recommended logging package. -// This logging interface was extracted from there. However, it should be simple -// to adapt any logger to this interface. type Logger interface { // Log a message at the given level with context key/value pairs - Debug(msg string, ctx ...interface{}) - Info(msg string, ctx ...interface{}) - Warn(msg string, ctx ...interface{}) - Error(msg string, ctx ...interface{}) + Log(level int, msg string, ctx ...interface{}) } // LogLevelFromString converts log level string to constant diff --git a/v3.md b/v3.md index 5bb7162a..75c9385c 100644 --- a/v3.md +++ b/v3.md @@ -3,3 +3,5 @@ Rename Oid to OID in accordance with Go conventions. Rename Json(b) to JSON(B) in accordance with Go conventions. + +Logger interface reduced to single Log method. From 73124171e22d33f2ae4f42b2c792e390007ea94a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 Aug 2016 15:10:00 -0500 Subject: [PATCH 005/264] Rename Pid to PID --- conn.go | 12 ++++++------ conn_pool_test.go | 12 ++++++------ conn_test.go | 6 +++--- examples/chat/main.go | 2 +- v3.md | 6 ++++-- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/conn.go b/conn.go index 6dfd5d5a..752c3ddd 100644 --- a/conn.go +++ b/conn.go @@ -48,7 +48,7 @@ type Conn struct { reader *bufio.Reader // buffered reader to improve read performance wbuf [1024]byte writeBuf WriteBuf - Pid int32 // backend pid + PID int32 // backend pid SecretKey int32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server PgTypes map[OID]PgType // oids to PgTypes @@ -85,7 +85,7 @@ type PrepareExOptions struct { // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system type Notification struct { - Pid int32 // backend pid that sent the notification + PID int32 // backend pid that sent the notification Channel string // channel from which notification was received Payload string } @@ -1137,7 +1137,7 @@ func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { } func (c *Conn) rxBackendKeyData(r *msgReader) { - c.Pid = r.readInt32() + c.PID = r.readInt32() c.SecretKey = r.readInt32() } @@ -1180,7 +1180,7 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) { func (c *Conn) rxNotificationResponse(r *msgReader) { n := new(Notification) - n.Pid = r.readInt32() + n.PID = r.readInt32() n.Channel = r.readCString() n.Payload = r.readCString() c.notifications = append(c.notifications, n) @@ -1248,8 +1248,8 @@ func (c *Conn) shouldLog(lvl int) bool { } func (c *Conn) log(lvl int, msg string, ctx ...interface{}) { - if c.Pid != 0 { - ctx = append(ctx, "pid", c.Pid) + if c.PID != 0 { + ctx = append(ctx, "pid", c.PID) } c.logger.Log(lvl, msg, ctx...) diff --git a/conn_pool_test.go b/conn_pool_test.go index 9aa31758..773a0272 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -428,7 +428,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { } }() - if _, err = c2.Exec("select pg_terminate_backend($1)", c1.Pid); err != nil { + if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -599,7 +599,7 @@ func TestConnPoolBeginRetry(t *testing.T) { pool.Release(victimConn) // Terminate connection that was released to pool - if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.Pid); err != nil { + if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -611,13 +611,13 @@ func TestConnPoolBeginRetry(t *testing.T) { } defer tx.Rollback() - var txPid int32 - err = tx.QueryRow("select pg_backend_pid()").Scan(&txPid) + var txPID int32 + err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID) if err != nil { t.Fatalf("tx.QueryRow Scan failed: %v", err) } - if txPid == victimConn.Pid { - t.Error("Expected txPid to defer from killed conn pid, but it didn't") + if txPID == victimConn.PID { + t.Error("Expected txPID to defer from killed conn pid, but it didn't") } }() } diff --git a/conn_test.go b/conn_test.go index 62e8b7d5..60f762ed 100644 --- a/conn_test.go +++ b/conn_test.go @@ -27,7 +27,7 @@ func TestConnect(t *testing.T) { t.Error("Runtime parameters not stored") } - if conn.Pid == 0 { + if conn.PID == 0 { t.Error("Backend PID not stored") } @@ -1255,7 +1255,7 @@ func TestFatalRxError(t *testing.T) { } defer otherConn.Close() - if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.Pid); err != nil { + if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -1281,7 +1281,7 @@ func TestFatalTxError(t *testing.T) { } defer otherConn.Close() - _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.Pid) + _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.PID) if err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } diff --git a/examples/chat/main.go b/examples/chat/main.go index 517508cc..ad8d56db 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -67,7 +67,7 @@ func listen() { os.Exit(1) } - fmt.Println("PID:", notification.Pid, "Channel:", notification.Channel, "Payload:", notification.Payload) + fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) } } diff --git a/v3.md b/v3.md index 75c9385c..38c41146 100644 --- a/v3.md +++ b/v3.md @@ -1,7 +1,9 @@ # V3 Changes -Rename Oid to OID in accordance with Go conventions. +Rename Oid to OID in accordance with Go naming conventions. -Rename Json(b) to JSON(B) in accordance with Go conventions. +Rename Json(b) to JSON(B) in accordance with Go naming conventions. + +Rename Pid to PID in accordance with Go naming conventions. Logger interface reduced to single Log method. From e47373227b79362952be8c0c0d952fa5991345b7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 Aug 2016 16:21:32 -0500 Subject: [PATCH 006/264] Rename Uuid to UUID --- v3.md | 2 ++ values.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/v3.md b/v3.md index 38c41146..1cb31615 100644 --- a/v3.md +++ b/v3.md @@ -6,4 +6,6 @@ Rename Json(b) to JSON(B) in accordance with Go naming conventions. Rename Pid to PID in accordance with Go naming conventions. +Rename Uuid to UUID in accordance with Go naming conventions. + Logger interface reduced to single Log method. diff --git a/values.go b/values.go index 4fecbf74..e8e5a6d5 100644 --- a/values.go +++ b/values.go @@ -46,7 +46,7 @@ const ( TimestampTzOID = 1184 TimestampTzArrayOID = 1185 RecordOID = 2249 - UuidOID = 2950 + UUIDOID = 2950 JSONBOID = 3802 ) From 2578cba28829f5e0ab6b5bc1685cfa657a11761d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 Aug 2016 16:59:07 -0500 Subject: [PATCH 007/264] Add v3 todos --- v3.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/v3.md b/v3.md index 1cb31615..4a80ab11 100644 --- a/v3.md +++ b/v3.md @@ -1,4 +1,6 @@ -# V3 Changes +# V3 Experimental + +## Changes Rename Oid to OID in accordance with Go naming conventions. @@ -9,3 +11,22 @@ Rename Pid to PID in accordance with Go naming conventions. Rename Uuid to UUID in accordance with Go naming conventions. Logger interface reduced to single Log method. + +## TODO / Possible / Investigate + +Organize errors better + +Optionally use Go 1.7 context + +Remove circular dependency between Conn and ConnPool such that ConnPool depends on Conn, but Conn doesn't know anything about ConnPool + +Extract types Null* and Hstore to separate package + +Remove names from prepared statements - use database/sql style objects + +Better way of handling text/binary protocol choice than pgx.DefaultTypeFormats or manually editing a PreparedStatement. Possibly an optional part of preparing a statement is specifying the format and/or a decoder. Or maybe it is part of a QueryEx call... Could be very interesting to make encoding and decoding possible without being a method of the type. This could drastically clean up those huge type switches. + +Copy protocol support (this potentially ties in with text/binary protocol) +ValueReader / msgReader cleanup + +Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) From 430d4943c75109529161d5ca07da8db30af87023 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Dec 2016 11:48:45 -0600 Subject: [PATCH 008/264] Replace BeginIso with BeginEx Adds support for read/write mode and deferrable modes. --- conn_pool.go | 12 ++++---- conn_pool_test.go | 4 +-- tx.go | 71 ++++++++++++++++++++++++++++++++--------------- tx_test.go | 36 ++++++++++++++++++------ v3.md | 4 +++ 5 files changed, 88 insertions(+), 39 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 67868769..126d5b14 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -383,7 +383,7 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { // Begin acquires a connection and begins a transaction on it. When the // transaction is closed the connection will be automatically released. func (p *ConnPool) Begin() (*Tx, error) { - return p.BeginIso("") + return p.BeginEx(nil) } // Prepare creates a prepared statement on a connection in the pool to test the @@ -469,17 +469,17 @@ func (p *ConnPool) Deallocate(name string) (err error) { return nil } -// BeginIso acquires a connection and begins a transaction in isolation mode iso -// on it. When the transaction is closed the connection will be automatically -// released. -func (p *ConnPool) BeginIso(iso string) (*Tx, error) { +// BeginEx acquires a connection and starts a transaction with txOptions +// determining the transaction mode. When the transaction is closed the +// connection will be automatically released. +func (p *ConnPool) BeginEx(txOptions *TxOptions) (*Tx, error) { for { c, err := p.Acquire() if err != nil { return nil, err } - tx, err := c.BeginIso(iso) + tx, err := c.BeginEx(txOptions) if err != nil { alive := c.IsAlive() p.Release(c) diff --git a/conn_pool_test.go b/conn_pool_test.go index db8702fb..212a79d9 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -560,9 +560,9 @@ func TestConnPoolTransactionIso(t *testing.T) { pool := createConnPool(t, 2) defer pool.Close() - tx, err := pool.BeginIso(pgx.Serializable) + tx, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("pool.Begin failed: %v", err) + t.Fatalf("pool.BeginEx failed: %v", err) } defer tx.Rollback() diff --git a/tx.go b/tx.go index 36f99c28..a636b364 100644 --- a/tx.go +++ b/tx.go @@ -1,16 +1,35 @@ package pgx import ( + "bytes" "errors" "fmt" ) +type TxIsoLevel string + // Transaction isolation levels const ( - Serializable = "serializable" - RepeatableRead = "repeatable read" - ReadCommitted = "read committed" - ReadUncommitted = "read uncommitted" + Serializable = TxIsoLevel("serializable") + RepeatableRead = TxIsoLevel("repeatable read") + ReadCommitted = TxIsoLevel("read committed") + ReadUncommitted = TxIsoLevel("read uncommitted") +) + +type TxAccessMode string + +// Transaction access modes +const ( + ReadWrite = TxAccessMode("read write") + ReadOnly = TxAccessMode("read only") +) + +type TxDeferrableMode string + +// Transaction deferrable modes +const ( + Deferrable = TxDeferrableMode("deferrable") + NotDeferrable = TxDeferrableMode("not deferrable") ) const ( @@ -21,6 +40,12 @@ const ( TxStatusRollbackSuccess = 2 ) +type TxOptions struct { + IsoLevel TxIsoLevel + AccessMode TxAccessMode + DeferrableMode TxDeferrableMode +} + var ErrTxClosed = errors.New("tx is closed") // ErrTxCommitRollback occurs when an error has occurred in a transaction and @@ -28,30 +53,32 @@ var ErrTxClosed = errors.New("tx is closed") // it is treated as ROLLBACK. var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") -// Begin starts a transaction with the default isolation level for the current -// connection. To use a specific isolation level see BeginIso. +// Begin starts a transaction with the default transaction mode for the +// current connection. To use a specific transaction mode see BeginEx. func (c *Conn) Begin() (*Tx, error) { - return c.begin("") + return c.BeginEx(nil) } -// BeginIso starts a transaction with isoLevel as the transaction isolation -// level. -// -// Valid isolation levels (and their constants) are: -// serializable (pgx.Serializable) -// repeatable read (pgx.RepeatableRead) -// read committed (pgx.ReadCommitted) -// read uncommitted (pgx.ReadUncommitted) -func (c *Conn) BeginIso(isoLevel string) (*Tx, error) { - return c.begin(isoLevel) -} - -func (c *Conn) begin(isoLevel string) (*Tx, error) { +// BeginEx starts a transaction with txOptions determining the transaction +// mode. +func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) { var beginSQL string - if isoLevel == "" { + if txOptions == nil { beginSQL = "begin" } else { - beginSQL = fmt.Sprintf("begin isolation level %s", isoLevel) + buf := &bytes.Buffer{} + buf.WriteString("begin") + if txOptions.IsoLevel != "" { + fmt.Fprintf(buf, " isolation level %s", txOptions.IsoLevel) + } + if txOptions.AccessMode != "" { + fmt.Fprintf(buf, " %s", txOptions.AccessMode) + } + if txOptions.DeferrableMode != "" { + fmt.Fprintf(buf, " %s", txOptions.DeferrableMode) + } + + beginSQL = buf.String() } _, err := c.Exec(beginSQL) diff --git a/tx_test.go b/tx_test.go index 435521a3..0ba5904b 100644 --- a/tx_test.go +++ b/tx_test.go @@ -107,15 +107,15 @@ func TestTxCommitSerializationFailure(t *testing.T) { } defer pool.Exec(`drop table tx_serializable_sums`) - tx1, err := pool.BeginIso(pgx.Serializable) + tx1, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("BeginIso failed: %v", err) + t.Fatalf("BeginEx failed: %v", err) } defer tx1.Rollback() - tx2, err := pool.BeginIso(pgx.Serializable) + tx2, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { - t.Fatalf("BeginIso failed: %v", err) + t.Fatalf("BeginEx failed: %v", err) } defer tx2.Rollback() @@ -182,20 +182,20 @@ func TestTransactionSuccessfulRollback(t *testing.T) { } } -func TestBeginIso(t *testing.T) { +func TestBeginExIsoLevels(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - isoLevels := []string{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} + isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { - tx, err := conn.BeginIso(iso) + tx, err := conn.BeginEx(&pgx.TxOptions{IsoLevel: iso}) if err != nil { - t.Fatalf("conn.BeginIso failed: %v", err) + t.Fatalf("conn.BeginEx failed: %v", err) } - var level string + var level pgx.TxIsoLevel conn.QueryRow("select current_setting('transaction_isolation')").Scan(&level) if level != iso { t.Errorf("Expected to be in isolation level %v but was %v", iso, level) @@ -208,6 +208,24 @@ func TestBeginIso(t *testing.T) { } } +func TestBeginExReadOnly(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tx, err := conn.BeginEx(&pgx.TxOptions{AccessMode: pgx.ReadOnly}) + if err != nil { + t.Fatalf("conn.BeginEx failed: %v", err) + } + defer tx.Rollback() + + _, err = conn.Exec("create table foo(id serial primary key)") + if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "25006" { + t.Errorf("Expected error SQLSTATE 25006, but got %#v", err) + } +} + func TestTxAfterClose(t *testing.T) { t.Parallel() diff --git a/v3.md b/v3.md index 4a80ab11..c4f04d7d 100644 --- a/v3.md +++ b/v3.md @@ -12,6 +12,10 @@ Rename Uuid to UUID in accordance with Go naming conventions. Logger interface reduced to single Log method. +Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode. + +Transaction isolation level constants are now typed strings instead of bare strings. + ## TODO / Possible / Investigate Organize errors better From a35621d28523af3072e85ccb8e3ce9e1a11e1f00 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 6 Jan 2017 18:45:02 -0600 Subject: [PATCH 009/264] Add note about cancel and timeouts --- v3.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v3.md b/v3.md index c4f04d7d..d89c4875 100644 --- a/v3.md +++ b/v3.md @@ -20,7 +20,7 @@ Transaction isolation level constants are now typed strings instead of bare stri Organize errors better -Optionally use Go 1.7 context +Optionally use Go 1.7 context / cancel and timeouts could be implemented this way Remove circular dependency between Conn and ConnPool such that ConnPool depends on Conn, but Conn doesn't know anything about ConnPool From 0131efd6c9700826177ae5759af03edef81b54da Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Jan 2017 10:11:27 -0600 Subject: [PATCH 010/264] Fix url_shortener example logging --- examples/url_shortener/main.go | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index f6a22c37..25e4cb90 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -98,7 +98,28 @@ func urlHandler(w http.ResponseWriter, req *http.Request) { } } +type log15Adapter struct { + logger log.Logger +} + +func (a *log15Adapter) Log(level int, msg string, ctx ...interface{}) { + switch level { + case pgx.LogLevelTrace, pgx.LogLevelDebug: + a.logger.Debug(msg, ctx...) + case pgx.LogLevelInfo: + a.logger.Info(msg, ctx...) + case pgx.LogLevelWarn: + a.logger.Warn(msg, ctx...) + case pgx.LogLevelError: + a.logger.Error(msg, ctx...) + default: + panic("invalid log level") + } +} + func main() { + logger := &log15Adapter{logger: log.New("module", "pgx")} + var err error connPoolConfig := pgx.ConnPoolConfig{ ConnConfig: pgx.ConnConfig{ @@ -106,7 +127,7 @@ func main() { User: "jack", Password: "jack", Database: "url_shortener", - Logger: log.New("module", "pgx"), + Logger: logger, }, MaxConnections: 5, AfterConnect: afterConnect, From ecedf3d94a48b5551e9e65d2b157d3d4d2d6da63 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Jan 2017 10:16:12 -0600 Subject: [PATCH 011/264] Fix stdlib test logger --- stdlib/sql_test.go | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 5a5f7049..546ec4fb 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -335,17 +335,8 @@ type testLogger struct { logs []testLog } -func (l *testLogger) Debug(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelDebug, msg: msg, ctx: ctx}) -} -func (l *testLogger) Info(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelInfo, msg: msg, ctx: ctx}) -} -func (l *testLogger) Warn(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelWarn, msg: msg, ctx: ctx}) -} -func (l *testLogger) Error(msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: pgx.LogLevelError, msg: msg, ctx: ctx}) +func (l *testLogger) Log(lvl int, msg string, ctx ...interface{}) { + l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, ctx: ctx}) } func TestConnQueryLog(t *testing.T) { From 356fcd4b0e121a287a76875112bba5c5096b53d0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Jan 2017 12:58:37 -0600 Subject: [PATCH 012/264] Add note for privatizing struct vars --- v3.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/v3.md b/v3.md index d89c4875..58887aed 100644 --- a/v3.md +++ b/v3.md @@ -34,3 +34,5 @@ Copy protocol support (this potentially ties in with text/binary protocol) ValueReader / msgReader cleanup Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) + +Every field that should not be set by user should be replaced by accessor method (e.g. Conn.PID, Conn.SecretKey) From e871ccfca20cb066afb25a057847df51b2199362 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Jan 2017 13:12:09 -0600 Subject: [PATCH 013/264] Update docs for BeginEx --- doc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc.go b/doc.go index 248c7e26..14843c28 100644 --- a/doc.go +++ b/doc.go @@ -190,8 +190,8 @@ type conversion from string. Transactions -Transactions are started by calling Begin or BeginIso. The BeginIso variant -creates a transaction with a specified isolation level. +Transactions are started by calling Begin or BeginEx. The BeginEx variant +can create a transaction with a specified isolation level. tx, err := conn.Begin() if err != nil { From ec513248acbfc5e83980b53634d0e70047f7ce66 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Jan 2017 13:37:36 -0600 Subject: [PATCH 014/264] Conn.PID accessed through method --- conn.go | 11 ++++++++--- conn_pool_test.go | 6 +++--- conn_test.go | 6 +++--- v3.md | 2 ++ 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index 380691cb..72e9652c 100644 --- a/conn.go +++ b/conn.go @@ -48,7 +48,7 @@ type Conn struct { reader *bufio.Reader // buffered reader to improve read performance wbuf [1024]byte writeBuf WriteBuf - PID int32 // backend pid + pid int32 // backend pid SecretKey int32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server PgTypes map[OID]PgType // oids to PgTypes @@ -381,6 +381,11 @@ func (c *Conn) loadInetConstants() error { return nil } +// PID returns the backend PID for this connection. +func (c *Conn) PID() int32 { + return c.pid +} + // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { @@ -1140,7 +1145,7 @@ func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { } func (c *Conn) rxBackendKeyData(r *msgReader) { - c.PID = r.readInt32() + c.pid = r.readInt32() c.SecretKey = r.readInt32() } @@ -1251,7 +1256,7 @@ func (c *Conn) shouldLog(lvl int) bool { } func (c *Conn) log(lvl int, msg string, ctx ...interface{}) { - if c.PID != 0 { + if c.pid != 0 { ctx = append(ctx, "pid", c.PID) } diff --git a/conn_pool_test.go b/conn_pool_test.go index 212a79d9..f6f166d8 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -428,7 +428,7 @@ func TestPoolReleaseDiscardsDeadConnections(t *testing.T) { } }() - if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID); err != nil { + if _, err = c2.Exec("select pg_terminate_backend($1)", c1.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -599,7 +599,7 @@ func TestConnPoolBeginRetry(t *testing.T) { pool.Release(victimConn) // Terminate connection that was released to pool - if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID); err != nil { + if _, err = killerConn.Exec("select pg_terminate_backend($1)", victimConn.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -616,7 +616,7 @@ func TestConnPoolBeginRetry(t *testing.T) { if err != nil { t.Fatalf("tx.QueryRow Scan failed: %v", err) } - if txPID == victimConn.PID { + if txPID == victimConn.PID() { t.Error("Expected txPID to defer from killed conn pid, but it didn't") } }() diff --git a/conn_test.go b/conn_test.go index 10f40552..ecd7e88d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -27,7 +27,7 @@ func TestConnect(t *testing.T) { t.Error("Runtime parameters not stored") } - if conn.PID == 0 { + if conn.PID() == 0 { t.Error("Backend PID not stored") } @@ -1255,7 +1255,7 @@ func TestFatalRxError(t *testing.T) { } defer otherConn.Close() - if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID); err != nil { + if _, err := otherConn.Exec("select pg_terminate_backend($1)", conn.PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } @@ -1281,7 +1281,7 @@ func TestFatalTxError(t *testing.T) { } defer otherConn.Close() - _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.PID) + _, err = otherConn.Exec("select pg_terminate_backend($1)", conn.PID()) if err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } diff --git a/v3.md b/v3.md index 58887aed..424635fc 100644 --- a/v3.md +++ b/v3.md @@ -16,6 +16,8 @@ Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and defe Transaction isolation level constants are now typed strings instead of bare strings. +Conn.Pid changed to accessor method Conn.PID() + ## TODO / Possible / Investigate Organize errors better From 69434056c65891277f70fe8b30aca07fe95b56ad Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Jan 2017 14:10:11 -0600 Subject: [PATCH 015/264] Remove Conn.TxStatus --- conn.go | 4 ++-- conn_pool.go | 2 +- conn_pool_test.go | 16 ++++++++-------- private_test.go | 7 +++++++ v3.md | 2 ++ 5 files changed, 20 insertions(+), 11 deletions(-) create mode 100644 private_test.go diff --git a/conn.go b/conn.go index 72e9652c..75792408 100644 --- a/conn.go +++ b/conn.go @@ -53,7 +53,7 @@ type Conn struct { RuntimeParams map[string]string // parameters that have been reported by the server PgTypes map[OID]PgType // oids to PgTypes config ConnConfig // config used when establishing this connection - TxStatus byte + txStatus byte preparedStatements map[string]*PreparedStatement channels map[string]struct{} notifications []*Notification @@ -1150,7 +1150,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) { } func (c *Conn) rxReadyForQuery(r *msgReader) { - c.TxStatus = r.readByte() + c.txStatus = r.readByte() } func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { diff --git a/conn_pool.go b/conn_pool.go index 126d5b14..4bb64a24 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -181,7 +181,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Release gives up use of a connection. func (p *ConnPool) Release(conn *Conn) { - if conn.TxStatus != 'I' { + if conn.txStatus != 'I' { conn.Exec("rollback") } diff --git a/conn_pool_test.go b/conn_pool_test.go index f6f166d8..0bbda0bc 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -328,14 +328,14 @@ func TestPoolReleaseWithTransactions(t *testing.T) { t.Fatal("Did not receive expected error") } - if conn.TxStatus != 'E' { - t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus) + if conn.TxStatus() != 'E' { + t.Fatalf("Expected TxStatus to be 'E', instead it was '%c'", conn.TxStatus()) } pool.Release(conn) - if conn.TxStatus != 'I' { - t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus) + if conn.TxStatus() != 'I' { + t.Fatalf("Expected release to rollback errored transaction, but it did not: '%c'", conn.TxStatus()) } conn, err = pool.Acquire() @@ -343,14 +343,14 @@ func TestPoolReleaseWithTransactions(t *testing.T) { t.Fatalf("Unable to acquire connection: %v", err) } mustExec(t, conn, "begin") - if conn.TxStatus != 'T' { - t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus) + if conn.TxStatus() != 'T' { + t.Fatalf("Expected txStatus to be 'T', instead it was '%c'", conn.TxStatus()) } pool.Release(conn) - if conn.TxStatus != 'I' { - t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus) + if conn.TxStatus() != 'I' { + t.Fatalf("Expected release to rollback uncommitted transaction, but it did not: '%c'", conn.TxStatus()) } } diff --git a/private_test.go b/private_test.go new file mode 100644 index 00000000..df732a72 --- /dev/null +++ b/private_test.go @@ -0,0 +1,7 @@ +package pgx + +// This file contains methods that expose internal pgx state to tests. + +func (c *Conn) TxStatus() byte { + return c.txStatus +} diff --git a/v3.md b/v3.md index 424635fc..ca13055f 100644 --- a/v3.md +++ b/v3.md @@ -18,6 +18,8 @@ Transaction isolation level constants are now typed strings instead of bare stri Conn.Pid changed to accessor method Conn.PID() +Remove Conn.TxStatus + ## TODO / Possible / Investigate Organize errors better From 15c4307c483663e0f592365399010b51f47cdb29 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Jan 2017 14:18:52 -0600 Subject: [PATCH 016/264] Add note for strongly typed queries --- v3.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/v3.md b/v3.md index ca13055f..b64c734a 100644 --- a/v3.md +++ b/v3.md @@ -40,3 +40,5 @@ ValueReader / msgReader cleanup Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) Every field that should not be set by user should be replaced by accessor method (e.g. Conn.PID, Conn.SecretKey) + +Investigate strongly typed queries. i.e. Some sort of interface where varargs of Query, Exec, and Scan wouldn't happen. Need to be some low-level interface where (probably generated) functions could (more or less) directly read and write to the connection. From 79b28d24e22eb0a4ff41fc00b03a3e5e71a5b91d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Jan 2017 16:01:21 -0600 Subject: [PATCH 017/264] Add todo to v3 --- v3.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/v3.md b/v3.md index b64c734a..f6bccb29 100644 --- a/v3.md +++ b/v3.md @@ -42,3 +42,5 @@ Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pu Every field that should not be set by user should be replaced by accessor method (e.g. Conn.PID, Conn.SecretKey) Investigate strongly typed queries. i.e. Some sort of interface where varargs of Query, Exec, and Scan wouldn't happen. Need to be some low-level interface where (probably generated) functions could (more or less) directly read and write to the connection. + +Reject scanning non-string like things into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/223 From 3cc6264dfdfea2cc01b05b8ca18ad733b26c7ec9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 11:31:11 -0600 Subject: [PATCH 018/264] Fix renamed constant --- query_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/query_test.go b/query_test.go index f5250589..15f57e49 100644 --- a/query_test.go +++ b/query_test.go @@ -298,7 +298,7 @@ type pgxNullInt64 struct { } func (n *pgxNullInt64) ScanPgx(vr *pgx.ValueReader) error { - if vr.Type().DataType != pgx.Int8Oid { + if vr.Type().DataType != pgx.Int8OID { return pgx.SerializationError(fmt.Sprintf("pgxNullInt64.Scan cannot decode OID %d", vr.Type().DataType)) } From 78adfb13d796427deafa89fc45aa5c7e47f8d51b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 14:20:00 -0600 Subject: [PATCH 019/264] Add Ping, PingContext, and ExecContext --- conn.go | 96 ++++++++++++++++++++++++++++++++++++++++++++++------ conn_test.go | 68 +++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index 602ecbff..645b9c5d 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "encoding/hex" "errors" "fmt" + "golang.org/x/net/context" "io" "net" "net/url" @@ -39,6 +40,22 @@ type ConnConfig struct { RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) } +func (cc *ConnConfig) networkAddress() (network, address string) { + network = "tcp" + address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) + // See if host is a valid path, if yes connect with a socket + if _, err := os.Stat(cc.Host); err == nil { + // For backward compatibility accept socket file paths -- but directories are now preferred + network = "unix" + address = cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + } + } + + return network, address +} + // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. // Use ConnPool to manage access to multiple database connections from multiple // goroutines. @@ -194,17 +211,7 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsql } } - network := "tcp" - address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) - // See if host is a valid path, if yes connect with a socket - if _, err := os.Stat(c.config.Host); err == nil { - // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = c.config.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) - } - } + network, address := c.config.networkAddress() if c.config.Dial == nil { c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial } @@ -1292,3 +1299,70 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) { func quoteIdentifier(s string) string { return `"` + strings.Replace(s, `"`, `""`, -1) + `"` } + +// cancelQuery sends a cancel request to the PostgreSQL server. It returns an +// error if unable to deliver the cancel request, but lack of an error does not +// ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See +// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 +func (c *Conn) cancelQuery() error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) + _, err = cancelConn.Write(buf) + return err +} + +func (c *Conn) Ping() error { + _, err := c.Exec(";") + return err +} + +func (c *Conn) PingContext(ctx context.Context) error { + _, err := c.ExecContext(ctx, ";") + return err +} + +func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } + + doneChan := make(chan struct{}) + closedChan := make(chan bool) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + <-doneChan + closedChan <- true + case <-doneChan: + closedChan <- false + } + }() + + commandTag, err = c.Exec(sql, arguments...) + + // Signal cancelation goroutine that operation is done + doneChan <- struct{}{} + + // If c was closed due to context cancelation then return context err + if <-closedChan { + return "", ctx.Err() + } + + return commandTag, err +} diff --git a/conn_test.go b/conn_test.go index 9ed073ce..a9cf02c9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "crypto/tls" "fmt" + "golang.org/x/net/context" "net" "os" "reflect" @@ -816,6 +817,73 @@ func TestExecFailure(t *testing.T) { } } +func TestExecContextWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);") + if err != nil { + t.Fatal(err) + } + if commandTag != "CREATE TABLE" { + t.Fatalf("Unexpected results from ExecContext: %v", commandTag) + } +} + +func TestExecContextFailureWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + if _, err := conn.ExecContext(ctx, "selct;"); err == nil { + t.Fatal("Expected SQL syntax error") + } + + rows, _ := conn.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err()) + } +} + +func TestExecContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + _, err := conn.ExecContext(ctx, "select pg_sleep(60)") + if err != context.Canceled { + t.Fatal("Expected context.Canceled err, got %v", err) + } + + time.Sleep(500 * time.Millisecond) + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } +} + func TestPrepare(t *testing.T) { t.Parallel() From 3e13b333d9d3e2fa14f8e7e43ae041dcd6602433 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 15:40:58 -0600 Subject: [PATCH 020/264] Add QueryContext --- query.go | 48 ++++++++++++++++++++++ query_test.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+) diff --git a/query.go b/query.go index 19b867e2..121dcfe3 100644 --- a/query.go +++ b/query.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "golang.org/x/net/context" "time" ) @@ -49,6 +50,9 @@ type Rows struct { afterClose func(*Rows) unlockConn bool closed bool + + ctx context.Context + doneChan chan struct{} } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -120,6 +124,15 @@ func (rows *Rows) Close() { return } rows.readUntilReadyForQuery() + + if rows.ctx != nil { + select { + case <-rows.ctx.Done(): + rows.err = rows.ctx.Err() + case rows.doneChan <- struct{}{}: + } + } + rows.close() } @@ -492,3 +505,38 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { rows, _ := c.Query(sql, args...) return (*Row)(rows) } + +func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + doneChan := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + case <-doneChan: + } + }() + + rows, err := c.Query(sql, args...) + + if err != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case doneChan <- struct{}{}: + return nil, err + } + } + + rows.ctx = ctx + rows.doneChan = doneChan + + return rows, nil +} diff --git a/query_test.go b/query_test.go index f08887b5..ca05fb42 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql" "fmt" + "golang.org/x/net/context" "strings" "testing" "time" @@ -1412,3 +1413,113 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) { ensureConnValid(t, conn) } + +func TestQueryContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryContext(ctx, "select 42::integer") + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + if rowCount != 1 { + t.Fatalf("Expected 1 row, got %d", rowCount) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryContextErrorWhileReceivingRows(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n") + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", rows.Err()) + } + + if rowCount != 9 { + t.Fatalf("Expected 9 rows, got %d", rowCount) + } + if result != 10 { + t.Fatalf("Expected result 10, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + rows, err := conn.QueryContext(ctx, "select pg_sleep(5)") + if err != nil { + t.Fatal(err) + } + + for rows.Next() { + t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") + } + + if rows.Err() != context.Canceled { + t.Fatal("Expected context.Canceled error, got %v", rows.Err()) + } + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } + +} From 24193ee3223581d6593d5de2364f72839c73b5ba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 15:57:06 -0600 Subject: [PATCH 021/264] Add QueryRowContext --- query.go | 15 ++++++------ query_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/query.go b/query.go index 121dcfe3..fc3f405b 100644 --- a/query.go +++ b/query.go @@ -507,12 +507,6 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - doneChan := make(chan struct{}) go func() { @@ -529,9 +523,9 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} if err != nil { select { case <-ctx.Done(): - return nil, ctx.Err() + return rows, ctx.Err() case doneChan <- struct{}{}: - return nil, err + return rows, err } } @@ -540,3 +534,8 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} return rows, nil } + +func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { + rows, _ := c.QueryContext(ctx, sql, args...) + return (*Row)(rows) +} diff --git a/query_test.go b/query_test.go index ca05fb42..6909ba1e 100644 --- a/query_test.go +++ b/query_test.go @@ -1521,5 +1521,71 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { if err != pgx.ErrNoRows { t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") } +} +func TestQueryRowContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result) + if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + var result []byte + err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result) + if err != context.Canceled { + t.Fatal("Expected context.Canceled error, got %v", err) + } + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } } From a9e7e3acbc04145211116a11959c4db176a5df9a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 16:03:20 -0600 Subject: [PATCH 022/264] Extract connection dead on server test --- conn_test.go | 11 +---------- helper_test.go | 17 ++++++++++++++++- query_test.go | 18 ++---------------- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/conn_test.go b/conn_test.go index a9cf02c9..e92c7ca3 100644 --- a/conn_test.go +++ b/conn_test.go @@ -872,16 +872,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled err, got %v", err) } - time.Sleep(500 * time.Millisecond) - - checkConn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, checkConn) - - var found bool - err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err != pgx.ErrNoRows { - t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") - } + ensureConnDeadOnServer(t, conn, *defaultConnConfig) } func TestPrepare(t *testing.T) { diff --git a/helper_test.go b/helper_test.go index eff731e8..997ae26f 100644 --- a/helper_test.go +++ b/helper_test.go @@ -21,7 +21,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio return conn } - func closeConn(t testing.TB, conn *pgx.Conn) { err := conn.Close() if err != nil { @@ -72,3 +71,19 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { t.Error("Wrong values returned") } } + +func ensureConnDeadOnServer(t *testing.T, conn *pgx.Conn, config pgx.ConnConfig) { + checkConn := mustConnect(t, config) + defer closeConn(t, checkConn) + + for i := 0; i < 10; i++ { + var found bool + err := checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err == pgx.ErrNoRows { + return + } else if err != nil { + t.Fatalf("Unable to check if conn is dead on server: %v", err) + } + } + t.Fatal("Expected conn to be disconnected from server, but it wasn't") +} diff --git a/query_test.go b/query_test.go index 6909ba1e..40886f2e 100644 --- a/query_test.go +++ b/query_test.go @@ -1513,14 +1513,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", rows.Err()) } - checkConn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, checkConn) - - var found bool - err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err != pgx.ErrNoRows { - t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") - } + ensureConnDeadOnServer(t, conn, *defaultConnConfig) } func TestQueryRowContextSuccess(t *testing.T) { @@ -1580,12 +1573,5 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", err) } - checkConn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, checkConn) - - var found bool - err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err != pgx.ErrNoRows { - t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") - } + ensureConnDeadOnServer(t, conn, *defaultConnConfig) } From 94eea5128e3eb1f37f9b70771b0d9a68545839b5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 18:09:25 -0600 Subject: [PATCH 023/264] Add context dependency to travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index d9ea43b0..4a3b91e2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -51,6 +51,7 @@ install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake + - go get -u golang.org/x/net/context script: - go test -v -race -short ./... From 37b86083e4361243246805ceae845e36c9692e9b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 18:44:55 -0600 Subject: [PATCH 024/264] Fix race condition with canceled contexts --- conn.go | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/conn.go b/conn.go index 645b9c5d..45bb9441 100644 --- a/conn.go +++ b/conn.go @@ -18,6 +18,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" ) @@ -74,8 +75,6 @@ type Conn struct { preparedStatements map[string]*PreparedStatement channels map[string]struct{} notifications []*Notification - alive bool - causeOfDeath error logger Logger logLevel int mr msgReader @@ -85,6 +84,10 @@ type Conn struct { busy bool poolResetCount int preallocatedRows []Rows + + closingLock sync.Mutex + alive bool + causeOfDeath error } // PreparedStatement is a description of a prepared statement @@ -391,14 +394,14 @@ func (c *Conn) loadInetConstants() error { // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - if !c.IsAlive() { + c.closingLock.Lock() + defer c.closingLock.Unlock() + + if !c.alive { return nil } - wbuf := newWriteBuf(c, 'X') - wbuf.closeMsg() - - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) c.die(errors.New("Closed")) if c.shouldLog(LogLevelInfo) { @@ -870,7 +873,10 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } func (c *Conn) IsAlive() bool { - return c.alive + c.closingLock.Lock() + alive := c.alive + c.closingLock.Unlock() + return alive } func (c *Conn) CauseOfDeath() error { From 14eedb4fcaa7eec18725aeb692346a1d2e883b30 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 21:10:13 -0600 Subject: [PATCH 025/264] Add ConnPool context methods --- conn.go | 3 +++ conn_pool.go | 34 ++++++++++++++++++++++++++++++++++ context-todo.txt | 12 ++++++++++++ query.go | 12 ++++++++---- stress_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 100 insertions(+), 4 deletions(-) create mode 100644 context-todo.txt diff --git a/conn.go b/conn.go index 45bb9441..f7c06014 100644 --- a/conn.go +++ b/conn.go @@ -1051,9 +1051,12 @@ func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { + c.closingLock.Lock() if !c.alive { + c.closingLock.Unlock() return 0, nil, ErrDeadConn } + c.closingLock.Unlock() t, err = c.mr.rxMsg() if err != nil { diff --git a/conn_pool.go b/conn_pool.go index 6d04565d..50b9d588 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -2,6 +2,7 @@ package pgx import ( "errors" + "golang.org/x/net/context" "sync" "time" ) @@ -357,6 +358,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman return c.Exec(sql, arguments...) } +func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + var c *Conn + if c, err = p.Acquire(); err != nil { + return + } + defer p.Release(c) + + return c.ExecContext(ctx, sql, arguments...) +} + // Query acquires a connection and delegates the call to that connection. When // *Rows are closed, the connection is released automatically. func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { @@ -377,6 +388,24 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { return rows, nil } +func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { + c, err := p.Acquire() + if err != nil { + // Because checking for errors can be deferred to the *Rows, build one with the error + return &Rows{closed: true, err: err}, err + } + + rows, err := c.QueryContext(ctx, sql, args...) + if err != nil { + p.Release(c) + return rows, err + } + + rows.AfterClose(p.rowsAfterClose) + + return rows, nil +} + // QueryRow acquires a connection and delegates the call to that connection. The // connection is released automatically after Scan is called on the returned // *Row. @@ -385,6 +414,11 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { + rows, _ := p.QueryContext(ctx, sql, args...) + return (*Row)(rows) +} + // Begin acquires a connection and begins a transaction on it. When the // transaction is closed the connection will be automatically released. func (p *ConnPool) Begin() (*Tx, error) { diff --git a/context-todo.txt b/context-todo.txt new file mode 100644 index 00000000..b5a20d0a --- /dev/null +++ b/context-todo.txt @@ -0,0 +1,12 @@ +Add more testing +- stress test style +- pgmock + +Add documentation + +Add PrepareContext +Add context methods to ConnPool +Add context methods to Tx +Add context support database/sql + +Benchmark - possibly cache done channel on Conn diff --git a/query.go b/query.go index fc3f405b..3ded881d 100644 --- a/query.go +++ b/query.go @@ -51,8 +51,9 @@ type Rows struct { unlockConn bool closed bool - ctx context.Context - doneChan chan struct{} + ctx context.Context + doneChan chan struct{} + closedChan chan bool } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -127,7 +128,7 @@ func (rows *Rows) Close() { if rows.ctx != nil { select { - case <-rows.ctx.Done(): + case <-rows.closedChan: rows.err = rows.ctx.Err() case rows.doneChan <- struct{}{}: } @@ -508,12 +509,14 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { doneChan := make(chan struct{}) + closedChan := make(chan bool) go func() { select { case <-ctx.Done(): c.cancelQuery() c.Close() + closedChan <- true case <-doneChan: } }() @@ -522,7 +525,7 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} if err != nil { select { - case <-ctx.Done(): + case <-closedChan: return rows, ctx.Err() case doneChan <- struct{}{}: return rows, err @@ -531,6 +534,7 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} rows.ctx = ctx rows.doneChan = doneChan + rows.closedChan = closedChan return rows, nil } diff --git a/stress_test.go b/stress_test.go index 150d13c8..d22d9d6b 100644 --- a/stress_test.go +++ b/stress_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "errors" "fmt" + "golang.org/x/net/context" "math/rand" "testing" "time" @@ -44,6 +45,8 @@ func TestStressConnPool(t *testing.T) { {"listenAndPoolUnlistens", listenAndPoolUnlistens}, {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, + {"canceledQueryContext", canceledQueryContext}, + {"canceledExecContext", canceledExecContext}, } var timer *time.Timer @@ -344,3 +347,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { return tx.Commit() } + +func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + rows, err := pool.QueryContext(ctx, "select pg_sleep(5)") + if err == context.Canceled { + return nil + } else if err != nil { + return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err) + } + + for rows.Next() { + return errors.New("canceledQueryContext: should never receive row") + } + + if rows.Err() != context.Canceled { + return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err()) + } + + return nil +} + +func canceledExecContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + _, err := pool.ExecContext(ctx, "select pg_sleep(5)") + if err != context.Canceled { + return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err) + } + + return nil +} From 351eb8ba679c66de3a67db7da9e0cd06f6fecda8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 6 Feb 2017 19:39:34 -0600 Subject: [PATCH 026/264] Initial proof-of-concept database/sql context support --- conn.go | 52 ++++++++++++++++++++++++++++++++++++++++----------- stdlib/sql.go | 46 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index f7c06014..b8131716 100644 --- a/conn.go +++ b/conn.go @@ -619,6 +619,41 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + return c.PrepareExContext(context.Background(), name, sql, opts) + +} + +func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + doneChan := make(chan struct{}) + closedChan := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + closedChan <- struct{}{} + case <-doneChan: + } + }() + + ps, err = c.prepareEx(name, sql, opts) + + select { + case <-closedChan: + return nil, ctx.Err() + case doneChan <- struct{}{}: + return ps, err + } +} + +func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { if name != "" { if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { return ps, nil @@ -1349,29 +1384,24 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa } doneChan := make(chan struct{}) - closedChan := make(chan bool) + closedChan := make(chan struct{}) go func() { select { case <-ctx.Done(): c.cancelQuery() c.Close() - <-doneChan - closedChan <- true + closedChan <- struct{}{} case <-doneChan: - closedChan <- false } }() commandTag, err = c.Exec(sql, arguments...) - // Signal cancelation goroutine that operation is done - doneChan <- struct{}{} - - // If c was closed due to context cancelation then return context err - if <-closedChan { + select { + case <-closedChan: return "", ctx.Err() + case doneChan <- struct{}{}: + return commandTag, err } - - return commandTag, err } diff --git a/stdlib/sql.go b/stdlib/sql.go index 610aefd4..74218a7b 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -44,6 +44,7 @@ package stdlib import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { return c.queryPrepared("", argsV) } +func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + ps, err := c.conn.PrepareExContext(ctx, "", query, nil) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + + return c.queryPreparedContext(ctx, "", argsV) +} + func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn @@ -226,6 +242,24 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er return &Rows{rows: rows}, nil } +func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + rows, err := c.conn.QueryContext(ctx, name, args...) + if err != nil { + fmt.Println(err) + return nil, err + } + + fmt.Println("ere") + + return &Rows{rows: rows}, nil +} + // Anything that isn't a database/sql compatible type needs to be forced to // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) @@ -318,6 +352,18 @@ func valueToInterface(argsV []driver.Value) []interface{} { return args } +func namedValueToInterface(argsV []driver.NamedValue) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + if v.Value != nil { + args = append(args, v.Value.(interface{})) + } else { + args = append(args, nil) + } + } + return args +} + type Tx struct { conn *pgx.Conn } From 004c18e5a21c7837cb6dc578f22471115b29fdc8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Feb 2017 20:35:37 -0600 Subject: [PATCH 027/264] Begin extracting context handling --- conn.go | 53 +++++++++++++++++++++++------------------------------ query.go | 27 ++++++--------------------- 2 files changed, 29 insertions(+), 51 deletions(-) diff --git a/conn.go b/conn.go index b8131716..453f1a51 100644 --- a/conn.go +++ b/conn.go @@ -88,6 +88,10 @@ type Conn struct { closingLock sync.Mutex alive bool causeOfDeath error + + // context support + doneChan chan struct{} + closedChan chan struct{} } // PreparedStatement is a description of a prepared statement @@ -257,6 +261,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.channels = make(map[string]struct{}) c.alive = true c.lastActivityTime = time.Now() + c.doneChan = make(chan struct{}) + c.closedChan = make(chan struct{}) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -619,8 +625,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - return c.PrepareExContext(context.Background(), name, sql, opts) - + return c.prepareEx(name, sql, opts) } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { @@ -630,25 +635,14 @@ func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *Pre default: } - doneChan := make(chan struct{}) - closedChan := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- struct{}{} - case <-doneChan: - } - }() + go c.contextHandler(ctx) ps, err = c.prepareEx(name, sql, opts) select { - case <-closedChan: + case <-c.closedChan: return nil, ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return ps, err } } @@ -1383,25 +1377,24 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa default: } - doneChan := make(chan struct{}) - closedChan := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- struct{}{} - case <-doneChan: - } - }() + go c.contextHandler(ctx) commandTag, err = c.Exec(sql, arguments...) select { - case <-closedChan: + case <-c.closedChan: return "", ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return commandTag, err } } + +func (c *Conn) contextHandler(ctx context.Context) { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + c.closedChan <- struct{}{} + case <-c.doneChan: + } +} diff --git a/query.go b/query.go index 3ded881d..daf1b354 100644 --- a/query.go +++ b/query.go @@ -51,9 +51,7 @@ type Rows struct { unlockConn bool closed bool - ctx context.Context - doneChan chan struct{} - closedChan chan bool + ctx context.Context } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -128,9 +126,9 @@ func (rows *Rows) Close() { if rows.ctx != nil { select { - case <-rows.closedChan: + case <-rows.conn.closedChan: rows.err = rows.ctx.Err() - case rows.doneChan <- struct{}{}: + case rows.conn.doneChan <- struct{}{}: } } @@ -508,33 +506,20 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - doneChan := make(chan struct{}) - closedChan := make(chan bool) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- true - case <-doneChan: - } - }() + go c.contextHandler(ctx) rows, err := c.Query(sql, args...) if err != nil { select { - case <-closedChan: + case <-c.closedChan: return rows, ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return rows, err } } rows.ctx = ctx - rows.doneChan = doneChan - rows.closedChan = closedChan return rows, nil } From 72b6d32e2f841e6be96c5602c248b2875d345c3c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Feb 2017 21:49:58 -0600 Subject: [PATCH 028/264] Extracted more context handling --- conn.go | 71 ++++++++++++++++++++++++++++++++++++---------------- conn_pool.go | 4 +++ query.go | 33 ++++++++---------------- 3 files changed, 64 insertions(+), 44 deletions(-) diff --git a/conn.go b/conn.go index 453f1a51..b662ba4c 100644 --- a/conn.go +++ b/conn.go @@ -90,8 +90,9 @@ type Conn struct { causeOfDeath error // context support - doneChan chan struct{} - closedChan chan struct{} + ctxInProgress bool + doneChan chan struct{} + closedChan chan error } // PreparedStatement is a description of a prepared statement @@ -262,7 +263,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.alive = true c.lastActivityTime = time.Now() c.doneChan = make(chan struct{}) - c.closedChan = make(chan struct{}) + c.closedChan = make(chan error) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -629,22 +630,14 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: + err = c.initContext(ctx) + if err != nil { + return nil, err } - go c.contextHandler(ctx) - ps, err = c.prepareEx(name, sql, opts) - - select { - case <-c.closedChan: - return nil, ctx.Err() - case c.doneChan <- struct{}{}: - return ps, err - } + err = c.termContext(err) + return ps, err } func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { @@ -1371,22 +1364,56 @@ func (c *Conn) PingContext(ctx context.Context) error { } func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.initContext(ctx) + if err != nil { + return "", err + } + + commandTag, err = c.Exec(sql, arguments...) + err = c.termContext(err) + return commandTag, err +} + +func (c *Conn) initContext(ctx context.Context) error { + if c.ctxInProgress { + return errors.New("ctx already in progress") + } + + if ctx.Done() == nil { + return nil + } + select { case <-ctx.Done(): - return "", ctx.Err() + return ctx.Err() default: } + c.ctxInProgress = true + go c.contextHandler(ctx) - commandTag, err = c.Exec(sql, arguments...) + return nil +} + +func (c *Conn) termContext(opErr error) error { + if !c.ctxInProgress { + return opErr + } + + var err error select { - case <-c.closedChan: - return "", ctx.Err() + case err = <-c.closedChan: + if opErr == nil { + err = nil + } case c.doneChan <- struct{}{}: - return commandTag, err + err = opErr } + + c.ctxInProgress = false + return err } func (c *Conn) contextHandler(ctx context.Context) { @@ -1394,7 +1421,7 @@ func (c *Conn) contextHandler(ctx context.Context) { case <-ctx.Done(): c.cancelQuery() c.Close() - c.closedChan <- struct{}{} + c.closedChan <- ctx.Err() case <-c.doneChan: } } diff --git a/conn_pool.go b/conn_pool.go index 50b9d588..2a243a76 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -182,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Release gives up use of a connection. func (p *ConnPool) Release(conn *Conn) { + if conn.ctxInProgress { + panic("should never release when context is in progress") + } + if conn.TxStatus != 'I' { conn.Exec("rollback") } diff --git a/query.go b/query.go index daf1b354..61136092 100644 --- a/query.go +++ b/query.go @@ -50,8 +50,6 @@ type Rows struct { afterClose func(*Rows) unlockConn bool closed bool - - ctx context.Context } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -84,6 +82,9 @@ func (rows *Rows) close() { } } +// TODO - consider inlining in Close(). This method calling rows.close is a +// foot-gun waiting to happen if anyone puts anything between the call to this +// and rows.close. func (rows *Rows) readUntilReadyForQuery() { for { t, r, err := rows.conn.rxMsg() @@ -122,16 +123,8 @@ func (rows *Rows) Close() { if rows.closed { return } + rows.err = rows.conn.termContext(rows.err) rows.readUntilReadyForQuery() - - if rows.ctx != nil { - select { - case <-rows.conn.closedChan: - rows.err = rows.ctx.Err() - case rows.conn.doneChan <- struct{}{}: - } - } - rows.close() } @@ -506,20 +499,16 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - go c.contextHandler(ctx) - - rows, err := c.Query(sql, args...) - + err := c.initContext(ctx) if err != nil { - select { - case <-c.closedChan: - return rows, ctx.Err() - case c.doneChan <- struct{}{}: - return rows, err - } + return nil, err } - rows.ctx = ctx + rows, err := c.Query(sql, args...) + if err != nil { + err = c.termContext(err) + return nil, err + } return rows, nil } From b8fdc38fa861830ab82c6325a019af83e9270913 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Feb 2017 19:37:23 -0600 Subject: [PATCH 029/264] Only store Conn's *bufio.Reader in msgReader Confusing and redundant to have the same *bufio.Reader in msgReader and Conn. --- conn.go | 10 ++++------ replication.go | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index b662ba4c..7ecd18b2 100644 --- a/conn.go +++ b/conn.go @@ -61,9 +61,8 @@ func (cc *ConnConfig) networkAddress() (network, address string) { // Use ConnPool to manage access to multiple database connections from multiple // goroutines. type Conn struct { - conn net.Conn // the underlying TCP or unix domain socket connection - lastActivityTime time.Time // the last time the connection was used - reader *bufio.Reader // buffered reader to improve read performance + conn net.Conn // the underlying TCP or unix domain socket connection + lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf Pid int32 // backend pid @@ -274,8 +273,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.reader = bufio.NewReader(c.conn) - c.mr.reader = c.reader + c.mr.reader = bufio.NewReader(c.conn) msg := newStartupMessage() @@ -862,7 +860,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = c.reader.Peek(1) + _, err = c.mr.reader.Peek(1) if err != nil { c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline if err, ok := err.(*net.OpError); ok && err.Timeout() { diff --git a/replication.go b/replication.go index 7b28d6b6..12a5c914 100644 --- a/replication.go +++ b/replication.go @@ -289,7 +289,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r * } // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = rc.c.reader.Peek(1) + _, err = rc.c.mr.reader.Peek(1) if err != nil { rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline if err, ok := err.(*net.OpError); ok && err.Timeout() { From 855276e2cf09ce6e53ee0c8876422b7975bf0667 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Feb 2017 19:40:01 -0600 Subject: [PATCH 030/264] Remove unused msgReader.Err() --- msg_reader.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/msg_reader.go b/msg_reader.go index 21db5d26..f7b497f7 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -16,11 +16,6 @@ type msgReader struct { shouldLog func(lvl int) bool } -// Err returns any error that the msgReader has experienced -func (r *msgReader) Err() error { - return r.err -} - // fatal tells rc that a Fatal error has occurred func (r *msgReader) fatal(err error) { if r.shouldLog(LogLevelTrace) { From 50b0bea9e57b9c6181b4318bf3f7a89b03cb6ea9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Feb 2017 21:04:16 -0600 Subject: [PATCH 031/264] msgReader pre-buffers messages when possible --- msg_reader.go | 26 ++++++- msg_reader_test.go | 189 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+), 3 deletions(-) create mode 100644 msg_reader_test.go diff --git a/msg_reader.go b/msg_reader.go index f7b497f7..1f4e67e9 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "io" + "net" ) // msgReader is a helper that reads values from a PostgreSQL message. @@ -35,20 +36,39 @@ func (r *msgReader) rxMsg() (byte, error) { r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) } - _, err := r.reader.Discard(int(r.msgBytesRemaining)) + n, err := r.reader.Discard(int(r.msgBytesRemaining)) + r.msgBytesRemaining -= int32(n) if err != nil { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } return 0, err } } b, err := r.reader.Peek(5) if err != nil { - r.fatal(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } return 0, err } + msgType := b[0] - r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4 + payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4 + + // Try to preload bufio.Reader with entire message + b, err = r.reader.Peek(5 + int(payloadSize)) + if err != nil && err != bufio.ErrBufferFull { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } + return 0, err + } + + r.msgBytesRemaining = payloadSize r.reader.Discard(5) + return msgType, nil } diff --git a/msg_reader_test.go b/msg_reader_test.go new file mode 100644 index 00000000..2bbd53c9 --- /dev/null +++ b/msg_reader_test.go @@ -0,0 +1,189 @@ +package pgx + +import ( + "bufio" + "net" + "testing" + "time" + + "github.com/jackc/pgmock/pgmsg" +) + +func TestMsgReaderPrebuffersWhenPossible(t *testing.T) { + t.Parallel() + + tests := []struct { + msgType byte + payloadSize int32 + buffered bool + }{ + {1, 50, true}, + {2, 0, true}, + {3, 500, true}, + {4, 1050, true}, + {5, 1500, true}, + {6, 1500, true}, + {7, 4000, true}, + {8, 24000, false}, + {9, 4000, true}, + {1, 1500, true}, + {2, 0, true}, + {3, 500, true}, + {4, 1050, true}, + {5, 1500, true}, + {6, 1500, true}, + {7, 4000, true}, + {8, 14000, false}, + {9, 0, true}, + {1, 500, true}, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + go func() { + var bigEndian pgmsg.BigEndianBuf + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + for _, tt := range tests { + _, err = conn.Write([]byte{tt.msgType}) + if err != nil { + t.Fatal(err) + } + + _, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4)) + if err != nil { + t.Fatal(err) + } + + payload := make([]byte, int(tt.payloadSize)) + _, err = conn.Write(payload) + if err != nil { + t.Fatal(err) + } + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + mr := &msgReader{ + reader: bufio.NewReader(conn), + shouldLog: func(int) bool { return false }, + } + + for i, tt := range tests { + msgType, err := mr.rxMsg() + if err != nil { + t.Fatalf("%d. Unexpected error: %v", i, err) + } + + if msgType != tt.msgType { + t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType) + } + + if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered { + t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered()) + } + } +} + +func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + testCount := 10000 + + go func() { + var bigEndian pgmsg.BigEndianBuf + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + for i := 0; i < testCount; i++ { + msgType := byte(i) + + _, err = conn.Write([]byte{msgType}) + if err != nil { + t.Fatal(err) + } + + msgSize := i % 4000 + + _, err = conn.Write(bigEndian.Int32(int32(msgSize + 4))) + if err != nil { + t.Fatal(err) + } + + payload := make([]byte, msgSize) + _, err = conn.Write(payload) + if err != nil { + t.Fatal(err) + } + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + mr := &msgReader{ + reader: bufio.NewReader(conn), + shouldLog: func(int) bool { return false }, + } + + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + + i := 0 + for { + msgType, err := mr.rxMsg() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + continue + } else { + t.Fatalf("%d. Unexpected error: %v", i, err) + } + } + + expectedMsgType := byte(i) + if msgType != expectedMsgType { + t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType) + } + + expectedMsgSize := i % 4000 + payload := mr.readBytes(mr.msgBytesRemaining) + if mr.err != nil { + t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err) + } + if len(payload) != expectedMsgSize { + t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload)) + } + + i++ + if i == testCount { + break + } + } +} From 09d37880bafc78b43a429610d8825b095e9f24df Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Feb 2017 21:42:58 -0600 Subject: [PATCH 032/264] wip --- conn-lock-todo.txt | 11 +++++++++++ conn.go | 7 ++++++- 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 conn-lock-todo.txt diff --git a/conn-lock-todo.txt b/conn-lock-todo.txt new file mode 100644 index 00000000..ab5eac95 --- /dev/null +++ b/conn-lock-todo.txt @@ -0,0 +1,11 @@ +Extract all locking state into a separate struct that will encapsulate locking and state change behavior. + +This struct should add or subsume at least the following: +* alive +* closingLock +* ctxInProgress (though this may be restructured because it's possible a Tx may have a ctx and a query run in that Tx could have one) +* busy +* lock/unlock +* Tx in-progress +* Rows in-progress +* ConnPool checked-out or checked-in - maybe include reference to conn pool diff --git a/conn.go b/conn.go index 7ecd18b2..78bdcedc 100644 --- a/conn.go +++ b/conn.go @@ -1403,6 +1403,9 @@ func (c *Conn) termContext(opErr error) error { select { case err = <-c.closedChan: + if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil { + c.Close() // Close connection if unable to disable deadline + } if opErr == nil { err = nil } @@ -1418,7 +1421,9 @@ func (c *Conn) contextHandler(ctx context.Context) { select { case <-ctx.Done(): c.cancelQuery() - c.Close() + if err := c.conn.SetDeadline(time.Now()); err != nil { + c.Close() // Close connection if unable to set deadline + } c.closedChan <- ctx.Err() case <-c.doneChan: } From f0dfe4fe8926487e5772dade1decef121a7279ea Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 13:01:51 -0600 Subject: [PATCH 033/264] Merge alive and busy states into atomic status --- conn.go | 56 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/conn.go b/conn.go index 78bdcedc..7243a4d1 100644 --- a/conn.go +++ b/conn.go @@ -18,10 +18,17 @@ import ( "regexp" "strconv" "strings" - "sync" + "sync/atomic" "time" ) +const ( + connStatusUninitialized = iota + connStatusClosed + connStatusIdle + connStatusBusy +) + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -80,12 +87,10 @@ type Conn struct { fp *fastpath pgsqlAfInet *byte pgsqlAfInet6 *byte - busy bool poolResetCount int preallocatedRows []Rows - closingLock sync.Mutex - alive bool + status int32 // One of connStatus* constants causeOfDeath error // context support @@ -252,14 +257,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl defer func() { if c != nil && err != nil { c.conn.Close() - c.alive = false + atomic.StoreInt32(&c.status, connStatusClosed) } }() c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) - c.alive = true + atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() c.doneChan = make(chan struct{}) c.closedChan = make(chan error) @@ -399,11 +404,14 @@ func (c *Conn) loadInetConstants() error { // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - c.closingLock.Lock() - defer c.closingLock.Unlock() - - if !c.alive { - return nil + for { + status := atomic.LoadInt32(&c.status) + if status < connStatusIdle { + return nil + } + if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) { + break + } } _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) @@ -893,10 +901,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } func (c *Conn) IsAlive() bool { - c.closingLock.Lock() - alive := c.alive - c.closingLock.Unlock() - return alive + return atomic.LoadInt32(&c.status) >= connStatusIdle } func (c *Conn) CauseOfDeath() error { @@ -1071,12 +1076,9 @@ func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { - c.closingLock.Lock() - if !c.alive { - c.closingLock.Unlock() + if atomic.LoadInt32(&c.status) < connStatusIdle { return 0, nil, ErrDeadConn } - c.closingLock.Unlock() t, err = c.mr.rxMsg() if err != nil { @@ -1261,25 +1263,23 @@ func (c *Conn) txPasswordMessage(password string) (err error) { } func (c *Conn) die(err error) { - c.alive = false + atomic.StoreInt32(&c.status, connStatusClosed) c.causeOfDeath = err c.conn.Close() } func (c *Conn) lock() error { - if c.busy { - return ErrConnBusy + if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) { + return nil } - c.busy = true - return nil + return ErrConnBusy } func (c *Conn) unlock() error { - if !c.busy { - return errors.New("unlock conn that is not busy") + if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) { + return nil } - c.busy = false - return nil + return errors.New("unlock conn that is not busy") } func (c *Conn) shouldLog(lvl int) bool { From e4f9108e8251f3a6e35c3bd698ad39273b172e9d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 14:59:16 -0600 Subject: [PATCH 034/264] wip --- conn.go | 87 +++++++++++++++++++++++++++++++++++++++++--------- conn_test.go | 2 +- copy_to.go | 1 - fastpath.go | 4 +++ helper_test.go | 16 ---------- query.go | 44 ++----------------------- query_test.go | 4 +-- 7 files changed, 82 insertions(+), 76 deletions(-) diff --git a/conn.go b/conn.go index 7243a4d1..f7443719 100644 --- a/conn.go +++ b/conn.go @@ -93,6 +93,8 @@ type Conn struct { status int32 // One of connStatus* constants causeOfDeath error + readyForQuery bool // can the connection be used to send a query + // context support ctxInProgress bool doneChan chan struct{} @@ -653,6 +655,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + if c.shouldLog(LogLevelError) { defer func() { if err != nil { @@ -692,6 +698,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared c.die(err) return nil, err } + c.readyForQuery = false ps = &PreparedStatement{Name: name, SQL: sql} @@ -706,7 +713,6 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } switch t { - case parseComplete: case parameterDescription: ps.ParameterOids = c.rxParameterDescription(r) @@ -720,7 +726,6 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared ps.FieldDescriptions[i].DataTypeName = t.Name ps.FieldDescriptions[i].FormatCode = t.DefaultFormat } - case noData: case readyForQuery: c.rxReadyForQuery(r) @@ -739,6 +744,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared // Deallocate released a prepared statement func (c *Conn) Deallocate(name string) (err error) { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + delete(c.preparedStatements, name) // close @@ -809,6 +818,10 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + stopTime := time.Now().Add(timeout) for { @@ -916,6 +929,9 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { } func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } if len(args) == 0 { wbuf := newWriteBuf(c, 'Q') @@ -927,6 +943,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { c.die(err) return err } + c.readyForQuery = false return nil } @@ -944,6 +961,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + // bind wbuf := newWriteBuf(c, 'B') wbuf.WriteByte(0) @@ -991,6 +1012,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} if err != nil { c.die(err) } + c.readyForQuery = false return err } @@ -1040,9 +1062,6 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag case readyForQuery: c.rxReadyForQuery(r) return commandTag, softErr - case rowDescription: - case dataRow: - case bindComplete: case commandComplete: commandTag = CommandTag(r.readCString()) default: @@ -1054,25 +1073,36 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag } // Processes messages that are not exclusive to one context such as -// authentication or query response. The response to these messages -// is the same regardless of when they occur. +// authentication or query response. The response to these messages is the same +// regardless of when they occur. It also ignores messages that are only +// meaningful in a given context. These messages can occur do to a context +// deadline interrupting message processing. For example, an interrupted query +// may have left DataRow messages on the wire. func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { switch t { - case 'S': - c.rxParameterStatus(r) - return nil + case bindComplete: + case commandComplete: + case dataRow: + case emptyQueryResponse: case errorResponse: return c.rxErrorResponse(r) + case noData: case noticeResponse: - return nil - case emptyQueryResponse: - return nil case notificationResponse: c.rxNotificationResponse(r) - return nil + case parameterDescription: + case parseComplete: + case readyForQuery: + c.rxReadyForQuery(r) + case rowDescription: + case 'S': + c.rxParameterStatus(r) + default: return fmt.Errorf("Received unknown message type: %c", t) } + + return nil } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { @@ -1082,7 +1112,9 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { t, err = c.mr.rxMsg() if err != nil { - c.die(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + c.die(err) + } } c.lastActivityTime = time.Now() @@ -1183,6 +1215,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) { } func (c *Conn) rxReadyForQuery(r *msgReader) { + c.readyForQuery = true c.TxStatus = r.readByte() } @@ -1428,3 +1461,27 @@ func (c *Conn) contextHandler(ctx context.Context) { case <-c.doneChan: } } + +func (c *Conn) ensureConnectionReadyForQuery() error { + for !c.readyForQuery { + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case errorResponse: + pgErr := c.rxErrorResponse(r) + if pgErr.Severity == "FATAL" { + return pgErr + } + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/conn_test.go b/conn_test.go index e92c7ca3..ca39b4b4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -872,7 +872,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled err, got %v", err) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) } func TestPrepare(t *testing.T) { diff --git a/copy_to.go b/copy_to.go index 91292bb0..dd70ada3 100644 --- a/copy_to.go +++ b/copy_to.go @@ -66,7 +66,6 @@ func (ct *copyTo) readUntilReadyForQuery() { ct.conn.rxReadyForQuery(r) close(ct.readerErrChan) return - case commandComplete: case errorResponse: ct.readerErrChan <- ct.conn.rxErrorResponse(r) default: diff --git a/fastpath.go b/fastpath.go index 19b98784..30a9f102 100644 --- a/fastpath.go +++ b/fastpath.go @@ -48,6 +48,10 @@ func fpInt64Arg(n int64) fpArg { } func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { + if err := f.cn.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + wbuf := newWriteBuf(f.cn, 'F') // function call wbuf.WriteInt32(int32(oid)) // function object id wbuf.WriteInt16(1) // # of argument format codes diff --git a/helper_test.go b/helper_test.go index 997ae26f..21f86de5 100644 --- a/helper_test.go +++ b/helper_test.go @@ -71,19 +71,3 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { t.Error("Wrong values returned") } } - -func ensureConnDeadOnServer(t *testing.T, conn *pgx.Conn, config pgx.ConnConfig) { - checkConn := mustConnect(t, config) - defer closeConn(t, checkConn) - - for i := 0; i < 10; i++ { - var found bool - err := checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err == pgx.ErrNoRows { - return - } else if err != nil { - t.Fatalf("Unable to check if conn is dead on server: %v", err) - } - } - t.Fatal("Expected conn to be disconnected from server, but it wasn't") -} diff --git a/query.go b/query.go index 61136092..b6470688 100644 --- a/query.go +++ b/query.go @@ -82,41 +82,6 @@ func (rows *Rows) close() { } } -// TODO - consider inlining in Close(). This method calling rows.close is a -// foot-gun waiting to happen if anyone puts anything between the call to this -// and rows.close. -func (rows *Rows) readUntilReadyForQuery() { - for { - t, r, err := rows.conn.rxMsg() - if err != nil { - rows.close() - return - } - - switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return - case rowDescription: - case dataRow: - case commandComplete: - case bindComplete: - case errorResponse: - err = rows.conn.rxErrorResponse(r) - if rows.err == nil { - rows.err = err - } - default: - err = rows.conn.processContextFreeMsg(t, r) - if err != nil { - rows.close() - return - } - } - } -} - // Close closes the rows, making the connection ready for use again. It is safe // to call Close after rows is already closed. func (rows *Rows) Close() { @@ -124,7 +89,6 @@ func (rows *Rows) Close() { return } rows.err = rows.conn.termContext(rows.err) - rows.readUntilReadyForQuery() rows.close() } @@ -174,10 +138,6 @@ func (rows *Rows) Next() bool { } switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return false case dataRow: fieldCount := r.readInt16() if int(fieldCount) != len(rows.fields) { @@ -188,7 +148,9 @@ func (rows *Rows) Next() bool { rows.mr = r return true case commandComplete: - case bindComplete: + rows.close() + return false + default: err = rows.conn.processContextFreeMsg(t, r) if err != nil { diff --git a/query_test.go b/query_test.go index 40886f2e..24310ab3 100644 --- a/query_test.go +++ b/query_test.go @@ -1513,7 +1513,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", rows.Err()) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) } func TestQueryRowContextSuccess(t *testing.T) { @@ -1573,5 +1573,5 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", err) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) } From 8cc480fc485a73281cdbcc41bc937a970133c0bf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 18:44:27 -0600 Subject: [PATCH 035/264] Fix grammar --- conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn.go b/conn.go index f7443719..3ee0fe6b 100644 --- a/conn.go +++ b/conn.go @@ -1075,7 +1075,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages is the same // regardless of when they occur. It also ignores messages that are only -// meaningful in a given context. These messages can occur do to a context +// meaningful in a given context. These messages can occur due to a context // deadline interrupting message processing. For example, an interrupted query // may have left DataRow messages on the wire. func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { From 9c74626d226753b61b8bdf0103749511975b6f70 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 18:44:39 -0600 Subject: [PATCH 036/264] Ping implemented in terms of PingContext --- conn.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 3ee0fe6b..51fab0e5 100644 --- a/conn.go +++ b/conn.go @@ -1385,8 +1385,7 @@ func (c *Conn) cancelQuery() error { } func (c *Conn) Ping() error { - _, err := c.Exec(";") - return err + return c.PingContext(context.Background()) } func (c *Conn) PingContext(ctx context.Context) error { From 6cdb58fc71181d84efb08496242dcab3ab4247fc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 18:46:46 -0600 Subject: [PATCH 037/264] Exec implemented in terms of ExecContext --- conn.go | 107 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/conn.go b/conn.go index 51fab0e5..5ede5944 100644 --- a/conn.go +++ b/conn.go @@ -1020,56 +1020,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { - if err = c.lock(); err != nil { - return commandTag, err - } - - startTime := time.Now() - c.lastActivityTime = startTime - - defer func() { - if err == nil { - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) - } - } else { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) - } - } - - if unlockErr := c.unlock(); unlockErr != nil && err == nil { - err = unlockErr - } - }() - - if err = c.sendQuery(sql, arguments...); err != nil { - return - } - - var softErr error - - for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() - if err != nil { - return commandTag, err - } - - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return commandTag, softErr - case commandComplete: - commandTag = CommandTag(r.readCString()) - default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { - softErr = e - } - } - } + return c.ExecContext(context.Background(), sql, arguments...) } // Processes messages that are not exclusive to one context such as @@ -1398,9 +1349,61 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa if err != nil { return "", err } + defer func() { + err = c.termContext(err) + }() + + if err = c.lock(); err != nil { + return commandTag, err + } + + startTime := time.Now() + c.lastActivityTime = startTime + + defer func() { + if err == nil { + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) + } + } else { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) + } + } + + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() + + if err = c.sendQuery(sql, arguments...); err != nil { + return + } + + var softErr error + + for { + var t byte + var r *msgReader + t, r, err = c.rxMsg() + if err != nil { + return commandTag, err + } + + switch t { + case readyForQuery: + c.rxReadyForQuery(r) + return commandTag, softErr + case commandComplete: + commandTag = CommandTag(r.readCString()) + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e + } + } + } - commandTag, err = c.Exec(sql, arguments...) - err = c.termContext(err) return commandTag, err } From deac6564eeb81e6ad3996b9e29f03854a8017f2d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 19:16:13 -0600 Subject: [PATCH 038/264] Implement Query in terms of QueryContext - Merge Rows.close into Rows.Close - Merge Rows.abort into Rows.Fatal --- query.go | 91 ++++++++++++++++++++------------------------------ replication.go | 6 ++-- 2 files changed, 39 insertions(+), 58 deletions(-) diff --git a/query.go b/query.go index b6470688..aa664649 100644 --- a/query.go +++ b/query.go @@ -56,7 +56,9 @@ func (rows *Rows) FieldDescriptions() []FieldDescription { return rows.fields } -func (rows *Rows) close() { +// Close closes the rows, making the connection ready for use again. It is safe +// to call Close after rows is already closed. +func (rows *Rows) Close() { if rows.closed { return } @@ -68,6 +70,8 @@ func (rows *Rows) close() { rows.closed = true + rows.err = rows.conn.termContext(rows.err) + if rows.err == nil { if rows.conn.shouldLog(LogLevelInfo) { endTime := time.Now() @@ -82,31 +86,10 @@ func (rows *Rows) close() { } } -// Close closes the rows, making the connection ready for use again. It is safe -// to call Close after rows is already closed. -func (rows *Rows) Close() { - if rows.closed { - return - } - rows.err = rows.conn.termContext(rows.err) - rows.close() -} - func (rows *Rows) Err() error { return rows.err } -// abort signals that the query was not successfully sent to the server. -// This differs from Fatal in that it is not necessary to readUntilReadyForQuery -func (rows *Rows) abort(err error) { - if rows.err != nil { - return - } - - rows.err = err - rows.close() -} - // Fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. func (rows *Rows) Fatal(err error) { @@ -148,7 +131,7 @@ func (rows *Rows) Next() bool { rows.mr = r return true case commandComplete: - rows.close() + rows.Close() return false default: @@ -408,32 +391,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) { // be returned in an error state. So it is allowed to ignore the error returned // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - c.lastActivityTime = time.Now() - - rows := c.getRows(sql, args) - - if err := c.lock(); err != nil { - rows.abort(err) - return rows, err - } - rows.unlockConn = true - - ps, ok := c.preparedStatements[sql] - if !ok { - var err error - ps, err = c.Prepare("", sql) - if err != nil { - rows.abort(err) - return rows, rows.err - } - } - rows.sql = ps.SQL - rows.fields = ps.FieldDescriptions - err := c.sendPreparedQuery(ps, args...) - if err != nil { - rows.abort(err) - } - return rows, rows.err + return c.QueryContext(context.Background(), sql, args...) } func (c *Conn) getRows(sql string, args []interface{}) *Rows { @@ -460,19 +418,42 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } -func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - err := c.initContext(ctx) +func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { + c.lastActivityTime = time.Now() + + rows = c.getRows(sql, args) + + if err := c.lock(); err != nil { + rows.Fatal(err) + return rows, err + } + rows.unlockConn = true + + ps, ok := c.preparedStatements[sql] + if !ok { + var err error + ps, err = c.PrepareExContext(ctx, "", sql, nil) + if err != nil { + rows.Fatal(err) + return rows, rows.err + } + } + rows.sql = ps.SQL + rows.fields = ps.FieldDescriptions + + err = c.initContext(ctx) if err != nil { - return nil, err + rows.Fatal(err) + return rows, err } - rows, err := c.Query(sql, args...) + err = c.sendPreparedQuery(ps, args...) if err != nil { + rows.Fatal(err) err = c.termContext(err) - return nil, err } - return rows, nil + return rows, err } func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { diff --git a/replication.go b/replication.go index 12a5c914..0acc9df9 100644 --- a/replication.go +++ b/replication.go @@ -312,14 +312,14 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows := rc.c.getRows(sql, nil) if err := rc.c.lock(); err != nil { - rows.abort(err) + rows.Fatal(err) return rows, err } rows.unlockConn = true err := rc.c.sendSimpleQuery(sql) if err != nil { - rows.abort(err) + rows.Fatal(err) } var t byte @@ -337,7 +337,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { // only Oids. Not much we can do about this. default: if e := rc.c.processContextFreeMsg(t, r); e != nil { - rows.abort(e) + rows.Fatal(e) return rows, e } } From 048a75406f1139b19f1be31f3ec2f590c901fc8e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 19:53:18 -0600 Subject: [PATCH 039/264] Fix context query cancellation Previous commits had a race condition due to not waiting for the PostgreSQL server to close the cancel query connection. This made it possible for the cancel request to impact a subsequent query on the same connection. This commit sets a flag that a cancel request was made and blocks until the PostgreSQL server closes the cancel connection. --- conn.go | 128 ++++++++++++++++++++++++++++++++++++++++--------- query.go | 5 ++ stress_test.go | 14 +++--- 3 files changed, 118 insertions(+), 29 deletions(-) diff --git a/conn.go b/conn.go index 5ede5944..f91929c5 100644 --- a/conn.go +++ b/conn.go @@ -93,7 +93,9 @@ type Conn struct { status int32 // One of connStatus* constants causeOfDeath error - readyForQuery bool // can the connection be used to send a query + readyForQuery bool // connection has received ReadyForQuery message since last query was sent + cancelQueryInProgress int32 + cancelQueryCompleted chan struct{} // context support ctxInProgress bool @@ -268,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.channels = make(map[string]struct{}) atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() + c.cancelQueryCompleted = make(chan struct{}, 1) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) @@ -634,10 +637,15 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - return c.prepareEx(name, sql, opts) + return c.PrepareExContext(context.Background(), name, sql, opts) } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + err = c.initContext(ctx) if err != nil { return nil, err @@ -743,7 +751,25 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } // Deallocate released a prepared statement -func (c *Conn) Deallocate(name string) (err error) { +func (c *Conn) Deallocate(name string) error { + return c.deallocateContext(context.Background(), name) +} + +// TODO - consider making this public +func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return err + } + + err = c.initContext(ctx) + if err != nil { + return err + } + defer func() { + err = c.termContext(err) + }() + if err := c.ensureConnectionReadyForQuery(); err != nil { return err } @@ -818,6 +844,13 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } + ctx, cancelFn := context.WithTimeout(context.Background(), timeout) + if err := c.waitForPreviousCancelQuery(ctx); err != nil { + cancelFn() + return nil, err + } + cancelFn() + if err := c.ensureConnectionReadyForQuery(); err != nil { return nil, err } @@ -1318,21 +1351,55 @@ func quoteIdentifier(s string) string { // ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See // https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 -func (c *Conn) cancelQuery() error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) - if err != nil { - return err +func (c *Conn) cancelQuery() { + if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) { + panic("cancelQuery when cancelQueryInProgress") } - defer cancelConn.Close() - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) - binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) - _, err = cancelConn.Write(buf) - return err + if err := c.conn.SetDeadline(time.Now()); err != nil { + c.Close() // Close connection if unable to set deadline + return + } + + doCancel := func() error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + // If server doesn't process cancellation request in bounded time then abort. + err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + return err + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) + } + + return nil + } + + go func() { + err := doCancel() + if err != nil { + c.Close() // Something is very wrong. Terminate the connection. + } + c.cancelQueryCompleted <- struct{}{} + }() } func (c *Conn) Ping() error { @@ -1345,6 +1412,11 @@ func (c *Conn) PingContext(ctx context.Context) error { } func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return "", err + } + err = c.initContext(ctx) if err != nil { return "", err @@ -1438,9 +1510,6 @@ func (c *Conn) termContext(opErr error) error { select { case err = <-c.closedChan: - if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil { - c.Close() // Close connection if unable to disable deadline - } if opErr == nil { err = nil } @@ -1456,14 +1525,29 @@ func (c *Conn) contextHandler(ctx context.Context) { select { case <-ctx.Done(): c.cancelQuery() - if err := c.conn.SetDeadline(time.Now()); err != nil { - c.Close() // Close connection if unable to set deadline - } c.closedChan <- ctx.Err() case <-c.doneChan: } } +func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { + if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 { + return nil + } + + select { + case <-c.cancelQueryCompleted: + atomic.StoreInt32(&c.cancelQueryInProgress, 0) + if err := c.conn.SetDeadline(time.Time{}); err != nil { + c.Close() // Close connection if unable to disable deadline + return err + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + func (c *Conn) ensureConnectionReadyForQuery() error { for !c.readyForQuery { t, r, err := c.rxMsg() diff --git a/query.go b/query.go index aa664649..dd7aafb0 100644 --- a/query.go +++ b/query.go @@ -419,6 +419,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + c.lastActivityTime = time.Now() rows = c.getRows(sql, args) diff --git a/stress_test.go b/stress_test.go index d22d9d6b..72d48a5c 100644 --- a/stress_test.go +++ b/stress_test.go @@ -66,7 +66,7 @@ func TestStressConnPool(t *testing.T) { action := actions[rand.Intn(len(actions))] err := action.fn(pool, n) if err != nil { - errChan <- err + errChan <- fmt.Errorf("%s: %v", action.name, err) break } } @@ -355,19 +355,19 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { cancelFunc() }() - rows, err := pool.QueryContext(ctx, "select pg_sleep(5)") + rows, err := pool.QueryContext(ctx, "select pg_sleep(2)") if err == context.Canceled { return nil } else if err != nil { - return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err) + return fmt.Errorf("Only allowed error is context.Canceled, got %v", err) } for rows.Next() { - return errors.New("canceledQueryContext: should never receive row") + return errors.New("should never receive row") } if rows.Err() != context.Canceled { - return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err()) + return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err()) } return nil @@ -380,9 +380,9 @@ func canceledExecContext(pool *pgx.ConnPool, actionNum int) error { cancelFunc() }() - _, err := pool.ExecContext(ctx, "select pg_sleep(5)") + _, err := pool.ExecContext(ctx, "select pg_sleep(2)") if err != context.Canceled { - return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err) + return fmt.Errorf("Expected context.Canceled error, got %v", err) } return nil From d0a6921d124dfab48c89004e1a683bce180b795f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 20:40:28 -0600 Subject: [PATCH 040/264] Add dependency to travis.yml --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 4a3b91e2..9ae8d963 100644 --- a/.travis.yml +++ b/.travis.yml @@ -52,6 +52,7 @@ install: - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake - go get -u golang.org/x/net/context + - go get -u github.com/jackc/pgmock/pgmsg script: - go test -v -race -short ./... From cc414269c1bbca67c779c9798c13bb78c0a1843f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Feb 2017 08:12:36 -0600 Subject: [PATCH 041/264] Remove debugging Println --- stdlib/sql.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 74218a7b..41c9d4dd 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -255,8 +255,6 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr return nil, err } - fmt.Println("ere") - return &Rows{rows: rows}, nil } From f597c16a7b7aca02710ec176ce9648aecc0a1734 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Feb 2017 21:46:15 -0600 Subject: [PATCH 042/264] Add ChunkReader --- chunkreader/chunkreader.go | 106 ++++++++++++++++++++++++++ chunkreader/chunkreader_test.go | 128 ++++++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 chunkreader/chunkreader.go create mode 100644 chunkreader/chunkreader_test.go diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go new file mode 100644 index 00000000..f9d6555c --- /dev/null +++ b/chunkreader/chunkreader.go @@ -0,0 +1,106 @@ +package chunkreader + +import ( + "io" +) + +type ChunkReader struct { + r io.Reader + + buf []byte + rp, wp int // buf read position and write position + taken bool + + options Options +} + +type Options struct { + MinBufLen int // Minimum buffer length + BlockLen int // Increments to expand buffer (e.g. a 8000 byte request with a BlockLen of 1024 would yield a buffer len of 8192) +} + +func NewChunkReader(r io.Reader) *ChunkReader { + cr, err := NewChunkReaderEx(r, Options{}) + if err != nil { + panic("default options can't be bad") + } + + return cr +} + +func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { + if options.MinBufLen == 0 { + options.MinBufLen = 4096 + } + if options.BlockLen == 0 { + options.BlockLen = 512 + } + + return &ChunkReader{ + r: r, + buf: make([]byte, options.MinBufLen), + options: options, + }, nil +} + +// Next returns buf filled with the next n bytes. buf is only valid until the +// next call to Next. If an error occurs, buf will be nil. +func (r *ChunkReader) Next(n int) (buf []byte, err error) { + // n bytes already in buf + if (r.wp - r.rp) >= n { + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, err + } + + // available space in buf is less than n + if len(r.buf) < n { + r.copyBufContents(r.newBuf(n)) + r.taken = false + } + + // buf is large enough, but need to shift filled area to start to make enough contiguous space + minReadCount := n - (r.wp - r.rp) + if (len(r.buf) - r.wp) < minReadCount { + newBuf := r.buf + if r.taken { + newBuf = r.newBuf(n) + r.taken = false + } + r.copyBufContents(newBuf) + } + + if err := r.appendAtLeast(minReadCount); err != nil { + return nil, err + } + + buf = r.buf[r.rp : r.rp+n] + r.rp += n + return buf, nil +} + +// KeepLast prevents the last data retrieved by Next from being reused by the +// ChunkReader. +func (r *ChunkReader) KeepLast() { + r.taken = true +} + +func (r *ChunkReader) appendAtLeast(fillLen int) error { + n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) + r.wp += n + return err +} + +func (r *ChunkReader) newBuf(min int) []byte { + size := ((min / r.options.BlockLen) + 1) * r.options.BlockLen + if size < r.options.MinBufLen { + size = r.options.MinBufLen + } + return make([]byte, size) +} + +func (r *ChunkReader) copyBufContents(dest []byte) { + r.wp = copy(dest, r.buf[r.rp:r.wp]) + r.rp = 0 + r.buf = dest +} diff --git a/chunkreader/chunkreader_test.go b/chunkreader/chunkreader_test.go new file mode 100644 index 00000000..9c19ff4a --- /dev/null +++ b/chunkreader/chunkreader_test.go @@ -0,0 +1,128 @@ +package chunkreader + +import ( + "bytes" + "testing" +) + +func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4} + server.Write(src) + + n1, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:2]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1) + } + + n2, err := r.Next(2) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[2:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2) + } + + if bytes.Compare(r.buf, src) != 0 { + t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf) + } + if r.rp != 4 { + t.Fatalf("Expected r.rp to be %v, but it was %v", 4, r.rp) + } + if r.wp != 4 { + t.Fatalf("Expected r.wp to be %v, but it was %v", 4, r.wp) + } +} + +func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4, 5, 6, 7, 8} + server.Write(src) + + n1, err := r.Next(5) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:5]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) + } + if len(r.buf) != 6 { + t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 6, len(r.buf)) + } +} + +func TestChunkReaderNextReusesBuf(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4, 5, 6, 7, 8} + server.Write(src) + + n1, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) + } + + n2, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[4:8]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) + } + + if bytes.Compare(n1, src[4:8]) != 0 { + t.Fatalf("Expected Next to have reused buf, %v found instead of %v", src[4:8], n1) + } +} + +func TestChunkReaderKeepLastPreventsBufReuse(t *testing.T) { + server := &bytes.Buffer{} + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) + if err != nil { + t.Fatal(err) + } + + src := []byte{1, 2, 3, 4, 5, 6, 7, 8} + server.Write(src) + + n1, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n1, src[0:4]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) + } + r.KeepLast() + + n2, err := r.Next(4) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(n2, src[4:8]) != 0 { + t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) + } + + if bytes.Compare(n1, src[0:4]) != 0 { + t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) + } +} From 84802ece05532943bb810dd6ad1f4bcc2f3fb0bf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Feb 2017 20:40:04 -0600 Subject: [PATCH 043/264] conn.Close closes underlying conn Previously, it merely sent the termination message. --- conn.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/conn.go b/conn.go index d1205636..07422a32 100644 --- a/conn.go +++ b/conn.go @@ -425,6 +425,11 @@ func (c *Conn) Close() (err error) { } _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) + if err != nil && c.shouldLog(LogLevelWarn) { + c.log(LogLevelWarn, "Failed to send terminate message", "err", err) + } + + err = c.conn.Close() c.die(errors.New("Closed")) if c.shouldLog(LogLevelInfo) { From 11b82b3ca4bda887a7aca04e0ef1cc513798b744 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Feb 2017 20:41:58 -0600 Subject: [PATCH 044/264] msgReader implemented in terms of ChunkReader This should substantially reduce memory allocations and memory copies. It also means that PostgreSQL messages are always entirely buffered in memory before processing begins. This simplifies the message processing code. In particular, Conn.WaitForNotification is dramatically simplified by this change. --- conn.go | 115 +++++++------------------ conn_test.go | 28 +++--- msg_reader.go | 202 ++++++++++++++------------------------------ msg_reader_test.go | 189 ----------------------------------------- replication.go | 48 +++-------- replication_test.go | 10 +-- stress_test.go | 5 +- 7 files changed, 130 insertions(+), 467 deletions(-) delete mode 100644 msg_reader_test.go diff --git a/conn.go b/conn.go index 07422a32..a8b0b22c 100644 --- a/conn.go +++ b/conn.go @@ -1,7 +1,6 @@ package pgx import ( - "bufio" "crypto/md5" "crypto/tls" "encoding/binary" @@ -20,6 +19,8 @@ import ( "strings" "sync/atomic" "time" + + "github.com/jackc/pgx/chunkreader" ) const ( @@ -283,7 +284,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.mr.reader = bufio.NewReader(c.conn) + c.mr.cr = chunkreader.NewChunkReader(c.conn) msg := newStartupMessage() @@ -844,9 +845,8 @@ func (c *Conn) Unlisten(channel string) error { return nil } -// WaitForNotification waits for a PostgreSQL notification for up to timeout. -// If the timeout occurs it returns pgx.ErrNotificationTimeout -func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) { +// WaitForNotification waits for a PostgreSQL notification. +func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notification, err error) { // Return already received notification immediately if len(c.notifications) > 0 { notification := c.notifications[0] @@ -854,97 +854,40 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } - ctx, cancelFn := context.WithTimeout(context.Background(), timeout) - if err := c.waitForPreviousCancelQuery(ctx); err != nil { - cancelFn() + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { return nil, err } - cancelFn() + + err = c.initContext(ctx) + if err != nil { + return nil, err + } + defer func() { + err = c.termContext(err) + }() + + if err = c.lock(); err != nil { + return nil, err + } + defer func() { + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() if err := c.ensureConnectionReadyForQuery(); err != nil { return nil, err } - stopTime := time.Now().Add(timeout) - for { - now := time.Now() - - if now.After(stopTime) { - return nil, ErrNotificationTimeout - } - - // If there has been no activity on this connection for a while send a nop message just to ensure - // the connection is alive - nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second) - if nextEnsureAliveTime.Before(now) { - // If the server can't respond to a nop in 15 seconds, assume it's dead - err := c.conn.SetReadDeadline(now.Add(15 * time.Second)) - if err != nil { - return nil, err - } - - _, err = c.Exec("--;") - if err != nil { - return nil, err - } - - c.lastActivityTime = now - } - - var deadline time.Time - if stopTime.Before(nextEnsureAliveTime) { - deadline = stopTime - } else { - deadline = nextEnsureAliveTime - } - - notification, err := c.waitForNotification(deadline) - if err != ErrNotificationTimeout { - return notification, err - } - } -} - -func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { - var zeroTime time.Time - - for { - // Use SetReadDeadline to implement the timeout. SetReadDeadline will - // cause operations to fail with a *net.OpError that has a Timeout() - // of true. Because the normal pgx rxMsg path considers any error to - // have potentially corrupted the state of the connection, it dies - // on any errors. So to avoid timeout errors in rxMsg we set the - // deadline and peek into the reader. If a timeout error occurs there - // we don't break the pgx connection. If the Peek returns that data - // is available then we turn off the read deadline before the rxMsg. - err := c.conn.SetReadDeadline(deadline) + t, r, err := c.rxMsg() if err != nil { return nil, err } - // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = c.mr.reader.Peek(1) + err = c.processContextFreeMsg(t, r) if err != nil { - c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline - if err, ok := err.(*net.OpError); ok && err.Timeout() { - return nil, ErrNotificationTimeout - } - return nil, err - } - - err = c.conn.SetReadDeadline(zeroTime) - if err != nil { - return nil, err - } - - var t byte - var r *msgReader - if t, r, err = c.rxMsg(); err == nil { - if err = c.processContextFreeMsg(t, r); err != nil { - return nil, err - } - } else { return nil, err } @@ -1114,7 +1057,7 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { c.lastActivityTime = time.Now() if c.shouldLog(LogLevelTrace) { - c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBytesRemaining", c.mr.msgBytesRemaining) + c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody)) } return t, &c.mr, err @@ -1236,11 +1179,11 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) { // wrong. So read the count, ignore it, and compute the proper value from // the size of the message. r.readInt16() - parameterCount := r.msgBytesRemaining / 4 + parameterCount := len(r.msgBody[r.rp:]) / 4 parameters = make([]OID, 0, parameterCount) - for i := int32(0); i < parameterCount; i++ { + for i := 0; i < parameterCount; i++ { parameters = append(parameters, r.readOID()) } return diff --git a/conn_test.go b/conn_test.go index a8398507..63b486a6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1084,7 +1084,7 @@ func TestListenNotify(t *testing.T) { mustExec(t, notifier, "notify chat") // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(time.Second) + notification, err := listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1099,7 +1099,10 @@ func TestListenNotify(t *testing.T) { if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = listener.WaitForNotification(0) + + ctx, cancelFn := context.WithCancel(context.Background()) + cancelFn() + notification, err = listener.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1108,8 +1111,9 @@ func TestListenNotify(t *testing.T) { } // when timeout occurs - notification, err = listener.WaitForNotification(time.Millisecond) - if err != pgx.ErrNotificationTimeout { + ctx, _ = context.WithTimeout(context.Background(), time.Millisecond) + notification, err = listener.WaitForNotification(ctx) + if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } if notification != nil { @@ -1118,7 +1122,7 @@ func TestListenNotify(t *testing.T) { // listener can listen again after a timeout mustExec(t, notifier, "notify chat") - notification, err = listener.WaitForNotification(time.Second) + notification, err = listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1143,7 +1147,7 @@ func TestUnlistenSpecificChannel(t *testing.T) { mustExec(t, notifier, "notify unlisten_test") // when notification is waiting on the socket to be read - notification, err := listener.WaitForNotification(time.Second) + notification, err := listener.WaitForNotification(context.Background()) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1163,8 +1167,10 @@ func TestUnlistenSpecificChannel(t *testing.T) { if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = listener.WaitForNotification(100 * time.Millisecond) - if err != pgx.ErrNotificationTimeout { + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + notification, err = listener.WaitForNotification(ctx) + if err != context.DeadlineExceeded { t.Errorf("WaitForNotification returned the wrong kind of error: %v", err) } } @@ -1246,7 +1252,8 @@ func TestListenNotifySelfNotification(t *testing.T) { // Notify self and WaitForNotification immediately mustExec(t, conn, "notify self") - notification, err := conn.WaitForNotification(time.Second) + ctx, _ := context.WithTimeout(context.Background(), time.Second) + notification, err := conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } @@ -1263,7 +1270,8 @@ func TestListenNotifySelfNotification(t *testing.T) { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } - notification, err = conn.WaitForNotification(time.Second) + ctx, _ = context.WithTimeout(context.Background(), time.Second) + notification, err = conn.WaitForNotification(ctx) if err != nil { t.Fatalf("Unexpected error on WaitForNotification: %v", err) } diff --git a/msg_reader.go b/msg_reader.go index f507c198..53e944bb 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -1,26 +1,29 @@ package pgx import ( - "bufio" + "bytes" "encoding/binary" "errors" - "io" "net" + + "github.com/jackc/pgx/chunkreader" ) // msgReader is a helper that reads values from a PostgreSQL message. type msgReader struct { - reader *bufio.Reader - msgBytesRemaining int32 - err error - log func(lvl int, msg string, ctx ...interface{}) - shouldLog func(lvl int) bool + cr *chunkreader.ChunkReader + msgType byte + msgBody []byte + rp int // read position + err error + log func(lvl int, msg string, ctx ...interface{}) + shouldLog func(lvl int) bool } // fatal tells rc that a Fatal error has occurred func (r *msgReader) fatal(err error) { if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp) } r.err = err } @@ -31,22 +34,7 @@ func (r *msgReader) rxMsg() (byte, error) { return 0, r.err } - if r.msgBytesRemaining > 0 { - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) - } - - n, err := r.reader.Discard(int(r.msgBytesRemaining)) - r.msgBytesRemaining -= int32(n) - if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { - r.fatal(err) - } - return 0, err - } - } - - b, err := r.reader.Peek(5) + header, err := r.cr.Next(5) if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { r.fatal(err) @@ -54,22 +42,20 @@ func (r *msgReader) rxMsg() (byte, error) { return 0, err } - msgType := b[0] - payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4 + r.msgType = header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - // Try to preload bufio.Reader with entire message - b, err = r.reader.Peek(5 + int(payloadSize)) - if err != nil && err != bufio.ErrBufferFull { + r.msgBody, err = r.cr.Next(bodyLen) + if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { r.fatal(err) } return 0, err } - r.msgBytesRemaining = payloadSize - r.reader.Discard(5) + r.rp = 0 - return msgType, nil + return r.msgType, nil } func (r *msgReader) readByte() byte { @@ -77,20 +63,16 @@ func (r *msgReader) readByte() byte { return 0 } - r.msgBytesRemaining-- - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 1 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.ReadByte() - if err != nil { - r.fatal(err) - return 0 - } + b := r.msgBody[r.rp] + r.rp++ if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp) } return b @@ -101,24 +83,16 @@ func (r *msgReader) readInt16() int16 { return 0 } - r.msgBytesRemaining -= 2 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 2 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(2) - if err != nil { - r.fatal(err) - return 0 - } - - n := int16(binary.BigEndian.Uint16(b)) - - r.reader.Discard(2) + n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:])) + r.rp += 2 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -129,24 +103,16 @@ func (r *msgReader) readInt32() int32 { return 0 } - r.msgBytesRemaining -= 4 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 4 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(4) - if err != nil { - r.fatal(err) - return 0 - } - - n := int32(binary.BigEndian.Uint32(b)) - - r.reader.Discard(4) + n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:])) + r.rp += 4 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -157,24 +123,16 @@ func (r *msgReader) readUint16() uint16 { return 0 } - r.msgBytesRemaining -= 2 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 2 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(2) - if err != nil { - r.fatal(err) - return 0 - } - - n := uint16(binary.BigEndian.Uint16(b)) - - r.reader.Discard(2) + n := binary.BigEndian.Uint16(r.msgBody[r.rp:]) + r.rp += 2 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -185,24 +143,16 @@ func (r *msgReader) readUint32() uint32 { return 0 } - r.msgBytesRemaining -= 4 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 4 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(4) - if err != nil { - r.fatal(err) - return 0 - } - - n := uint32(binary.BigEndian.Uint32(b)) - - r.reader.Discard(4) + n := binary.BigEndian.Uint32(r.msgBody[r.rp:]) + r.rp += 4 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -213,24 +163,16 @@ func (r *msgReader) readInt64() int64 { return 0 } - r.msgBytesRemaining -= 8 - if r.msgBytesRemaining < 0 { + if len(r.msgBody)-r.rp < 8 { r.fatal(errors.New("read past end of message")) return 0 } - b, err := r.reader.Peek(8) - if err != nil { - r.fatal(err) - return 0 - } - - n := int64(binary.BigEndian.Uint64(b)) - - r.reader.Discard(8) + n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:])) + r.rp += 8 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp) } return n @@ -246,22 +188,17 @@ func (r *msgReader) readCString() string { return "" } - b, err := r.reader.ReadBytes(0) - if err != nil { - r.fatal(err) + nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0) + if nullIdx == -1 { + r.fatal(errors.New("null terminated string not found")) return "" } - r.msgBytesRemaining -= int32(len(b)) - if r.msgBytesRemaining < 0 { - r.fatal(errors.New("read past end of message")) - return "" - } - - s := string(b[0 : len(b)-1]) + s := string(r.msgBody[r.rp : r.rp+nullIdx]) + r.rp += nullIdx + 1 if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp) } return s @@ -273,58 +210,43 @@ func (r *msgReader) readString(countI32 int32) string { return "" } - r.msgBytesRemaining -= countI32 - if r.msgBytesRemaining < 0 { + count := int(countI32) + + if len(r.msgBody)-r.rp < count { r.fatal(errors.New("read past end of message")) return "" } - count := int(countI32) - var s string - - if r.reader.Buffered() >= count { - buf, _ := r.reader.Peek(count) - s = string(buf) - r.reader.Discard(count) - } else { - buf := make([]byte, count) - _, err := io.ReadFull(r.reader, buf) - if err != nil { - r.fatal(err) - return "" - } - s = string(buf) - } + s := string(r.msgBody[r.rp : r.rp+count]) + r.rp += count if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp) } return s } // readBytes reads count bytes and returns as []byte -func (r *msgReader) readBytes(count int32) []byte { +func (r *msgReader) readBytes(countI32 int32) []byte { if r.err != nil { return nil } - r.msgBytesRemaining -= count - if r.msgBytesRemaining < 0 { + count := int(countI32) + + if len(r.msgBody)-r.rp < count { r.fatal(errors.New("read past end of message")) return nil } - b := make([]byte, int(count)) + b := r.msgBody[r.rp : r.rp+count] + r.rp += count - _, err := io.ReadFull(r.reader, b) - if err != nil { - r.fatal(err) - return nil - } + r.cr.KeepLast() if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readBytes", "value", b, "msgBytesRemaining", r.msgBytesRemaining) + r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp) } return b diff --git a/msg_reader_test.go b/msg_reader_test.go deleted file mode 100644 index 2bbd53c9..00000000 --- a/msg_reader_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package pgx - -import ( - "bufio" - "net" - "testing" - "time" - - "github.com/jackc/pgmock/pgmsg" -) - -func TestMsgReaderPrebuffersWhenPossible(t *testing.T) { - t.Parallel() - - tests := []struct { - msgType byte - payloadSize int32 - buffered bool - }{ - {1, 50, true}, - {2, 0, true}, - {3, 500, true}, - {4, 1050, true}, - {5, 1500, true}, - {6, 1500, true}, - {7, 4000, true}, - {8, 24000, false}, - {9, 4000, true}, - {1, 1500, true}, - {2, 0, true}, - {3, 500, true}, - {4, 1050, true}, - {5, 1500, true}, - {6, 1500, true}, - {7, 4000, true}, - {8, 14000, false}, - {9, 0, true}, - {1, 500, true}, - } - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer ln.Close() - - go func() { - var bigEndian pgmsg.BigEndianBuf - - conn, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - for _, tt := range tests { - _, err = conn.Write([]byte{tt.msgType}) - if err != nil { - t.Fatal(err) - } - - _, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4)) - if err != nil { - t.Fatal(err) - } - - payload := make([]byte, int(tt.payloadSize)) - _, err = conn.Write(payload) - if err != nil { - t.Fatal(err) - } - } - }() - - conn, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - mr := &msgReader{ - reader: bufio.NewReader(conn), - shouldLog: func(int) bool { return false }, - } - - for i, tt := range tests { - msgType, err := mr.rxMsg() - if err != nil { - t.Fatalf("%d. Unexpected error: %v", i, err) - } - - if msgType != tt.msgType { - t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType) - } - - if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered { - t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered()) - } - } -} - -func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) { - t.Parallel() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer ln.Close() - - testCount := 10000 - - go func() { - var bigEndian pgmsg.BigEndianBuf - - conn, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - for i := 0; i < testCount; i++ { - msgType := byte(i) - - _, err = conn.Write([]byte{msgType}) - if err != nil { - t.Fatal(err) - } - - msgSize := i % 4000 - - _, err = conn.Write(bigEndian.Int32(int32(msgSize + 4))) - if err != nil { - t.Fatal(err) - } - - payload := make([]byte, msgSize) - _, err = conn.Write(payload) - if err != nil { - t.Fatal(err) - } - } - }() - - conn, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - mr := &msgReader{ - reader: bufio.NewReader(conn), - shouldLog: func(int) bool { return false }, - } - - conn.SetReadDeadline(time.Now().Add(time.Millisecond)) - - i := 0 - for { - msgType, err := mr.rxMsg() - if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - conn.SetReadDeadline(time.Now().Add(time.Millisecond)) - continue - } else { - t.Fatalf("%d. Unexpected error: %v", i, err) - } - } - - expectedMsgType := byte(i) - if msgType != expectedMsgType { - t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType) - } - - expectedMsgSize := i % 4000 - payload := mr.readBytes(mr.msgBytesRemaining) - if mr.err != nil { - t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err) - } - if len(payload) != expectedMsgSize { - t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload)) - } - - i++ - if i == testCount { - break - } - } -} diff --git a/replication.go b/replication.go index 0acc9df9..a3e58fa3 100644 --- a/replication.go +++ b/replication.go @@ -1,9 +1,9 @@ package pgx import ( + "context" "errors" "fmt" - "net" "time" ) @@ -234,7 +234,7 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err walStart := reader.readInt64() serverWalEnd := reader.readInt64() serverTime := reader.readInt64() - walData := reader.readBytes(reader.msgBytesRemaining) + walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp)) walMessage := WalMessage{WalStart: uint64(walStart), ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), @@ -261,47 +261,23 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err return } -// Wait for a single replication message up to timeout time. +// Wait for a single replication message. // // Properly using this requires some knowledge of the postgres replication mechanisms, // as the client can receive both WAL data (the ultimate payload) and server heartbeat // updates. The caller also must send standby status updates in order to keep the connection // alive and working. // -// This returns pgx.ErrNotificationTimeout when there is no replication message by the specified -// duration. -func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *ReplicationMessage, err error) { - var zeroTime time.Time - - deadline := time.Now().Add(timeout) - - // Use SetReadDeadline to implement the timeout. SetReadDeadline will - // cause operations to fail with a *net.OpError that has a Timeout() - // of true. Because the normal pgx rxMsg path considers any error to - // have potentially corrupted the state of the connection, it dies - // on any errors. So to avoid timeout errors in rxMsg we set the - // deadline and peek into the reader. If a timeout error occurs there - // we don't break the pgx connection. If the Peek returns that data - // is available then we turn off the read deadline before the rxMsg. - err = rc.c.conn.SetReadDeadline(deadline) - if err != nil { - return nil, err - } - - // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = rc.c.mr.reader.Peek(1) - if err != nil { - rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline - if err, ok := err.(*net.OpError); ok && err.Timeout() { - return nil, ErrNotificationTimeout - } - return nil, err - } - - err = rc.c.conn.SetReadDeadline(zeroTime) +// This returns the context error when there is no replication message before +// the context is canceled. +func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (r *ReplicationMessage, err error) { + err = rc.c.initContext(ctx) if err != nil { return nil, err } + defer func() { + err = rc.c.termContext(err) + }() return rc.readReplicationMessage() } @@ -401,12 +377,14 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti return } + ctx, _ := context.WithTimeout(context.Background(), initialReplicationResponseTimeout) + // The first replication message that comes back here will be (in a success case) // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has // started. This call will either return nil, nil or if it returns an error // that indicates the start replication command failed var r *ReplicationMessage - r, err = rc.WaitForReplicationMessage(initialReplicationResponseTimeout) + r, err = rc.WaitForReplicationMessage(ctx) if err != nil && r != nil { if rc.c.shouldLog(LogLevelError) { rc.c.log(LogLevelError, "Unxpected replication message %v", r) diff --git a/replication_test.go b/replication_test.go index 4f810c78..2c2d0af5 100644 --- a/replication_test.go +++ b/replication_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "fmt" "github.com/jackc/pgx" "reflect" @@ -88,11 +89,10 @@ func TestSimpleReplicationConnection(t *testing.T) { for { var message *pgx.ReplicationMessage - message, err = replicationConn.WaitForReplicationMessage(time.Duration(1 * time.Second)) - if err != nil { - if err != pgx.ErrNotificationTimeout { - t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) - } + ctx, _ := context.WithTimeout(context.Background(), time.Second) + message, err = replicationConn.WaitForReplicationMessage(ctx) + if err != nil && err != context.DeadlineExceeded { + t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) } if message != nil { if message.WalMessage != nil { diff --git a/stress_test.go b/stress_test.go index 72d48a5c..82979fd6 100644 --- a/stress_test.go +++ b/stress_test.go @@ -244,8 +244,9 @@ func listenAndPoolUnlistens(pool *pgx.ConnPool, actionNum int) error { return err } - _, err = conn.WaitForNotification(100 * time.Millisecond) - if err == pgx.ErrNotificationTimeout { + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + _, err = conn.WaitForNotification(ctx) + if err == context.DeadlineExceeded { return nil } return err From c8be89a16b11e1f341dc1818ec2f53508b2e3894 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Feb 2017 20:48:55 -0600 Subject: [PATCH 045/264] v3 notes updated --- v3.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/v3.md b/v3.md index f6bccb29..cd6345d0 100644 --- a/v3.md +++ b/v3.md @@ -20,12 +20,18 @@ Conn.Pid changed to accessor method Conn.PID() Remove Conn.TxStatus +Added Context methods + +Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support. + +Conn.WaitForNotification no longer automatically pings internally every 15 seconds. (Reconsider this later...) + +ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support. + ## TODO / Possible / Investigate Organize errors better -Optionally use Go 1.7 context / cancel and timeouts could be implemented this way - Remove circular dependency between Conn and ConnPool such that ConnPool depends on Conn, but Conn doesn't know anything about ConnPool Extract types Null* and Hstore to separate package @@ -34,6 +40,8 @@ Remove names from prepared statements - use database/sql style objects Better way of handling text/binary protocol choice than pgx.DefaultTypeFormats or manually editing a PreparedStatement. Possibly an optional part of preparing a statement is specifying the format and/or a decoder. Or maybe it is part of a QueryEx call... Could be very interesting to make encoding and decoding possible without being a method of the type. This could drastically clean up those huge type switches. +Also maybe support binary and text for everything possible + Copy protocol support (this potentially ties in with text/binary protocol) ValueReader / msgReader cleanup @@ -44,3 +52,8 @@ Every field that should not be set by user should be replaced by accessor method Investigate strongly typed queries. i.e. Some sort of interface where varargs of Query, Exec, and Scan wouldn't happen. Need to be some low-level interface where (probably generated) functions could (more or less) directly read and write to the connection. Reject scanning non-string like things into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/223 + +Further clean up logging interface -- still some pre-loglevel code in place +Possibly integrate internal logging support with context. Possibly add method that adds arbitrary pgx log data to context. Or add ability to configure what key(s) pgx looks at for additional log context. +Consider whether to switch to logrus style or stick with log15 style logs +Keep ability to change logging while running From f947f0971fcbd12262cf66094ec6324d0c23638a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 14 Feb 2017 21:57:48 -0600 Subject: [PATCH 046/264] more v3 notes --- v3.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/v3.md b/v3.md index cd6345d0..89a080b2 100644 --- a/v3.md +++ b/v3.md @@ -34,6 +34,10 @@ Organize errors better Remove circular dependency between Conn and ConnPool such that ConnPool depends on Conn, but Conn doesn't know anything about ConnPool +Or maybe double-down on conn/pool coupling and improve connpool + +Add auto-idle pinging to conns in pool + Extract types Null* and Hstore to separate package Remove names from prepared statements - use database/sql style objects @@ -42,14 +46,13 @@ Better way of handling text/binary protocol choice than pgx.DefaultTypeFormats o Also maybe support binary and text for everything possible -Copy protocol support (this potentially ties in with text/binary protocol) -ValueReader / msgReader cleanup +dValueReader / msgReader cleanup Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) Every field that should not be set by user should be replaced by accessor method (e.g. Conn.PID, Conn.SecretKey) -Investigate strongly typed queries. i.e. Some sort of interface where varargs of Query, Exec, and Scan wouldn't happen. Need to be some low-level interface where (probably generated) functions could (more or less) directly read and write to the connection. +Investigate strongly typed queries. i.e. Some sort of interface where varargs of Query, Exec, and Scan wouldn't happen. Need to be some low-level interface where (probably generated) functions could (more or less) directly read and write to the connection. Clean code and type-safety / control would be the benefits. Row scanning performance is already so fast there is little to improve (go_db_bench shows under 1 microsecond per row). Reject scanning non-string like things into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/223 From efcc172c8b6366d5bedcb4c34f8277cc5e5f4017 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 16 Feb 2017 18:08:43 -0600 Subject: [PATCH 047/264] Remove unreachable code --- conn.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/conn.go b/conn.go index a8b0b22c..88c344f4 100644 --- a/conn.go +++ b/conn.go @@ -1417,8 +1417,6 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa } } } - - return commandTag, err } func (c *Conn) initContext(ctx context.Context) error { From c540b65edf09e60a622b8851cd4b325aaed642b2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 16 Feb 2017 18:11:30 -0600 Subject: [PATCH 048/264] Fix leaked contexts --- replication.go | 3 ++- replication_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/replication.go b/replication.go index a3e58fa3..9bc4a1a4 100644 --- a/replication.go +++ b/replication.go @@ -377,7 +377,8 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti return } - ctx, _ := context.WithTimeout(context.Background(), initialReplicationResponseTimeout) + ctx, cancelFn := context.WithTimeout(context.Background(), initialReplicationResponseTimeout) + defer cancelFn() // The first replication message that comes back here will be (in a success case) // a empty CopyBoth that is (apparently) sent as the confirmation that the replication has diff --git a/replication_test.go b/replication_test.go index 2c2d0af5..43793f3c 100644 --- a/replication_test.go +++ b/replication_test.go @@ -89,7 +89,8 @@ func TestSimpleReplicationConnection(t *testing.T) { for { var message *pgx.ReplicationMessage - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() message, err = replicationConn.WaitForReplicationMessage(ctx) if err != nil && err != context.DeadlineExceeded { t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) From e390ac33f58a26a6f2105f435dab111c08674a0b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 16 Feb 2017 18:12:42 -0600 Subject: [PATCH 049/264] Fix Fatal -> Fatalf --- conn_test.go | 2 +- query_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/conn_test.go b/conn_test.go index 63b486a6..9a703bbd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -869,7 +869,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) { _, err := conn.ExecContext(ctx, "select pg_sleep(60)") if err != context.Canceled { - t.Fatal("Expected context.Canceled err, got %v", err) + t.Fatalf("Expected context.Canceled err, got %v", err) } ensureConnValid(t, conn) diff --git a/query_test.go b/query_test.go index 83c2f9c1..f2942951 100644 --- a/query_test.go +++ b/query_test.go @@ -1510,7 +1510,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { } if rows.Err() != context.Canceled { - t.Fatal("Expected context.Canceled error, got %v", rows.Err()) + t.Fatalf("Expected context.Canceled error, got %v", rows.Err()) } ensureConnValid(t, conn) @@ -1570,7 +1570,7 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { var result []byte err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result) if err != context.Canceled { - t.Fatal("Expected context.Canceled error, got %v", err) + t.Fatalf("Expected context.Canceled error, got %v", err) } ensureConnValid(t, conn) From ccc65c361aa0f5fe060cbfacdf9e1f5a5fe2564f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 16 Feb 2017 18:26:24 -0600 Subject: [PATCH 050/264] Privatize Conn.SecretKey --- conn.go | 6 +++--- conn_test.go | 4 ---- v3.md | 4 +++- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/conn.go b/conn.go index 88c344f4..4a0b6fda 100644 --- a/conn.go +++ b/conn.go @@ -74,7 +74,7 @@ type Conn struct { wbuf [1024]byte writeBuf WriteBuf pid int32 // backend pid - SecretKey int32 // key to use to send a cancel query message to the server + secretKey int32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server PgTypes map[OID]PgType // oids to PgTypes config ConnConfig // config used when establishing this connection @@ -1148,7 +1148,7 @@ func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { func (c *Conn) rxBackendKeyData(r *msgReader) { c.pid = r.readInt32() - c.SecretKey = r.readInt32() + c.secretKey = r.readInt32() } func (c *Conn) rxReadyForQuery(r *msgReader) { @@ -1321,7 +1321,7 @@ func (c *Conn) cancelQuery() { binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.secretKey)) _, err = cancelConn.Write(buf) if err != nil { return err diff --git a/conn_test.go b/conn_test.go index 9a703bbd..cc87efa8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -32,10 +32,6 @@ func TestConnect(t *testing.T) { t.Error("Backend PID not stored") } - if conn.SecretKey == 0 { - t.Error("Backend secret key not stored") - } - var currentDB string err = conn.QueryRow("select current_database()").Scan(¤tDB) if err != nil { diff --git a/v3.md b/v3.md index 89a080b2..4663cfd9 100644 --- a/v3.md +++ b/v3.md @@ -18,6 +18,8 @@ Transaction isolation level constants are now typed strings instead of bare stri Conn.Pid changed to accessor method Conn.PID() +Conn.SecretKey removed + Remove Conn.TxStatus Added Context methods @@ -50,7 +52,7 @@ dValueReader / msgReader cleanup Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) -Every field that should not be set by user should be replaced by accessor method (e.g. Conn.PID, Conn.SecretKey) +Every field that should not be set by user should be replaced by accessor method (only ones left are Conn.RuntimeParams and Conn.PgTypes) Investigate strongly typed queries. i.e. Some sort of interface where varargs of Query, Exec, and Scan wouldn't happen. Need to be some low-level interface where (probably generated) functions could (more or less) directly read and write to the connection. Clean code and type-safety / control would be the benefits. Row scanning performance is already so fast there is little to improve (go_db_bench shows under 1 microsecond per row). From 4d5622186822a1c14592e2183b030448f063af67 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 16 Feb 2017 19:19:45 -0600 Subject: [PATCH 051/264] Do not scan binary values into strings refs #219 and #228 --- query.go | 14 +++++++++++++- query_test.go | 17 +++++++++++++++++ v3.md | 4 ++-- values.go | 21 +++++++++++++++------ 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/query.go b/query.go index efb039d5..99b383e0 100644 --- a/query.go +++ b/query.go @@ -301,7 +301,19 @@ func (rows *Rows) Values() ([]interface{}, error) { // All intrinsic types (except string) are encoded with binary // encoding so anything else should be treated as a string case TextFormatCode: - values = append(values, vr.ReadString(vr.Len())) + switch vr.Type().DataType { + case JSONOID: + var d interface{} + decodeJSON(vr, &d) + values = append(values, d) + case JSONBOID: + var d interface{} + decodeJSONB(vr, &d) + values = append(values, d) + default: + values = append(values, vr.ReadString(vr.Len())) + } + case BinaryFormatCode: switch vr.Type().DataType { case TextOID, VarcharOID: diff --git a/query_test.go b/query_test.go index f2942951..a78914b6 100644 --- a/query_test.go +++ b/query_test.go @@ -100,6 +100,23 @@ func TestConnQueryValues(t *testing.T) { } } +// https://github.com/jackc/pgx/issues/228 +func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var s string + + err := conn.QueryRow("select 1").Scan(&s) + if err == nil || !strings.Contains(err.Error(), "cannot decode binary value into string") { + t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) + } + + ensureConnValid(t, conn) +} + // Test that a connection stays valid when query results are closed early func TestConnQueryCloseEarly(t *testing.T) { t.Parallel() diff --git a/v3.md b/v3.md index 4663cfd9..68619d4d 100644 --- a/v3.md +++ b/v3.md @@ -30,6 +30,8 @@ Conn.WaitForNotification no longer automatically pings internally every 15 secon ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support. +Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228 + ## TODO / Possible / Investigate Organize errors better @@ -56,8 +58,6 @@ Every field that should not be set by user should be replaced by accessor method Investigate strongly typed queries. i.e. Some sort of interface where varargs of Query, Exec, and Scan wouldn't happen. Need to be some low-level interface where (probably generated) functions could (more or less) directly read and write to the connection. Clean code and type-safety / control would be the benefits. Row scanning performance is already so fast there is little to improve (go_db_bench shows under 1 microsecond per row). -Reject scanning non-string like things into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/223 - Further clean up logging interface -- still some pre-loglevel code in place Possibly integrate internal logging support with context. Possibly add method that adds arbitrary pgx log data to context. Or add ability to configure what key(s) pgx looks at for additional log context. Consider whether to switch to logrus style or stick with log15 style logs diff --git a/values.go b/values.go index a59ca0c3..4255f5ea 100644 --- a/values.go +++ b/values.go @@ -102,20 +102,15 @@ func init() { "date": BinaryFormatCode, "float4": BinaryFormatCode, "float8": BinaryFormatCode, - "json": BinaryFormatCode, - "jsonb": BinaryFormatCode, "inet": BinaryFormatCode, "int2": BinaryFormatCode, "int4": BinaryFormatCode, "int8": BinaryFormatCode, - "name": BinaryFormatCode, "oid": BinaryFormatCode, "record": BinaryFormatCode, - "text": BinaryFormatCode, "tid": BinaryFormatCode, "timestamp": BinaryFormatCode, "timestamptz": BinaryFormatCode, - "varchar": BinaryFormatCode, "xid": BinaryFormatCode, } } @@ -2022,6 +2017,20 @@ func decodeText(vr *ValueReader) string { return "" } + if vr.Type().FormatCode == BinaryFormatCode { + vr.Fatal(ProtocolError("cannot decode binary value into string")) + return "" + } + + return vr.ReadString(vr.Len()) +} + +func decodeTextAllowBinary(vr *ValueReader) string { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into string")) + return "" + } + return vr.ReadString(vr.Len()) } @@ -2370,7 +2379,7 @@ func decodeRecord(vr *ValueReader) []interface{} { case InetOID, CidrOID: record = append(record, decodeInet(&fieldVR)) case TextOID, VarcharOID, UnknownOID: - record = append(record, decodeText(&fieldVR)) + record = append(record, decodeTextAllowBinary(&fieldVR)) default: vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) return nil From 47eda78ea1cbd9f5fd9cfaef26bf3ee40a3045d9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 16 Feb 2017 19:41:28 -0600 Subject: [PATCH 052/264] Refactor huge switch statement --- conn.go | 7 +------ values.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index 4a0b6fda..7d70b9f8 100644 --- a/conn.go +++ b/conn.go @@ -964,12 +964,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} case string, *string: wbuf.WriteInt16(TextFormatCode) default: - switch oid { - case BoolOID, ByteaOID, Int2OID, Int4OID, Int8OID, Float4OID, Float8OID, TimestampTzOID, TimestampTzArrayOID, TimestampOID, TimestampArrayOID, DateOID, BoolArrayOID, ByteaArrayOID, Int2ArrayOID, Int4ArrayOID, Int8ArrayOID, Float4ArrayOID, Float8ArrayOID, TextArrayOID, VarcharArrayOID, OIDOID, InetOID, CidrOID, InetArrayOID, CidrArrayOID, RecordOID, JSONOID, JSONBOID: - wbuf.WriteInt16(BinaryFormatCode) - default: - wbuf.WriteInt16(TextFormatCode) - } + wbuf.WriteInt16(internalNativeGoTypeFormats[oid]) } } diff --git a/values.go b/values.go index 4255f5ea..45ed914c 100644 --- a/values.go +++ b/values.go @@ -77,6 +77,9 @@ const minInt = -maxInt - 1 // set here. var DefaultTypeFormats map[string]int16 +// internalNativeGoTypeFormats lists the encoding type for native Go types (not handled with Encoder interface) +var internalNativeGoTypeFormats map[OID]int16 + func init() { DefaultTypeFormats = map[string]int16{ "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) @@ -113,6 +116,38 @@ func init() { "timestamptz": BinaryFormatCode, "xid": BinaryFormatCode, } + + internalNativeGoTypeFormats = map[OID]int16{ + BoolArrayOID: BinaryFormatCode, + BoolOID: BinaryFormatCode, + ByteaArrayOID: BinaryFormatCode, + ByteaOID: BinaryFormatCode, + CidrArrayOID: BinaryFormatCode, + CidrOID: BinaryFormatCode, + DateOID: BinaryFormatCode, + Float4ArrayOID: BinaryFormatCode, + Float4OID: BinaryFormatCode, + Float8ArrayOID: BinaryFormatCode, + Float8OID: BinaryFormatCode, + InetArrayOID: BinaryFormatCode, + InetOID: BinaryFormatCode, + Int2ArrayOID: BinaryFormatCode, + Int2OID: BinaryFormatCode, + Int4ArrayOID: BinaryFormatCode, + Int4OID: BinaryFormatCode, + Int8ArrayOID: BinaryFormatCode, + Int8OID: BinaryFormatCode, + JSONBOID: BinaryFormatCode, + JSONOID: BinaryFormatCode, + OIDOID: BinaryFormatCode, + RecordOID: BinaryFormatCode, + TextArrayOID: BinaryFormatCode, + TimestampArrayOID: BinaryFormatCode, + TimestampOID: BinaryFormatCode, + TimestampTzArrayOID: BinaryFormatCode, + TimestampTzOID: BinaryFormatCode, + VarcharArrayOID: BinaryFormatCode, + } } // SerializationError occurs on failure to encode or decode a value From 366440d40d078c11ad3010fa6c6752731561186d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Feb 2017 07:40:06 -0600 Subject: [PATCH 053/264] Remove *msgReader.readOID --- conn.go | 6 +++--- msg_reader.go | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 7d70b9f8..9303fb74 100644 --- a/conn.go +++ b/conn.go @@ -1157,9 +1157,9 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { for i := int16(0); i < fieldCount; i++ { f := &fields[i] f.Name = r.readCString() - f.Table = r.readOID() + f.Table = OID(r.readUint32()) f.AttributeNumber = r.readInt16() - f.DataType = r.readOID() + f.DataType = OID(r.readUint32()) f.DataTypeSize = r.readInt16() f.Modifier = r.readInt32() f.FormatCode = r.readInt16() @@ -1179,7 +1179,7 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) { parameters = make([]OID, 0, parameterCount) for i := 0; i < parameterCount; i++ { - parameters = append(parameters, r.readOID()) + parameters = append(parameters, OID(r.readUint32())) } return } diff --git a/msg_reader.go b/msg_reader.go index 53e944bb..1858037a 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -178,10 +178,6 @@ func (r *msgReader) readInt64() int64 { return n } -func (r *msgReader) readOID() OID { - return OID(r.readInt32()) -} - // readCString reads a null terminated string func (r *msgReader) readCString() string { if r.err != nil { From dd0ee5bc6f7f453977ce2a59fa2cbd6e24441005 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 2 Mar 2017 20:24:44 -0600 Subject: [PATCH 054/264] Remove reference to gopkg.in It doesn't work with sub-packages and now that Go has vendoring in the standard build system it is less necessary. fixes #164 --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 5f5a3f55..ea2038a8 100644 --- a/README.md +++ b/README.md @@ -124,8 +124,4 @@ Change the following settings in your postgresql.conf: ## Version Policy -pgx follows semantic versioning for the documented public API on stable releases. Branch ```v2``` is the latest stable release. ```master``` can contain new features or behavior that will change or be removed before being merged to the stable ```v2``` branch (in practice, this occurs very rarely). - -Consider using a vendoring -tool such as [godep](https://github.com/tools/godep) or importing pgx via ```import -"gopkg.in/jackc/pgx.v2"``` to lock to the ```v2``` branch. +pgx follows semantic versioning for the documented public API on stable releases. Branch `v2` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v2` branch (in practice, this occurs very rarely). From 5e997e82f4255a4458edff7f6398ac0e3fbb4dd4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:33:34 -0600 Subject: [PATCH 055/264] Initial proof-of-concept for pgtype Squashed commit of the following: commit c19454582b335ce5bdda6320f7e4e8c76cfeaf44 Author: Jack Christensen Date: Fri Mar 3 15:24:47 2017 -0600 Add AssignTo to pgtype.Timestamptz Also handle infinity for pgtype.Date commit 7329933610b38f4bc15731b1f7c55c520b49e300 Author: Jack Christensen Date: Fri Mar 3 15:12:18 2017 -0600 Implement AssignTo for most pgtypes commit cc3d1e4af896d34ec98c3bf2e982d0367451f21c Author: Jack Christensen Date: Thu Mar 2 21:19:07 2017 -0600 Use pgtype.Int2Array in pgx commit 36da5cc2178d1a31a56dc6e6f128843bd80dea0b Author: Jack Christensen Date: Tue Feb 28 21:45:33 2017 -0600 Add text array transcoding commit 1b0f18d99f38b69f8c2db26388815e67b2b03d59 Author: Jack Christensen Date: Mon Feb 27 19:28:55 2017 -0600 Add ParseUntypedTextArray commit 0f50ce3e833fc38495d333228daf04f5142be676 Author: Jack Christensen Date: Mon Feb 27 18:54:20 2017 -0600 wip commit d934f273627d79997035c282416db922f2fbe87a Author: Jack Christensen Date: Sun Feb 26 17:14:32 2017 -0600 WIP - beginning text format array parsing commit 7276ad33ce7fa9c250745a3ed909998f3dae4a32 Author: Jack Christensen Date: Sat Feb 25 22:50:11 2017 -0600 Beginning binary arrays commit 917faa5a3175d376222423c10aca297a20f96448 Author: Jack Christensen Date: Sat Feb 25 19:36:35 2017 -0600 Fix incomplete tests commit de8c140cfb98b7b047d53c5718ccbf12eaf813a1 Author: Jack Christensen Date: Sat Feb 25 19:32:22 2017 -0600 Add timestamptz null and infinity commit 7d9f954de4e071a1eccac762248079b90dbeb53f Author: Jack Christensen Date: Sat Feb 25 18:19:38 2017 -0600 Add infinity to pgtype.Date commit 7bf783ae20ba05571c2fb9f661183233c95eab41 Author: Jack Christensen Date: Sat Feb 25 17:19:55 2017 -0600 Add Status to pgtype.Date commit 984500455c9b9a4b6221758540d248e6410d93a4 Author: Jack Christensen Date: Sat Feb 25 16:54:01 2017 -0600 Add status to Int4 and Int8 commit 6fe76fcfc2de31552790db3b093480a9d5b2a742 Author: Jack Christensen Date: Sat Feb 25 16:40:27 2017 -0600 Extract testSuccessfulTranscode commit 001647c1da03f796014cf21f41c9a7fd2cfadfde Author: Jack Christensen Date: Sat Feb 25 16:15:51 2017 -0600 Add Status to pgtype.Int2 commit 720451f06d13d9c9fa2a0482e010f24bf4627c2a Author: Jack Christensen Date: Sat Feb 25 15:56:44 2017 -0600 Add status to pgtype.Bool commit 325f700b6edff215a692b10bc5b94cdfe1100769 Author: Jack Christensen Date: Fri Feb 24 17:28:15 2017 -0600 Add date to conversion system commit 4a9343e45d3897f59eab98a0009d2ddbe07e02d7 Author: Jack Christensen Date: Fri Feb 24 16:28:35 2017 -0600 Add bool to oid based encoding commit d984fcafab1476cf84852485b6711f4b2069eb6d Author: Jack Christensen Date: Fri Feb 24 16:15:38 2017 -0600 Add pgtype interfaces commit 0f93bfc2de4023b069b966c0988bf7f0469d1809 Author: Jack Christensen Date: Fri Feb 24 14:48:34 2017 -0600 Begin introduction of Convert commit e5707023cac7c07342b8c910e480d09a1caaf6ee Author: Jack Christensen Date: Fri Feb 24 14:10:56 2017 -0600 Move bool to pgtype commit bb764d2129efe7fb21e841dbb35e6d0dc7586d37 Author: Jack Christensen Date: Fri Feb 24 13:45:05 2017 -0600 Add Int2 test commit 08c49437f455a32f7c3f0a524cd21a895d440301 Author: Jack Christensen Date: Fri Feb 24 13:44:09 2017 -0600 Add Int4 test commit 16722952222fd15c53c8fa84974645504a6d0dc0 Author: Jack Christensen Date: Fri Feb 24 08:56:59 2017 -0600 Add int8 tests commit 83a5447cd2c46b58d0880023cc4e9af0c84988a2 Author: Jack Christensen Date: Wed Feb 22 18:08:05 2017 -0600 wip commit 0ca0ee72068a72b016729b01fccef22474595285 Author: Jack Christensen Date: Mon Feb 20 18:56:52 2017 -0600 wip commit d2c2baf4ea2cd0793d68c7094c425217df952bec Author: Jack Christensen Date: Mon Feb 20 18:46:10 2017 -0600 wip commit f78371da0098356527b193fd496a338da5fe414b Author: Jack Christensen Date: Mon Feb 20 17:43:39 2017 -0600 wip commit 3366699bea62ec0110db05f3cb2998d58ac9ce5d Author: Jack Christensen Date: Mon Feb 20 14:07:47 2017 -0600 wip commit 66b79e940870fd0133ebb10ac1547e1d4d7d0b51 Author: Jack Christensen Date: Mon Feb 20 13:35:37 2017 -0600 Extract pgio commit 8b07d97d1305ed98fd76db6e306a289c0af92d56 Author: Jack Christensen Date: Mon Feb 20 13:20:00 2017 -0600 wip commit 62f1adb3427f4317b708da075dce50c4d4daff7b Author: Jack Christensen Date: Mon Feb 20 12:08:46 2017 -0600 wip commit a712d2546933a5a8433c65eef0ff2ee135077c87 Author: Jack Christensen Date: Mon Feb 20 09:30:52 2017 -0600 wip commit 4faf97cc588126dda160fc360680719572a23105 Author: Jack Christensen Date: Fri Feb 17 22:20:18 2017 -0600 wip --- bench-tmp_test.go | 55 ++++ conn.go | 20 +- copy_to_test.go | 4 +- messages.go | 40 +-- pgio/doc.go | 8 + pgio/read.go | 104 +++++++ pgio/write.go | 97 ++++++ pgtype/array.go | 375 +++++++++++++++++++++++ pgtype/array_test.go | 98 ++++++ pgtype/bool.go | 166 ++++++++++ pgtype/bool_test.go | 43 +++ pgtype/convert.go | 239 +++++++++++++++ pgtype/date.go | 191 ++++++++++++ pgtype/date_test.go | 51 ++++ pgtype/extra-interface.txt | 3 + pgtype/int2.go | 167 ++++++++++ pgtype/int2_test.go | 55 ++++ pgtype/int2array.go | 308 +++++++++++++++++++ pgtype/int2array_test.go | 87 ++++++ pgtype/int4.go | 158 ++++++++++ pgtype/int4_test.go | 55 ++++ pgtype/int8.go | 149 +++++++++ pgtype/int8_test.go | 55 ++++ pgtype/pgtype.go | 102 +++++++ pgtype/pgtype_test.go | 108 +++++++ pgtype/text_element.go | 112 +++++++ pgtype/timestamptz.go | 203 +++++++++++++ pgtype/timestamptz_test.go | 60 ++++ query.go | 63 +++- query_test.go | 15 +- value_reader.go | 27 ++ values.go | 607 +++++++++++++------------------------ values_test.go | 44 +-- 33 files changed, 3412 insertions(+), 457 deletions(-) create mode 100644 bench-tmp_test.go create mode 100644 pgio/doc.go create mode 100644 pgio/read.go create mode 100644 pgio/write.go create mode 100644 pgtype/array.go create mode 100644 pgtype/array_test.go create mode 100644 pgtype/bool.go create mode 100644 pgtype/bool_test.go create mode 100644 pgtype/convert.go create mode 100644 pgtype/date.go create mode 100644 pgtype/date_test.go create mode 100644 pgtype/extra-interface.txt create mode 100644 pgtype/int2.go create mode 100644 pgtype/int2_test.go create mode 100644 pgtype/int2array.go create mode 100644 pgtype/int2array_test.go create mode 100644 pgtype/int4.go create mode 100644 pgtype/int4_test.go create mode 100644 pgtype/int8.go create mode 100644 pgtype/int8_test.go create mode 100644 pgtype/pgtype.go create mode 100644 pgtype/pgtype_test.go create mode 100644 pgtype/text_element.go create mode 100644 pgtype/timestamptz.go create mode 100644 pgtype/timestamptz_test.go diff --git a/bench-tmp_test.go b/bench-tmp_test.go new file mode 100644 index 00000000..a8e3f7db --- /dev/null +++ b/bench-tmp_test.go @@ -0,0 +1,55 @@ +package pgx_test + +import ( + "testing" +) + +func BenchmarkPgtypeInt4ParseBinary(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("selectBinary", "select n::int4 from generate_series(1, 100) n") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var n int32 + + rows, err := conn.Query("selectBinary") + if err != nil { + b.Fatal(err) + } + + for rows.Next() { + err := rows.Scan(&n) + if err != nil { + b.Fatal(err) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } +} + +func BenchmarkPgtypeInt4EncodeBinary(b *testing.B) { + conn := mustConnect(b, *defaultConnConfig) + defer closeConn(b, conn) + + _, err := conn.Prepare("encodeBinary", "select $1::int4, $2::int4, $3::int4, $4::int4, $5::int4, $6::int4, $7::int4") + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rows, err := conn.Query("encodeBinary", int32(i), int32(i), int32(i), int32(i), int32(i), int32(i), int32(i)) + if err != nil { + b.Fatal(err) + } + rows.Close() + } +} diff --git a/conn.go b/conn.go index 9303fb74..09dada10 100644 --- a/conn.go +++ b/conn.go @@ -7,7 +7,6 @@ import ( "encoding/hex" "errors" "fmt" - "golang.org/x/net/context" "io" "net" "net/url" @@ -20,7 +19,10 @@ import ( "sync/atomic" "time" + "golang.org/x/net/context" + "github.com/jackc/pgx/chunkreader" + "github.com/jackc/pgx/pgtype" ) const ( @@ -102,6 +104,8 @@ type Conn struct { ctxInProgress bool doneChan chan struct{} closedChan chan error + + oidPgtypeValues map[OID]pgtype.Value } // PreparedStatement is a description of a prepared statement @@ -275,6 +279,16 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.doneChan = make(chan struct{}) c.closedChan = make(chan error) + c.oidPgtypeValues = map[OID]pgtype.Value{ + BoolOID: &pgtype.Bool{}, + DateOID: &pgtype.Date{}, + Int2OID: &pgtype.Int2{}, + Int2ArrayOID: &pgtype.Int2Array{}, + Int4OID: &pgtype.Int4{}, + Int8OID: &pgtype.Int8{}, + TimestampTzOID: &pgtype.Timestamptz{}, + } + if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "Starting TLS handshake") @@ -961,6 +975,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} switch arg := arguments[i].(type) { case Encoder: wbuf.WriteInt16(arg.FormatCode()) + case pgtype.BinaryEncoder: + wbuf.WriteInt16(BinaryFormatCode) + case pgtype.TextEncoder: + wbuf.WriteInt16(TextFormatCode) case string, *string: wbuf.WriteInt16(TextFormatCode) default: diff --git a/copy_to_test.go b/copy_to_test.go index 43cb5acc..7d5f2509 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -26,7 +26,7 @@ func TestConnCopyToSmall(t *testing.T) { )`) inputRows := [][]interface{}{ - {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -83,7 +83,7 @@ func TestConnCopyToLarge(t *testing.T) { inputRows := [][]interface{}{} for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) } copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows)) diff --git a/messages.go b/messages.go index c2964b82..f6be9ff9 100644 --- a/messages.go +++ b/messages.go @@ -101,6 +101,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf { // by the Encoder interface when implementing custom encoders. type WriteBuf struct { buf []byte + convBuf [8]byte sizeIdx int conn *Conn } @@ -125,35 +126,40 @@ func (wb *WriteBuf) WriteCString(s string) { } func (wb *WriteBuf) WriteInt16(n int16) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, uint16(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint16(uint16(n)) } -func (wb *WriteBuf) WriteUint16(n uint16) { - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, n) - wb.buf = append(wb.buf, b...) +func (wb *WriteBuf) WriteUint16(n uint16) (int, error) { + binary.BigEndian.PutUint16(wb.convBuf[:2], n) + wb.buf = append(wb.buf, wb.convBuf[:2]...) + return 2, nil } func (wb *WriteBuf) WriteInt32(n int32) { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, uint32(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint32(uint32(n)) } -func (wb *WriteBuf) WriteUint32(n uint32) { - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, n) - wb.buf = append(wb.buf, b...) +func (wb *WriteBuf) WriteUint32(n uint32) (int, error) { + binary.BigEndian.PutUint32(wb.convBuf[:4], n) + wb.buf = append(wb.buf, wb.convBuf[:4]...) + return 4, nil } func (wb *WriteBuf) WriteInt64(n int64) { - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(n)) - wb.buf = append(wb.buf, b...) + wb.WriteUint64(uint64(n)) +} + +func (wb *WriteBuf) WriteUint64(n uint64) (int, error) { + binary.BigEndian.PutUint64(wb.convBuf[:8], n) + wb.buf = append(wb.buf, wb.convBuf[:8]...) + return 8, nil } func (wb *WriteBuf) WriteBytes(b []byte) { wb.buf = append(wb.buf, b...) } + +func (wb *WriteBuf) Write(b []byte) (int, error) { + wb.buf = append(wb.buf, b...) + return len(b), nil +} diff --git a/pgio/doc.go b/pgio/doc.go new file mode 100644 index 00000000..36233a47 --- /dev/null +++ b/pgio/doc.go @@ -0,0 +1,8 @@ +// Package pgio a extremely low-level IO toolkit for the PostgreSQL wire protocol. +/* +pgio provides functions for reading and writing integers from io.Reader and +io.Writer while doing byte order conversion. It publishes interfaces which +readers and writers may implement to decode and encode messages with the minimum +of memory allocations. +*/ +package pgio diff --git a/pgio/read.go b/pgio/read.go new file mode 100644 index 00000000..7c39162c --- /dev/null +++ b/pgio/read.go @@ -0,0 +1,104 @@ +package pgio + +import ( + "encoding/binary" + "io" +) + +type Uint16Reader interface { + ReadUint16() (n uint16, err error) +} + +type Uint32Reader interface { + ReadUint32() (n uint32, err error) +} + +type Uint64Reader interface { + ReadUint64() (n uint64, err error) +} + +// ReadByte reads a byte from r. +func ReadByte(r io.Reader) (byte, error) { + if r, ok := r.(io.ByteReader); ok { + return r.ReadByte() + } + + buf := make([]byte, 1) + _, err := r.Read(buf) + return buf[0], err +} + +// ReadUint16 reads an uint16 from r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint16 +// method. +func ReadUint16(r io.Reader) (uint16, error) { + if r, ok := r.(Uint16Reader); ok { + return r.ReadUint16() + } + + buf := make([]byte, 2) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint16(buf), nil +} + +// ReadInt16 reads an int16 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint16 +// method. +func ReadInt16(r io.Reader) (int16, error) { + n, err := ReadUint16(r) + return int16(n), err +} + +// ReadUint32 reads an uint32 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint32 +// method. +func ReadUint32(r io.Reader) (uint32, error) { + if r, ok := r.(Uint32Reader); ok { + return r.ReadUint32() + } + + buf := make([]byte, 4) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint32(buf), nil +} + +// ReadInt32 reads an int32 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint32 +// method. +func ReadInt32(r io.Reader) (int32, error) { + n, err := ReadUint32(r) + return int32(n), err +} + +// ReadUint64 reads an uint64 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint64 +// method. +func ReadUint64(r io.Reader) (uint64, error) { + if r, ok := r.(Uint64Reader); ok { + return r.ReadUint64() + } + + buf := make([]byte, 8) + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint64(buf), nil +} + +// ReadInt64 reads an int64 r in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Read if r provides a ReadUint64 +// method. +func ReadInt64(r io.Reader) (int64, error) { + n, err := ReadUint64(r) + return int64(n), err +} diff --git a/pgio/write.go b/pgio/write.go new file mode 100644 index 00000000..823fbd00 --- /dev/null +++ b/pgio/write.go @@ -0,0 +1,97 @@ +package pgio + +import ( + "encoding/binary" + "io" +) + +type Uint16Writer interface { + WriteUint16(uint16) (n int, err error) +} + +type Uint32Writer interface { + WriteUint32(uint32) (n int, err error) +} + +type Uint64Writer interface { + WriteUint64(uint64) (n int, err error) +} + +// WriteByte writes b to w. +func WriteByte(w io.Writer, b byte) error { + if w, ok := w.(io.ByteWriter); ok { + return w.WriteByte(b) + } + _, err := w.Write([]byte{b}) + return err +} + +// WriteUint16 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint16 +// method. +func WriteUint16(w io.Writer, n uint16) (int, error) { + if w, ok := w.(Uint16Writer); ok { + return w.WriteUint16(n) + } + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, n) + return w.Write(b) +} + +// WriteInt16 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint16 +// method. +func WriteInt16(w io.Writer, n int16) (int, error) { + return WriteUint16(w, uint16(n)) +} + +// WriteUint32 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint32 +// method. +func WriteUint32(w io.Writer, n uint32) (int, error) { + if w, ok := w.(Uint32Writer); ok { + return w.WriteUint32(n) + } + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, n) + return w.Write(b) +} + +// WriteInt32 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint32 +// method. +func WriteInt32(w io.Writer, n int32) (int, error) { + return WriteUint32(w, uint32(n)) +} + +// WriteUint64 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint64 +// method. +func WriteUint64(w io.Writer, n uint64) (int, error) { + if w, ok := w.(Uint64Writer); ok { + return w.WriteUint64(n) + } + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, n) + return w.Write(b) +} + +// WriteInt64 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint64 +// method. +func WriteInt64(w io.Writer, n int64) (int, error) { + return WriteUint64(w, uint64(n)) +} + +// WriteCString writes s to w followed by a null byte. +func WriteCString(w io.Writer, s string) (int, error) { + n, err := io.WriteString(w, s) + if err != nil { + return n, err + } + err = WriteByte(w, 0) + if err != nil { + return n, err + } + return n + 1, nil +} diff --git a/pgtype/array.go b/pgtype/array.go new file mode 100644 index 00000000..75d2e440 --- /dev/null +++ b/pgtype/array.go @@ -0,0 +1,375 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "strconv" + "unicode" + + "github.com/jackc/pgx/pgio" +) + +// Information on the internals of PostgreSQL arrays can be found in +// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of +// particular interest is the array_send function. + +type ArrayHeader struct { + ContainsNull bool + ElementOID int32 + Dimensions []ArrayDimension +} + +type ArrayDimension struct { + Length int32 + LowerBound int32 +} + +func (ah *ArrayHeader) DecodeBinary(r io.Reader) error { + numDims, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if numDims > 0 { + ah.Dimensions = make([]ArrayDimension, numDims) + } + + containsNull, err := pgio.ReadInt32(r) + if err != nil { + return err + } + ah.ContainsNull = containsNull == 1 + + ah.ElementOID, err = pgio.ReadInt32(r) + if err != nil { + return err + } + + for i := range ah.Dimensions { + ah.Dimensions[i].Length, err = pgio.ReadInt32(r) + if err != nil { + return err + } + + ah.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) + if err != nil { + return err + } + } + + return nil +} + +func (ah *ArrayHeader) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, int32(len(ah.Dimensions))) + if err != nil { + return err + } + + var containsNull int32 + if ah.ContainsNull { + containsNull = 1 + } + _, err = pgio.WriteInt32(w, containsNull) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, ah.ElementOID) + if err != nil { + return err + } + + for i := range ah.Dimensions { + _, err = pgio.WriteInt32(w, ah.Dimensions[i].Length) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, ah.Dimensions[i].LowerBound) + if err != nil { + return err + } + } + + return nil +} + +type UntypedTextArray struct { + Elements []string + Dimensions []ArrayDimension +} + +func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { + uta := &UntypedTextArray{} + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + var explicitDimensions []ArrayDimension + + // Array has explicit dimensions + if r == '[' { + buf.UnreadRune() + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '=' { + break + } else if r != '[' { + return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + } + + lower, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ':' { + return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + } + + upper, err := arrayParseInteger(buf) + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r != ']' { + return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + } + + explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + } + + if r != '{' { + return nil, fmt.Errorf("invalid array, expected '{': %v", err) + } + + implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} + + // Consume all initial opening brackets. This provides number of dimensions. + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + if r == '{' { + implicitDimensions[len(implicitDimensions)-1].Length = 1 + implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) + } else { + buf.UnreadRune() + break + } + } + currentDim := len(implicitDimensions) - 1 + counterDim := currentDim + + for { + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid array: %v", err) + } + + switch r { + case '{': + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + currentDim++ + case ',': + case '}': + currentDim-- + if currentDim < counterDim { + counterDim = currentDim + } + default: + buf.UnreadRune() + value, err := arrayParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid array value: %v", err) + } + if currentDim == counterDim { + implicitDimensions[currentDim].Length++ + } + uta.Elements = append(uta.Elements, value) + } + + if currentDim < 0 { + break + } + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + if len(uta.Elements) == 0 { + uta.Dimensions = nil + } else if len(explicitDimensions) > 0 { + uta.Dimensions = explicitDimensions + } else { + uta.Dimensions = implicitDimensions + } + + return uta, nil +} + +func skipWhitespace(buf *bytes.Buffer) { + var r rune + var err error + for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { + } + + if err != io.EOF { + buf.UnreadRune() + } +} + +func arrayParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return arrayParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case ',', '}': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func arrayParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + buf.UnreadRune() + return s.String(), nil + } + s.WriteRune(r) + } +} + +func arrayParseInteger(buf *bytes.Buffer) (int32, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return 0, err + } + + if '0' <= r && r <= '9' { + s.WriteRune(r) + } else { + buf.UnreadRune() + n, err := strconv.ParseInt(s.String(), 10, 32) + if err != nil { + return 0, err + } + return int32(n), nil + } + } +} + +func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return nil + } + + for _, dim := range dimensions { + err := pgio.WriteByte(w, '[') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ':') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ']') + if err != nil { + return err + } + } + + return pgio.WriteByte(w, '=') +} diff --git a/pgtype/array_test.go b/pgtype/array_test.go new file mode 100644 index 00000000..5e5f00e7 --- /dev/null +++ b/pgtype/array_test.go @@ -0,0 +1,98 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestParseUntypedTextArray(t *testing.T) { + tests := []struct { + source string + result pgtype.UntypedTextArray + }{ + { + source: "{}", + result: pgtype.UntypedTextArray{ + Elements: nil, + Dimensions: nil, + }, + }, + { + source: "{1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{a,b}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b"}, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + }, + }, + { + source: `{"NULL"}`, + result: pgtype.UntypedTextArray{ + Elements: []string{"NULL"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: `{"He said, \"Hello.\""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{`He said, "Hello."`}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, + { + source: "{{a,b},{c,d},{e,f}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + }, + }, + { + source: "{{{a,b},{c,d},{e,f}},{{a,b},{c,d},{e,f}}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d", "e", "f", "a", "b", "c", "d", "e", "f"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 1}, + {Length: 3, LowerBound: 1}, + {Length: 2, LowerBound: 1}, + }, + }, + }, + { + source: "[4:4]={1}", + result: pgtype.UntypedTextArray{ + Elements: []string{"1"}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 4}}, + }, + }, + { + source: "[4:5][2:3]={{a,b},{c,d}}", + result: pgtype.UntypedTextArray{ + Elements: []string{"a", "b", "c", "d"}, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + }, + }, + } + + for i, tt := range tests { + r, err := pgtype.ParseUntypedTextArray(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(*r, tt.result) { + t.Errorf("%d: expected %+v to be parsed to %+v, but it was %+v", i, tt.source, tt.result, *r) + } + } +} diff --git a/pgtype/bool.go b/pgtype/bool.go new file mode 100644 index 00000000..81c72472 --- /dev/null +++ b/pgtype/bool.go @@ -0,0 +1,166 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Bool struct { + Bool bool + Status Status +} + +func (b *Bool) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Bool: + *b = value + case bool: + *b = Bool{Bool: value, Status: Present} + case string: + bb, err := strconv.ParseBool(value) + if err != nil { + return err + } + *b = Bool{Bool: bb, Status: Present} + default: + if originalSrc, ok := underlyingBoolType(src); ok { + return b.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (b *Bool) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + *v = b.Bool + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if b.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return b.AssignTo(el.Interface()) + case reflect.Bool: + if b.Status != Present { + return fmt.Errorf("cannot assign %v to %T", b, dst) + } + el.SetBool(b.Bool) + return nil + } + } + return fmt.Errorf("cannot put decode %v into %T", b, dst) + } + + return nil +} + +func (b *Bool) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *b = Bool{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool{Bool: byt == 't', Status: Present} + return nil +} + +func (b *Bool) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *b = Bool{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf("invalid length for bool: %v", size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *b = Bool{Bool: byt == 1, Status: Present} + return nil +} + +func (b Bool) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b.Bool { + buf = []byte{'t'} + } else { + buf = []byte{'f'} + } + + _, err = w.Write(buf) + return err +} + +func (b Bool) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, b.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + var buf []byte + if b.Bool { + buf = []byte{1} + } else { + buf = []byte{0} + } + + _, err = w.Write(buf) + return err +} diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go new file mode 100644 index 00000000..53df1747 --- /dev/null +++ b/pgtype/bool_test.go @@ -0,0 +1,43 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestBoolTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bool", []interface{}{ + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Null}, + }) +} + +func TestBoolConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Bool + }{ + {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/convert.go b/pgtype/convert.go new file mode 100644 index 00000000..3f3d9e5f --- /dev/null +++ b/pgtype/convert.go @@ -0,0 +1,239 @@ +package pgtype + +import ( + "fmt" + "math" + "reflect" + "time" +) + +const maxUint = ^uint(0) +const maxInt = int(maxUint >> 1) +const minInt = -maxInt - 1 + +// underlyingIntType gets the underlying type that can be converted to Int2, Int4, or Int8 +func underlyingIntType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Int: + convVal := int(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int8: + convVal := int8(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int16: + convVal := int16(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int32: + convVal := int32(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Int64: + convVal := int64(refVal.Int()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint: + convVal := uint(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint8: + convVal := uint8(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint16: + convVal := uint16(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint32: + convVal := uint32(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Uint64: + convVal := uint64(refVal.Uint()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingBoolType gets the underlying type that can be converted to Bool +func underlyingBoolType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Bool: + convVal := refVal.Bool() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + +// underlyingTimeType gets the underlying type that can be converted to time.Time +func underlyingTimeType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + timeType := reflect.TypeOf(time.Time{}) + if refVal.Type().ConvertibleTo(timeType) { + return refVal.Convert(timeType).Interface(), true + } + + return time.Time{}, false +} + +// underlyingSliceType gets the underlying slice type +func underlyingSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + baseSliceType := reflect.SliceOf(refVal.Type().Elem()) + if refVal.Type().ConvertibleTo(baseSliceType) { + convVal := refVal.Convert(baseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + } + + return nil, false +} + +func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *int: + if srcVal < int64(minInt) { + return fmt.Errorf("%d is less than minimum value for int", srcVal) + } else if srcVal > int64(maxInt) { + return fmt.Errorf("%d is greater than maximum value for int", srcVal) + } + *v = int(srcVal) + case *int8: + if srcVal < math.MinInt8 { + return fmt.Errorf("%d is less than minimum value for int8", srcVal) + } else if srcVal > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + } + *v = int8(srcVal) + case *int16: + if srcVal < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for int16", srcVal) + } else if srcVal > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + } + *v = int16(srcVal) + case *int32: + if srcVal < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for int32", srcVal) + } else if srcVal > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + } + *v = int32(srcVal) + case *int64: + if srcVal < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for int64", srcVal) + } else if srcVal > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + } + *v = int64(srcVal) + case *uint: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint", srcVal) + } else if uint64(srcVal) > uint64(maxUint) { + return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + } + *v = uint(srcVal) + case *uint8: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint8", srcVal) + } else if srcVal > math.MaxUint8 { + return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + } + *v = uint8(srcVal) + case *uint16: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint16 { + return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + } + *v = uint16(srcVal) + case *uint32: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint32", srcVal) + } else if srcVal > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + } + *v = uint32(srcVal) + case *uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for uint64", srcVal) + } + *v = uint64(srcVal) + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return int64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if el.OverflowInt(int64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetInt(int64(srcVal)) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if srcVal < 0 { + return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + } + if el.OverflowUint(uint64(srcVal)) { + return fmt.Errorf("cannot put %d into %T", srcVal, dst) + } + el.SetUint(uint64(srcVal)) + return nil + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} diff --git a/pgtype/date.go b/pgtype/date.go new file mode 100644 index 00000000..f3e3e4c6 --- /dev/null +++ b/pgtype/date.go @@ -0,0 +1,191 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +type Date struct { + Time time.Time + Status Status + InfinityModifier +} + +const ( + negativeInfinityDayOffset = -2147483648 + infinityDayOffset = 2147483647 +) + +func (d *Date) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Date: + *d = value + case time.Time: + *d = Date{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return d.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (d *Date) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if d.Status != Present || d.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", d, dst) + } + *v = d.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if d.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return d.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot decode %v into %T", d, dst) + } + + return nil +} + +func (d *Date) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *d = Date{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *d = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *d = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) + if err != nil { + return err + } + + *d = Date{Time: t, Status: Present} + } + + return nil +} + +func (d *Date) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *d = Date{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for date: %v", size) + } + + dayOffset, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + switch dayOffset { + case infinityDayOffset: + *d = Date{Status: Present, InfinityModifier: Infinity} + case negativeInfinityDayOffset: + *d = Date{Status: Present, InfinityModifier: -Infinity} + default: + t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) + *d = Date{Time: t, Status: Present} + } + + return nil +} + +func (d Date) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, d.Status); done { + return err + } + + var s string + + switch d.InfinityModifier { + case None: + s = d.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +func (d Date) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, d.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + var daysSinceDateEpoch int32 + switch d.InfinityModifier { + case None: + tUnix := time.Date(d.Time.Year(), d.Time.Month(), d.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch = int32(secSinceDateEpoch / 86400) + case Infinity: + daysSinceDateEpoch = infinityDayOffset + case NegativeInfinity: + daysSinceDateEpoch = negativeInfinityDayOffset + } + + _, err = pgio.WriteInt32(w, daysSinceDateEpoch) + return err +} diff --git a/pgtype/date_test.go b/pgtype/date_test.go new file mode 100644 index 00000000..c3e971d0 --- /dev/null +++ b/pgtype/date_test.go @@ -0,0 +1,51 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestDateTranscode(t *testing.T) { + testSuccessfulTranscode(t, "date", []interface{}{ + pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }) +} + +func TestDateConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Date + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Date + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} diff --git a/pgtype/extra-interface.txt b/pgtype/extra-interface.txt new file mode 100644 index 00000000..16453823 --- /dev/null +++ b/pgtype/extra-interface.txt @@ -0,0 +1,3 @@ +Can pass function to get inet data and function to get oid/name mapping as optional interface with io.Reader or io.Writer + +Could be useful for arrays of types without defined OIDs like hstore. diff --git a/pgtype/int2.go b/pgtype/int2.go new file mode 100644 index 00000000..2da8a96d --- /dev/null +++ b/pgtype/int2.go @@ -0,0 +1,167 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int2 struct { + Int int16 + Status Status +} + +func (i *Int2) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2: + *i = value + case int8: + *i = Int2{Int: int16(value), Status: Present} + case uint8: + *i = Int2{Int: int16(value), Status: Present} + case int16: + *i = Int2{Int: int16(value), Status: Present} + case uint16: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int32: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint32: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int64: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint64: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case int: + if value < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case uint: + if value > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", value) + } + *i = Int2{Int: int16(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return err + } + *i = Int2{Int: int16(num), Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (i *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int2) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int2{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 16) + if err != nil { + return err + } + + *i = Int2{Int: int16(n), Status: Present} + return nil +} + +func (i *Int2) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int2{Status: Null} + return nil + } + + if size != 2 { + return fmt.Errorf("invalid length for int2: %v", size) + } + + n, err := pgio.ReadInt16(r) + if err != nil { + return err + } + + *i = Int2{Int: int16(n), Status: Present} + return nil +} + +func (i Int2) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(int64(i.Int), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int2) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = pgio.WriteInt16(w, i.Int) + return err +} diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go new file mode 100644 index 00000000..a8493a16 --- /dev/null +++ b/pgtype/int2_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt2Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int2", []interface{}{ + pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, + pgtype.Int2{Int: -1, Status: pgtype.Present}, + pgtype.Int2{Int: 0, Status: pgtype.Present}, + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, + pgtype.Int2{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt2ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int2 + }{ + {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int2 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int2array.go b/pgtype/int2array.go new file mode 100644 index 00000000..86375516 --- /dev/null +++ b/pgtype/int2array.go @@ -0,0 +1,308 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int2Array struct { + Elements []Int2 + Dimensions []ArrayDimension + Status Status +} + +func (a *Int2Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int2Array: + *a = value + case []int16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []uint16: + if value == nil { + *a = Int2Array{Status: Null} + } else if len(value) == 0 { + *a = Int2Array{Status: Present} + } else { + elements := make([]Int2, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *a = Int2Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return a.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int2", value) + } + + return nil +} + +func (a *Int2Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]int16: + if a.Status == Present { + *v = make([]int16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + case *[]uint16: + if a.Status == Present { + *v = make([]uint16, len(a.Elements)) + for i := range a.Elements { + if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + default: + return fmt.Errorf("cannot put decode %v into %T", a, dst) + } + + return nil +} + +func (a *Int2Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int2 + + if len(uta.Elements) > 0 { + elements = make([]Int2, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int2 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *a = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (a *Int2Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *a = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int2, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *a = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (a *Int2Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, a.Status); done { + return err + } + + if len(a.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, a.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(a.Dimensions)) + dimElemCounts[len(a.Dimensions)-1] = int(a.Dimensions[len(a.Dimensions)-1].Length) + for i := len(a.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(a.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range a.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (a *Int2Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, a.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range a.Elements { + err := a.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if a.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Int2OID + arrayHeader.Dimensions = a.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/int2array_test.go b/pgtype/int2array_test.go new file mode 100644 index 00000000..5ea81990 --- /dev/null +++ b/pgtype/int2array_test.go @@ -0,0 +1,87 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt2ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int2[]", []interface{}{ + &pgtype.Int2Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{Status: pgtype.Null}, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + pgtype.Int2{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +// func TestInt2ConvertFrom(t *testing.T) { +// type _int8 int8 + +// successfulTests := []struct { +// source interface{} +// result pgtype.Int2 +// }{ +// {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, +// {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, +// } + +// for i, tt := range successfulTests { +// var r pgtype.Int2 +// err := r.ConvertFrom(tt.source) +// if err != nil { +// t.Errorf("%d: %v", i, err) +// } + +// if r != tt.result { +// t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) +// } +// } +// } diff --git a/pgtype/int4.go b/pgtype/int4.go new file mode 100644 index 00000000..84c45522 --- /dev/null +++ b/pgtype/int4.go @@ -0,0 +1,158 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int4 struct { + Int int32 + Status Status +} + +func (i *Int4) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int4: + *i = value + case int8: + *i = Int4{Int: int32(value), Status: Present} + case uint8: + *i = Int4{Int: int32(value), Status: Present} + case int16: + *i = Int4{Int: int32(value), Status: Present} + case uint16: + *i = Int4{Int: int32(value), Status: Present} + case int32: + *i = Int4{Int: int32(value), Status: Present} + case uint32: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case int64: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case uint64: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case int: + if value < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case uint: + if value > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", value) + } + *i = Int4{Int: int32(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return err + } + *i = Int4{Int: int32(num), Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (i *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int4) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int4{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 32) + if err != nil { + return err + } + + *i = Int4{Int: int32(n), Status: Present} + return nil +} + +func (i *Int4) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int4{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for int4: %v", size) + } + + n, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + *i = Int4{Int: n, Status: Present} + return nil +} + +func (i Int4) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(int64(i.Int), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int4) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, i.Int) + return err +} diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go new file mode 100644 index 00000000..04411849 --- /dev/null +++ b/pgtype/int4_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt4Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int4", []interface{}{ + pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, + pgtype.Int4{Int: -1, Status: pgtype.Present}, + pgtype.Int4{Int: 0, Status: pgtype.Present}, + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, + pgtype.Int4{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt4ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int4 + }{ + {source: int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int4{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int4{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int4 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/int8.go b/pgtype/int8.go new file mode 100644 index 00000000..c0e14e44 --- /dev/null +++ b/pgtype/int8.go @@ -0,0 +1,149 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Int8 struct { + Int int64 + Status Status +} + +func (i *Int8) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int8: + *i = value + case int8: + *i = Int8{Int: int64(value), Status: Present} + case uint8: + *i = Int8{Int: int64(value), Status: Present} + case int16: + *i = Int8{Int: int64(value), Status: Present} + case uint16: + *i = Int8{Int: int64(value), Status: Present} + case int32: + *i = Int8{Int: int64(value), Status: Present} + case uint32: + *i = Int8{Int: int64(value), Status: Present} + case int64: + *i = Int8{Int: int64(value), Status: Present} + case uint64: + if value > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case int: + if int64(value) < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + if int64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case uint: + if uint64(value) > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", value) + } + *i = Int8{Int: int64(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + *i = Int8{Int: num, Status: Present} + default: + if originalSrc, ok := underlyingIntType(src); ok { + return i.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (i *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(i.Int), i.Status, dst) +} + +func (i *Int8) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int8{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseInt(string(buf), 10, 64) + if err != nil { + return err + } + + *i = Int8{Int: n, Status: Present} + return nil +} + +func (i *Int8) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *i = Int8{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for int8: %v", size) + } + + n, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + *i = Int8{Int: n, Status: Present} + return nil +} + +func (i Int8) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + s := strconv.FormatInt(i.Int, 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (i Int8) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, i.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + _, err = pgio.WriteInt64(w, i.Int) + return err +} diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go new file mode 100644 index 00000000..ba246224 --- /dev/null +++ b/pgtype/int8_test.go @@ -0,0 +1,55 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8Transcode(t *testing.T) { + testSuccessfulTranscode(t, "int8", []interface{}{ + pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, + pgtype.Int8{Int: -1, Status: pgtype.Present}, + pgtype.Int8{Int: 0, Status: pgtype.Present}, + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, + pgtype.Int8{Int: 0, Status: pgtype.Null}, + }) +} + +func TestInt8ConvertFrom(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result pgtype.Int8 + }{ + {source: int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Int8{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Int8{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Int8 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go new file mode 100644 index 00000000..f9833363 --- /dev/null +++ b/pgtype/pgtype.go @@ -0,0 +1,102 @@ +package pgtype + +import ( + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// PostgreSQL oids for common types +const ( + BoolOID = 16 + ByteaOID = 17 + CharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TidOID = 27 + XidOID = 28 + CidOID = 29 + JSONOID = 114 + CidrOID = 650 + CidrArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + UnknownOID = 705 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + AclItemOID = 1033 + AclItemArrayOID = 1034 + InetArrayOID = 1041 + VarcharOID = 1043 + DateOID = 1082 + TimestampOID = 1114 + TimestampArrayOID = 1115 + TimestampTzOID = 1184 + TimestampTzArrayOID = 1185 + RecordOID = 2249 + UUIDOID = 2950 + JSONBOID = 3802 +) + +type Status byte + +const ( + Undefined Status = iota + Null + Present +) + +type InfinityModifier int8 + +const ( + Infinity InfinityModifier = 1 + None InfinityModifier = 0 + NegativeInfinity InfinityModifier = -Infinity +) + +type Value interface { + ConvertFrom(src interface{}) error + AssignTo(dst interface{}) error +} + +type BinaryDecoder interface { + DecodeBinary(r io.Reader) error +} + +type TextDecoder interface { + DecodeText(r io.Reader) error +} + +type BinaryEncoder interface { + EncodeBinary(w io.Writer) error +} + +type TextEncoder interface { + EncodeText(w io.Writer) error +} + +var errUndefined = errors.New("cannot encode status undefined") + +func encodeNotPresent(w io.Writer, status Status) (done bool, err error) { + switch status { + case Undefined: + return true, errUndefined + case Null: + _, err = pgio.WriteInt32(w, -1) + return true, err + } + return false, nil +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go new file mode 100644 index 00000000..a1a575f7 --- /dev/null +++ b/pgtype/pgtype_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "fmt" + "io" + "os" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func mustConnectPgx(t testing.TB) *pgx.Conn { + config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(config) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func mustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(w io.Writer) error { + return f.e.EncodeText(w) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { + return f.e.EncodeBinary(w) +} + +func forceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + return forceTextEncoder{e: e.(pgtype.TextEncoder)} + case pgx.BinaryFormatCode: + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + default: + panic("bad encoder") + } +} + +func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + testSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} diff --git a/pgtype/text_element.go b/pgtype/text_element.go new file mode 100644 index 00000000..1a585d08 --- /dev/null +++ b/pgtype/text_element.go @@ -0,0 +1,112 @@ +package pgtype + +import ( + "bytes" + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// TextElementWriter is a wrapper that makes TextEncoders composable into other +// TextEncoders. TextEncoder first writes the length of the subsequent value. +// This is not necessary when the value is part of another value such as an +// array. TextElementWriter requires one int32 to be written first which it +// ignores. No other integer writes are valid. +type TextElementWriter struct { + w io.Writer + lengthHeaderIgnored bool +} + +func NewTextElementWriter(w io.Writer) *TextElementWriter { + return &TextElementWriter{w: w} +} + +func (w *TextElementWriter) WriteUint16(n uint16) (int, error) { + return 0, errors.New("WriteUint16 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint32(n uint32) (int, error) { + if !w.lengthHeaderIgnored { + w.lengthHeaderIgnored = true + + if int32(n) == -1 { + return io.WriteString(w.w, "NULL") + } + + return 4, nil + } + + return 0, errors.New("WriteUint32 should only be called once on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint64(n uint64) (int, error) { + if w.lengthHeaderIgnored { + return pgio.WriteUint64(w.w, n) + } + + return 0, errors.New("WriteUint64 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) Write(buf []byte) (int, error) { + if w.lengthHeaderIgnored { + return w.w.Write(buf) + } + + return 0, errors.New("int32 must be written first") +} + +func (w *TextElementWriter) Reset() { + w.lengthHeaderIgnored = false +} + +// TextElementReader is a wrapper that makes TextDecoders composable into other +// TextDecoders. TextEncoders first read the length of the subsequent value. +// This length value is not present when the value is part of another value such +// as an array. TextElementReader provides a substitute length value from the +// length of the string. No other integer reads are valid. Each time DecodeText +// is called with a TextElementReader as the source the TextElementReader must +// first have Reset called with the new element string data. +type TextElementReader struct { + buf *bytes.Buffer + lengthHeaderIgnored bool +} + +func NewTextElementReader(r io.Reader) *TextElementReader { + return &TextElementReader{buf: &bytes.Buffer{}} +} + +func (r *TextElementReader) ReadUint16() (uint16, error) { + return 0, errors.New("ReadUint16 should never be called on TextElementReader") +} + +func (r *TextElementReader) ReadUint32() (uint32, error) { + if !r.lengthHeaderIgnored { + r.lengthHeaderIgnored = true + if r.buf.String() == "NULL" { + n32 := int32(-1) + return uint32(n32), nil + } + return uint32(r.buf.Len()), nil + } + + return 0, errors.New("ReadUint32 should only be called once on TextElementReader") +} + +func (r *TextElementReader) WriteUint64(n uint64) (int, error) { + return 0, errors.New("ReadUint64 should never be called on TextElementReader") +} + +func (r *TextElementReader) Read(buf []byte) (int, error) { + if r.lengthHeaderIgnored { + return r.buf.Read(buf) + } + + return 0, errors.New("int32 must be read first") +} + +func (r *TextElementReader) Reset(s string) { + r.lengthHeaderIgnored = false + r.buf.Reset() + r.buf.WriteString(s) +} diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go new file mode 100644 index 00000000..cc33b296 --- /dev/null +++ b/pgtype/timestamptz.go @@ -0,0 +1,203 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" +const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" +const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +const ( + negativeInfinityMicrosecondOffset = -9223372036854775808 + infinityMicrosecondOffset = 9223372036854775807 +) + +type Timestamptz struct { + Time time.Time + Status Status + InfinityModifier +} + +func (t *Timestamptz) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Timestamptz: + *t = value + case time.Time: + *t = Timestamptz{Time: value, Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return t.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (t *Timestamptz) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if t.Status != Present || t.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", t, dst) + } + *v = t.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if t.Status == Null { + if !el.IsNil() { + // if the destination pointer is not nil, nil it out + el.Set(reflect.Zero(el.Type())) + } + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return t.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot assign %v into %T", t, dst) + } + + return nil +} + +func (t *Timestamptz) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *t = Timestamptz{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + var format string + if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { + format = pgTimestamptzSecondFormat + } else if sbuf[len(sbuf)-6] == '-' || sbuf[len(sbuf)-6] == '+' { + format = pgTimestamptzMinuteFormat + } else { + format = pgTimestamptzHourFormat + } + + tim, err := time.Parse(format, sbuf) + if err != nil { + return err + } + + *t = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (t *Timestamptz) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *t = Timestamptz{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", size) + } + + microsecSinceY2K, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + *t = Timestamptz{Time: tim, Status: Present} + } + + return nil +} + +func (t Timestamptz) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, t.Status); done { + return err + } + + var s string + + switch t.InfinityModifier { + case None: + s = t.Time.UTC().Format(pgTimestamptzSecondFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +func (t Timestamptz) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, t.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + var microsecSinceY2K int64 + switch t.InfinityModifier { + case None: + microsecSinceUnixEpoch := t.Time.Unix()*1000000 + int64(t.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + _, err = pgio.WriteInt64(w, microsecSinceY2K) + return err +} diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go new file mode 100644 index 00000000..795195f8 --- /dev/null +++ b/pgtype/timestamptz_test.go @@ -0,0 +1,60 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestamptzTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamptz) + bt := b.(pgtype.Timestamptz) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestamptzConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamptz + }{ + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.Local), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/query.go b/query.go index 99b383e0..4af1de10 100644 --- a/query.go +++ b/query.go @@ -4,8 +4,11 @@ import ( "database/sql" "errors" "fmt" - "golang.org/x/net/context" "time" + + "golang.org/x/net/context" + + "github.com/jackc/pgx/pgtype" ) // Row is a convenience wrapper over Rows that is returned by QueryRow. @@ -219,6 +222,27 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } + } else if s, ok := d.(ScannerV3); ok { + val, err := decodeByOID(vr) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + err = s.ScanPgxV3(nil, val) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { + vr.err = errRewoundLen + err = s.DecodeBinary(&valueReader2{vr}) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { + vr.err = errRewoundLen + err = s.DecodeText(&valueReader2{vr}) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } } else if s, ok := d.(sql.Scanner); ok { var val interface{} if 0 <= vr.Len() { @@ -265,8 +289,39 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { d2 := d decodeJSONB(vr, &d2) } else { - if err := Decode(vr, d); err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + if pgVal, present := rows.conn.oidPgtypeValues[vr.Type().DataType]; present { + switch vr.Type().FormatCode { + case TextFormatCode: + if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok { + vr.err = errRewoundLen + err = textDecoder.DecodeText(&valueReader2{vr}) + if err != nil { + vr.Fatal(err) + } + } else { + vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", pgVal)) + } + case BinaryFormatCode: + if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok { + vr.err = errRewoundLen + err = binaryDecoder.DecodeBinary(&valueReader2{vr}) + if err != nil { + vr.Fatal(err) + } + } else { + vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", pgVal)) + } + default: + vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) + } + + if err := pgVal.AssignTo(d); err != nil { + vr.Fatal(err) + } + } else { + if err := Decode(vr, d); err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } } } if vr.Err() != nil { @@ -296,7 +351,7 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, nil) continue } - + // TODO - consider what are the implications of returning complex types since database/sql uses this method switch vr.Type().FormatCode { // All intrinsic types (except string) are encoded with binary // encoding so anything else should be treated as a string diff --git a/query_test.go b/query_test.go index a78914b6..fd5d2e5b 100644 --- a/query_test.go +++ b/query_test.go @@ -4,11 +4,12 @@ import ( "bytes" "database/sql" "fmt" - "golang.org/x/net/context" "strings" "testing" "time" + "golang.org/x/net/context" + "github.com/jackc/pgx" "github.com/shopspring/decimal" @@ -110,7 +111,7 @@ func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) var s string err := conn.QueryRow("select 1").Scan(&s) - if err == nil || !strings.Contains(err.Error(), "cannot decode binary value into string") { + if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), "cannot assign")) { t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) } @@ -199,7 +200,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" { + if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -518,7 +519,7 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}}, - {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, + {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, {"select $1::oid", []interface{}{pgx.OID(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } @@ -541,7 +542,7 @@ func TestQueryRowCoreTypes(t *testing.T) { if err == nil { t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + if err != nil && !strings.Contains(err.Error(), "Cannot decode null") && !strings.Contains(err.Error(), "cannot assign") { t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql) } @@ -944,7 +945,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into any integer type"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot encode int8 into oid 600"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 600"}, } for i, tt := range tests { @@ -1017,7 +1018,7 @@ func TestQueryRowCoreInt16Slice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { + if err != nil && !(strings.Contains(err.Error(), "Cannot decode null") || strings.Contains(err.Error(), "cannot assign")) { t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) } diff --git a/value_reader.go b/value_reader.go index 249b8ba3..c91a21af 100644 --- a/value_reader.go +++ b/value_reader.go @@ -4,6 +4,8 @@ import ( "errors" ) +var errRewoundLen = errors.New("len was rewound") + // ValueReader is used by the Scanner interface to decode values. type ValueReader struct { mr *msgReader @@ -154,3 +156,28 @@ func (r *ValueReader) ReadBytes(count int32) []byte { return r.mr.readBytes(count) } + +type valueReader2 struct { + *ValueReader +} + +func (r *valueReader2) Read(dst []byte) (int, error) { + if r.err != nil { + return 0, r.err + } + + src := r.ReadBytes(int32(len(dst))) + + copy(dst, src) + + return len(dst), nil +} + +func (r *valueReader2) ReadUint32() (uint32, error) { + if r.err == errRewoundLen { + r.err = nil + return uint32(r.Len()), nil + } + + return r.ValueReader.ReadUint32(), nil +} diff --git a/values.go b/values.go index 45ed914c..a9c4c209 100644 --- a/values.go +++ b/values.go @@ -13,6 +13,8 @@ import ( "strconv" "strings" "time" + + "github.com/jackc/pgx/pgtype" ) // PostgreSQL oids for common types @@ -200,6 +202,10 @@ type Encoder interface { FormatCode() int16 } +type ScannerV3 interface { + ScanPgxV3(fieldDescription interface{}, src interface{}) error +} + // NullFloat32 represents an float4 that may be null. NullFloat32 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -497,7 +503,7 @@ func (n NullInt16) Encode(w *WriteBuf, oid OID) error { return nil } - return encodeInt16(w, oid, n.Int16) + return pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w) } // NullInt32 represents an integer that may be null. NullInt32 implements the @@ -536,7 +542,7 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { return nil } - return encodeInt32(w, oid, n.Int32) + return pgtype.Int4{Int: n.Int32, Status: pgtype.Present}.EncodeBinary(w) } // OID (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, @@ -782,7 +788,7 @@ func (n NullInt64) Encode(w *WriteBuf, oid OID) error { return nil } - return encodeInt64(w, oid, n.Int64) + return pgtype.Int8{Int: n.Int64, Status: pgtype.Present}.EncodeBinary(w) } // NullBool represents an bool that may be null. NullBool implements the Scanner @@ -1020,6 +1026,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { switch arg := arg.(type) { case Encoder: return arg.Encode(wbuf, oid) + case pgtype.BinaryEncoder: + return arg.EncodeBinary(wbuf) + case pgtype.TextEncoder: + return arg.EncodeText(wbuf) case driver.Valuer: v, err := arg.Value() if err != nil { @@ -1054,17 +1064,19 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeJSONB(wbuf, oid, arg) } + if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { + err := value.ConvertFrom(arg) + if err != nil { + return err + } + return value.(pgtype.BinaryEncoder).EncodeBinary(wbuf) + } + switch arg := arg.(type) { case []string: return encodeStringSlice(wbuf, oid, arg) - case bool: - return encodeBool(wbuf, oid, arg) case []bool: return encodeBoolSlice(wbuf, oid, arg) - case int: - return encodeInt(wbuf, oid, arg) - case uint: - return encodeUInt(wbuf, oid, arg) case Char: return encodeChar(wbuf, oid, arg) case AclItem: @@ -1075,32 +1087,12 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case int8: - return encodeInt8(wbuf, oid, arg) - case uint8: - return encodeUInt8(wbuf, oid, arg) - case int16: - return encodeInt16(wbuf, oid, arg) - case []int16: - return encodeInt16Slice(wbuf, oid, arg) - case uint16: - return encodeUInt16(wbuf, oid, arg) - case []uint16: - return encodeUInt16Slice(wbuf, oid, arg) - case int32: - return encodeInt32(wbuf, oid, arg) case []int32: return encodeInt32Slice(wbuf, oid, arg) - case uint32: - return encodeUInt32(wbuf, oid, arg) case []uint32: return encodeUInt32Slice(wbuf, oid, arg) - case int64: - return encodeInt64(wbuf, oid, arg) case []int64: return encodeInt64Slice(wbuf, oid, arg) - case uint64: - return encodeUInt64(wbuf, oid, arg) case []uint64: return encodeUInt64Slice(wbuf, oid, arg) case float32: @@ -1140,32 +1132,57 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { func stripNamedType(val *reflect.Value) (interface{}, bool) { switch val.Kind() { case reflect.Int: - return int(val.Int()), true + convVal := int(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int8: - return int8(val.Int()), true + convVal := int8(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int16: - return int16(val.Int()), true + convVal := int16(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int32: - return int32(val.Int()), true + convVal := int32(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int64: - return int64(val.Int()), true + convVal := int64(val.Int()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint: - return uint(val.Uint()), true + convVal := uint(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint8: - return uint8(val.Uint()), true + convVal := uint8(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint16: - return uint16(val.Uint()), true + convVal := uint16(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint32: - return uint32(val.Uint()), true + convVal := uint32(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint64: - return uint64(val.Uint()), true + convVal := uint64(val.Uint()) + return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.String: - return val.String(), true + convVal := val.String() + return convVal, reflect.TypeOf(convVal) != val.Type() } return nil, false } +func decodeByOID(vr *ValueReader) (interface{}, error) { + switch vr.Type().DataType { + case Int2OID, Int4OID, Int8OID: + n := decodeInt(vr) + return n, vr.Err() + case BoolOID: + b := decodeBool(vr) + return b, vr.Err() + default: + buf := vr.ReadBytes(vr.Len()) + return buf, vr.Err() + } +} + // Decode decodes from vr into d. d must be a pointer. This allows // implementations of the Decoder interface to delegate the actual work of // decoding to the built-in functionality. @@ -1381,28 +1398,36 @@ func Decode(vr *ValueReader, d interface{}) error { } func decodeBool(vr *ValueReader) bool { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into bool")) - return false - } - if vr.Type().DataType != BoolOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) return false } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var b pgtype.Bool + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = b.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = b.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return false } - if vr.Len() != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return false } - b := vr.ReadByte() - return b != 0 + if b.Status != pgtype.Present { + vr.Fatal(fmt.Errorf("Cannot decode null into bool")) + return false + } + + return b.Bool } func encodeBool(w *WriteBuf, oid OID, value bool) error { @@ -1410,16 +1435,8 @@ func encodeBool(w *WriteBuf, oid OID, value bool) error { return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) } - w.WriteInt32(1) - - var n byte - if value { - n = 1 - } - - w.WriteByte(n) - - return nil + b := pgtype.Bool{Bool: value, Status: pgtype.Present} + return b.EncodeBinary(w) } func decodeInt(vr *ValueReader) int64 { @@ -1447,17 +1464,31 @@ func decodeInt8(vr *ValueReader) int64 { return 0 } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var n pgtype.Int8 + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = n.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = n.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return 0 } - return vr.ReadInt64() + if n.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return 0 + } + + return n.Int } func decodeChar(vr *ValueReader) Char { @@ -1485,88 +1516,37 @@ func decodeChar(vr *ValueReader) Char { } func decodeInt2(vr *ValueReader) int16 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } if vr.Type().DataType != Int2OID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) return 0 } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var n pgtype.Int2 + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = n.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = n.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } - if vr.Len() != 2 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return 0 } - return vr.ReadInt16() -} - -func encodeInt(w *WriteBuf, oid OID, value int) error { - switch oid { - case Int2OID: - if value < math.MinInt16 { - return fmt.Errorf("%d is less than min pg:int2", value) - } else if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than max pg:int2", value) - } - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4OID: - if value < math.MinInt32 { - return fmt.Errorf("%d is less than min pg:int4", value) - } else if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than max pg:int4", value) - } - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - if int64(value) <= int64(math.MaxInt64) { - w.WriteInt32(8) - w.WriteInt64(int64(value)) - } else { - return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64)) - } - default: - return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) + if n.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return 0 } - return nil -} - -func encodeUInt(w *WriteBuf, oid OID, value uint) error { - switch oid { - case Int2OID: - if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than max pg:int2", value) - } - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4OID: - if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than max pg:int4", value) - } - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - //****** Changed value to int64(value) and math.MaxInt64 to int64(math.MaxInt64) - if int64(value) > int64(math.MaxInt64) { - return fmt.Errorf("%d is greater than max pg:int8", value) - } - w.WriteInt32(8) - w.WriteInt64(int64(value)) - - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) - } - - return nil + return n.Int } func encodeChar(w *WriteBuf, oid OID, value Char) error { @@ -1575,187 +1555,6 @@ func encodeChar(w *WriteBuf, oid OID, value Char) error { return nil } -func encodeInt8(w *WriteBuf, oid OID, value int8) error { - switch oid { - case Int2OID: - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int8", oid) - } - - return nil -} - -func encodeUInt8(w *WriteBuf, oid OID, value uint8) error { - switch oid { - case Int2OID: - w.WriteInt32(2) - w.WriteInt16(int16(value)) - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid) - } - - return nil -} - -func encodeInt16(w *WriteBuf, oid OID, value int16) error { - switch oid { - case Int2OID: - w.WriteInt32(2) - w.WriteInt16(value) - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) - } - - return nil -} - -func encodeUInt16(w *WriteBuf, oid OID, value uint16) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(int32(value)) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int16", oid) - } - - return nil -} - -func encodeInt32(w *WriteBuf, oid OID, value int32) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - w.WriteInt32(4) - w.WriteInt32(value) - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int32", oid) - } - - return nil -} - -func encodeUInt32(w *WriteBuf, oid OID, value uint32) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(int64(value)) - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint32", oid) - } - - return nil -} - -func encodeInt64(w *WriteBuf, oid OID, value int64) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case Int8OID: - w.WriteInt32(8) - w.WriteInt64(value) - default: - return fmt.Errorf("cannot encode %s into oid %v", "int64", oid) - } - - return nil -} - -func encodeUInt64(w *WriteBuf, oid OID, value uint64) error { - switch oid { - case Int2OID: - if value <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(value)) - } else { - return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) - } - case Int4OID: - if value <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(value)) - } else { - return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32) - } - case Int8OID: - - if value <= math.MaxInt64 { - w.WriteInt32(8) - w.WriteInt64(int64(value)) - } else { - return fmt.Errorf("%d is greater than max int64 %d", value, int64(math.MaxInt64)) - } - default: - return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid) - } - - return nil -} - func decodeInt4(vr *ValueReader) int32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int32")) @@ -1767,17 +1566,31 @@ func decodeInt4(vr *ValueReader) int32 { return 0 } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var n pgtype.Int4 + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = n.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = n.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 } - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) return 0 } - return vr.ReadInt32() + if n.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return 0 + } + + return n.Int } func decodeOID(vr *ValueReader) OID { @@ -2179,51 +1992,54 @@ func encodeJSONB(w *WriteBuf, oid OID, value interface{}) error { } func decodeDate(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return zeroTime - } - if vr.Type().DataType != DateOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime + return time.Time{} } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var d pgtype.Date + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = d.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = d.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime + return time.Time{} } - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) + return time.Time{} } - dayOffset := vr.ReadInt32() - return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) + + if d.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into int16")) + return time.Time{} + } + + return d.Time } func encodeTime(w *WriteBuf, oid OID, value time.Time) error { switch oid { case DateOID: - tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix() - dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() - - secSinceDateEpoch := tUnix - dateEpoch - daysSinceDateEpoch := secSinceDateEpoch / 86400 - - w.WriteInt32(4) - w.WriteInt32(int32(daysSinceDateEpoch)) - - return nil + var d pgtype.Date + err := d.ConvertFrom(value) + if err != nil { + return err + } + return d.EncodeBinary(w) case TimestampTzOID, TimestampOID: - microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000 - microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - - w.WriteInt32(8) - w.WriteInt64(microsecSinceY2K) - - return nil + var t pgtype.Timestamptz + err := t.ConvertFrom(value) + if err != nil { + return err + } + return t.EncodeBinary(w) default: return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) } @@ -2244,19 +2060,31 @@ func decodeTimestampTz(vr *ValueReader) time.Time { return zeroTime } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var t pgtype.Timestamptz + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = t.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = t.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime + return time.Time{} } - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", vr.Len()))) - return zeroTime + if err != nil { + vr.Fatal(err) + return time.Time{} } - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) + if t.Status == pgtype.Null { + vr.Fatal(ProtocolError("Cannot decode null into time.Time")) + return time.Time{} + } + + return t.Time } func decodeTimestamp(vr *ValueReader) time.Time { @@ -2578,42 +2406,45 @@ func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error { } func decodeInt2Array(vr *ValueReader) []int16 { - if vr.Len() == -1 { - return nil - } - if vr.Type().DataType != Int2ArrayOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType))) return nil } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var a pgtype.Int2Array + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = a.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = a.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return nil } - numElems, err := decode1dArrayHeader(vr) if err != nil { vr.Fatal(err) return nil } - a := make([]int16, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 2: - a[i] = vr.ReadInt16() - case -1: + if a.Status == pgtype.Null { + return nil + } + + rawArray := make([]int16, len(a.Elements)) + for i := range a.Elements { + if a.Elements[i].Status == pgtype.Present { + rawArray[i] = a.Elements[i].Int + } else { vr.Fatal(ProtocolError("Cannot decode null element")) return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) - return nil } } - return a + return rawArray } func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { @@ -2660,38 +2491,6 @@ func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { return a } -func encodeInt16Slice(w *WriteBuf, oid OID, slice []int16) error { - if oid != Int2ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]int16", oid) - } - - encodeArrayHeader(w, Int2OID, len(slice), 6) - for _, v := range slice { - w.WriteInt32(2) - w.WriteInt16(v) - } - - return nil -} - -func encodeUInt16Slice(w *WriteBuf, oid OID, slice []uint16) error { - if oid != Int2ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint16", oid) - } - - encodeArrayHeader(w, Int2OID, len(slice), 6) - for _, v := range slice { - if v <= math.MaxInt16 { - w.WriteInt32(2) - w.WriteInt16(int16(v)) - } else { - return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16) - } - } - - return nil -} - func decodeInt4Array(vr *ValueReader) []int32 { if vr.Len() == -1 { return nil diff --git a/values_test.go b/values_test.go index 6ab221f7..ef13ccdf 100644 --- a/values_test.go +++ b/values_test.go @@ -18,24 +18,24 @@ func TestDateTranscode(t *testing.T) { defer closeConn(t, conn) dates := []time.Time{ - time.Date(1, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1000, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1600, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1700, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), - time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2001, 1, 2, 0, 0, 0, 0, time.Local), - time.Date(2004, 2, 29, 0, 0, 0, 0, time.Local), - time.Date(2013, 7, 4, 0, 0, 0, 0, time.Local), - time.Date(2013, 12, 25, 0, 0, 0, 0, time.Local), - time.Date(2029, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2081, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(2096, 2, 29, 0, 0, 0, 0, time.Local), - time.Date(2550, 1, 1, 0, 0, 0, 0, time.Local), - time.Date(9999, 12, 31, 0, 0, 0, 0, time.Local), + time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), + time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), + time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), + time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), } for _, actualDate := range dates { @@ -629,8 +629,8 @@ func TestNullX(t *testing.T) { {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, } @@ -1048,11 +1048,11 @@ func TestEncodeTypeRename(t *testing.T) { defer closeConn(t, conn) type _int int - inInt := _int(3) + inInt := _int(1) var outInt _int type _int8 int8 - inInt8 := _int8(3) + inInt8 := _int8(2) var outInt8 _int8 type _int16 int16 From 91dea95b686bce4124508e3f76e6ca5fdf3e8015 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:35:58 -0600 Subject: [PATCH 056/264] Only test on Go 1.8 on Travis --- .travis.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9ae8d963..324d800b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,7 @@ language: go go: - - 1.7.4 - - 1.6.4 + - 1.8.0 - tip # Derived from https://github.com/lib/pq/blob/master/.travis.yml From eb484e1368d255ab037787b6906df888586f4fd3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:44:17 -0600 Subject: [PATCH 057/264] TestStressConnPool now runs for X iterations ...instead of T time. Also run in parallel. --- stress_test.go | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/stress_test.go b/stress_test.go index 82979fd6..bb0a8287 100644 --- a/stress_test.go +++ b/stress_test.go @@ -3,11 +3,12 @@ package pgx_test import ( "errors" "fmt" - "golang.org/x/net/context" "math/rand" "testing" "time" + "golang.org/x/net/context" + "github.com/jackc/fake" "github.com/jackc/pgx" ) @@ -23,6 +24,8 @@ type queryRower interface { } func TestStressConnPool(t *testing.T) { + t.Parallel() + maxConnections := 8 pool := createConnPool(t, maxConnections) defer pool.Close() @@ -49,11 +52,12 @@ func TestStressConnPool(t *testing.T) { {"canceledExecContext", canceledExecContext}, } - var timer *time.Timer + var actionCount int + if testing.Short() { - timer = time.NewTimer(5 * time.Second) + actionCount = 1000 } else { - timer = time.NewTimer(60 * time.Second) + actionCount = 10000 } workerCount := 16 @@ -77,11 +81,8 @@ func TestStressConnPool(t *testing.T) { go work() } - var stop bool - for i := 0; !stop; i++ { + for i := 0; i < actionCount; i++ { select { - case <-timer.C: - stop = true case workChan <- i: case err := <-errChan: close(workChan) From 70f04f227e51c166cbd0e1850cbd8120c5eef19b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:45:42 -0600 Subject: [PATCH 058/264] Remove long TLS stress test This was used to check that over 512 MB could be read over a TLS connection. This previously could fail due to SSL renegotiation. But now pgx explicitly disables renegotiation when connecting to the PostgreSQL server. Also, the Go TLS library now supports limited renegotiation. And Amazon Redshift was the only target that this mattered on, and it now supports disabling renegotiation. So removing this long running and no longer needed test. --- stress_test.go | 36 ------------------------------------ 1 file changed, 36 deletions(-) diff --git a/stress_test.go b/stress_test.go index bb0a8287..afb5d860 100644 --- a/stress_test.go +++ b/stress_test.go @@ -96,42 +96,6 @@ func TestStressConnPool(t *testing.T) { } } -func TestStressTLSConnection(t *testing.T) { - t.Parallel() - - if tlsConnConfig == nil { - t.Skip("Skipping due to undefined tlsConnConfig") - } - - if testing.Short() { - t.Skip("Skipping due to testing -short") - } - - conn, err := pgx.Connect(*tlsConnConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } - defer conn.Close() - - for i := 0; i < 50; i++ { - sql := `select * from generate_series(1, $1)` - - rows, err := conn.Query(sql, 2000000) - if err != nil { - t.Fatal(err) - } - - var n int32 - for rows.Next() { - rows.Scan(&n) - } - - if rows.Err() != nil { - t.Fatalf("queryCount: %d, Row number: %d. %v", i, n, rows.Err()) - } - } -} - func setupStressDB(t *testing.T, pool *pgx.ConnPool) { _, err := pool.Exec(` drop table if exists widgets; From e53f739cbd764fed982e1110ce416ca5c487b38a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:54:06 -0600 Subject: [PATCH 059/264] Add STRESS_FACTOR to stress tests --- stress_test.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/stress_test.go b/stress_test.go index afb5d860..814c8023 100644 --- a/stress_test.go +++ b/stress_test.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "math/rand" + "os" + "strconv" "testing" "time" @@ -52,13 +54,15 @@ func TestStressConnPool(t *testing.T) { {"canceledExecContext", canceledExecContext}, } - var actionCount int - - if testing.Short() { - actionCount = 1000 - } else { - actionCount = 10000 + actionCount := 1000 + if s := os.Getenv("STRESS_FACTOR"); s != "" { + stressFactor, err := strconv.ParseInt(s, 10, 64) + if err != nil { + t.Fatalf("failed to parse STRESS_FACTOR: %v", s) + } + actionCount *= int(stressFactor) } + workerCount := 16 workChan := make(chan int) From cea412f2ba721c123775fbe9bc1355c62e9b7b2b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 15:57:24 -0600 Subject: [PATCH 060/264] Fix chat example --- examples/chat/main.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/chat/main.go b/examples/chat/main.go index ad8d56db..69ef456b 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -2,10 +2,11 @@ package main import ( "bufio" + "context" "fmt" - "github.com/jackc/pgx" "os" - "time" + + "github.com/jackc/pgx" ) var pool *pgx.ConnPool @@ -58,10 +59,7 @@ func listen() { conn.Listen("chat") for { - notification, err := conn.WaitForNotification(time.Second) - if err == pgx.ErrNotificationTimeout { - continue - } + notification, err := conn.WaitForNotification(context.Background()) if err != nil { fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) os.Exit(1) From 15b44f409684bdad0dce8f0dc23cc23a0c62756d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 16:00:08 -0600 Subject: [PATCH 061/264] Remove -short from travis --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 324d800b..dc3263f4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -54,7 +54,7 @@ install: - go get -u github.com/jackc/pgmock/pgmsg script: - - go test -v -race -short ./... + - go test -v -race ./... matrix: allow_failures: From 908c439317d9c8831562fcd34e41a7f6b1fd2d85 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 16:01:06 -0600 Subject: [PATCH 062/264] Use stdlib context --- .travis.yml | 1 - conn.go | 3 +-- conn_pool.go | 2 +- conn_test.go | 2 +- query.go | 3 +-- query_test.go | 4 +--- stress_test.go | 3 +-- 7 files changed, 6 insertions(+), 12 deletions(-) diff --git a/.travis.yml b/.travis.yml index dc3263f4..e3d94acd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -50,7 +50,6 @@ install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake - - go get -u golang.org/x/net/context - go get -u github.com/jackc/pgmock/pgmsg script: diff --git a/conn.go b/conn.go index 09dada10..1c0b4e22 100644 --- a/conn.go +++ b/conn.go @@ -1,6 +1,7 @@ package pgx import ( + "context" "crypto/md5" "crypto/tls" "encoding/binary" @@ -19,8 +20,6 @@ import ( "sync/atomic" "time" - "golang.org/x/net/context" - "github.com/jackc/pgx/chunkreader" "github.com/jackc/pgx/pgtype" ) diff --git a/conn_pool.go b/conn_pool.go index 9701f170..9dfbf734 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -1,8 +1,8 @@ package pgx import ( + "context" "errors" - "golang.org/x/net/context" "sync" "time" ) diff --git a/conn_test.go b/conn_test.go index cc87efa8..b44fd6db 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,9 +1,9 @@ package pgx_test import ( + "context" "crypto/tls" "fmt" - "golang.org/x/net/context" "net" "os" "reflect" diff --git a/query.go b/query.go index 4af1de10..80e7c47c 100644 --- a/query.go +++ b/query.go @@ -1,13 +1,12 @@ package pgx import ( + "context" "database/sql" "errors" "fmt" "time" - "golang.org/x/net/context" - "github.com/jackc/pgx/pgtype" ) diff --git a/query_test.go b/query_test.go index fd5d2e5b..c30ab2ef 100644 --- a/query_test.go +++ b/query_test.go @@ -2,16 +2,14 @@ package pgx_test import ( "bytes" + "context" "database/sql" "fmt" "strings" "testing" "time" - "golang.org/x/net/context" - "github.com/jackc/pgx" - "github.com/shopspring/decimal" ) diff --git a/stress_test.go b/stress_test.go index 814c8023..47a3f4d6 100644 --- a/stress_test.go +++ b/stress_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "errors" "fmt" "math/rand" @@ -9,8 +10,6 @@ import ( "testing" "time" - "golang.org/x/net/context" - "github.com/jackc/fake" "github.com/jackc/pgx" ) From ed9e8bb1688118251bdb38b3919e148d889359f9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 16:08:00 -0600 Subject: [PATCH 063/264] Remove skip test for missing json type All supported versions of PostgreSQL now have json type. --- example_json_test.go | 8 +------- stdlib/sql_test.go | 20 ++------------------ 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/example_json_test.go b/example_json_test.go index 631430b8..09e27cff 100644 --- a/example_json_test.go +++ b/example_json_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "fmt" + "github.com/jackc/pgx" ) @@ -12,13 +13,6 @@ func Example_JSON() { return } - if _, ok := conn.PgTypes[pgx.JSONOID]; !ok { - // No JSON type -- must be running against very old PostgreSQL - // Pretend it works - fmt.Println("John", 42) - return - } - type person struct { Name string `json:"name"` Age int `json:"age"` diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 546ec4fb..c8062c61 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -3,9 +3,10 @@ package stdlib_test import ( "bytes" "database/sql" + "testing" + "github.com/jackc/pgx" "github.com/jackc/pgx/stdlib" - "testing" ) func openDB(t *testing.T) *sql.DB { @@ -497,10 +498,6 @@ func TestConnQueryJSONIntoByteSlice(t *testing.T) { db := openDB(t) defer closeDB(t, db) - if !serverHasJSON(t, db) { - t.Skip("Skipping due to server's lack of JSON type") - } - _, err := db.Exec(` create temporary table docs( body json not null @@ -537,10 +534,6 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { db := openDB(t) defer closeDB(t, db) - if !serverHasJSON(t, db) { - t.Skip("Skipping due to server's lack of JSON type") - } - _, err := db.Exec(` create temporary table docs( body json not null @@ -575,15 +568,6 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { ensureConnValid(t, db) } -func serverHasJSON(t *testing.T, db *sql.DB) bool { - var hasJSON bool - err := db.QueryRow(`select exists(select 1 from pg_type where typname='json')`).Scan(&hasJSON) - if err != nil { - t.Fatalf("db.QueryRow unexpectedly failed: %v", err) - } - return hasJSON -} - func TestTransactionLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db) From 0a0c086edd29955837b4e112a9a3e1db772a7db6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 16:46:50 -0600 Subject: [PATCH 064/264] Fix broken stdlib tests --- query.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++ stdlib/sql.go | 2 +- 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/query.go b/query.go index 80e7c47c..52643b8d 100644 --- a/query.go +++ b/query.go @@ -439,6 +439,99 @@ func (rows *Rows) Values() ([]interface{}, error) { return values, rows.Err() } +// ValuesForStdlib is a temporary function to keep all systems operational +// while refactoring. Do not use. +func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { + if rows.closed { + return nil, errors.New("rows is closed") + } + + values := make([]interface{}, 0, len(rows.fields)) + + for range rows.fields { + vr, _ := rows.nextColumn() + + if vr.Len() == -1 { + values = append(values, nil) + continue + } + // TODO - consider what are the implications of returning complex types since database/sql uses this method + switch vr.Type().FormatCode { + // All intrinsic types (except string) are encoded with binary + // encoding so anything else should be treated as a string + case TextFormatCode: + values = append(values, vr.ReadString(vr.Len())) + case BinaryFormatCode: + switch vr.Type().DataType { + case TextOID, VarcharOID: + values = append(values, decodeText(vr)) + case BoolOID: + values = append(values, decodeBool(vr)) + case ByteaOID: + values = append(values, decodeBytea(vr)) + case Int8OID: + values = append(values, decodeInt8(vr)) + case Int2OID: + values = append(values, decodeInt2(vr)) + case Int4OID: + values = append(values, decodeInt4(vr)) + case OIDOID: + values = append(values, decodeOID(vr)) + case Float4OID: + values = append(values, decodeFloat4(vr)) + case Float8OID: + values = append(values, decodeFloat8(vr)) + case BoolArrayOID: + values = append(values, decodeBoolArray(vr)) + case Int2ArrayOID: + values = append(values, decodeInt2Array(vr)) + case Int4ArrayOID: + values = append(values, decodeInt4Array(vr)) + case Int8ArrayOID: + values = append(values, decodeInt8Array(vr)) + case Float4ArrayOID: + values = append(values, decodeFloat4Array(vr)) + case Float8ArrayOID: + values = append(values, decodeFloat8Array(vr)) + case TextArrayOID, VarcharArrayOID: + values = append(values, decodeTextArray(vr)) + case TimestampArrayOID, TimestampTzArrayOID: + values = append(values, decodeTimestampArray(vr)) + case DateOID: + values = append(values, decodeDate(vr)) + case TimestampTzOID: + values = append(values, decodeTimestampTz(vr)) + case TimestampOID: + values = append(values, decodeTimestamp(vr)) + case InetOID, CidrOID: + values = append(values, decodeInet(vr)) + case JSONOID: + var d interface{} + decodeJSON(vr, &d) + values = append(values, d) + case JSONBOID: + var d interface{} + decodeJSONB(vr, &d) + values = append(values, d) + default: + rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) + } + default: + rows.Fatal(errors.New("Unknown format code")) + } + + if vr.Err() != nil { + rows.Fatal(vr.Err()) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + } + + return values, rows.Err() +} + // AfterClose adds f to a LILO queue of functions that will be called when // rows is closed. func (rows *Rows) AfterClose(f func(*Rows)) { diff --git a/stdlib/sql.go b/stdlib/sql.go index 420b521e..d138ae1e 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -320,7 +320,7 @@ func (r *Rows) Next(dest []driver.Value) error { } } - values, err := r.rows.Values() + values, err := r.rows.ValuesForStdlib() if err != nil { return err } From 2e2b11be347901074a0d61e4c7d4c37163e9f2f9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:15:05 -0600 Subject: [PATCH 065/264] Add more tests for pgtype.Bool --- pgtype/bool.go | 5 +--- pgtype/bool_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/pgtype/bool.go b/pgtype/bool.go index 81c72472..14bc2d6e 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -50,10 +50,7 @@ func (b *Bool) AssignTo(dst interface{}) error { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: if b.Status == Null { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } + el.Set(reflect.Zero(el.Type())) return nil } if el.IsNil() { diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 53df1747..74140b5e 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -1,11 +1,14 @@ package pgtype_test import ( + "reflect" "testing" "github.com/jackc/pgx/pgtype" ) +type _bool bool + func TestBoolTranscode(t *testing.T) { testSuccessfulTranscode(t, "bool", []interface{}{ pgtype.Bool{Bool: false, Status: pgtype.Present}, @@ -15,18 +18,19 @@ func TestBoolTranscode(t *testing.T) { } func TestBoolConvertFrom(t *testing.T) { - type _int8 int8 - successfulTests := []struct { source interface{} result pgtype.Bool }{ + {source: pgtype.Bool{Bool: false, Status: pgtype.Null}, result: pgtype.Bool{Bool: false, Status: pgtype.Null}}, {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, {source: "t", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: "f", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + {source: _bool(true), result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: _bool(false), result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -41,3 +45,54 @@ func TestBoolConvertFrom(t *testing.T) { } } } + +func TestBoolAssignTo(t *testing.T) { + var b bool + var _b _bool + var pb *bool + var _pb *_bool + + simpleTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &b, expected: false}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &b, expected: true}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Present}, dst: &_b, expected: _bool(false)}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_b, expected: _bool(true)}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &pb, expected: ((*bool)(nil))}, + {src: pgtype.Bool{Bool: false, Status: pgtype.Null}, dst: &_pb, expected: ((*_bool)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Bool + dst interface{} + expected interface{} + }{ + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &pb, expected: true}, + {src: pgtype.Bool{Bool: true, Status: pgtype.Present}, dst: &_pb, expected: _bool(true)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} From 66712f8259574032fac82fd52f807e88149afa06 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:16:07 -0600 Subject: [PATCH 066/264] travis needs go 1.8 not 1.8.0 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index e3d94acd..60e1670f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: go go: - - 1.8.0 + - 1.8 - tip # Derived from https://github.com/lib/pq/blob/master/.travis.yml From 272f095a447a5b4b7ea4b395a29c70c57948a3c7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:35:02 -0600 Subject: [PATCH 067/264] Standardize receiver variable name for pgtype Conversion functions now use standardized src and dst depending on their role. --- pgtype/array.go | 42 +++++++++++----------- pgtype/bool.go | 54 ++++++++++++++-------------- pgtype/date.go | 58 +++++++++++++++--------------- pgtype/int2.go | 56 ++++++++++++++--------------- pgtype/int2array.go | 82 +++++++++++++++++++++---------------------- pgtype/int4.go | 56 ++++++++++++++--------------- pgtype/int8.go | 56 ++++++++++++++--------------- pgtype/timestamptz.go | 58 +++++++++++++++--------------- 8 files changed, 231 insertions(+), 231 deletions(-) diff --git a/pgtype/array.go b/pgtype/array.go index 75d2e440..76492c61 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -25,34 +25,34 @@ type ArrayDimension struct { LowerBound int32 } -func (ah *ArrayHeader) DecodeBinary(r io.Reader) error { +func (dst *ArrayHeader) DecodeBinary(r io.Reader) error { numDims, err := pgio.ReadInt32(r) if err != nil { return err } if numDims > 0 { - ah.Dimensions = make([]ArrayDimension, numDims) + dst.Dimensions = make([]ArrayDimension, numDims) } containsNull, err := pgio.ReadInt32(r) if err != nil { return err } - ah.ContainsNull = containsNull == 1 + dst.ContainsNull = containsNull == 1 - ah.ElementOID, err = pgio.ReadInt32(r) + dst.ElementOID, err = pgio.ReadInt32(r) if err != nil { return err } - for i := range ah.Dimensions { - ah.Dimensions[i].Length, err = pgio.ReadInt32(r) + for i := range dst.Dimensions { + dst.Dimensions[i].Length, err = pgio.ReadInt32(r) if err != nil { return err } - ah.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) + dst.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) if err != nil { return err } @@ -61,14 +61,14 @@ func (ah *ArrayHeader) DecodeBinary(r io.Reader) error { return nil } -func (ah *ArrayHeader) EncodeBinary(w io.Writer) error { - _, err := pgio.WriteInt32(w, int32(len(ah.Dimensions))) +func (src *ArrayHeader) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) if err != nil { return err } var containsNull int32 - if ah.ContainsNull { + if src.ContainsNull { containsNull = 1 } _, err = pgio.WriteInt32(w, containsNull) @@ -76,18 +76,18 @@ func (ah *ArrayHeader) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt32(w, ah.ElementOID) + _, err = pgio.WriteInt32(w, src.ElementOID) if err != nil { return err } - for i := range ah.Dimensions { - _, err = pgio.WriteInt32(w, ah.Dimensions[i].Length) + for i := range src.Dimensions { + _, err = pgio.WriteInt32(w, src.Dimensions[i].Length) if err != nil { return err } - _, err = pgio.WriteInt32(w, ah.Dimensions[i].LowerBound) + _, err = pgio.WriteInt32(w, src.Dimensions[i].LowerBound) if err != nil { return err } @@ -102,7 +102,7 @@ type UntypedTextArray struct { } func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { - uta := &UntypedTextArray{} + dst := &UntypedTextArray{} buf := bytes.NewBufferString(src) @@ -219,7 +219,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { if currentDim == counterDim { implicitDimensions[currentDim].Length++ } - uta.Elements = append(uta.Elements, value) + dst.Elements = append(dst.Elements, value) } if currentDim < 0 { @@ -233,15 +233,15 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } - if len(uta.Elements) == 0 { - uta.Dimensions = nil + if len(dst.Elements) == 0 { + dst.Dimensions = nil } else if len(explicitDimensions) > 0 { - uta.Dimensions = explicitDimensions + dst.Dimensions = explicitDimensions } else { - uta.Dimensions = implicitDimensions + dst.Dimensions = implicitDimensions } - return uta, nil + return dst, nil } func skipWhitespace(buf *bytes.Buffer) { diff --git a/pgtype/bool.go b/pgtype/bool.go index 14bc2d6e..2889b787 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -14,21 +14,21 @@ type Bool struct { Status Status } -func (b *Bool) ConvertFrom(src interface{}) error { +func (dst *Bool) ConvertFrom(src interface{}) error { switch value := src.(type) { case Bool: - *b = value + *dst = value case bool: - *b = Bool{Bool: value, Status: Present} + *dst = Bool{Bool: value, Status: Present} case string: bb, err := strconv.ParseBool(value) if err != nil { return err } - *b = Bool{Bool: bb, Status: Present} + *dst = Bool{Bool: bb, Status: Present} default: if originalSrc, ok := underlyingBoolType(src); ok { - return b.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Bool", value) } @@ -36,20 +36,20 @@ func (b *Bool) ConvertFrom(src interface{}) error { return nil } -func (b *Bool) AssignTo(dst interface{}) error { +func (src *Bool) AssignTo(dst interface{}) error { switch v := dst.(type) { case *bool: - if b.Status != Present { - return fmt.Errorf("cannot assign %v to %T", b, dst) + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - *v = b.Bool + *v = src.Bool default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() switch el.Kind() { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: - if b.Status == Null { + if src.Status == Null { el.Set(reflect.Zero(el.Type())) return nil } @@ -57,29 +57,29 @@ func (b *Bool) AssignTo(dst interface{}) error { // allocate destination el.Set(reflect.New(el.Type().Elem())) } - return b.AssignTo(el.Interface()) + return src.AssignTo(el.Interface()) case reflect.Bool: - if b.Status != Present { - return fmt.Errorf("cannot assign %v to %T", b, dst) + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - el.SetBool(b.Bool) + el.SetBool(src.Bool) return nil } } - return fmt.Errorf("cannot put decode %v into %T", b, dst) + return fmt.Errorf("cannot put decode %v into %T", src, dst) } return nil } -func (b *Bool) DecodeText(r io.Reader) error { +func (dst *Bool) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *b = Bool{Status: Null} + *dst = Bool{Status: Null} return nil } @@ -92,18 +92,18 @@ func (b *Bool) DecodeText(r io.Reader) error { return err } - *b = Bool{Bool: byt == 't', Status: Present} + *dst = Bool{Bool: byt == 't', Status: Present} return nil } -func (b *Bool) DecodeBinary(r io.Reader) error { +func (dst *Bool) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *b = Bool{Status: Null} + *dst = Bool{Status: Null} return nil } @@ -116,12 +116,12 @@ func (b *Bool) DecodeBinary(r io.Reader) error { return err } - *b = Bool{Bool: byt == 1, Status: Present} + *dst = Bool{Bool: byt == 1, Status: Present} return nil } -func (b Bool) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, b.Status); done { +func (src Bool) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -131,7 +131,7 @@ func (b Bool) EncodeText(w io.Writer) error { } var buf []byte - if b.Bool { + if src.Bool { buf = []byte{'t'} } else { buf = []byte{'f'} @@ -141,8 +141,8 @@ func (b Bool) EncodeText(w io.Writer) error { return err } -func (b Bool) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, b.Status); done { +func (src Bool) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -152,7 +152,7 @@ func (b Bool) EncodeBinary(w io.Writer) error { } var buf []byte - if b.Bool { + if src.Bool { buf = []byte{1} } else { buf = []byte{0} diff --git a/pgtype/date.go b/pgtype/date.go index f3e3e4c6..6cd8e499 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -20,15 +20,15 @@ const ( infinityDayOffset = 2147483647 ) -func (d *Date) ConvertFrom(src interface{}) error { +func (dst *Date) ConvertFrom(src interface{}) error { switch value := src.(type) { case Date: - *d = value + *dst = value case time.Time: - *d = Date{Time: value, Status: Present} + *dst = Date{Time: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return d.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Date", value) } @@ -36,20 +36,20 @@ func (d *Date) ConvertFrom(src interface{}) error { return nil } -func (d *Date) AssignTo(dst interface{}) error { +func (src *Date) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: - if d.Status != Present || d.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", d, dst) + if src.Status != Present || src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - *v = d.Time + *v = src.Time default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() switch el.Kind() { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: - if d.Status == Null { + if src.Status == Null { if !el.IsNil() { // if the destination pointer is not nil, nil it out el.Set(reflect.Zero(el.Type())) @@ -60,23 +60,23 @@ func (d *Date) AssignTo(dst interface{}) error { // allocate destination el.Set(reflect.New(el.Type().Elem())) } - return d.AssignTo(el.Interface()) + return src.AssignTo(el.Interface()) } } - return fmt.Errorf("cannot decode %v into %T", d, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil } -func (d *Date) DecodeText(r io.Reader) error { +func (dst *Date) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *d = Date{Status: Null} + *dst = Date{Status: Null} return nil } @@ -89,29 +89,29 @@ func (d *Date) DecodeText(r io.Reader) error { sbuf := string(buf) switch sbuf { case "infinity": - *d = Date{Status: Present, InfinityModifier: Infinity} + *dst = Date{Status: Present, InfinityModifier: Infinity} case "-infinity": - *d = Date{Status: Present, InfinityModifier: -Infinity} + *dst = Date{Status: Present, InfinityModifier: -Infinity} default: t, err := time.ParseInLocation("2006-01-02", sbuf, time.UTC) if err != nil { return err } - *d = Date{Time: t, Status: Present} + *dst = Date{Time: t, Status: Present} } return nil } -func (d *Date) DecodeBinary(r io.Reader) error { +func (dst *Date) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *d = Date{Status: Null} + *dst = Date{Status: Null} return nil } @@ -126,27 +126,27 @@ func (d *Date) DecodeBinary(r io.Reader) error { switch dayOffset { case infinityDayOffset: - *d = Date{Status: Present, InfinityModifier: Infinity} + *dst = Date{Status: Present, InfinityModifier: Infinity} case negativeInfinityDayOffset: - *d = Date{Status: Present, InfinityModifier: -Infinity} + *dst = Date{Status: Present, InfinityModifier: -Infinity} default: t := time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.UTC) - *d = Date{Time: t, Status: Present} + *dst = Date{Time: t, Status: Present} } return nil } -func (d Date) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, d.Status); done { +func (src Date) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } var s string - switch d.InfinityModifier { + switch src.InfinityModifier { case None: - s = d.Time.Format("2006-01-02") + s = src.Time.Format("2006-01-02") case Infinity: s = "infinity" case NegativeInfinity: @@ -162,8 +162,8 @@ func (d Date) EncodeText(w io.Writer) error { return err } -func (d Date) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, d.Status); done { +func (src Date) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -173,9 +173,9 @@ func (d Date) EncodeBinary(w io.Writer) error { } var daysSinceDateEpoch int32 - switch d.InfinityModifier { + switch src.InfinityModifier { case None: - tUnix := time.Date(d.Time.Year(), d.Time.Month(), d.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() + tUnix := time.Date(src.Time.Year(), src.Time.Month(), src.Time.Day(), 0, 0, 0, 0, time.UTC).Unix() dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() secSinceDateEpoch := tUnix - dateEpoch diff --git a/pgtype/int2.go b/pgtype/int2.go index 2da8a96d..fb6a8ccc 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -14,21 +14,21 @@ type Int2 struct { Status Status } -func (i *Int2) ConvertFrom(src interface{}) error { +func (dst *Int2) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2: - *i = value + *dst = value case int8: - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint8: - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case int16: - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint16: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case int32: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -36,12 +36,12 @@ func (i *Int2) ConvertFrom(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint32: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case int64: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -49,12 +49,12 @@ func (i *Int2) ConvertFrom(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint64: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case int: if value < math.MinInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) @@ -62,21 +62,21 @@ func (i *Int2) ConvertFrom(src interface{}) error { if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case uint: if value > math.MaxInt16 { return fmt.Errorf("%d is greater than maximum value for Int2", value) } - *i = Int2{Int: int16(value), Status: Present} + *dst = Int2{Int: int16(value), Status: Present} case string: num, err := strconv.ParseInt(value, 10, 16) if err != nil { return err } - *i = Int2{Int: int16(num), Status: Present} + *dst = Int2{Int: int16(num), Status: Present} default: if originalSrc, ok := underlyingIntType(src); ok { - return i.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) } @@ -84,18 +84,18 @@ func (i *Int2) ConvertFrom(src interface{}) error { return nil } -func (i *Int2) AssignTo(dst interface{}) error { - return int64AssignTo(int64(i.Int), i.Status, dst) +func (src *Int2) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) } -func (i *Int2) DecodeText(r io.Reader) error { +func (dst *Int2) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int2{Status: Null} + *dst = Int2{Status: Null} return nil } @@ -110,18 +110,18 @@ func (i *Int2) DecodeText(r io.Reader) error { return err } - *i = Int2{Int: int16(n), Status: Present} + *dst = Int2{Int: int16(n), Status: Present} return nil } -func (i *Int2) DecodeBinary(r io.Reader) error { +func (dst *Int2) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int2{Status: Null} + *dst = Int2{Status: Null} return nil } @@ -134,16 +134,16 @@ func (i *Int2) DecodeBinary(r io.Reader) error { return err } - *i = Int2{Int: int16(n), Status: Present} + *dst = Int2{Int: int16(n), Status: Present} return nil } -func (i Int2) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int2) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } - s := strconv.FormatInt(int64(i.Int), 10) + s := strconv.FormatInt(int64(src.Int), 10) _, err := pgio.WriteInt32(w, int32(len(s))) if err != nil { return nil @@ -152,8 +152,8 @@ func (i Int2) EncodeText(w io.Writer) error { return err } -func (i Int2) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int2) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -162,6 +162,6 @@ func (i Int2) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt16(w, i.Int) + _, err = pgio.WriteInt16(w, src.Int) return err } diff --git a/pgtype/int2array.go b/pgtype/int2array.go index 86375516..4ac0c409 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -14,15 +14,15 @@ type Int2Array struct { Status Status } -func (a *Int2Array) ConvertFrom(src interface{}) error { +func (dst *Int2Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2Array: - *a = value + *dst = value case []int16: if value == nil { - *a = Int2Array{Status: Null} + *dst = Int2Array{Status: Null} } else if len(value) == 0 { - *a = Int2Array{Status: Present} + *dst = Int2Array{Status: Present} } else { elements := make([]Int2, len(value)) for i := range value { @@ -30,7 +30,7 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { return err } } - *a = Int2Array{ + *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -38,9 +38,9 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { } case []uint16: if value == nil { - *a = Int2Array{Status: Null} + *dst = Int2Array{Status: Null} } else if len(value) == 0 { - *a = Int2Array{Status: Present} + *dst = Int2Array{Status: Present} } else { elements := make([]Int2, len(value)) for i := range value { @@ -48,7 +48,7 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { return err } } - *a = Int2Array{ + *dst = Int2Array{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -56,7 +56,7 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { } default: if originalSrc, ok := underlyingSliceType(src); ok { - return a.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) } @@ -64,13 +64,13 @@ func (a *Int2Array) ConvertFrom(src interface{}) error { return nil } -func (a *Int2Array) AssignTo(dst interface{}) error { +func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { case *[]int16: - if a.Status == Present { - *v = make([]int16, len(a.Elements)) - for i := range a.Elements { - if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + if src.Status == Present { + *v = make([]int16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } @@ -78,10 +78,10 @@ func (a *Int2Array) AssignTo(dst interface{}) error { *v = nil } case *[]uint16: - if a.Status == Present { - *v = make([]uint16, len(a.Elements)) - for i := range a.Elements { - if err := a.Elements[i].AssignTo(&((*v)[i])); err != nil { + if src.Status == Present { + *v = make([]uint16, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } @@ -89,20 +89,20 @@ func (a *Int2Array) AssignTo(dst interface{}) error { *v = nil } default: - return fmt.Errorf("cannot put decode %v into %T", a, dst) + return fmt.Errorf("cannot put decode %v into %T", src, dst) } return nil } -func (a *Int2Array) DecodeText(r io.Reader) error { +func (dst *Int2Array) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *a = Int2Array{Status: Null} + *dst = Int2Array{Status: Null} return nil } @@ -135,19 +135,19 @@ func (a *Int2Array) DecodeText(r io.Reader) error { } } - *a = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} return nil } -func (a *Int2Array) DecodeBinary(r io.Reader) error { +func (dst *Int2Array) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *a = Int2Array{Status: Null} + *dst = Int2Array{Status: Null} return nil } @@ -158,7 +158,7 @@ func (a *Int2Array) DecodeBinary(r io.Reader) error { } if len(arrayHeader.Dimensions) == 0 { - *a = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int2Array{Dimensions: arrayHeader.Dimensions, Status: Present} return nil } @@ -176,16 +176,16 @@ func (a *Int2Array) DecodeBinary(r io.Reader) error { } } - *a = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = Int2Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} return nil } -func (a *Int2Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, a.Status); done { +func (src *Int2Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } - if len(a.Dimensions) == 0 { + if len(src.Dimensions) == 0 { _, err := pgio.WriteInt32(w, 2) if err != nil { return err @@ -197,7 +197,7 @@ func (a *Int2Array) EncodeText(w io.Writer) error { buf := &bytes.Buffer{} - err := EncodeTextArrayDimensions(buf, a.Dimensions) + err := EncodeTextArrayDimensions(buf, src.Dimensions) if err != nil { return err } @@ -207,15 +207,15 @@ func (a *Int2Array) EncodeText(w io.Writer) error { // [4]. A multi-dimensional array of lengths [3,5,2] would have a // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' // or '}'. - dimElemCounts := make([]int, len(a.Dimensions)) - dimElemCounts[len(a.Dimensions)-1] = int(a.Dimensions[len(a.Dimensions)-1].Length) - for i := len(a.Dimensions) - 2; i > -1; i-- { - dimElemCounts[i] = int(a.Dimensions[i].Length) * dimElemCounts[i+1] + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } textElementWriter := NewTextElementWriter(buf) - for i, elem := range a.Elements { + for i, elem := range src.Elements { if i > 0 { err = pgio.WriteByte(buf, ',') if err != nil { @@ -257,8 +257,8 @@ func (a *Int2Array) EncodeText(w io.Writer) error { return err } -func (a *Int2Array) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, a.Status); done { +func (src *Int2Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -268,18 +268,18 @@ func (a *Int2Array) EncodeBinary(w io.Writer) error { // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} - for i := range a.Elements { - err := a.Elements[i].EncodeBinary(elemBuf) + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { return err } - if a.Elements[i].Status == Null { + if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true } } arrayHeader.ElementOID = Int2OID - arrayHeader.Dimensions = a.Dimensions + arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - // or how not pay allocations for the byte order conversions. diff --git a/pgtype/int4.go b/pgtype/int4.go index 84c45522..1a4733b0 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -14,25 +14,25 @@ type Int4 struct { Status Status } -func (i *Int4) ConvertFrom(src interface{}) error { +func (dst *Int4) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int4: - *i = value + *dst = value case int8: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint8: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case int16: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint16: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case int32: - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint32: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case int64: if value < math.MinInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) @@ -40,12 +40,12 @@ func (i *Int4) ConvertFrom(src interface{}) error { if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint64: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case int: if value < math.MinInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) @@ -53,21 +53,21 @@ func (i *Int4) ConvertFrom(src interface{}) error { if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case uint: if value > math.MaxInt32 { return fmt.Errorf("%d is greater than maximum value for Int4", value) } - *i = Int4{Int: int32(value), Status: Present} + *dst = Int4{Int: int32(value), Status: Present} case string: num, err := strconv.ParseInt(value, 10, 32) if err != nil { return err } - *i = Int4{Int: int32(num), Status: Present} + *dst = Int4{Int: int32(num), Status: Present} default: if originalSrc, ok := underlyingIntType(src); ok { - return i.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -75,18 +75,18 @@ func (i *Int4) ConvertFrom(src interface{}) error { return nil } -func (i *Int4) AssignTo(dst interface{}) error { - return int64AssignTo(int64(i.Int), i.Status, dst) +func (src *Int4) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) } -func (i *Int4) DecodeText(r io.Reader) error { +func (dst *Int4) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int4{Status: Null} + *dst = Int4{Status: Null} return nil } @@ -101,18 +101,18 @@ func (i *Int4) DecodeText(r io.Reader) error { return err } - *i = Int4{Int: int32(n), Status: Present} + *dst = Int4{Int: int32(n), Status: Present} return nil } -func (i *Int4) DecodeBinary(r io.Reader) error { +func (dst *Int4) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int4{Status: Null} + *dst = Int4{Status: Null} return nil } @@ -125,16 +125,16 @@ func (i *Int4) DecodeBinary(r io.Reader) error { return err } - *i = Int4{Int: n, Status: Present} + *dst = Int4{Int: n, Status: Present} return nil } -func (i Int4) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int4) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } - s := strconv.FormatInt(int64(i.Int), 10) + s := strconv.FormatInt(int64(src.Int), 10) _, err := pgio.WriteInt32(w, int32(len(s))) if err != nil { return nil @@ -143,8 +143,8 @@ func (i Int4) EncodeText(w io.Writer) error { return err } -func (i Int4) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int4) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -153,6 +153,6 @@ func (i Int4) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt32(w, i.Int) + _, err = pgio.WriteInt32(w, src.Int) return err } diff --git a/pgtype/int8.go b/pgtype/int8.go index c0e14e44..7f307f18 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -14,29 +14,29 @@ type Int8 struct { Status Status } -func (i *Int8) ConvertFrom(src interface{}) error { +func (dst *Int8) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int8: - *i = value + *dst = value case int8: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint8: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case int16: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint16: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case int32: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint32: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case int64: - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint64: if value > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case int: if int64(value) < math.MinInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) @@ -44,21 +44,21 @@ func (i *Int8) ConvertFrom(src interface{}) error { if int64(value) > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case uint: if uint64(value) > math.MaxInt64 { return fmt.Errorf("%d is greater than maximum value for Int8", value) } - *i = Int8{Int: int64(value), Status: Present} + *dst = Int8{Int: int64(value), Status: Present} case string: num, err := strconv.ParseInt(value, 10, 64) if err != nil { return err } - *i = Int8{Int: num, Status: Present} + *dst = Int8{Int: num, Status: Present} default: if originalSrc, ok := underlyingIntType(src); ok { - return i.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -66,18 +66,18 @@ func (i *Int8) ConvertFrom(src interface{}) error { return nil } -func (i *Int8) AssignTo(dst interface{}) error { - return int64AssignTo(int64(i.Int), i.Status, dst) +func (src *Int8) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) } -func (i *Int8) DecodeText(r io.Reader) error { +func (dst *Int8) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int8{Status: Null} + *dst = Int8{Status: Null} return nil } @@ -92,18 +92,18 @@ func (i *Int8) DecodeText(r io.Reader) error { return err } - *i = Int8{Int: n, Status: Present} + *dst = Int8{Int: n, Status: Present} return nil } -func (i *Int8) DecodeBinary(r io.Reader) error { +func (dst *Int8) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *i = Int8{Status: Null} + *dst = Int8{Status: Null} return nil } @@ -116,16 +116,16 @@ func (i *Int8) DecodeBinary(r io.Reader) error { return err } - *i = Int8{Int: n, Status: Present} + *dst = Int8{Int: n, Status: Present} return nil } -func (i Int8) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int8) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } - s := strconv.FormatInt(i.Int, 10) + s := strconv.FormatInt(src.Int, 10) _, err := pgio.WriteInt32(w, int32(len(s))) if err != nil { return nil @@ -134,8 +134,8 @@ func (i Int8) EncodeText(w io.Writer) error { return err } -func (i Int8) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, i.Status); done { +func (src Int8) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -144,6 +144,6 @@ func (i Int8) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt64(w, i.Int) + _, err = pgio.WriteInt64(w, src.Int) return err } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index cc33b296..4f08cd2a 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -25,15 +25,15 @@ type Timestamptz struct { InfinityModifier } -func (t *Timestamptz) ConvertFrom(src interface{}) error { +func (dst *Timestamptz) ConvertFrom(src interface{}) error { switch value := src.(type) { case Timestamptz: - *t = value + *dst = value case time.Time: - *t = Timestamptz{Time: value, Status: Present} + *dst = Timestamptz{Time: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return t.ConvertFrom(originalSrc) + return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamptz", value) } @@ -41,20 +41,20 @@ func (t *Timestamptz) ConvertFrom(src interface{}) error { return nil } -func (t *Timestamptz) AssignTo(dst interface{}) error { +func (src *Timestamptz) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: - if t.Status != Present || t.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", t, dst) + if src.Status != Present || src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) } - *v = t.Time + *v = src.Time default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() switch el.Kind() { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: - if t.Status == Null { + if src.Status == Null { if !el.IsNil() { // if the destination pointer is not nil, nil it out el.Set(reflect.Zero(el.Type())) @@ -65,23 +65,23 @@ func (t *Timestamptz) AssignTo(dst interface{}) error { // allocate destination el.Set(reflect.New(el.Type().Elem())) } - return t.AssignTo(el.Interface()) + return src.AssignTo(el.Interface()) } } - return fmt.Errorf("cannot assign %v into %T", t, dst) + return fmt.Errorf("cannot assign %v into %T", src, dst) } return nil } -func (t *Timestamptz) DecodeText(r io.Reader) error { +func (dst *Timestamptz) DecodeText(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *t = Timestamptz{Status: Null} + *dst = Timestamptz{Status: Null} return nil } @@ -94,9 +94,9 @@ func (t *Timestamptz) DecodeText(r io.Reader) error { sbuf := string(buf) switch sbuf { case "infinity": - *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} case "-infinity": - *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} default: var format string if sbuf[len(sbuf)-9] == '-' || sbuf[len(sbuf)-9] == '+' { @@ -112,20 +112,20 @@ func (t *Timestamptz) DecodeText(r io.Reader) error { return err } - *t = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: tim, Status: Present} } return nil } -func (t *Timestamptz) DecodeBinary(r io.Reader) error { +func (dst *Timestamptz) DecodeBinary(r io.Reader) error { size, err := pgio.ReadInt32(r) if err != nil { return err } if size == -1 { - *t = Timestamptz{Status: Null} + *dst = Timestamptz{Status: Null} return nil } @@ -140,28 +140,28 @@ func (t *Timestamptz) DecodeBinary(r io.Reader) error { switch microsecSinceY2K { case infinityMicrosecondOffset: - *t = Timestamptz{Status: Present, InfinityModifier: Infinity} + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} case negativeInfinityMicrosecondOffset: - *t = Timestamptz{Status: Present, InfinityModifier: -Infinity} + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} default: microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) - *t = Timestamptz{Time: tim, Status: Present} + *dst = Timestamptz{Time: tim, Status: Present} } return nil } -func (t Timestamptz) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, t.Status); done { +func (src Timestamptz) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } var s string - switch t.InfinityModifier { + switch src.InfinityModifier { case None: - s = t.Time.UTC().Format(pgTimestamptzSecondFormat) + s = src.Time.UTC().Format(pgTimestamptzSecondFormat) case Infinity: s = "infinity" case NegativeInfinity: @@ -177,8 +177,8 @@ func (t Timestamptz) EncodeText(w io.Writer) error { return err } -func (t Timestamptz) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, t.Status); done { +func (src Timestamptz) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -188,9 +188,9 @@ func (t Timestamptz) EncodeBinary(w io.Writer) error { } var microsecSinceY2K int64 - switch t.InfinityModifier { + switch src.InfinityModifier { case None: - microsecSinceUnixEpoch := t.Time.Unix()*1000000 + int64(t.Time.Nanosecond())/1000 + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K case Infinity: microsecSinceY2K = infinityMicrosecondOffset From 7fd09c4cd2a86d327028c145c8235c3bd3de96a9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:39:55 -0600 Subject: [PATCH 068/264] Supply DATABASE_URL for travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 60e1670f..40b9c399 100644 --- a/.travis.yml +++ b/.travis.yml @@ -27,6 +27,7 @@ before_install: - sudo /etc/init.d/postgresql restart env: + - DATABASE_URL=postgres://pgx_md5:secret@127.0.0.1/pgx_test matrix: - PGVERSION=9.6 - PGVERSION=9.5 From 2fb46fb16fcd43dcd863a9bd3c4c0f2961a1021e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:42:25 -0600 Subject: [PATCH 069/264] Fix travis.yml --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 40b9c399..cd9ab572 100644 --- a/.travis.yml +++ b/.travis.yml @@ -27,7 +27,8 @@ before_install: - sudo /etc/init.d/postgresql restart env: - - DATABASE_URL=postgres://pgx_md5:secret@127.0.0.1/pgx_test + global: + - DATABASE_URL=postgres://pgx_md5:secret@127.0.0.1/pgx_test matrix: - PGVERSION=9.6 - PGVERSION=9.5 From 9e5d81d8f5de69cc65868fc6bdfe5e1f7bb59eae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 17:59:26 -0600 Subject: [PATCH 070/264] Add test for pgtype.Int2.AssignTo --- pgtype/bool_test.go | 2 -- pgtype/int2_test.go | 70 +++++++++++++++++++++++++++++++++++++++++-- pgtype/pgtype_test.go | 5 ++++ 3 files changed, 73 insertions(+), 4 deletions(-) diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 74140b5e..374f07da 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -7,8 +7,6 @@ import ( "github.com/jackc/pgx/pgtype" ) -type _bool bool - func TestBoolTranscode(t *testing.T) { testSuccessfulTranscode(t, "bool", []interface{}{ pgtype.Bool{Bool: false, Status: pgtype.Present}, diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index a8493a16..1074c9b5 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "math" + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -19,8 +20,6 @@ func TestInt2Transcode(t *testing.T) { } func TestInt2ConvertFrom(t *testing.T) { - type _int8 int8 - successfulTests := []struct { source interface{} result pgtype.Int2 @@ -53,3 +52,70 @@ func TestInt2ConvertFrom(t *testing.T) { } } } + +func TestInt2AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int2 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int2{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index a1a575f7..32ebebfe 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -11,6 +11,11 @@ import ( "github.com/jackc/pgx/pgtype" ) +// Test for renamed types +type _bool bool +type _int8 int8 +type _int16 int16 + func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) if err != nil { From 5b861d0a5f91b506fd858a73b6ac2949bf54de2c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 18:23:26 -0600 Subject: [PATCH 071/264] Add tests to more pgtypes Int4, Int8, Date, Timestamptz --- pgtype/date.go | 5 +-- pgtype/date_test.go | 46 +++++++++++++++++++++++++ pgtype/int4_test.go | 70 ++++++++++++++++++++++++++++++++++++-- pgtype/int8_test.go | 70 ++++++++++++++++++++++++++++++++++++-- pgtype/timestamptz_test.go | 46 +++++++++++++++++++++++++ 5 files changed, 229 insertions(+), 8 deletions(-) diff --git a/pgtype/date.go b/pgtype/date.go index 6cd8e499..307f1e59 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -50,10 +50,7 @@ func (src *Date) AssignTo(dst interface{}) error { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: if src.Status == Null { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } + el.Set(reflect.Zero(el.Type())) return nil } if el.IsNil() { diff --git a/pgtype/date_test.go b/pgtype/date_test.go index c3e971d0..65d743e9 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "time" @@ -28,6 +29,7 @@ func TestDateConvertFrom(t *testing.T) { source interface{} result pgtype.Date }{ + {source: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, @@ -49,3 +51,47 @@ func TestDateConvertFrom(t *testing.T) { } } } + +func TestDateAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Date{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Date + dst interface{} + expected interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go index 04411849..cd57e2c9 100644 --- a/pgtype/int4_test.go +++ b/pgtype/int4_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "math" + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -19,8 +20,6 @@ func TestInt4Transcode(t *testing.T) { } func TestInt4ConvertFrom(t *testing.T) { - type _int8 int8 - successfulTests := []struct { source interface{} result pgtype.Int4 @@ -53,3 +52,70 @@ func TestInt4ConvertFrom(t *testing.T) { } } } + +func TestInt4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int4{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go index ba246224..f9d8646f 100644 --- a/pgtype/int8_test.go +++ b/pgtype/int8_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "math" + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -19,8 +20,6 @@ func TestInt8Transcode(t *testing.T) { } func TestInt8ConvertFrom(t *testing.T) { - type _int8 int8 - successfulTests := []struct { source interface{} result pgtype.Int8 @@ -53,3 +52,70 @@ func TestInt8ConvertFrom(t *testing.T) { } } } + +func TestInt8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Int8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.Int8{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 795195f8..adb72620 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "time" @@ -37,6 +38,7 @@ func TestTimestamptzConvertFrom(t *testing.T) { source interface{} result pgtype.Timestamptz }{ + {source: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, @@ -58,3 +60,47 @@ func TestTimestamptzConvertFrom(t *testing.T) { } } } + +func TestTimestamptzAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + {src: pgtype.Timestamptz{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamptz + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} From 6a3b22cee8b71b078434372ba027d0b2a7ad0a2a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 18:39:52 -0600 Subject: [PATCH 072/264] Add pgtype error cases --- pgtype/date_test.go | 16 ++++++++++++++++ pgtype/int2_test.go | 20 ++++++++++++++++++++ pgtype/int4_test.go | 21 +++++++++++++++++++++ pgtype/int8_test.go | 22 ++++++++++++++++++++++ pgtype/timestamptz.go | 5 +---- pgtype/timestamptz_test.go | 16 ++++++++++++++++ 6 files changed, 96 insertions(+), 4 deletions(-) diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 65d743e9..3a473b6a 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -94,4 +94,20 @@ func TestDateAssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Date + dst interface{} + }{ + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Date{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index 1074c9b5..8601309d 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -118,4 +118,24 @@ func TestInt2AssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Int2 + dst interface{} + }{ + {src: pgtype.Int2{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int2{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int2{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go index cd57e2c9..0ac2e5b5 100644 --- a/pgtype/int4_test.go +++ b/pgtype/int4_test.go @@ -118,4 +118,25 @@ func TestInt4AssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Int4 + dst interface{} + }{ + {src: pgtype.Int4{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int4{Int: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int4{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int4{Int: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go index f9d8646f..15762a50 100644 --- a/pgtype/int8_test.go +++ b/pgtype/int8_test.go @@ -118,4 +118,26 @@ func TestInt8AssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Int8 + dst interface{} + }{ + {src: pgtype.Int8{Int: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Int8{Int: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Int8{Int: 5000000000, Status: pgtype.Present}, dst: &i32}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Int8{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Int8{Int: 0, Status: pgtype.Null}, dst: &i64}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 4f08cd2a..721c8084 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -55,10 +55,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { // if dst is a pointer to pointer, strip the pointer and try again case reflect.Ptr: if src.Status == Null { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } + el.Set(reflect.Zero(el.Type())) return nil } if el.IsNil() { diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index adb72620..8f80ca81 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -103,4 +103,20 @@ func TestTimestamptzAssignTo(t *testing.T) { t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) } } + + errorTests := []struct { + src pgtype.Timestamptz + dst interface{} + }{ + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamptz{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } } From 0e8dd862b1fd3adf1c3f1f2798f3fd55694edabd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 3 Mar 2017 19:19:31 -0600 Subject: [PATCH 073/264] Add tests for pgtype.Int2Array --- pgtype/convert.go | 22 ++++++ pgtype/int2array.go | 3 + pgtype/int2array_test.go | 153 +++++++++++++++++++++++++++++++-------- pgtype/pgtype_test.go | 1 + 4 files changed, 147 insertions(+), 32 deletions(-) diff --git a/pgtype/convert.go b/pgtype/convert.go index 3f3d9e5f..e35e2310 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -122,6 +122,28 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } +func underlyingPtrSliceType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + if refVal.Kind() != reflect.Ptr { + return nil, false + } + if refVal.IsNil() { + return nil, false + } + + sliceVal := refVal.Elem().Interface() + baseSliceType := reflect.SliceOf(reflect.TypeOf(sliceVal).Elem()) + ptrBaseSliceType := reflect.PtrTo(baseSliceType) + + if refVal.Type().ConvertibleTo(ptrBaseSliceType) { + convVal := refVal.Convert(ptrBaseSliceType) + return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() + } + + return nil, false +} + func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { if srcStatus == Present { switch v := dst.(type) { diff --git a/pgtype/int2array.go b/pgtype/int2array.go index 4ac0c409..e6809c1e 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -89,6 +89,9 @@ func (src *Int2Array) AssignTo(dst interface{}) error { *v = nil } default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } return fmt.Errorf("cannot put decode %v into %T", src, dst) } diff --git a/pgtype/int2array_test.go b/pgtype/int2array_test.go index 5ea81990..ced0eab4 100644 --- a/pgtype/int2array_test.go +++ b/pgtype/int2array_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "github.com/jackc/pgx/pgtype" @@ -50,38 +51,126 @@ func TestInt2ArrayTranscode(t *testing.T) { }) } -// func TestInt2ConvertFrom(t *testing.T) { -// type _int8 int8 +func TestInt2ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int2Array + }{ + { + source: []int16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint16{1}, + result: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int16)(nil)), + result: pgtype.Int2Array{Status: pgtype.Null}, + }, + } -// successfulTests := []struct { -// source interface{} -// result pgtype.Int2 -// }{ -// {source: int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: int8(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int16(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int32(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: int64(-1), result: pgtype.Int2{Int: -1, Status: pgtype.Present}}, -// {source: uint8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint16(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint32(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: uint64(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: "1", result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// {source: _int8(1), result: pgtype.Int2{Int: 1, Status: pgtype.Present}}, -// } + for i, tt := range successfulTests { + var r pgtype.Int2Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } -// for i, tt := range successfulTests { -// var r pgtype.Int2 -// err := r.ConvertFrom(tt.source) -// if err != nil { -// t.Errorf("%d: %v", i, err) -// } + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} -// if r != tt.result { -// t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) -// } -// } -// } +func TestInt2ArrayAssignTo(t *testing.T) { + var int16Slice []int16 + var uint16Slice []uint16 + var namedInt16Slice _int16Slice + + simpleTests := []struct { + src pgtype.Int2Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + expected: []int16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + expected: []uint16{1}, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt16Slice, + expected: _int16Slice{1}, + }, + { + src: pgtype.Int2Array{Status: pgtype.Null}, + dst: &int16Slice, + expected: (([]int16)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int2Array + dst interface{} + }{ + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int16Slice, + }, + { + src: pgtype.Int2Array{ + Elements: []pgtype.Int2{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint16Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 32ebebfe..a727e2e5 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -15,6 +15,7 @@ import ( type _bool bool type _int8 int8 type _int16 int16 +type _int16Slice []int16 func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) From aabf43a725d5cdca6e3e79f871a509d2f54584df Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 09:44:10 -0600 Subject: [PATCH 074/264] Remove types from Decode handled by pgtypes --- values.go | 88 ------------------------------------------------------- 1 file changed, 88 deletions(-) diff --git a/values.go b/values.go index a9c4c209..ccc2eeb7 100644 --- a/values.go +++ b/values.go @@ -1188,86 +1188,6 @@ func decodeByOID(vr *ValueReader) (interface{}, error) { // decoding to the built-in functionality. func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { - case *bool: - *v = decodeBool(vr) - case *int: - n := decodeInt(vr) - if n < int64(minInt) { - return fmt.Errorf("%d is less than minimum value for int", n) - } else if n > int64(maxInt) { - return fmt.Errorf("%d is greater than maximum value for int", n) - } - *v = int(n) - case *int8: - n := decodeInt(vr) - if n < math.MinInt8 { - return fmt.Errorf("%d is less than minimum value for int8", n) - } else if n > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for int8", n) - } - *v = int8(n) - case *int16: - n := decodeInt(vr) - if n < math.MinInt16 { - return fmt.Errorf("%d is less than minimum value for int16", n) - } else if n > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for int16", n) - } - *v = int16(n) - case *int32: - n := decodeInt(vr) - if n < math.MinInt32 { - return fmt.Errorf("%d is less than minimum value for int32", n) - } else if n > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for int32", n) - } - *v = int32(n) - case *int64: - n := decodeInt(vr) - if n < math.MinInt64 { - return fmt.Errorf("%d is less than minimum value for int64", n) - } else if n > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for int64", n) - } - *v = int64(n) - case *uint: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint8", n) - } else if maxInt == math.MaxInt32 && n > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint", n) - } - *v = uint(n) - case *uint8: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint8", n) - } else if n > math.MaxUint8 { - return fmt.Errorf("%d is greater than maximum value for uint8", n) - } - *v = uint8(n) - case *uint16: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint16", n) - } else if n > math.MaxUint16 { - return fmt.Errorf("%d is greater than maximum value for uint16", n) - } - *v = uint16(n) - case *uint32: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint32", n) - } else if n > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint32", n) - } - *v = uint32(n) - case *uint64: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for uint64", n) - } - *v = uint64(n) case *Char: *v = decodeChar(vr) case *AclItem: @@ -1294,10 +1214,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeAclItemArray(vr) case *[]bool: *v = decodeBoolArray(vr) - case *[]int16: - *v = decodeInt2Array(vr) - case *[]uint16: - *v = decodeInt2ArrayToUInt(vr) case *[]int32: *v = decodeInt4Array(vr) case *[]uint32: @@ -1320,10 +1236,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeRecord(vr) case *time.Time: switch vr.Type().DataType { - case DateOID: - *v = decodeDate(vr) - case TimestampTzOID: - *v = decodeTimestampTz(vr) case TimestampOID: *v = decodeTimestamp(vr) default: From ffb949054d5e1d3114ddd39c74c3fded84f1d8d9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 11:48:53 -0600 Subject: [PATCH 075/264] Add arrays to all other pgtypes --- conn.go | 19 +- pgtype/boolarray.go | 286 ++++++++++++++++++++++++++++ pgtype/boolarray_test.go | 152 +++++++++++++++ pgtype/datearray.go | 287 +++++++++++++++++++++++++++++ pgtype/datearray_test.go | 142 ++++++++++++++ pgtype/int2array.go | 6 + pgtype/int4array.go | 317 ++++++++++++++++++++++++++++++++ pgtype/int4array_test.go | 176 ++++++++++++++++++ pgtype/int8array.go | 317 ++++++++++++++++++++++++++++++++ pgtype/int8array_test.go | 176 ++++++++++++++++++ pgtype/pgtype.go | 5 +- pgtype/pgtype_test.go | 2 + pgtype/timestamptzarray.go | 287 +++++++++++++++++++++++++++++ pgtype/timestamptzarray_test.go | 158 ++++++++++++++++ pgtype/typed_array.go.erb | 286 ++++++++++++++++++++++++++++ pgtype/typed_array_gen.sh | 6 + query_test.go | 6 - values.go | 19 +- 18 files changed, 2614 insertions(+), 33 deletions(-) create mode 100644 pgtype/boolarray.go create mode 100644 pgtype/boolarray_test.go create mode 100644 pgtype/datearray.go create mode 100644 pgtype/datearray_test.go create mode 100644 pgtype/int4array.go create mode 100644 pgtype/int4array_test.go create mode 100644 pgtype/int8array.go create mode 100644 pgtype/int8array_test.go create mode 100644 pgtype/timestamptzarray.go create mode 100644 pgtype/timestamptzarray_test.go create mode 100644 pgtype/typed_array.go.erb create mode 100644 pgtype/typed_array_gen.sh diff --git a/conn.go b/conn.go index 1c0b4e22..b8d92b0b 100644 --- a/conn.go +++ b/conn.go @@ -279,13 +279,18 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.closedChan = make(chan error) c.oidPgtypeValues = map[OID]pgtype.Value{ - BoolOID: &pgtype.Bool{}, - DateOID: &pgtype.Date{}, - Int2OID: &pgtype.Int2{}, - Int2ArrayOID: &pgtype.Int2Array{}, - Int4OID: &pgtype.Int4{}, - Int8OID: &pgtype.Int8{}, - TimestampTzOID: &pgtype.Timestamptz{}, + BoolOID: &pgtype.Bool{}, + BoolArrayOID: &pgtype.BoolArray{}, + DateOID: &pgtype.Date{}, + DateArrayOID: &pgtype.DateArray{}, + Int2OID: &pgtype.Int2{}, + Int2ArrayOID: &pgtype.Int2Array{}, + Int4OID: &pgtype.Int4{}, + Int4ArrayOID: &pgtype.Int4Array{}, + Int8OID: &pgtype.Int8{}, + Int8ArrayOID: &pgtype.Int8Array{}, + TimestampTzOID: &pgtype.Timestamptz{}, + TimestampTzArrayOID: &pgtype.TimestamptzArray{}, } if tlsConfig != nil { diff --git a/pgtype/boolarray.go b/pgtype/boolarray.go new file mode 100644 index 00000000..8dd68dc2 --- /dev/null +++ b/pgtype/boolarray.go @@ -0,0 +1,286 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type BoolArray struct { + Elements []Bool + Dimensions []ArrayDimension + Status Status +} + +func (dst *BoolArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case BoolArray: + *dst = value + + case []bool: + if value == nil { + *dst = BoolArray{Status: Null} + } else if len(value) == 0 { + *dst = BoolArray{Status: Present} + } else { + elements := make([]Bool, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = BoolArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bool", value) + } + + return nil +} + +func (src *BoolArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]bool: + if src.Status == Present { + *v = make([]bool, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *BoolArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = BoolArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Bool + + if len(uta.Elements) > 0 { + elements = make([]Bool, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bool + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *BoolArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = BoolArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bool, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *BoolArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *BoolArray) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = BoolOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/boolarray_test.go b/pgtype/boolarray_test.go new file mode 100644 index 00000000..c5f15f97 --- /dev/null +++ b/pgtype/boolarray_test.go @@ -0,0 +1,152 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestBoolArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bool[]", []interface{}{ + &pgtype.BoolArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BoolArray{Status: pgtype.Null}, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Status: pgtype.Null}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.BoolArray{ + Elements: []pgtype.Bool{ + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + pgtype.Bool{Bool: true, Status: pgtype.Present}, + pgtype.Bool{Bool: false, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestBoolArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.BoolArray + }{ + { + source: []bool{true}, + result: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]bool)(nil)), + result: pgtype.BoolArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.BoolArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestBoolArrayAssignTo(t *testing.T) { + var boolSlice []bool + type _boolSlice []bool + var namedBoolSlice _boolSlice + + simpleTests := []struct { + src pgtype.BoolArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &boolSlice, + expected: []bool{true}, + }, + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Bool: true, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedBoolSlice, + expected: _boolSlice{true}, + }, + { + src: pgtype.BoolArray{Status: pgtype.Null}, + dst: &boolSlice, + expected: (([]bool)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.BoolArray + dst interface{} + }{ + { + src: pgtype.BoolArray{ + Elements: []pgtype.Bool{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &boolSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/datearray.go b/pgtype/datearray.go new file mode 100644 index 00000000..877f328e --- /dev/null +++ b/pgtype/datearray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "time" + + "github.com/jackc/pgx/pgio" +) + +type DateArray struct { + Elements []Date + Dimensions []ArrayDimension + Status Status +} + +func (dst *DateArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case DateArray: + *dst = value + + case []time.Time: + if value == nil { + *dst = DateArray{Status: Null} + } else if len(value) == 0 { + *dst = DateArray{Status: Present} + } else { + elements := make([]Date, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = DateArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Date", value) + } + + return nil +} + +func (src *DateArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]time.Time: + if src.Status == Present { + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *DateArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = DateArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Date + + if len(uta.Elements) > 0 { + elements = make([]Date, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Date + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = DateArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *DateArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = DateArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = DateArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Date, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = DateArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *DateArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *DateArray) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = DateOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/datearray_test.go b/pgtype/datearray_test.go new file mode 100644 index 00000000..60f15983 --- /dev/null +++ b/pgtype/datearray_test.go @@ -0,0 +1,142 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestDateArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "date[]", []interface{}{ + &pgtype.DateArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.DateArray{Status: pgtype.Null}, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Status: pgtype.Null}, + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.DateArray{ + Elements: []pgtype.Date{ + pgtype.Date{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Date{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestDateArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.DateArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.DateArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.DateArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestDateArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.DateArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.DateArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.DateArray + dst interface{} + }{ + { + src: pgtype.DateArray{ + Elements: []pgtype.Date{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int2array.go b/pgtype/int2array.go index e6809c1e..4fc6d882 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -18,6 +18,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2Array: *dst = value + case []int16: if value == nil { *dst = Int2Array{Status: Null} @@ -36,6 +37,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } + case []uint16: if value == nil { *dst = Int2Array{Status: Null} @@ -54,6 +56,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -66,6 +69,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { + case *[]int16: if src.Status == Present { *v = make([]int16, len(src.Elements)) @@ -77,6 +81,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } + case *[]uint16: if src.Status == Present { *v = make([]uint16, len(src.Elements)) @@ -88,6 +93,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) diff --git a/pgtype/int4array.go b/pgtype/int4array.go new file mode 100644 index 00000000..40e1490d --- /dev/null +++ b/pgtype/int4array.go @@ -0,0 +1,317 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int4Array struct { + Elements []Int4 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int4Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int4Array: + *dst = value + + case []int32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint32: + if value == nil { + *dst = Int4Array{Status: Null} + } else if len(value) == 0 { + *dst = Int4Array{Status: Present} + } else { + elements := make([]Int4, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Int4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int4", value) + } + + return nil +} + +func (src *Int4Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]int32: + if src.Status == Present { + *v = make([]int32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]uint32: + if src.Status == Present { + *v = make([]uint32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Int4Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Int4Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int4 + + if len(uta.Elements) > 0 { + elements = make([]Int4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int4 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int4Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Int4Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int4, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = Int4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Int4Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *Int4Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Int4OID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/int4array_test.go b/pgtype/int4array_test.go new file mode 100644 index 00000000..38ba27cb --- /dev/null +++ b/pgtype/int4array_test.go @@ -0,0 +1,176 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt4ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int4[]", []interface{}{ + &pgtype.Int4Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4Array{Status: pgtype.Null}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Int: 3, Status: pgtype.Present}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Int: 3, Status: pgtype.Present}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt4ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int4Array + }{ + { + source: []int32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint32{1}, + result: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int32)(nil)), + result: pgtype.Int4Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int4Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt4ArrayAssignTo(t *testing.T) { + var int32Slice []int32 + var uint32Slice []uint32 + var namedInt32Slice _int32Slice + + simpleTests := []struct { + src pgtype.Int4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int32Slice, + expected: []int32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint32Slice, + expected: []uint32{1}, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt32Slice, + expected: _int32Slice{1}, + }, + { + src: pgtype.Int4Array{Status: pgtype.Null}, + dst: &int32Slice, + expected: (([]int32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int4Array + dst interface{} + }{ + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int32Slice, + }, + { + src: pgtype.Int4Array{ + Elements: []pgtype.Int4{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int8array.go b/pgtype/int8array.go new file mode 100644 index 00000000..35ecf946 --- /dev/null +++ b/pgtype/int8array.go @@ -0,0 +1,317 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int8Array struct { + Elements []Int8 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Int8Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Int8Array: + *dst = value + + case []int64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []uint64: + if value == nil { + *dst = Int8Array{Status: Null} + } else if len(value) == 0 { + *dst = Int8Array{Status: Present} + } else { + elements := make([]Int8, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Int8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Int8", value) + } + + return nil +} + +func (src *Int8Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]int64: + if src.Status == Present { + *v = make([]int64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]uint64: + if src.Status == Present { + *v = make([]uint64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Int8Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Int8Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int8 + + if len(uta.Elements) > 0 { + elements = make([]Int8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int8 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Int8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Int8Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Int8Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Int8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Int8, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = Int8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Int8Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *Int8Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Int8OID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/int8array_test.go b/pgtype/int8array_test.go new file mode 100644 index 00000000..137768c6 --- /dev/null +++ b/pgtype/int8array_test.go @@ -0,0 +1,176 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int8[]", []interface{}{ + &pgtype.Int8Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int8Array{Status: pgtype.Null}, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: 2, Status: pgtype.Present}, + pgtype.Int8{Int: 3, Status: pgtype.Present}, + pgtype.Int8{Int: 4, Status: pgtype.Present}, + pgtype.Int8{Status: pgtype.Null}, + pgtype.Int8{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int8Array{ + Elements: []pgtype.Int8{ + pgtype.Int8{Int: 1, Status: pgtype.Present}, + pgtype.Int8{Int: 2, Status: pgtype.Present}, + pgtype.Int8{Int: 3, Status: pgtype.Present}, + pgtype.Int8{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInt8ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Int8Array + }{ + { + source: []int64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []uint64{1}, + result: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]int64)(nil)), + result: pgtype.Int8Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Int8Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInt8ArrayAssignTo(t *testing.T) { + var int64Slice []int64 + var uint64Slice []uint64 + var namedInt64Slice _int64Slice + + simpleTests := []struct { + src pgtype.Int8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int64Slice, + expected: []int64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint64Slice, + expected: []uint64{1}, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedInt64Slice, + expected: _int64Slice{1}, + }, + { + src: pgtype.Int8Array{Status: pgtype.Null}, + dst: &int64Slice, + expected: (([]int64)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Int8Array + dst interface{} + }{ + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &int64Slice, + }, + { + src: pgtype.Int8Array{ + Elements: []pgtype.Int8{{Int: -1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &uint64Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index f9833363..5722c8ab 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -44,8 +44,9 @@ const ( DateOID = 1082 TimestampOID = 1114 TimestampArrayOID = 1115 - TimestampTzOID = 1184 - TimestampTzArrayOID = 1185 + DateArrayOID = 1182 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 RecordOID = 2249 UUIDOID = 2950 JSONBOID = 3802 diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index a727e2e5..97afc249 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -16,6 +16,8 @@ type _bool bool type _int8 int8 type _int16 int16 type _int16Slice []int16 +type _int32Slice []int32 +type _int64Slice []int64 func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) diff --git a/pgtype/timestamptzarray.go b/pgtype/timestamptzarray.go new file mode 100644 index 00000000..72b28e43 --- /dev/null +++ b/pgtype/timestamptzarray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "time" + + "github.com/jackc/pgx/pgio" +) + +type TimestamptzArray struct { + Elements []Timestamptz + Dimensions []ArrayDimension + Status Status +} + +func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case TimestamptzArray: + *dst = value + + case []time.Time: + if value == nil { + *dst = TimestamptzArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestamptzArray{Status: Present} + } else { + elements := make([]Timestamptz, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = TimestamptzArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamptz", value) + } + + return nil +} + +func (src *TimestamptzArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]time.Time: + if src.Status == Present { + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *TimestamptzArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TimestamptzArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Timestamptz + + if len(uta.Elements) > 0 { + elements = make([]Timestamptz, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamptz + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TimestamptzArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TimestamptzArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestamptzArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamptz, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = TimestamptzArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TimestamptzArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = TimestamptzOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/timestamptzarray_test.go b/pgtype/timestamptzarray_test.go new file mode 100644 index 00000000..af2c004b --- /dev/null +++ b/pgtype/timestamptzarray_test.go @@ -0,0 +1,158 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestamptzArrayTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ + &pgtype.TimestamptzArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{Status: pgtype.Null}, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Status: pgtype.Null}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{ + pgtype.Timestamptz{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamptz{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestamptzArray) + bta := b.(pgtype.TimestamptzArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestamptzArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestamptzArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestamptzArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestamptzArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestamptzArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestamptzArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestamptzArray + dst interface{} + }{ + { + src: pgtype.TimestamptzArray{ + Elements: []pgtype.Timestamptz{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb new file mode 100644 index 00000000..e6e480b0 --- /dev/null +++ b/pgtype/typed_array.go.erb @@ -0,0 +1,286 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type <%= pgtype_array_type %> struct { + Elements []<%= pgtype_element_type %> + Dimensions []ArrayDimension + Status Status +} + +func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case <%= pgtype_array_type %>: + *dst = value + <% go_array_types.split(",").each do |t| %> + case <%= t %>: + if value == nil { + *dst = <%= pgtype_array_type %>{Status: Null} + } else if len(value) == 0 { + *dst = <%= pgtype_array_type %>{Status: Present} + } else { + elements := make([]<%= pgtype_element_type %>, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = <%= pgtype_array_type %>{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + <% end %> + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) + } + + return nil +} + +func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: + if src.Status == Present { + *v = make(<%= t %>, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + <% end %> + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []<%= pgtype_element_type %> + + if len(uta.Elements) > 0 { + elements = make([]<%= pgtype_element_type %>, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem <%= pgtype_element_type %> + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *<%= pgtype_array_type %>) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = <%= pgtype_array_type %>{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = <%= pgtype_array_type %>{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]<%= pgtype_element_type %>, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = <%= element_oid %> + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh new file mode 100644 index 00000000..9fec58e8 --- /dev/null +++ b/pgtype/typed_array_gen.sh @@ -0,0 +1,6 @@ +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID typed_array.go.erb > int2array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID typed_array.go.erb > int4array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID typed_array.go.erb > int2array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID typed_array.go.erb > boolarray.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go diff --git a/query_test.go b/query_test.go index c30ab2ef..981df3ee 100644 --- a/query_test.go +++ b/query_test.go @@ -1063,9 +1063,6 @@ func TestQueryRowCoreInt32Slice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } ensureConnValid(t, conn) } @@ -1110,9 +1107,6 @@ func TestQueryRowCoreInt64Slice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } ensureConnValid(t, conn) } diff --git a/values.go b/values.go index ccc2eeb7..e83af308 100644 --- a/values.go +++ b/values.go @@ -54,6 +54,7 @@ const ( DateOID = 1082 TimestampOID = 1114 TimestampArrayOID = 1115 + DateArrayOID = 1182 TimestampTzOID = 1184 TimestampTzArrayOID = 1185 RecordOID = 2249 @@ -1087,14 +1088,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case []int32: - return encodeInt32Slice(wbuf, oid, arg) - case []uint32: - return encodeUInt32Slice(wbuf, oid, arg) - case []int64: - return encodeInt64Slice(wbuf, oid, arg) - case []uint64: - return encodeUInt64Slice(wbuf, oid, arg) case float32: return encodeFloat32(wbuf, oid, arg) case []float32: @@ -1212,16 +1205,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeFloat8(vr) case *[]AclItem: *v = decodeAclItemArray(vr) - case *[]bool: - *v = decodeBoolArray(vr) - case *[]int32: - *v = decodeInt4Array(vr) - case *[]uint32: - *v = decodeInt4ArrayToUInt(vr) - case *[]int64: - *v = decodeInt8Array(vr) - case *[]uint64: - *v = decodeInt8ArrayToUInt(vr) case *[]float32: *v = decodeFloat4Array(vr) case *[]float64: From 3179e2debc4354879efc3aece8d1779e93e5745d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 12:36:24 -0600 Subject: [PATCH 076/264] Add timestamp to pgtype --- conn.go | 14 +- pgtype/timestamp.go | 204 ++++++++++++++++++++++++ pgtype/timestamp_test.go | 123 +++++++++++++++ pgtype/timestamparray.go | 287 ++++++++++++++++++++++++++++++++++ pgtype/timestamparray_test.go | 158 +++++++++++++++++++ pgtype/typed_array_gen.sh | 1 + query_test.go | 2 +- values.go | 13 -- values_test.go | 10 -- 9 files changed, 782 insertions(+), 30 deletions(-) create mode 100644 pgtype/timestamp.go create mode 100644 pgtype/timestamp_test.go create mode 100644 pgtype/timestamparray.go create mode 100644 pgtype/timestamparray_test.go diff --git a/conn.go b/conn.go index b8d92b0b..5fd82ea0 100644 --- a/conn.go +++ b/conn.go @@ -279,18 +279,20 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.closedChan = make(chan error) c.oidPgtypeValues = map[OID]pgtype.Value{ - BoolOID: &pgtype.Bool{}, BoolArrayOID: &pgtype.BoolArray{}, - DateOID: &pgtype.Date{}, + BoolOID: &pgtype.Bool{}, DateArrayOID: &pgtype.DateArray{}, - Int2OID: &pgtype.Int2{}, + DateOID: &pgtype.Date{}, Int2ArrayOID: &pgtype.Int2Array{}, - Int4OID: &pgtype.Int4{}, + Int2OID: &pgtype.Int2{}, Int4ArrayOID: &pgtype.Int4Array{}, - Int8OID: &pgtype.Int8{}, + Int4OID: &pgtype.Int4{}, Int8ArrayOID: &pgtype.Int8Array{}, - TimestampTzOID: &pgtype.Timestamptz{}, + Int8OID: &pgtype.Int8{}, + TimestampArrayOID: &pgtype.TimestampArray{}, + TimestampOID: &pgtype.Timestamp{}, TimestampTzArrayOID: &pgtype.TimestamptzArray{}, + TimestampTzOID: &pgtype.Timestamptz{}, } if tlsConfig != nil { diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go new file mode 100644 index 00000000..c6933988 --- /dev/null +++ b/pgtype/timestamp.go @@ -0,0 +1,204 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + "time" + + "github.com/jackc/pgx/pgio" +) + +const pgTimestampFormat = "2006-01-02 15:04:05.999999999" + +// Timestamp represents the PostgreSQL timestamp type. The PostgreSQL +// timestamp does not have a time zone. This presents a problem when +// translating to and from time.Time which requires a time zone. It is highly +// recommended to use timestamptz whenever possible. Timestamp methods either +// convert to UTC or return an error on non-UTC times. +type Timestamp struct { + Time time.Time // Time must always be in UTC. + Status Status + InfinityModifier +} + +// ConvertFrom converts src into a Timestamp and stores in dst. If src is a +// time.Time in a non-UTC time zone, the time zone is discarded. +func (dst *Timestamp) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Timestamp: + *dst = value + case time.Time: + *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} + default: + if originalSrc, ok := underlyingTimeType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamp", value) + } + + return nil +} + +func (src *Timestamp) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *time.Time: + if src.Status != Present || src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + + return nil +} + +// DecodeText decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Timestamp{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + sbuf := string(buf) + switch sbuf { + case "infinity": + *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + default: + tim, err := time.Parse(pgTimestampFormat, sbuf) + if err != nil { + return err + } + + *dst = Timestamp{Time: tim, Status: Present} + } + + return nil +} + +// DecodeBinary decodes from src into dst. The decoded time is considered to +// be in UTC. +func (dst *Timestamp) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Timestamp{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for timestamp: %v", size) + } + + microsecSinceY2K, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + switch microsecSinceY2K { + case infinityMicrosecondOffset: + *dst = Timestamp{Status: Present, InfinityModifier: Infinity} + case negativeInfinityMicrosecondOffset: + *dst = Timestamp{Status: Present, InfinityModifier: -Infinity} + default: + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + tim := time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000).UTC() + *dst = Timestamp{Time: tim, Status: Present} + } + + return nil +} + +// EncodeText writes the text encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Timestamp) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + if src.Time.Location() != time.UTC { + return fmt.Errorf("cannot encode non-UTC time into timestamp") + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format(pgTimestampFormat) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + + _, err = w.Write([]byte(s)) + return err +} + +// EncodeBinary writes the binary encoding of src into w. If src.Time is not in +// the UTC time zone it returns an error. +func (src Timestamp) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + if src.Time.Location() != time.UTC { + return fmt.Errorf("cannot encode non-UTC time into timestamp") + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + var microsecSinceY2K int64 + switch src.InfinityModifier { + case None: + microsecSinceUnixEpoch := src.Time.Unix()*1000000 + int64(src.Time.Nanosecond())/1000 + microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + case Infinity: + microsecSinceY2K = infinityMicrosecondOffset + case NegativeInfinity: + microsecSinceY2K = negativeInfinityMicrosecondOffset + } + + _, err = pgio.WriteInt64(w, microsecSinceY2K) + return err +} diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go new file mode 100644 index 00000000..6d6e738c --- /dev/null +++ b/pgtype/timestamp_test.go @@ -0,0 +1,123 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestampTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Status: pgtype.Null}, + pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Timestamp) + bt := b.(pgtype.Timestamp) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier + }) +} + +func TestTimestampConvertFrom(t *testing.T) { + type _time time.Time + + successfulTests := []struct { + source interface{} + result pgtype.Timestamp + }{ + {source: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Timestamp + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampAssignTo(t *testing.T) { + var tim time.Time + var ptim *time.Time + + simpleTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, dst: &tim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.UTC)}, + {src: pgtype.Timestamp{Time: time.Time{}, Status: pgtype.Null}, dst: &ptim, expected: ((*time.Time)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Timestamp + dst interface{} + expected interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, dst: &ptim, expected: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Timestamp + dst interface{} + }{ + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, dst: &tim}, + {src: pgtype.Timestamp{Time: time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Null}, dst: &tim}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/timestamparray.go b/pgtype/timestamparray.go new file mode 100644 index 00000000..f1b1d003 --- /dev/null +++ b/pgtype/timestamparray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "time" + + "github.com/jackc/pgx/pgio" +) + +type TimestampArray struct { + Elements []Timestamp + Dimensions []ArrayDimension + Status Status +} + +func (dst *TimestampArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case TimestampArray: + *dst = value + + case []time.Time: + if value == nil { + *dst = TimestampArray{Status: Null} + } else if len(value) == 0 { + *dst = TimestampArray{Status: Present} + } else { + elements := make([]Timestamp, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = TimestampArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Timestamp", value) + } + + return nil +} + +func (src *TimestampArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]time.Time: + if src.Status == Present { + *v = make([]time.Time, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *TimestampArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TimestampArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Timestamp + + if len(uta.Elements) > 0 { + elements = make([]Timestamp, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Timestamp + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TimestampArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TimestampArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TimestampArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Timestamp, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = TimestampArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TimestampArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *TimestampArray) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = TimestampOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/timestamparray_test.go b/pgtype/timestamparray_test.go new file mode 100644 index 00000000..68189cc7 --- /dev/null +++ b/pgtype/timestamparray_test.go @@ -0,0 +1,158 @@ +package pgtype_test + +import ( + "reflect" + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTimestampArrayTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ + &pgtype.TimestampArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{Status: pgtype.Null}, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2016, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2017, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2012, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Status: pgtype.Null}, + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{ + pgtype.Timestamp{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2015, 2, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2015, 2, 3, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + pgtype.Timestamp{Time: time.Date(2015, 2, 4, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }, func(a, b interface{}) bool { + ata := a.(pgtype.TimestampArray) + bta := b.(pgtype.TimestampArray) + + if len(ata.Elements) != len(bta.Elements) || ata.Status != bta.Status { + return false + } + + for i := range ata.Elements { + ae, be := ata.Elements[i], bta.Elements[i] + if !(ae.Time.Equal(be.Time) && ae.Status == be.Status && ae.InfinityModifier == be.InfinityModifier) { + return false + } + } + + return true + }) +} + +func TestTimestampArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TimestampArray + }{ + { + source: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + result: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]time.Time)(nil)), + result: pgtype.TimestampArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TimestampArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTimestampArrayAssignTo(t *testing.T) { + var timeSlice []time.Time + + simpleTests := []struct { + src pgtype.TimestampArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Time: time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + expected: []time.Time{time.Date(2015, 2, 1, 0, 0, 0, 0, time.UTC)}, + }, + { + src: pgtype.TimestampArray{Status: pgtype.Null}, + dst: &timeSlice, + expected: (([]time.Time)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TimestampArray + dst interface{} + }{ + { + src: pgtype.TimestampArray{ + Elements: []pgtype.Timestamp{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &timeSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 9fec58e8..9f4e1ce0 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -4,3 +4,4 @@ erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64, erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID typed_array.go.erb > boolarray.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go diff --git a/query_test.go b/query_test.go index 981df3ee..84b90d4b 100644 --- a/query_test.go +++ b/query_test.go @@ -516,7 +516,7 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, - {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}}, + {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, {"select $1::oid", []interface{}{pgx.OID(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } diff --git a/values.go b/values.go index e83af308..e4347a8b 100644 --- a/values.go +++ b/values.go @@ -1096,10 +1096,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeFloat64(wbuf, oid, arg) case []float64: return encodeFloat64Slice(wbuf, oid, arg) - case time.Time: - return encodeTime(wbuf, oid, arg) - case []time.Time: - return encodeTimeSlice(wbuf, oid, arg) case net.IP: return encodeIP(wbuf, oid, arg) case []net.IP: @@ -1211,19 +1207,10 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeFloat8Array(vr) case *[]string: *v = decodeTextArray(vr) - case *[]time.Time: - *v = decodeTimestampArray(vr) case *[][]byte: *v = decodeByteaArray(vr) case *[]interface{}: *v = decodeRecord(vr) - case *time.Time: - switch vr.Type().DataType { - case TimestampOID: - *v = decodeTimestamp(vr) - default: - return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType) - } case *net.IP: ipnet := decodeInet(vr) if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { diff --git a/values_test.go b/values_test.go index ef13ccdf..28f7371f 100644 --- a/values_test.go +++ b/values_test.go @@ -772,14 +772,6 @@ func TestArrayDecoding(t *testing.T) { } }, }, - { - "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { - t.Errorf("failed to encode time.Time[] to timestamp[]") - } - }, - }, { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { @@ -1003,8 +995,6 @@ func TestPointerPointer(t *testing.T) { {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, - {"select $1::timestamp", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, - {"select $1::timestamp", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, } for i, tt := range tests { From 2010bea5557fed90eb20d3a8baa767bf83fe869e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 13:29:04 -0600 Subject: [PATCH 077/264] Add float4, float8 and arrays --- conn.go | 4 + pgtype/convert.go | 52 ++++++- pgtype/float4.go | 171 ++++++++++++++++++++++ pgtype/float4_test.go | 148 +++++++++++++++++++ pgtype/float4array.go | 286 +++++++++++++++++++++++++++++++++++++ pgtype/float4array_test.go | 151 ++++++++++++++++++++ pgtype/float8.go | 161 +++++++++++++++++++++ pgtype/float8_test.go | 148 +++++++++++++++++++ pgtype/float8array.go | 286 +++++++++++++++++++++++++++++++++++++ pgtype/float8array_test.go | 151 ++++++++++++++++++++ pgtype/int2.go | 2 +- pgtype/int4.go | 2 +- pgtype/int8.go | 2 +- pgtype/pgtype_test.go | 2 + pgtype/typed_array_gen.sh | 2 + query_test.go | 6 - values.go | 16 --- 17 files changed, 1563 insertions(+), 27 deletions(-) create mode 100644 pgtype/float4.go create mode 100644 pgtype/float4_test.go create mode 100644 pgtype/float4array.go create mode 100644 pgtype/float4array_test.go create mode 100644 pgtype/float8.go create mode 100644 pgtype/float8_test.go create mode 100644 pgtype/float8array.go create mode 100644 pgtype/float8array_test.go diff --git a/conn.go b/conn.go index 5fd82ea0..1e277a0e 100644 --- a/conn.go +++ b/conn.go @@ -283,6 +283,10 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl BoolOID: &pgtype.Bool{}, DateArrayOID: &pgtype.DateArray{}, DateOID: &pgtype.Date{}, + Float4ArrayOID: &pgtype.Float4Array{}, + Float4OID: &pgtype.Float4{}, + Float8ArrayOID: &pgtype.Float8Array{}, + Float8OID: &pgtype.Float8{}, Int2ArrayOID: &pgtype.Int2Array{}, Int2OID: &pgtype.Int2{}, Int4ArrayOID: &pgtype.Int4Array{}, diff --git a/pgtype/convert.go b/pgtype/convert.go index e35e2310..c4b52322 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -11,8 +11,8 @@ const maxUint = ^uint(0) const maxInt = int(maxUint >> 1) const minInt = -maxInt - 1 -// underlyingIntType gets the underlying type that can be converted to Int2, Int4, or Int8 -func underlyingIntType(val interface{}) (interface{}, bool) { +// underlyingNumberType gets the underlying type that can be converted to Int2, Int4, Int8, Float4, or Float8 +func underlyingNumberType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) switch refVal.Kind() { @@ -52,6 +52,12 @@ func underlyingIntType(val interface{}) (interface{}, bool) { case reflect.Uint64: convVal := uint64(refVal.Uint()) return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float32: + convVal := float32(refVal.Float()) + return convVal, reflect.TypeOf(convVal) != refVal.Type() + case reflect.Float64: + convVal := refVal.Float() + return convVal, reflect.TypeOf(convVal) != refVal.Type() case reflect.String: convVal := refVal.String() return convVal, reflect.TypeOf(convVal) != refVal.Type() @@ -259,3 +265,45 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } + +func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { + if srcStatus == Present { + switch v := dst.(type) { + case *float32: + *v = float32(srcVal) + case *float64: + *v = srcVal + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return float64AssignTo(srcVal, srcStatus, el.Interface()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i64 := int64(srcVal) + if float64(i64) == srcVal { + return int64AssignTo(i64, srcStatus, dst) + } + } + } + return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + } + return nil + } + + // if dst is a pointer to pointer and srcStatus is not Present, nil it out + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + if el.Kind() == reflect.Ptr { + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) +} diff --git a/pgtype/float4.go b/pgtype/float4.go new file mode 100644 index 00000000..a1e5aa18 --- /dev/null +++ b/pgtype/float4.go @@ -0,0 +1,171 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Float4 struct { + Float float32 + Status Status +} + +func (dst *Float4) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Float4: + *dst = value + case float32: + *dst = Float4{Float: value, Status: Present} + case float64: + *dst = Float4{Float: float32(value), Status: Present} + case int8: + *dst = Float4{Float: float32(value), Status: Present} + case uint8: + *dst = Float4{Float: float32(value), Status: Present} + case int16: + *dst = Float4{Float: float32(value), Status: Present} + case uint16: + *dst = Float4{Float: float32(value), Status: Present} + case int32: + f32 := float32(value) + if int32(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint32: + f32 := float32(value) + if uint32(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case int64: + f32 := float32(value) + if int64(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint64: + f32 := float32(value) + if uint64(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case int: + f32 := float32(value) + if int(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case uint: + f32 := float32(value) + if uint(f32) == value { + *dst = Float4{Float: f32, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float32", value) + } + case string: + num, err := strconv.ParseFloat(value, 32) + if err != nil { + return err + } + *dst = Float4{Float: float32(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (src *Float4) AssignTo(dst interface{}) error { + return float64AssignTo(float64(src.Float), src.Status, dst) +} + +func (dst *Float4) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float4{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseFloat(string(buf), 32) + if err != nil { + return err + } + + *dst = Float4{Float: float32(n), Status: Present} + return nil +} + +func (dst *Float4) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float4{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for float4: %v", size) + } + + n, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} + return nil +} + +func (src Float4) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatFloat(float64(src.Float), 'f', -1, 32) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src Float4) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) + return err +} diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go new file mode 100644 index 00000000..62420b8d --- /dev/null +++ b/pgtype/float4_test.go @@ -0,0 +1,148 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestFloat4Transcode(t *testing.T) { + testSuccessfulTranscode(t, "float4", []interface{}{ + pgtype.Float4{Float: -1, Status: pgtype.Present}, + pgtype.Float4{Float: 0, Status: pgtype.Present}, + pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, + pgtype.Float4{Float: 0, Status: pgtype.Null}, + }) +} + +func TestFloat4ConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4 + }{ + {source: float32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Float4{Float: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Float4{Float: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float4 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float4 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float4{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4 + dst interface{} + }{ + {src: pgtype.Float4{Float: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Float4{Float: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Float4{Float: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Float4{Float: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/float4array.go b/pgtype/float4array.go new file mode 100644 index 00000000..c06490cf --- /dev/null +++ b/pgtype/float4array.go @@ -0,0 +1,286 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Float4Array struct { + Elements []Float4 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Float4Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Float4Array: + *dst = value + + case []float32: + if value == nil { + *dst = Float4Array{Status: Null} + } else if len(value) == 0 { + *dst = Float4Array{Status: Present} + } else { + elements := make([]Float4, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Float4Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float4", value) + } + + return nil +} + +func (src *Float4Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]float32: + if src.Status == Present { + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Float4Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float4Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Float4 + + if len(uta.Elements) > 0 { + elements = make([]Float4, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float4 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float4Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Float4Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float4Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float4Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float4, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = Float4Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Float4Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *Float4Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Float4OID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/float4array_test.go b/pgtype/float4array_test.go new file mode 100644 index 00000000..b22f4fbc --- /dev/null +++ b/pgtype/float4array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestFloat4ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "float4[]", []interface{}{ + &pgtype.Float4Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float4Array{Status: pgtype.Null}, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Float: 2, Status: pgtype.Present}, + pgtype.Float4{Float: 3, Status: pgtype.Present}, + pgtype.Float4{Float: 4, Status: pgtype.Present}, + pgtype.Float4{Status: pgtype.Null}, + pgtype.Float4{Float: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float4Array{ + Elements: []pgtype.Float4{ + pgtype.Float4{Float: 1, Status: pgtype.Present}, + pgtype.Float4{Float: 2, Status: pgtype.Present}, + pgtype.Float4{Float: 3, Status: pgtype.Present}, + pgtype.Float4{Float: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestFloat4ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float4Array + }{ + { + source: []float32{1}, + result: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.Float4Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float4Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat4ArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var namedFloat32Slice _float32Slice + + simpleTests := []struct { + src pgtype.Float4Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1.23}, + }, + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedFloat32Slice, + expected: _float32Slice{1.23}, + }, + { + src: pgtype.Float4Array{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float4Array + dst interface{} + }{ + { + src: pgtype.Float4Array{ + Elements: []pgtype.Float4{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/float8.go b/pgtype/float8.go new file mode 100644 index 00000000..c1347cb2 --- /dev/null +++ b/pgtype/float8.go @@ -0,0 +1,161 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +type Float8 struct { + Float float64 + Status Status +} + +func (dst *Float8) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Float8: + *dst = value + case float32: + *dst = Float8{Float: float64(value), Status: Present} + case float64: + *dst = Float8{Float: value, Status: Present} + case int8: + *dst = Float8{Float: float64(value), Status: Present} + case uint8: + *dst = Float8{Float: float64(value), Status: Present} + case int16: + *dst = Float8{Float: float64(value), Status: Present} + case uint16: + *dst = Float8{Float: float64(value), Status: Present} + case int32: + *dst = Float8{Float: float64(value), Status: Present} + case uint32: + *dst = Float8{Float: float64(value), Status: Present} + case int64: + f64 := float64(value) + if int64(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case uint64: + f64 := float64(value) + if uint64(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case int: + f64 := float64(value) + if int(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case uint: + f64 := float64(value) + if uint(f64) == value { + *dst = Float8{Float: f64, Status: Present} + } else { + return fmt.Errorf("%v cannot be exactly represented as float64", value) + } + case string: + num, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + *dst = Float8{Float: float64(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (src *Float8) AssignTo(dst interface{}) error { + return float64AssignTo(src.Float, src.Status, dst) +} + +func (dst *Float8) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float8{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseFloat(string(buf), 64) + if err != nil { + return err + } + + *dst = Float8{Float: n, Status: Present} + return nil +} + +func (dst *Float8) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float8{Status: Null} + return nil + } + + if size != 8 { + return fmt.Errorf("invalid length for float4: %v", size) + } + + n, err := pgio.ReadInt64(r) + if err != nil { + return err + } + + *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} + return nil +} + +func (src Float8) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatFloat(float64(src.Float), 'f', -1, 64) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src Float8) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 8) + if err != nil { + return err + } + + _, err = pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) + return err +} diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go new file mode 100644 index 00000000..748ffd25 --- /dev/null +++ b/pgtype/float8_test.go @@ -0,0 +1,148 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestFloat8Transcode(t *testing.T) { + testSuccessfulTranscode(t, "float8", []interface{}{ + pgtype.Float8{Float: -1, Status: pgtype.Present}, + pgtype.Float8{Float: 0, Status: pgtype.Present}, + pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, + pgtype.Float8{Float: 0, Status: pgtype.Null}, + }) +} + +func TestFloat8ConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8 + }{ + {source: float32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: float64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.Float8{Float: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.Float8{Float: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Float8 + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8AssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Float8 + dst interface{} + expected interface{} + }{ + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: pgtype.Float8{Float: 42, Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8 + dst interface{} + }{ + {src: pgtype.Float8{Float: 150, Status: pgtype.Present}, dst: &i8}, + {src: pgtype.Float8{Float: 40000, Status: pgtype.Present}, dst: &i16}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.Float8{Float: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.Float8{Float: 0, Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/float8array.go b/pgtype/float8array.go new file mode 100644 index 00000000..776fc1e6 --- /dev/null +++ b/pgtype/float8array.go @@ -0,0 +1,286 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Float8Array struct { + Elements []Float8 + Dimensions []ArrayDimension + Status Status +} + +func (dst *Float8Array) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Float8Array: + *dst = value + + case []float64: + if value == nil { + *dst = Float8Array{Status: Null} + } else if len(value) == 0 { + *dst = Float8Array{Status: Present} + } else { + elements := make([]Float8, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = Float8Array{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Float8", value) + } + + return nil +} + +func (src *Float8Array) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]float64: + if src.Status == Present { + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Float8Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float8Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Float8 + + if len(uta.Elements) > 0 { + elements = make([]Float8, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Float8 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = Float8Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *Float8Array) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Float8Array{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = Float8Array{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Float8, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = Float8Array{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *Float8Array) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *Float8Array) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = Float8OID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/float8array_test.go b/pgtype/float8array_test.go new file mode 100644 index 00000000..d4402281 --- /dev/null +++ b/pgtype/float8array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestFloat8ArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "float8[]", []interface{}{ + &pgtype.Float8Array{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float8Array{Status: pgtype.Null}, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Float: 2, Status: pgtype.Present}, + pgtype.Float8{Float: 3, Status: pgtype.Present}, + pgtype.Float8{Float: 4, Status: pgtype.Present}, + pgtype.Float8{Status: pgtype.Null}, + pgtype.Float8{Float: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Float8Array{ + Elements: []pgtype.Float8{ + pgtype.Float8{Float: 1, Status: pgtype.Present}, + pgtype.Float8{Float: 2, Status: pgtype.Present}, + pgtype.Float8{Float: 3, Status: pgtype.Present}, + pgtype.Float8{Float: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestFloat8ArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Float8Array + }{ + { + source: []float64{1}, + result: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float64)(nil)), + result: pgtype.Float8Array{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Float8Array + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestFloat8ArrayAssignTo(t *testing.T) { + var float64Slice []float64 + var namedFloat64Slice _float64Slice + + simpleTests := []struct { + src pgtype.Float8Array + dst interface{} + expected interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1.23}, + }, + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Float: 1.23, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedFloat64Slice, + expected: _float64Slice{1.23}, + }, + { + src: pgtype.Float8Array{Status: pgtype.Null}, + dst: &float64Slice, + expected: (([]float64)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Float8Array + dst interface{} + }{ + { + src: pgtype.Float8Array{ + Elements: []pgtype.Float8{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/int2.go b/pgtype/int2.go index fb6a8ccc..8057550b 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -75,7 +75,7 @@ func (dst *Int2) ConvertFrom(src interface{}) error { } *dst = Int2{Int: int16(num), Status: Present} default: - if originalSrc, ok := underlyingIntType(src); ok { + if originalSrc, ok := underlyingNumberType(src); ok { return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) diff --git a/pgtype/int4.go b/pgtype/int4.go index 1a4733b0..43691bb6 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -66,7 +66,7 @@ func (dst *Int4) ConvertFrom(src interface{}) error { } *dst = Int4{Int: int32(num), Status: Present} default: - if originalSrc, ok := underlyingIntType(src); ok { + if originalSrc, ok := underlyingNumberType(src); ok { return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) diff --git a/pgtype/int8.go b/pgtype/int8.go index 7f307f18..b87bb85a 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -57,7 +57,7 @@ func (dst *Int8) ConvertFrom(src interface{}) error { } *dst = Int8{Int: num, Status: Present} default: - if originalSrc, ok := underlyingIntType(src); ok { + if originalSrc, ok := underlyingNumberType(src); ok { return dst.ConvertFrom(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 97afc249..a1dcd11b 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -18,6 +18,8 @@ type _int16 int16 type _int16Slice []int16 type _int32Slice []int32 type _int64Slice []int64 +type _float32Slice []float32 +type _float64Slice []float64 func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 9f4e1ce0..4ce6c3b5 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -5,3 +5,5 @@ erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool e erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go diff --git a/query_test.go b/query_test.go index 84b90d4b..9b6f9a06 100644 --- a/query_test.go +++ b/query_test.go @@ -1151,9 +1151,6 @@ func TestQueryRowCoreFloat32Slice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } ensureConnValid(t, conn) } @@ -1198,9 +1195,6 @@ func TestQueryRowCoreFloat64Slice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } ensureConnValid(t, conn) } diff --git a/values.go b/values.go index e4347a8b..d2ec9fc2 100644 --- a/values.go +++ b/values.go @@ -1088,14 +1088,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case float32: - return encodeFloat32(wbuf, oid, arg) - case []float32: - return encodeFloat32Slice(wbuf, oid, arg) - case float64: - return encodeFloat64(wbuf, oid, arg) - case []float64: - return encodeFloat64Slice(wbuf, oid, arg) case net.IP: return encodeIP(wbuf, oid, arg) case []net.IP: @@ -1195,16 +1187,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeCid(vr) case *string: *v = decodeText(vr) - case *float32: - *v = decodeFloat4(vr) - case *float64: - *v = decodeFloat8(vr) case *[]AclItem: *v = decodeAclItemArray(vr) - case *[]float32: - *v = decodeFloat4Array(vr) - case *[]float64: - *v = decodeFloat8Array(vr) case *[]string: *v = decodeTextArray(vr) case *[][]byte: From 4cdea13f0f6a1e0239b2db00b68fde85eddd265b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 17:33:41 -0600 Subject: [PATCH 078/264] Add inet and cidr to pgtype --- conn.go | 4 + pgtype/cidrarray.go | 31 ++++ pgtype/convert.go | 16 ++ pgtype/inet.go | 240 ++++++++++++++++++++++++++++ pgtype/inet_test.go | 115 ++++++++++++++ pgtype/inetarray.go | 320 ++++++++++++++++++++++++++++++++++++++ pgtype/inetarray_test.go | 164 +++++++++++++++++++ pgtype/pgtype_test.go | 10 ++ pgtype/typed_array_gen.sh | 1 + values.go | 28 ---- values_test.go | 30 ++-- 11 files changed, 916 insertions(+), 43 deletions(-) create mode 100644 pgtype/cidrarray.go create mode 100644 pgtype/inet.go create mode 100644 pgtype/inet_test.go create mode 100644 pgtype/inetarray.go create mode 100644 pgtype/inetarray_test.go diff --git a/conn.go b/conn.go index 1e277a0e..b6670735 100644 --- a/conn.go +++ b/conn.go @@ -281,12 +281,16 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.oidPgtypeValues = map[OID]pgtype.Value{ BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, + CidrArrayOID: &pgtype.CidrArray{}, + CidrOID: &pgtype.Inet{}, DateArrayOID: &pgtype.DateArray{}, DateOID: &pgtype.Date{}, Float4ArrayOID: &pgtype.Float4Array{}, Float4OID: &pgtype.Float4{}, Float8ArrayOID: &pgtype.Float8Array{}, Float8OID: &pgtype.Float8{}, + InetArrayOID: &pgtype.InetArray{}, + InetOID: &pgtype.Inet{}, Int2ArrayOID: &pgtype.Int2Array{}, Int2OID: &pgtype.Int2{}, Int4ArrayOID: &pgtype.Int4Array{}, diff --git a/pgtype/cidrarray.go b/pgtype/cidrarray.go new file mode 100644 index 00000000..66dd20d0 --- /dev/null +++ b/pgtype/cidrarray.go @@ -0,0 +1,31 @@ +package pgtype + +import ( + "io" +) + +type CidrArray InetArray + +func (dst *CidrArray) ConvertFrom(src interface{}) error { + return (*InetArray)(dst).ConvertFrom(src) +} + +func (src *CidrArray) AssignTo(dst interface{}) error { + return (*InetArray)(src).AssignTo(dst) +} + +func (dst *CidrArray) DecodeText(r io.Reader) error { + return (*InetArray)(dst).DecodeText(r) +} + +func (dst *CidrArray) DecodeBinary(r io.Reader) error { + return (*InetArray)(dst).DecodeBinary(r) +} + +func (src *CidrArray) EncodeText(w io.Writer) error { + return (*InetArray)(src).EncodeText(w) +} + +func (src *CidrArray) EncodeBinary(w io.Writer) error { + return (*InetArray)(src).encodeBinary(w, CidrOID) +} diff --git a/pgtype/convert.go b/pgtype/convert.go index c4b52322..7111f8bc 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -85,6 +85,22 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingPtrType dereferences a pointer +func underlyingPtrType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + return nil, false +} + // underlyingTimeType gets the underlying type that can be converted to time.Time func underlyingTimeType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/pgtype/inet.go b/pgtype/inet.go new file mode 100644 index 00000000..e47c64b0 --- /dev/null +++ b/pgtype/inet.go @@ -0,0 +1,240 @@ +package pgtype + +import ( + "fmt" + "io" + "net" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +// Network address family is dependent on server socket.h value for AF_INET. +// In practice, all platforms appear to have the same value. See +// src/include/utils/inet.h for more information. +const ( + defaultAFInet = 2 + defaultAFInet6 = 3 +) + +// Inet represents both inet and cidr PostgreSQL types. +type Inet struct { + IPNet *net.IPNet + Status Status +} + +func (dst *Inet) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Inet: + *dst = value + case net.IPNet: + *dst = Inet{IPNet: &value, Status: Present} + case *net.IPNet: + *dst = Inet{IPNet: value, Status: Present} + case net.IP: + bitCount := len(value) * 8 + mask := net.CIDRMask(bitCount, bitCount) + *dst = Inet{IPNet: &net.IPNet{Mask: mask, IP: value}, Status: Present} + case string: + _, ipnet, err := net.ParseCIDR(value) + if err != nil { + return err + } + *dst = Inet{IPNet: ipnet, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (src *Inet) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *net.IPNet: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = *src.IPNet + case *net.IP: + if src.Status == Present { + + if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.IPNet.IP + } else { + *v = nil + } + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Inet) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Inet{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + var ipnet *net.IPNet + + if ip := net.ParseIP(string(buf)); ip != nil { + ipv4 := ip.To4() + if ipv4 != nil { + ip = ipv4 + } + bitCount := len(ip) * 8 + mask := net.CIDRMask(bitCount, bitCount) + ipnet = &net.IPNet{Mask: mask, IP: ip} + } else { + _, ipnet, err = net.ParseCIDR(string(buf)) + if err != nil { + return err + } + } + + *dst = Inet{IPNet: ipnet, Status: Present} + return nil +} + +func (dst *Inet) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Inet{Status: Null} + return nil + } + + if size != 8 && size != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", size) + } + + // ignore family + _, err = pgio.ReadByte(r) + if err != nil { + return err + } + + bits, err := pgio.ReadByte(r) + if err != nil { + return err + } + + // ignore is_cidr + _, err = pgio.ReadByte(r) + if err != nil { + return err + } + + addressLength, err := pgio.ReadByte(r) + if err != nil { + return err + } + + var ipnet net.IPNet + ipnet.IP = make(net.IP, int(addressLength)) + _, err = r.Read(ipnet.IP) + if err != nil { + return err + } + + ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) + + *dst = Inet{IPNet: &ipnet, Status: Present} + + return nil +} + +func (src Inet) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := src.IPNet.String() + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +// EncodeBinary encodes src into w. +func (src Inet) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var size int32 + var family byte + switch len(src.IPNet.IP) { + case net.IPv4len: + size = 8 + family = defaultAFInet + case net.IPv6len: + size = 20 + family = defaultAFInet6 + default: + return fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + } + + if _, err := pgio.WriteInt32(w, size); err != nil { + return err + } + + if err := pgio.WriteByte(w, family); err != nil { + return err + } + + ones, _ := src.IPNet.Mask.Size() + if err := pgio.WriteByte(w, byte(ones)); err != nil { + return err + } + + // is_cidr is ignored on server + if err := pgio.WriteByte(w, 0); err != nil { + return err + } + + if err := pgio.WriteByte(w, byte(len(src.IPNet.IP))); err != nil { + return err + } + + _, err := w.Write(src.IPNet.IP) + return err +} diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go new file mode 100644 index 00000000..5e86376b --- /dev/null +++ b/pgtype/inet_test.go @@ -0,0 +1,115 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInetTranscode(t *testing.T) { + for _, pgTypeName := range []string{"inet", "cidr"} { + testSuccessfulTranscode(t, pgTypeName, []interface{}{ + pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + }) + } +} + +func TestInetConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Inet + }{ + {source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}}, + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.Inet + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetAssignTo(t *testing.T) { + var ipnet net.IPNet + var pipnet *net.IPNet + var ip net.IP + var pip *net.IP + + simpleTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %#v, but result was %#v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Inet + dst interface{} + expected interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Inet + dst interface{} + }{ + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/inetarray.go b/pgtype/inetarray.go new file mode 100644 index 00000000..eb5a4c88 --- /dev/null +++ b/pgtype/inetarray.go @@ -0,0 +1,320 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + "net" + + "github.com/jackc/pgx/pgio" +) + +type InetArray struct { + Elements []Inet + Dimensions []ArrayDimension + Status Status +} + +func (dst *InetArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case InetArray: + *dst = value + case CidrArray: + *dst = InetArray(value) + case []*net.IPNet: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + case []net.IP: + if value == nil { + *dst = InetArray{Status: Null} + } else if len(value) == 0 { + *dst = InetArray{Status: Present} + } else { + elements := make([]Inet, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = InetArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Inet", value) + } + + return nil +} + +func (src *InetArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]*net.IPNet: + if src.Status == Present { + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + case *[]net.IP: + if src.Status == Present { + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot put decode %v into %T", src, dst) + } + + return nil +} + +func (dst *InetArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = InetArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Inet + + if len(uta.Elements) > 0 { + elements = make([]Inet, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Inet + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = InetArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *InetArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = InetArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = InetArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Inet, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = InetArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *InetArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *InetArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, InetOID) +} + +func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = elementOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/inetarray_test.go b/pgtype/inetarray_test.go new file mode 100644 index 00000000..8cab5355 --- /dev/null +++ b/pgtype/inetarray_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInetArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "inet[]", []interface{}{ + &pgtype.InetArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{Status: pgtype.Null}, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{Status: pgtype.Null}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.InetArray{ + Elements: []pgtype.Inet{ + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestInetArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.InetArray + }{ + { + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.InetArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.InetArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestInetArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.InetArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.InetArray{ + Elements: []pgtype.Inet{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.InetArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index a1dcd11b..7d34ae34 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -3,6 +3,7 @@ package pgtype_test import ( "fmt" "io" + "net" "os" "reflect" "testing" @@ -44,6 +45,15 @@ func mustClose(t testing.TB, conn interface { } } +func mustParseCIDR(t testing.TB, s string) *net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + + return ipnet +} + type forceTextEncoder struct { e pgtype.TextEncoder } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 4ce6c3b5..47afdf1d 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -7,3 +7,4 @@ erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go diff --git a/values.go b/values.go index d2ec9fc2..8a2da367 100644 --- a/values.go +++ b/values.go @@ -1088,14 +1088,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case net.IP: - return encodeIP(wbuf, oid, arg) - case []net.IP: - return encodeIPSlice(wbuf, oid, arg) - case net.IPNet: - return encodeIPNet(wbuf, oid, arg) - case []net.IPNet: - return encodeIPNetSlice(wbuf, oid, arg) case OID: return encodeOID(wbuf, oid, arg) case Xid: @@ -1195,26 +1187,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeByteaArray(vr) case *[]interface{}: *v = decodeRecord(vr) - case *net.IP: - ipnet := decodeInet(vr) - if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("Cannot decode netmask into *net.IP") - } - *v = ipnet.IP - case *[]net.IP: - ipnets := decodeInetArray(vr) - ips := make([]net.IP, len(ipnets)) - for i, ipnet := range ipnets { - if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("Cannot decode netmask into *net.IP") - } - ips[i] = ipnet.IP - } - *v = ips - case *net.IPNet: - *v = decodeInet(vr) - case *[]net.IPNet: - *v = decodeInetArray(vr) default: if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { el := v.Elem() diff --git a/values_test.go b/values_test.go index 28f7371f..d6ce705a 100644 --- a/values_test.go +++ b/values_test.go @@ -232,13 +232,13 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) } } -func mustParseCIDR(t *testing.T, s string) net.IPNet { +func mustParseCIDR(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) } - return *ipnet + return ipnet } func TestStringToNotTextTypeTranscode(t *testing.T) { @@ -275,7 +275,7 @@ func TestInetCidrTranscodeIPNet(t *testing.T) { tests := []struct { sql string - value net.IPNet + value *net.IPNet }{ {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, @@ -358,7 +358,7 @@ func TestInetCidrTranscodeIP(t *testing.T) { failTests := []struct { sql string - value net.IPNet + value *net.IPNet }{ {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, @@ -367,8 +367,8 @@ func TestInetCidrTranscodeIP(t *testing.T) { var actual net.IP err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if !strings.Contains(err.Error(), "Cannot decode netmask") { - t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) continue } @@ -384,11 +384,11 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { tests := []struct { sql string - value []net.IPNet + value []*net.IPNet }{ { "select $1::inet[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "0.0.0.0/32"), mustParseCIDR(t, "127.0.0.1/32"), mustParseCIDR(t, "12.34.56.0/32"), @@ -403,7 +403,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { }, { "select $1::cidr[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "0.0.0.0/32"), mustParseCIDR(t, "127.0.0.1/32"), mustParseCIDR(t, "12.34.56.0/32"), @@ -419,7 +419,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { } for i, tt := range tests { - var actual []net.IPNet + var actual []*net.IPNet err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) if err != nil { @@ -485,18 +485,18 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { failTests := []struct { sql string - value []net.IPNet + value []*net.IPNet }{ { "select $1::inet[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), }, }, { "select $1::cidr[]", - []net.IPNet{ + []*net.IPNet{ mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), }, @@ -507,8 +507,8 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { var actual []net.IP err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) - if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") { - t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) continue } From 005916166a32d95af93db636f3b9e1a6097e0a2d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 17:43:12 -0600 Subject: [PATCH 079/264] Remove behavior migrated to pgtype --- query_test.go | 64 +-------------------------------------------------- values.go | 19 --------------- 2 files changed, 1 insertion(+), 82 deletions(-) diff --git a/query_test.go b/query_test.go index 9b6f9a06..364e6b57 100644 --- a/query_test.go +++ b/query_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "database/sql" - "fmt" "strings" "testing" "time" @@ -309,67 +308,6 @@ func TestConnQueryScanner(t *testing.T) { ensureConnValid(t, conn) } -type pgxNullInt64 struct { - Int64 int64 - Valid bool // Valid is true if Int64 is not NULL -} - -func (n *pgxNullInt64) ScanPgx(vr *pgx.ValueReader) error { - if vr.Type().DataType != pgx.Int8OID { - return pgx.SerializationError(fmt.Sprintf("pgxNullInt64.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int64, n.Valid = 0, false - return nil - } - n.Valid = true - - err := pgx.Decode(vr, &n.Int64) - if err != nil { - return err - } - return vr.Err() -} - -func TestConnQueryPgxScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows, err := conn.Query("select null::int8, 1::int8") - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var n, m pgxNullInt64 - err = rows.Scan(&n, &m) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - rows.Close() - - if n.Valid { - t.Error("Null should not be valid, but it was") - } - - if !m.Valid { - t.Error("1 should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - func TestConnQueryErrorWhileReturningRows(t *testing.T) { t.Parallel() @@ -942,7 +880,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1", []interface{}{"Jack"}, []interface{}{&actual.i16}, "could not determine data type of parameter $1 (SQLSTATE 42P18)"}, {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, - {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "Cannot decode oid 25 into any integer type"}, + {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"}, {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 600"}, } diff --git a/values.go b/values.go index 8a2da367..074f7432 100644 --- a/values.go +++ b/values.go @@ -1076,8 +1076,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { switch arg := arg.(type) { case []string: return encodeStringSlice(wbuf, oid, arg) - case []bool: - return encodeBoolSlice(wbuf, oid, arg) case Char: return encodeChar(wbuf, oid, arg) case AclItem: @@ -1207,23 +1205,6 @@ func Decode(vr *ValueReader, d interface{}) error { } d = el.Interface() return Decode(vr, d) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - n := decodeInt(vr) - if el.OverflowInt(n) { - return fmt.Errorf("Scan cannot decode %d into %T", n, d) - } - el.SetInt(n) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - n := decodeInt(vr) - if n < 0 { - return fmt.Errorf("%d is less than zero for %T", n, d) - } - if el.OverflowUint(uint64(n)) { - return fmt.Errorf("Scan cannot decode %d into %T", n, d) - } - el.SetUint(uint64(n)) - return nil case reflect.String: el.SetString(decodeText(vr)) return nil From b1fc8109db7e4434a33e8ff623e03044189188cb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 18:00:51 -0600 Subject: [PATCH 080/264] Remove AF_INET fetching system Also remove old encode/decode inet/cidr code. This removed some functionality from Rows.Values, but that entire system will soon change anyway. --- conn.go | 39 +---------- conn_pool.go | 6 +- query.go | 4 -- values.go | 183 --------------------------------------------------- 4 files changed, 3 insertions(+), 229 deletions(-) diff --git a/conn.go b/conn.go index b6670735..19833dc0 100644 --- a/conn.go +++ b/conn.go @@ -87,8 +87,6 @@ type Conn struct { logLevel int mr msgReader fp *fastpath - pgsqlAfInet *byte - pgsqlAfInet6 *byte poolResetCount int preallocatedRows []Rows @@ -179,10 +177,10 @@ func (e ProtocolError) Error() string { // config.Host must be specified. config.User will default to the OS user name. // Other config fields are optional. func Connect(config ConnConfig) (c *Conn, err error) { - return connect(config, nil, nil, nil) + return connect(config, nil) } -func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsqlAfInet6 *byte) (c *Conn, err error) { +func connect(config ConnConfig, pgTypes map[OID]PgType) (c *Conn, err error) { c = new(Conn) c.config = config @@ -194,15 +192,6 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsql } } - if pgsqlAfInet != nil { - c.pgsqlAfInet = new(byte) - *c.pgsqlAfInet = *pgsqlAfInet - } - if pgsqlAfInet6 != nil { - c.pgsqlAfInet6 = new(byte) - *c.pgsqlAfInet6 = *pgsqlAfInet6 - } - if c.config.LogLevel != 0 { c.logLevel = c.config.LogLevel } else { @@ -372,13 +361,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - if c.pgsqlAfInet == nil || c.pgsqlAfInet6 == nil { - err = c.loadInetConstants() - if err != nil { - return err - } - } - return nil default: if err = c.processContextFreeMsg(t, r); err != nil { @@ -418,23 +400,6 @@ where ( return rows.Err() } -// Family is needed for binary encoding of inet/cidr. The constant is based on -// the server's definition of AF_INET. In theory, this could differ between -// platforms, so request an IPv4 and an IPv6 inet and get the family from that. -func (c *Conn) loadInetConstants() error { - var ipv4, ipv6 []byte - - err := c.QueryRow("select '127.0.0.1'::inet, '1::'::inet").Scan(&ipv4, &ipv6) - if err != nil { - return err - } - - c.pgsqlAfInet = &ipv4[0] - c.pgsqlAfInet6 = &ipv6[0] - - return nil -} - // PID returns the backend PID for this connection. func (c *Conn) PID() int32 { return c.pid diff --git a/conn_pool.go b/conn_pool.go index 9dfbf734..fd632006 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -29,8 +29,6 @@ type ConnPool struct { preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration pgTypes map[OID]PgType - pgsqlAfInet *byte - pgsqlAfInet6 *byte txAfterClose func(tx *Tx) rowsAfterClose func(rows *Rows) } @@ -294,7 +292,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.pgTypes, p.pgsqlAfInet, p.pgsqlAfInet6) + c, err := connect(p.config, p.pgTypes) if err != nil { return nil, err } @@ -330,8 +328,6 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // all the known statements for the new connection. func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { p.pgTypes = c.PgTypes - p.pgsqlAfInet = c.pgsqlAfInet - p.pgsqlAfInet6 = c.pgsqlAfInet6 if p.afterConnect != nil { err := p.afterConnect(c) diff --git a/query.go b/query.go index 52643b8d..9019fca4 100644 --- a/query.go +++ b/query.go @@ -410,8 +410,6 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeTimestampTz(vr)) case TimestampOID: values = append(values, decodeTimestamp(vr)) - case InetOID, CidrOID: - values = append(values, decodeInet(vr)) case JSONOID: var d interface{} decodeJSON(vr, &d) @@ -503,8 +501,6 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, decodeTimestampTz(vr)) case TimestampOID: values = append(values, decodeTimestamp(vr)) - case InetOID, CidrOID: - values = append(values, decodeInet(vr)) case JSONOID: var d interface{} decodeJSON(vr, &d) diff --git a/values.go b/values.go index 074f7432..3d7d63a2 100644 --- a/values.go +++ b/values.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "math" - "net" "reflect" "regexp" "strconv" @@ -1934,82 +1933,6 @@ func decodeTimestamp(vr *ValueReader) time.Time { return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) } -func decodeInet(vr *ValueReader) net.IPNet { - var zero net.IPNet - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into net.IPNet")) - return zero - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zero - } - - pgType := vr.Type() - if pgType.DataType != InetOID && pgType.DataType != CidrOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name))) - return zero - } - if vr.Len() != 8 && vr.Len() != 20 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len()))) - return zero - } - - vr.ReadByte() // ignore family - bits := vr.ReadByte() - vr.ReadByte() // ignore is_cidr - addressLength := vr.ReadByte() - - var ipnet net.IPNet - ipnet.IP = vr.ReadBytes(int32(addressLength)) - ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) - - return ipnet -} - -func encodeIPNet(w *WriteBuf, oid OID, value net.IPNet) error { - if oid != InetOID && oid != CidrOID { - return fmt.Errorf("cannot encode %s into oid %v", "net.IPNet", oid) - } - - var size int32 - var family byte - switch len(value.IP) { - case net.IPv4len: - size = 8 - family = *w.conn.pgsqlAfInet - case net.IPv6len: - size = 20 - family = *w.conn.pgsqlAfInet6 - default: - return fmt.Errorf("Unexpected IP length: %v", len(value.IP)) - } - - w.WriteInt32(size) - w.WriteByte(family) - ones, _ := value.Mask.Size() - w.WriteByte(byte(ones)) - w.WriteByte(0) // is_cidr is ignored on server - w.WriteByte(byte(len(value.IP))) - w.WriteBytes(value.IP) - - return nil -} - -func encodeIP(w *WriteBuf, oid OID, value net.IP) error { - if oid != InetOID && oid != CidrOID { - return fmt.Errorf("cannot encode %s into oid %v", "net.IP", oid) - } - - var ipnet net.IPNet - ipnet.IP = value - bitCount := len(value) * 8 - ipnet.Mask = net.CIDRMask(bitCount, bitCount) - return encodeIPNet(w, oid, ipnet) -} - func decodeRecord(vr *ValueReader) []interface{} { if vr.Len() == -1 { return nil @@ -2058,8 +1981,6 @@ func decodeRecord(vr *ValueReader) []interface{} { record = append(record, decodeTimestampTz(&fieldVR)) case TimestampOID: record = append(record, decodeTimestamp(&fieldVR)) - case InetOID, CidrOID: - record = append(record, decodeInet(&fieldVR)) case TextOID, VarcharOID, UnknownOID: record = append(record, decodeTextAllowBinary(&fieldVR)) default: @@ -2983,110 +2904,6 @@ func encodeTimeSlice(w *WriteBuf, oid OID, slice []time.Time) error { return nil } -func decodeInetArray(vr *ValueReader) []net.IPNet { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != InetArrayOID && vr.Type().DataType != CidrArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []net.IP", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]net.IPNet, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - if elSize == -1 { - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - } - - vr.ReadByte() // ignore family - bits := vr.ReadByte() - vr.ReadByte() // ignore is_cidr - addressLength := vr.ReadByte() - - var ipnet net.IPNet - ipnet.IP = vr.ReadBytes(int32(addressLength)) - ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) - - a[i] = ipnet - } - - return a -} - -func encodeIPNetSlice(w *WriteBuf, oid OID, slice []net.IPNet) error { - var elOID OID - switch oid { - case InetArrayOID: - elOID = InetOID - case CidrArrayOID: - elOID = CidrOID - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) - } - - size := int32(20) // array header size - for _, ipnet := range slice { - size += 4 + 4 + int32(len(ipnet.IP)) // size of element + inet/cidr metadata + IP bytes - } - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOID)) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - - for _, ipnet := range slice { - encodeIPNet(w, elOID, ipnet) - } - - return nil -} - -func encodeIPSlice(w *WriteBuf, oid OID, slice []net.IP) error { - var elOID OID - switch oid { - case InetArrayOID: - elOID = InetOID - case CidrArrayOID: - elOID = CidrOID - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]net.IPNet", oid) - } - - size := int32(20) // array header size - for _, ip := range slice { - size += 4 + 4 + int32(len(ip)) // size of element + inet/cidr metadata + IP bytes - } - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOID)) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - - for _, ip := range slice { - encodeIP(w, elOID, ip) - } - - return nil -} - func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) { w.WriteInt32(int32(20 + length*sizePerItem)) w.WriteInt32(1) // number of dimensions From fa57904d6b0b448716f1112fb2c13118f741cfab Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 21:20:56 -0600 Subject: [PATCH 081/264] Add text to pgtype --- conn.go | 4 + pgtype/array_test.go | 7 + pgtype/bool.go | 2 +- pgtype/boolarray.go | 10 +- pgtype/convert.go | 19 +++ pgtype/datearray.go | 2 +- pgtype/float4array.go | 10 +- pgtype/float8array.go | 10 +- pgtype/inetarray.go | 2 +- pgtype/int2array.go | 14 +- pgtype/int4array.go | 14 +- pgtype/int8array.go | 14 +- pgtype/pgtype_test.go | 1 + pgtype/text.go | 115 ++++++++++++++ pgtype/text_test.go | 100 +++++++++++++ pgtype/textarray.go | 297 +++++++++++++++++++++++++++++++++++++ pgtype/textarray_test.go | 151 +++++++++++++++++++ pgtype/timestamparray.go | 2 +- pgtype/timestamptzarray.go | 2 +- pgtype/typed_array.go.erb | 2 +- pgtype/typed_array_gen.sh | 1 + pgtype/varchararray.go | 31 ++++ query.go | 32 ---- query_test.go | 3 - values.go | 72 --------- 25 files changed, 768 insertions(+), 149 deletions(-) create mode 100644 pgtype/text.go create mode 100644 pgtype/text_test.go create mode 100644 pgtype/textarray.go create mode 100644 pgtype/textarray_test.go create mode 100644 pgtype/varchararray.go diff --git a/conn.go b/conn.go index 19833dc0..d97942aa 100644 --- a/conn.go +++ b/conn.go @@ -286,10 +286,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl Int4OID: &pgtype.Int4{}, Int8ArrayOID: &pgtype.Int8Array{}, Int8OID: &pgtype.Int8{}, + TextArrayOID: &pgtype.TextArray{}, + TextOID: &pgtype.Text{}, TimestampArrayOID: &pgtype.TimestampArray{}, TimestampOID: &pgtype.Timestamp{}, TimestampTzArrayOID: &pgtype.TimestamptzArray{}, TimestampTzOID: &pgtype.Timestamptz{}, + VarcharArrayOID: &pgtype.VarcharArray{}, + VarcharOID: &pgtype.Text{}, } if tlsConfig != nil { diff --git a/pgtype/array_test.go b/pgtype/array_test.go index 5e5f00e7..d1cdb4c5 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -40,6 +40,13 @@ func TestParseUntypedTextArray(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, + { + source: `{""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{""}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, { source: `{"He said, \"Hello.\""}`, result: pgtype.UntypedTextArray{ diff --git a/pgtype/bool.go b/pgtype/bool.go index 2889b787..076403f9 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -66,7 +66,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/boolarray.go b/pgtype/boolarray.go index 8dd68dc2..b6b5db02 100644 --- a/pgtype/boolarray.go +++ b/pgtype/boolarray.go @@ -18,7 +18,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { switch value := src.(type) { case BoolArray: *dst = value - + case []bool: if value == nil { *dst = BoolArray{Status: Null} @@ -37,7 +37,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { func (src *BoolArray) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]bool: if src.Status == Present { *v = make([]bool, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/convert.go b/pgtype/convert.go index 7111f8bc..31bbf060 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -85,6 +85,25 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingStringType gets the underlying type that can be converted to String +func underlyingStringType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + // underlyingPtrType dereferences a pointer func underlyingPtrType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/pgtype/datearray.go b/pgtype/datearray.go index 877f328e..5e93501e 100644 --- a/pgtype/datearray.go +++ b/pgtype/datearray.go @@ -68,7 +68,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/float4array.go b/pgtype/float4array.go index c06490cf..8834d213 100644 --- a/pgtype/float4array.go +++ b/pgtype/float4array.go @@ -18,7 +18,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Float4Array: *dst = value - + case []float32: if value == nil { *dst = Float4Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { func (src *Float4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]float32: if src.Status == Present { *v = make([]float32, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/float8array.go b/pgtype/float8array.go index 776fc1e6..bad9ed9f 100644 --- a/pgtype/float8array.go +++ b/pgtype/float8array.go @@ -18,7 +18,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Float8Array: *dst = value - + case []float64: if value == nil { *dst = Float8Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { func (src *Float8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]float64: if src.Status == Present { *v = make([]float64, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/inetarray.go b/pgtype/inetarray.go index eb5a4c88..cd12e917 100644 --- a/pgtype/inetarray.go +++ b/pgtype/inetarray.go @@ -97,7 +97,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/int2array.go b/pgtype/int2array.go index 4fc6d882..a989347d 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -18,7 +18,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2Array: *dst = value - + case []int16: if value == nil { *dst = Int2Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint16: if value == nil { *dst = Int2Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int16: if src.Status == Present { *v = make([]int16, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint16: if src.Status == Present { *v = make([]uint16, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/int4array.go b/pgtype/int4array.go index 40e1490d..89caf263 100644 --- a/pgtype/int4array.go +++ b/pgtype/int4array.go @@ -18,7 +18,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int4Array: *dst = value - + case []int32: if value == nil { *dst = Int4Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint32: if value == nil { *dst = Int4Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { func (src *Int4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int32: if src.Status == Present { *v = make([]int32, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint32: if src.Status == Present { *v = make([]uint32, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/int8array.go b/pgtype/int8array.go index 35ecf946..003ed055 100644 --- a/pgtype/int8array.go +++ b/pgtype/int8array.go @@ -18,7 +18,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int8Array: *dst = value - + case []int64: if value == nil { *dst = Int8Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint64: if value == nil { *dst = Int8Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { func (src *Int8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int64: if src.Status == Present { *v = make([]int64, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint64: if src.Status == Present { *v = make([]uint64, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 7d34ae34..304fd0ea 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -13,6 +13,7 @@ import ( ) // Test for renamed types +type _string string type _bool bool type _int8 int8 type _int16 int16 diff --git a/pgtype/text.go b/pgtype/text.go new file mode 100644 index 00000000..c9054468 --- /dev/null +++ b/pgtype/text.go @@ -0,0 +1,115 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +type Text struct { + String string + Status Status +} + +func (dst *Text) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Text: + *dst = value + case string: + *dst = Text{String: value, Status: Present} + case *string: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: *value, Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (src *Text) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.String + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + case reflect.String: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + el.SetString(src.String) + return nil + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Text) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Text{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + *dst = Text{String: string(buf), Status: Present} + return nil +} + +func (dst *Text) DecodeBinary(r io.Reader) error { + return dst.DecodeText(r) +} + +func (src Text) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, int32(len(src.String))) + if err != nil { + return nil + } + + _, err = io.WriteString(w, src.String) + return err +} + +func (src Text) EncodeBinary(w io.Writer) error { + return src.EncodeText(w) +} diff --git a/pgtype/text_test.go b/pgtype/text_test.go new file mode 100644 index 00000000..6e944857 --- /dev/null +++ b/pgtype/text_test.go @@ -0,0 +1,100 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTextTranscode(t *testing.T) { + for _, pgTypeName := range []string{"text", "varchar"} { + testSuccessfulTranscode(t, pgTypeName, []interface{}{ + pgtype.Text{String: "", Status: pgtype.Present}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + }) + } +} + +func TestTextConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Text + }{ + {source: pgtype.Text{String: "foo", Status: pgtype.Present}, result: pgtype.Text{String: "foo", Status: pgtype.Present}}, + {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Text + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestTextAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Text + dst interface{} + }{ + {src: pgtype.Text{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/textarray.go b/pgtype/textarray.go new file mode 100644 index 00000000..c420e5c9 --- /dev/null +++ b/pgtype/textarray.go @@ -0,0 +1,297 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type TextArray struct { + Elements []Text + Dimensions []ArrayDimension + Status Status +} + +func (dst *TextArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case TextArray: + *dst = value + + case []string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (src *TextArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *TextArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TextArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Text + + if len(uta.Elements) > 0 { + elements = make([]Text, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Text + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TextArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TextArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TextArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Text, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TextArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + if elem.String == "" && elem.Status == Present { + _, err := io.WriteString(buf, `""`) + if err != nil { + return err + } + } else { + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *TextArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TextOID) +} + +func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = elementOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/textarray_test.go b/pgtype/textarray_test.go new file mode 100644 index 00000000..29e3a6c7 --- /dev/null +++ b/pgtype/textarray_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTextArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "text[]", []interface{}{ + &pgtype.TextArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{Status: pgtype.Null}, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "bar", Status: pgtype.Present}, + pgtype.Text{String: "baz", Status: pgtype.Present}, + pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "bar", Status: pgtype.Present}, + pgtype.Text{String: "baz", Status: pgtype.Present}, + pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestTextArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TextArray + }{ + { + source: []string{"foo"}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.TextArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TextArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTextArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.TextArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.TextArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TextArray + dst interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/timestamparray.go b/pgtype/timestamparray.go index f1b1d003..3acbb35f 100644 --- a/pgtype/timestamparray.go +++ b/pgtype/timestamparray.go @@ -68,7 +68,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/timestamptzarray.go b/pgtype/timestamptzarray.go index 72b28e43..9df746e6 100644 --- a/pgtype/timestamptzarray.go +++ b/pgtype/timestamptzarray.go @@ -68,7 +68,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index e6e480b0..647ed7c0 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -67,7 +67,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 47afdf1d..f984e12e 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -8,3 +8,4 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID typed_array.go.erb > textarray.go diff --git a/pgtype/varchararray.go b/pgtype/varchararray.go new file mode 100644 index 00000000..13d94bc0 --- /dev/null +++ b/pgtype/varchararray.go @@ -0,0 +1,31 @@ +package pgtype + +import ( + "io" +) + +type VarcharArray TextArray + +func (dst *VarcharArray) ConvertFrom(src interface{}) error { + return (*TextArray)(dst).ConvertFrom(src) +} + +func (src *VarcharArray) AssignTo(dst interface{}) error { + return (*TextArray)(src).AssignTo(dst) +} + +func (dst *VarcharArray) DecodeText(r io.Reader) error { + return (*TextArray)(dst).DecodeText(r) +} + +func (dst *VarcharArray) DecodeBinary(r io.Reader) error { + return (*TextArray)(dst).DecodeBinary(r) +} + +func (src *VarcharArray) EncodeText(w io.Writer) error { + return (*TextArray)(src).EncodeText(w) +} + +func (src *VarcharArray) EncodeBinary(w io.Writer) error { + return (*TextArray)(src).encodeBinary(w, VarcharOID) +} diff --git a/query.go b/query.go index 9019fca4..ffe51ecc 100644 --- a/query.go +++ b/query.go @@ -388,22 +388,6 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeFloat4(vr)) case Float8OID: values = append(values, decodeFloat8(vr)) - case BoolArrayOID: - values = append(values, decodeBoolArray(vr)) - case Int2ArrayOID: - values = append(values, decodeInt2Array(vr)) - case Int4ArrayOID: - values = append(values, decodeInt4Array(vr)) - case Int8ArrayOID: - values = append(values, decodeInt8Array(vr)) - case Float4ArrayOID: - values = append(values, decodeFloat4Array(vr)) - case Float8ArrayOID: - values = append(values, decodeFloat8Array(vr)) - case TextArrayOID, VarcharArrayOID: - values = append(values, decodeTextArray(vr)) - case TimestampArrayOID, TimestampTzArrayOID: - values = append(values, decodeTimestampArray(vr)) case DateOID: values = append(values, decodeDate(vr)) case TimestampTzOID: @@ -479,22 +463,6 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, decodeFloat4(vr)) case Float8OID: values = append(values, decodeFloat8(vr)) - case BoolArrayOID: - values = append(values, decodeBoolArray(vr)) - case Int2ArrayOID: - values = append(values, decodeInt2Array(vr)) - case Int4ArrayOID: - values = append(values, decodeInt4Array(vr)) - case Int8ArrayOID: - values = append(values, decodeInt8Array(vr)) - case Float4ArrayOID: - values = append(values, decodeFloat4Array(vr)) - case Float8ArrayOID: - values = append(values, decodeFloat8Array(vr)) - case TextArrayOID, VarcharArrayOID: - values = append(values, decodeTextArray(vr)) - case TimestampArrayOID, TimestampTzArrayOID: - values = append(values, decodeTimestampArray(vr)) case DateOID: values = append(values, decodeDate(vr)) case TimestampTzOID: diff --git a/query_test.go b/query_test.go index 364e6b57..801ba851 100644 --- a/query_test.go +++ b/query_test.go @@ -1179,9 +1179,6 @@ func TestQueryRowCoreStringSlice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } ensureConnValid(t, conn) } diff --git a/values.go b/values.go index 3d7d63a2..c011c8cf 100644 --- a/values.go +++ b/values.go @@ -1073,8 +1073,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { } switch arg := arg.(type) { - case []string: - return encodeStringSlice(wbuf, oid, arg) case Char: return encodeChar(wbuf, oid, arg) case AclItem: @@ -1178,8 +1176,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeText(vr) case *[]AclItem: *v = decodeAclItemArray(vr) - case *[]string: - *v = decodeTextArray(vr) case *[][]byte: *v = decodeByteaArray(vr) case *[]interface{}: @@ -2569,41 +2565,6 @@ func encodeFloat64Slice(w *WriteBuf, oid OID, slice []float64) error { return nil } -func decodeTextArray(vr *ValueReader) []string { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != TextArrayOID && vr.Type().DataType != VarcharArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]string, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - if elSize == -1 { - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - } - - a[i] = vr.ReadString(elSize) - } - - return a -} - // escapeAclItem escapes an AclItem before it is added to // its aclitem[] string representation. The PostgreSQL aclitem // datatype itself can need escapes because it follows the @@ -2808,39 +2769,6 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { return aclItems } -func encodeStringSlice(w *WriteBuf, oid OID, slice []string) error { - var elOID OID - switch oid { - case VarcharArrayOID: - elOID = VarcharOID - case TextArrayOID: - elOID = TextOID - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid) - } - - var totalStringSize int - for _, v := range slice { - totalStringSize += len(v) - } - - size := 20 + len(slice)*4 + totalStringSize - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOID)) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - - for _, v := range slice { - w.WriteInt32(int32(len(v))) - w.WriteBytes([]byte(v)) - } - - return nil -} - func decodeTimestampArray(vr *ValueReader) []time.Time { if vr.Len() == -1 { return nil From 12ac0c33b826d29ea5b3375d70771e42ac758538 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 21:23:57 -0600 Subject: [PATCH 082/264] Remove unused array code from pgx --- values.go | 555 ------------------------------------------------------ 1 file changed, 555 deletions(-) diff --git a/values.go b/values.go index c011c8cf..f050726e 100644 --- a/values.go +++ b/values.go @@ -2021,65 +2021,6 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { return length, nil } -func decodeBoolArray(vr *ValueReader) []bool { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != BoolArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []bool", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]bool, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 1: - if vr.ReadByte() == 1 { - a[i] = true - } - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an bool element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeBoolSlice(w *WriteBuf, oid OID, slice []bool) error { - if oid != BoolArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]bool", oid) - } - - encodeArrayHeader(w, BoolOID, len(slice), 5) - for _, v := range slice { - w.WriteInt32(1) - var b byte - if v { - b = 1 - } - w.WriteByte(b) - } - - return nil -} - func decodeByteaArray(vr *ValueReader) [][]byte { if vr.Len() == -1 { return nil @@ -2141,430 +2082,6 @@ func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error { return nil } -func decodeInt2Array(vr *ValueReader) []int16 { - if vr.Type().DataType != Int2ArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int16", vr.Type().DataType))) - return nil - } - - vr.err = errRewoundLen - - var a pgtype.Int2Array - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = a.DecodeText(&valueReader2{vr}) - case BinaryFormatCode: - err = a.DecodeBinary(&valueReader2{vr}) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - if err != nil { - vr.Fatal(err) - return nil - } - - if a.Status == pgtype.Null { - return nil - } - - rawArray := make([]int16, len(a.Elements)) - for i := range a.Elements { - if a.Elements[i].Status == pgtype.Present { - rawArray[i] = a.Elements[i].Int - } else { - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - } - } - - return rawArray -} - -func decodeInt2ArrayToUInt(vr *ValueReader) []uint16 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int2ArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint16", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]uint16, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 2: - tmp := vr.ReadInt16() - if tmp < 0 { - vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint16", tmp))) - return nil - } - a[i] = uint16(tmp) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int2 element: %d", elSize))) - return nil - } - } - - return a -} - -func decodeInt4Array(vr *ValueReader) []int32 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int4ArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int32", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]int32, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 4: - a[i] = vr.ReadInt32() - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize))) - return nil - } - } - - return a -} - -func decodeInt4ArrayToUInt(vr *ValueReader) []uint32 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int4ArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint32", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]uint32, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 4: - tmp := vr.ReadInt32() - if tmp < 0 { - vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint32", tmp))) - return nil - } - a[i] = uint32(tmp) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int4 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeInt32Slice(w *WriteBuf, oid OID, slice []int32) error { - if oid != Int4ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]int32", oid) - } - - encodeArrayHeader(w, Int4OID, len(slice), 8) - for _, v := range slice { - w.WriteInt32(4) - w.WriteInt32(v) - } - - return nil -} - -func encodeUInt32Slice(w *WriteBuf, oid OID, slice []uint32) error { - if oid != Int4ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint32", oid) - } - - encodeArrayHeader(w, Int4OID, len(slice), 8) - for _, v := range slice { - if v <= math.MaxInt32 { - w.WriteInt32(4) - w.WriteInt32(int32(v)) - } else { - return fmt.Errorf("%d is greater than max integer %d", v, math.MaxInt32) - } - } - - return nil -} - -func decodeInt8Array(vr *ValueReader) []int64 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int8ArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []int64", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]int64, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 8: - a[i] = vr.ReadInt64() - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize))) - return nil - } - } - - return a -} - -func decodeInt8ArrayToUInt(vr *ValueReader) []uint64 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Int8ArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []uint64", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]uint64, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 8: - tmp := vr.ReadInt64() - if tmp < 0 { - vr.Fatal(ProtocolError(fmt.Sprintf("%d is less than zero for uint64", tmp))) - return nil - } - a[i] = uint64(tmp) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an int8 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeInt64Slice(w *WriteBuf, oid OID, slice []int64) error { - if oid != Int8ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]int64", oid) - } - - encodeArrayHeader(w, Int8OID, len(slice), 12) - for _, v := range slice { - w.WriteInt32(8) - w.WriteInt64(v) - } - - return nil -} - -func encodeUInt64Slice(w *WriteBuf, oid OID, slice []uint64) error { - if oid != Int8ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]uint64", oid) - } - - encodeArrayHeader(w, Int8OID, len(slice), 12) - for _, v := range slice { - if v <= math.MaxInt64 { - w.WriteInt32(8) - w.WriteInt64(int64(v)) - } else { - return fmt.Errorf("%d is greater than max bigint %d", v, int64(math.MaxInt64)) - } - } - - return nil -} - -func decodeFloat4Array(vr *ValueReader) []float32 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Float4ArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float32", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]float32, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 4: - n := vr.ReadInt32() - a[i] = math.Float32frombits(uint32(n)) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeFloat32Slice(w *WriteBuf, oid OID, slice []float32) error { - if oid != Float4ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]float32", oid) - } - - encodeArrayHeader(w, Float4OID, len(slice), 8) - for _, v := range slice { - w.WriteInt32(4) - w.WriteInt32(int32(math.Float32bits(v))) - } - - return nil -} - -func decodeFloat8Array(vr *ValueReader) []float64 { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != Float8ArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []float64", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]float64, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 8: - n := vr.ReadInt64() - a[i] = math.Float64frombits(uint64(n)) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4 element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeFloat64Slice(w *WriteBuf, oid OID, slice []float64) error { - if oid != Float8ArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[]float64", oid) - } - - encodeArrayHeader(w, Float8OID, len(slice), 12) - for _, v := range slice { - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(v))) - } - - return nil -} - // escapeAclItem escapes an AclItem before it is added to // its aclitem[] string representation. The PostgreSQL aclitem // datatype itself can need escapes because it follows the @@ -2768,75 +2285,3 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { } return aclItems } - -func decodeTimestampArray(vr *ValueReader) []time.Time { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != TimestampArrayOID && vr.Type().DataType != TimestampTzArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]time.Time, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case 8: - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - a[i] = time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an time.Time element: %d", elSize))) - return nil - } - } - - return a -} - -func encodeTimeSlice(w *WriteBuf, oid OID, slice []time.Time) error { - var elOID OID - switch oid { - case TimestampArrayOID: - elOID = TimestampOID - case TimestampTzArrayOID: - elOID = TimestampTzOID - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]time.Time", oid) - } - - encodeArrayHeader(w, int(elOID), len(slice), 12) - for _, t := range slice { - w.WriteInt32(8) - microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 - microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K - w.WriteInt64(microsecSinceY2K) - } - - return nil -} - -func encodeArrayHeader(w *WriteBuf, oid, length, sizePerItem int) { - w.WriteInt32(int32(20 + length*sizePerItem)) - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(oid)) // type of elements - w.WriteInt32(int32(length)) // number of elements - w.WriteInt32(1) // index of first element -} From 575574cf98f0ded4d84372d3fb262b1161ae0d67 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Mar 2017 22:12:03 -0600 Subject: [PATCH 083/264] Move cid to pgtype --- conn.go | 1 + pgtype/cid.go | 141 +++++++++++++++++++++++++++++++++++++++++++++ pgtype/cid_test.go | 94 ++++++++++++++++++++++++++++++ pgtype/pgtype.go | 2 +- values.go | 101 +------------------------------- values_test.go | 4 -- 6 files changed, 238 insertions(+), 105 deletions(-) create mode 100644 pgtype/cid.go create mode 100644 pgtype/cid_test.go diff --git a/conn.go b/conn.go index d97942aa..7bb26677 100644 --- a/conn.go +++ b/conn.go @@ -270,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.oidPgtypeValues = map[OID]pgtype.Value{ BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, + CIDOID: &pgtype.CID{}, CidrArrayOID: &pgtype.CidrArray{}, CidrOID: &pgtype.Inet{}, DateArrayOID: &pgtype.DateArray{}, diff --git a/pgtype/cid.go b/pgtype/cid.go new file mode 100644 index 00000000..9f8c87d8 --- /dev/null +++ b/pgtype/cid.go @@ -0,0 +1,141 @@ +package pgtype + +import ( + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +// CID is PostgreSQL's Command Identifier type. +// +// When one does +// +// select cmin, cmax, * from some_table; +// +// it is the data type of the cmin and cmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/c.h as CommandId +// in the PostgreSQL sources. +type CID struct { + Uint uint32 + Status Status +} + +// ConvertFrom converts from src to dst. Note that as CID is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. +func (dst *CID) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case CID: + *dst = value + case uint32: + *dst = CID{Uint: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to CID", value) + } + + return nil +} + +// AssignTo assigns from src to dst. Note that as CID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *CID) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *uint32: + if src.Status == Present { + *v = src.Uint + } else { + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + case **uint32: + if src.Status == Present { + n := src.Uint + *v = &n + } else { + *v = nil + } + } + + return nil +} + +func (dst *CID) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = CID{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseUint(string(buf), 10, 32) + if err != nil { + return err + } + + *dst = CID{Uint: uint32(n), Status: Present} + return nil +} + +func (dst *CID) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = CID{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for cid: %v", size) + } + + n, err := pgio.ReadUint32(r) + if err != nil { + return err + } + + *dst = CID{Uint: n, Status: Present} + return nil +} + +func (src CID) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatUint(uint64(src.Uint), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src CID) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteUint32(w, src.Uint) + return err +} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go new file mode 100644 index 00000000..72f5dfea --- /dev/null +++ b/pgtype/cid_test.go @@ -0,0 +1,94 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "cid", []interface{}{ + pgtype.CID{Uint: 42, Status: pgtype.Present}, + pgtype.CID{Status: pgtype.Null}, + }) +} + +func TestCIDConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CID + }{ + {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.CID + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.CID + dst interface{} + expected interface{} + }{ + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.CID + dst interface{} + }{ + {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5722c8ab..1200bf12 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -20,7 +20,7 @@ const ( OIDOID = 26 TidOID = 27 XidOID = 28 - CidOID = 29 + CIDOID = 29 JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 diff --git a/values.go b/values.go index f050726e..b143ac1a 100644 --- a/values.go +++ b/values.go @@ -29,7 +29,7 @@ const ( OIDOID = 26 TidOID = 27 XidOID = 28 - CidOID = 29 + CIDOID = 29 JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 @@ -645,58 +645,6 @@ func (n NullXid) Encode(w *WriteBuf, oid OID) error { return encodeXid(w, oid, n.Xid) } -// Cid is PostgreSQL's Command Identifier type. -// -// When one does -// -// select cmin, cmax, * from some_table; -// -// it is the data type of the cmin and cmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/c.h as CommandId -// in the PostgreSQL sources. -type Cid uint32 - -// NullCid represents a Command Identifier (Cid) that may be null. NullCid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullCid struct { - Cid Cid - Valid bool // Valid is true if Cid is not NULL -} - -func (n *NullCid) Scan(vr *ValueReader) error { - if vr.Type().DataType != CidOID { - return SerializationError(fmt.Sprintf("NullCid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Cid, n.Valid = 0, false - return nil - } - n.Valid = true - n.Cid = decodeCid(vr) - return vr.Err() -} - -func (n NullCid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullCid) Encode(w *WriteBuf, oid OID) error { - if oid != CidOID { - return SerializationError(fmt.Sprintf("NullCid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeCid(w, oid, n.Cid) -} - // Tid is PostgreSQL's Tuple Identifier type. // // When one does @@ -1087,8 +1035,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeOID(wbuf, oid, arg) case Xid: return encodeXid(wbuf, oid, arg) - case Cid: - return encodeCid(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -1170,8 +1116,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeXid(vr) case *Tid: *v = decodeTid(vr) - case *Cid: - *v = decodeCid(vr) case *string: *v = decodeText(vr) case *[]AclItem: @@ -1493,49 +1437,6 @@ func encodeXid(w *WriteBuf, oid OID, value Xid) error { return nil } -func decodeCid(vr *ValueReader) Cid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Cid")) - return Cid(0) - } - - if vr.Type().DataType != CidOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Cid", vr.Type().DataType))) - return Cid(0) - } - - // Unlikely Cid will ever go over the wire as text format, but who knows? - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) - } - return Cid(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) - return Cid(0) - } - return Cid(vr.ReadUint32()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Cid(0) - } -} - -func encodeCid(w *WriteBuf, oid OID, value Cid) error { - if oid != CidOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Cid", oid) - } - - w.WriteInt32(4) - w.WriteUint32(uint32(value)) - - return nil -} - // Note that we do not match negative numbers, because neither the // BlockNumber nor OffsetNumber of a Tid can be negative. var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) diff --git a/values_test.go b/values_test.go index d6ce705a..ae3ecc84 100644 --- a/values_test.go +++ b/values_test.go @@ -573,7 +573,6 @@ func TestNullX(t *testing.T) { n pgx.NullName oid pgx.NullOID xid pgx.NullXid - cid pgx.NullCid tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 @@ -611,9 +610,6 @@ func TestNullX(t *testing.T) { {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 1, Valid: true}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 1, Valid: false}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 0, Valid: false}}}, - {"select $1::cid", []interface{}{pgx.NullCid{Cid: 4294967295, Valid: true}}, []interface{}{&actual.cid}, allTypes{cid: pgx.NullCid{Cid: 4294967295, Valid: true}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, From 0f68bdcd52524612d8f2cc0d7c04d1dc4a47fd2d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 08:59:26 -0600 Subject: [PATCH 084/264] Generalize array template --- pgtype/typed_array.go.erb | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 647ed7c0..8c18073b 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -211,9 +211,16 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { } textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) - if err != nil { - return err + if elem.String == "" && elem.Status == Present { + _, err := io.WriteString(buf, `""`) + if err != nil { + return err + } + } else { + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } } for _, dec := range dimElemCounts { @@ -236,6 +243,10 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { } func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, <%= element_oid %>) +} + +func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +267,7 @@ func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = <%= element_oid %> + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - From cb1c05476f327901db8a13bc00ff9f1192d44c1e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 09:07:07 -0600 Subject: [PATCH 085/264] Move XID to pgypte --- conn.go | 1 + pgtype/pgtype.go | 2 +- pgtype/xid.go | 45 ++++++++++++++++++++ pgtype/xid_test.go | 94 ++++++++++++++++++++++++++++++++++++++++ values.go | 104 +-------------------------------------------- values_test.go | 4 -- 6 files changed, 142 insertions(+), 108 deletions(-) create mode 100644 pgtype/xid.go create mode 100644 pgtype/xid_test.go diff --git a/conn.go b/conn.go index 7bb26677..2b826dad 100644 --- a/conn.go +++ b/conn.go @@ -295,6 +295,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl TimestampTzOID: &pgtype.Timestamptz{}, VarcharArrayOID: &pgtype.VarcharArray{}, VarcharOID: &pgtype.Text{}, + XIDOID: &pgtype.XID{}, } if tlsConfig != nil { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 1200bf12..15c0cc76 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -19,7 +19,7 @@ const ( TextOID = 25 OIDOID = 26 TidOID = 27 - XidOID = 28 + XIDOID = 28 CIDOID = 29 JSONOID = 114 CidrOID = 650 diff --git a/pgtype/xid.go b/pgtype/xid.go new file mode 100644 index 00000000..f4d087a5 --- /dev/null +++ b/pgtype/xid.go @@ -0,0 +1,45 @@ +package pgtype + +import ( + "io" +) + +// Xid is PostgreSQL's Transaction ID type. +// +// In later versions of PostgreSQL, it is the type used for the backend_xid +// and backend_xmin columns of the pg_stat_activity system view. +// +// Also, when one does +// +// select xmin, xmax, * from some_table; +// +// it is the data type of the xmin and xmax hidden system columns. +// +// It is currently implemented as an unsigned four byte integer. +// Its definition can be found in src/include/postgres_ext.h as TransactionId +// in the PostgreSQL sources. +type XID CID + +func (dst *XID) ConvertFrom(src interface{}) error { + return (*CID)(dst).ConvertFrom(src) +} + +func (src *XID) AssignTo(dst interface{}) error { + return (*CID)(src).AssignTo(dst) +} + +func (dst *XID) DecodeText(r io.Reader) error { + return (*CID)(dst).DecodeText(r) +} + +func (dst *XID) DecodeBinary(r io.Reader) error { + return (*CID)(dst).DecodeBinary(r) +} + +func (src XID) EncodeText(w io.Writer) error { + return (CID)(src).EncodeText(w) +} + +func (src XID) EncodeBinary(w io.Writer) error { + return (CID)(src).EncodeBinary(w) +} diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go new file mode 100644 index 00000000..664920bc --- /dev/null +++ b/pgtype/xid_test.go @@ -0,0 +1,94 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestXIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "xid", []interface{}{ + pgtype.XID{Uint: 42, Status: pgtype.Present}, + pgtype.XID{Status: pgtype.Null}, + }) +} + +func TestXIDConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.XID + }{ + {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.XID + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestXIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.XID + dst interface{} + expected interface{} + }{ + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.XID + dst interface{} + }{ + {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/values.go b/values.go index b143ac1a..b6848cf5 100644 --- a/values.go +++ b/values.go @@ -28,7 +28,7 @@ const ( TextOID = 25 OIDOID = 26 TidOID = 27 - XidOID = 28 + XIDOID = 28 CIDOID = 29 JSONOID = 114 CidrOID = 650 @@ -590,61 +590,6 @@ func (n NullOID) Encode(w *WriteBuf, oid OID) error { return encodeOID(w, oid, n.OID) } -// Xid is PostgreSQL's Transaction ID type. -// -// In later versions of PostgreSQL, it is the type used for the backend_xid -// and backend_xmin columns of the pg_stat_activity system view. -// -// Also, when one does -// -// select xmin, xmax, * from some_table; -// -// it is the data type of the xmin and xmax hidden system columns. -// -// It is currently implemented as an unsigned four byte integer. -// Its definition can be found in src/include/postgres_ext.h as TransactionId -// in the PostgreSQL sources. -type Xid uint32 - -// NullXid represents a Transaction ID (Xid) that may be null. NullXid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullXid struct { - Xid Xid - Valid bool // Valid is true if Xid is not NULL -} - -func (n *NullXid) Scan(vr *ValueReader) error { - if vr.Type().DataType != XidOID { - return SerializationError(fmt.Sprintf("NullXid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Xid, n.Valid = 0, false - return nil - } - n.Valid = true - n.Xid = decodeXid(vr) - return vr.Err() -} - -func (n NullXid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullXid) Encode(w *WriteBuf, oid OID) error { - if oid != XidOID { - return SerializationError(fmt.Sprintf("NullXid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeXid(w, oid, n.Xid) -} - // Tid is PostgreSQL's Tuple Identifier type. // // When one does @@ -1033,8 +978,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeString(wbuf, oid, string(arg)) case OID: return encodeOID(wbuf, oid, arg) - case Xid: - return encodeXid(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -1112,8 +1055,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = Name(decodeText(vr)) case *OID: *v = decodeOID(vr) - case *Xid: - *v = decodeXid(vr) case *Tid: *v = decodeTid(vr) case *string: @@ -1394,49 +1335,6 @@ func encodeOID(w *WriteBuf, oid OID, value OID) error { return nil } -func decodeXid(vr *ValueReader) Xid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Xid")) - return Xid(0) - } - - if vr.Type().DataType != XidOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Xid", vr.Type().DataType))) - return Xid(0) - } - - // Unlikely Xid will ever go over the wire as text format, but who knows? - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) - } - return Xid(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) - return Xid(0) - } - return Xid(vr.ReadUint32()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Xid(0) - } -} - -func encodeXid(w *WriteBuf, oid OID, value Xid) error { - if oid != XidOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Xid", oid) - } - - w.WriteInt32(4) - w.WriteUint32(uint32(value)) - - return nil -} - // Note that we do not match negative numbers, because neither the // BlockNumber nor OffsetNumber of a Tid can be negative. var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) diff --git a/values_test.go b/values_test.go index ae3ecc84..0283f17d 100644 --- a/values_test.go +++ b/values_test.go @@ -572,7 +572,6 @@ func TestNullX(t *testing.T) { a pgx.NullAclItem n pgx.NullName oid pgx.NullOID - xid pgx.NullXid tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 @@ -598,9 +597,6 @@ func TestNullX(t *testing.T) { {"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 1, Valid: true}}}, {"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 0, Valid: false}}}, {"select $1::oid", []interface{}{pgx.NullOID{OID: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 4294967295, Valid: true}}}, - {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 1, Valid: true}}}, - {"select $1::xid", []interface{}{pgx.NullXid{Xid: 1, Valid: false}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 0, Valid: false}}}, - {"select $1::xid", []interface{}{pgx.NullXid{Xid: 4294967295, Valid: true}}, []interface{}{&actual.xid}, allTypes{xid: pgx.NullXid{Xid: 4294967295, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, From 164bf9eebe76f0be3320ae3089d6b63ce838f69e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 09:13:25 -0600 Subject: [PATCH 086/264] Extract pguint32 --- pgtype/cid.go | 108 +++---------------------------------- pgtype/pguint32.go | 130 +++++++++++++++++++++++++++++++++++++++++++++ pgtype/xid.go | 19 ++++--- 3 files changed, 149 insertions(+), 108 deletions(-) create mode 100644 pgtype/pguint32.go diff --git a/pgtype/cid.go b/pgtype/cid.go index 9f8c87d8..21d6fb80 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -1,11 +1,7 @@ package pgtype import ( - "fmt" "io" - "strconv" - - "github.com/jackc/pgx/pgio" ) // CID is PostgreSQL's Command Identifier type. @@ -19,123 +15,33 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. -type CID struct { - Uint uint32 - Status Status -} +type CID pguint32 // ConvertFrom converts from src to dst. Note that as CID is not a general // number type ConvertFrom does not do automatic type conversion as other number // types do. func (dst *CID) ConvertFrom(src interface{}) error { - switch value := src.(type) { - case CID: - *dst = value - case uint32: - *dst = CID{Uint: value, Status: Present} - default: - return fmt.Errorf("cannot convert %v to CID", value) - } - - return nil + return (*pguint32)(dst).ConvertFrom(src) } // AssignTo assigns from src to dst. Note that as CID is not a general number // type AssignTo does not do automatic type conversion as other number types do. func (src *CID) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *uint32: - if src.Status == Present { - *v = src.Uint - } else { - return fmt.Errorf("cannot assign %v into %T", src, dst) - } - case **uint32: - if src.Status == Present { - n := src.Uint - *v = &n - } else { - *v = nil - } - } - - return nil + return (*pguint32)(src).AssignTo(dst) } func (dst *CID) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { - *dst = CID{Status: Null} - return nil - } - - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseUint(string(buf), 10, 32) - if err != nil { - return err - } - - *dst = CID{Uint: uint32(n), Status: Present} - return nil + return (*pguint32)(dst).DecodeText(r) } func (dst *CID) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { - *dst = CID{Status: Null} - return nil - } - - if size != 4 { - return fmt.Errorf("invalid length for cid: %v", size) - } - - n, err := pgio.ReadUint32(r) - if err != nil { - return err - } - - *dst = CID{Uint: n, Status: Present} - return nil + return (*pguint32)(dst).DecodeBinary(r) } func (src CID) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - s := strconv.FormatUint(uint64(src.Uint), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + return (pguint32)(src).EncodeText(w) } func (src CID) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteUint32(w, src.Uint) - return err + return (pguint32)(src).EncodeBinary(w) } diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go new file mode 100644 index 00000000..66b385fb --- /dev/null +++ b/pgtype/pguint32.go @@ -0,0 +1,130 @@ +package pgtype + +import ( + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +// pguint32 is the core type that is used to implement PostgreSQL types such as +// CID and XID. +type pguint32 struct { + Uint uint32 + Status Status +} + +// ConvertFrom converts from src to dst. Note that as pguint32 is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. +func (dst *pguint32) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case uint32: + *dst = pguint32{Uint: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to pguint32", value) + } + + return nil +} + +// AssignTo assigns from src to dst. Note that as pguint32 is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *pguint32) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *uint32: + if src.Status == Present { + *v = src.Uint + } else { + return fmt.Errorf("cannot assign %v into %T", src, dst) + } + case **uint32: + if src.Status == Present { + n := src.Uint + *v = &n + } else { + *v = nil + } + } + + return nil +} + +func (dst *pguint32) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = pguint32{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseUint(string(buf), 10, 32) + if err != nil { + return err + } + + *dst = pguint32{Uint: uint32(n), Status: Present} + return nil +} + +func (dst *pguint32) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = pguint32{Status: Null} + return nil + } + + if size != 4 { + return fmt.Errorf("invalid length for cid: %v", size) + } + + n, err := pgio.ReadUint32(r) + if err != nil { + return err + } + + *dst = pguint32{Uint: n, Status: Present} + return nil +} + +func (src pguint32) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + s := strconv.FormatUint(uint64(src.Uint), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err +} + +func (src pguint32) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteUint32(w, src.Uint) + return err +} diff --git a/pgtype/xid.go b/pgtype/xid.go index f4d087a5..b311cbfb 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -18,28 +18,33 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. -type XID CID +type XID pguint32 +// ConvertFrom converts from src to dst. Note that as XID is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. func (dst *XID) ConvertFrom(src interface{}) error { - return (*CID)(dst).ConvertFrom(src) + return (*pguint32)(dst).ConvertFrom(src) } +// AssignTo assigns from src to dst. Note that as XID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. func (src *XID) AssignTo(dst interface{}) error { - return (*CID)(src).AssignTo(dst) + return (*pguint32)(src).AssignTo(dst) } func (dst *XID) DecodeText(r io.Reader) error { - return (*CID)(dst).DecodeText(r) + return (*pguint32)(dst).DecodeText(r) } func (dst *XID) DecodeBinary(r io.Reader) error { - return (*CID)(dst).DecodeBinary(r) + return (*pguint32)(dst).DecodeBinary(r) } func (src XID) EncodeText(w io.Writer) error { - return (CID)(src).EncodeText(w) + return (pguint32)(src).EncodeText(w) } func (src XID) EncodeBinary(w io.Writer) error { - return (CID)(src).EncodeBinary(w) + return (pguint32)(src).EncodeBinary(w) } From f66b80c387d39c45596516f1b547563d2a6763f7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 09:18:50 -0600 Subject: [PATCH 087/264] Fix comment on XID --- pgtype/xid.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/xid.go b/pgtype/xid.go index b311cbfb..d4003b5d 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -4,7 +4,7 @@ import ( "io" ) -// Xid is PostgreSQL's Transaction ID type. +// XID is PostgreSQL's Transaction ID type. // // In later versions of PostgreSQL, it is the type used for the backend_xid // and backend_xmin columns of the pg_stat_activity system view. From af8519991e37719ae5288d4eba78026cdd814910 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 13:05:49 -0600 Subject: [PATCH 088/264] Move OID to pgtype --- conn.go | 5 +- pgtype/oid.go | 41 ++++++++++++++ pgtype/oid_test.go | 94 +++++++++++++++++++++++++++++++ pgtype/pguint32.go | 2 +- query.go | 6 -- query_test.go | 7 +-- values.go | 135 +++++++++++++++++++-------------------------- values_test.go | 4 -- 8 files changed, 199 insertions(+), 95 deletions(-) create mode 100644 pgtype/oid.go create mode 100644 pgtype/oid_test.go diff --git a/conn.go b/conn.go index 2b826dad..c55d5618 100644 --- a/conn.go +++ b/conn.go @@ -287,6 +287,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl Int4OID: &pgtype.Int4{}, Int8ArrayOID: &pgtype.Int8Array{}, Int8OID: &pgtype.Int8{}, + OIDOID: &pgtype.OID{}, TextArrayOID: &pgtype.TextArray{}, TextOID: &pgtype.Text{}, TimestampArrayOID: &pgtype.TimestampArray{}, @@ -392,7 +393,7 @@ where ( c.PgTypes = make(map[OID]PgType, 128) for rows.Next() { - var oid OID + var oid uint32 var t PgType rows.Scan(&oid, &t.Name) @@ -400,7 +401,7 @@ where ( // The zero value is text format so we ignore any types without a default type format t.DefaultFormat, _ = DefaultTypeFormats[t.Name] - c.PgTypes[oid] = t + c.PgTypes[OID(oid)] = t } return rows.Err() diff --git a/pgtype/oid.go b/pgtype/oid.go new file mode 100644 index 00000000..d137f352 --- /dev/null +++ b/pgtype/oid.go @@ -0,0 +1,41 @@ +package pgtype + +import ( + "io" +) + +// OID (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-oid.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OID pguint32 + +// ConvertFrom converts from src to dst. Note that as OID is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. +func (dst *OID) ConvertFrom(src interface{}) error { + return (*pguint32)(dst).ConvertFrom(src) +} + +// AssignTo assigns from src to dst. Note that as OID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OID) DecodeText(r io.Reader) error { + return (*pguint32)(dst).DecodeText(r) +} + +func (dst *OID) DecodeBinary(r io.Reader) error { + return (*pguint32)(dst).DecodeBinary(r) +} + +func (src OID) EncodeText(w io.Writer) error { + return (pguint32)(src).EncodeText(w) +} + +func (src OID) EncodeBinary(w io.Writer) error { + return (pguint32)(src).EncodeBinary(w) +} diff --git a/pgtype/oid_test.go b/pgtype/oid_test.go new file mode 100644 index 00000000..c8e0b2d6 --- /dev/null +++ b/pgtype/oid_test.go @@ -0,0 +1,94 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestOIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "oid", []interface{}{ + pgtype.OID{Uint: 42, Status: pgtype.Present}, + pgtype.OID{Status: pgtype.Null}, + }) +} + +func TestOIDConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.OID + }{ + {source: uint32(1), result: pgtype.OID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.OID + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestOIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.OID + dst interface{} + expected interface{} + }{ + {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.OID + dst interface{} + expected interface{} + }{ + {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.OID + dst interface{} + }{ + {src: pgtype.OID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 66b385fb..9c1ccd6c 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -89,7 +89,7 @@ func (dst *pguint32) DecodeBinary(r io.Reader) error { } if size != 4 { - return fmt.Errorf("invalid length for cid: %v", size) + return fmt.Errorf("invalid length: %v", size) } n, err := pgio.ReadUint32(r) diff --git a/query.go b/query.go index ffe51ecc..965f3913 100644 --- a/query.go +++ b/query.go @@ -256,8 +256,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { val = int64(decodeInt4(vr)) case TextOID, VarcharOID: val = decodeText(vr) - case OIDOID: - val = int64(decodeOID(vr)) case Float4OID: val = float64(decodeFloat4(vr)) case Float8OID: @@ -382,8 +380,6 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeInt2(vr)) case Int4OID: values = append(values, decodeInt4(vr)) - case OIDOID: - values = append(values, decodeOID(vr)) case Float4OID: values = append(values, decodeFloat4(vr)) case Float8OID: @@ -457,8 +453,6 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, decodeInt2(vr)) case Int4OID: values = append(values, decodeInt4(vr)) - case OIDOID: - values = append(values, decodeOID(vr)) case Float4OID: values = append(values, decodeFloat4(vr)) case Float8OID: diff --git a/query_test.go b/query_test.go index 801ba851..bbd7871e 100644 --- a/query_test.go +++ b/query_test.go @@ -53,7 +53,7 @@ func TestConnQueryValues(t *testing.T) { var rowCount int32 - rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n::oid from generate_series(1,$1) n", 10) + rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -84,7 +84,7 @@ func TestConnQueryValues(t *testing.T) { t.Errorf(`Expected values[3] to be %v, but it was %d`, nil, values[3]) } - if values[4] != pgx.OID(rowCount) { + if values[4] != rowCount { t.Errorf(`Expected values[4] to be %d, but it was %d`, rowCount, values[4]) } } @@ -478,9 +478,6 @@ func TestQueryRowCoreTypes(t *testing.T) { if err == nil { t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") && !strings.Contains(err.Error(), "cannot assign") { - t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql) - } ensureConnValid(t, conn) } diff --git a/values.go b/values.go index b6848cf5..59d6f3c4 100644 --- a/values.go +++ b/values.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" ) @@ -548,46 +549,75 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { // OID (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, // used internally by PostgreSQL as a primary key for various system tables. It is currently implemented // as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h -// in the PostgreSQL sources. +// in the PostgreSQL sources. OID cannot be NULL. To allow for NULL OIDs use pgtype.OID. type OID uint32 -// NullOID represents a Command Identifier (OID) that may be null. NullOID implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullOID struct { - OID OID - Valid bool // Valid is true if OID is not NULL +func (dst *OID) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + return fmt.Errorf("cannot decode nil into OID") + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseUint(string(buf), 10, 32) + if err != nil { + return err + } + + *dst = OID(n) + return nil } -func (n *NullOID) Scan(vr *ValueReader) error { - if vr.Type().DataType != OIDOID { - return SerializationError(fmt.Sprintf("NullOID.Scan cannot decode OID %d", vr.Type().DataType)) +func (dst *OID) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err } - if vr.Len() == -1 { - n.OID, n.Valid = 0, false - return nil + if size == -1 { + return fmt.Errorf("cannot decode nil into OID") } - n.Valid = true - n.OID = decodeOID(vr) - return vr.Err() + + if size != 4 { + return fmt.Errorf("invalid length for OID: %v", size) + } + + n, err := pgio.ReadUint32(r) + if err != nil { + return err + } + + *dst = OID(n) + return nil } -func (n NullOID) FormatCode() int16 { return BinaryFormatCode } - -func (n NullOID) Encode(w *WriteBuf, oid OID) error { - if oid != OIDOID { - return SerializationError(fmt.Sprintf("NullOID.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) +func (src OID) EncodeText(w io.Writer) error { + s := strconv.FormatUint(uint64(src), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { return nil } + _, err = w.Write([]byte(s)) + return err +} - return encodeOID(w, oid, n.OID) +func (src OID) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteUint32(w, uint32(src)) + return err } // Tid is PostgreSQL's Tuple Identifier type. @@ -976,8 +1006,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case OID: - return encodeOID(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -1053,8 +1081,6 @@ func Decode(vr *ValueReader, d interface{}) error { case *Name: // name goes over the wire just like text *v = Name(decodeText(vr)) - case *OID: - *v = decodeOID(vr) case *Tid: *v = decodeTid(vr) case *string: @@ -1292,49 +1318,6 @@ func decodeInt4(vr *ValueReader) int32 { return n.Int } -func decodeOID(vr *ValueReader) OID { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into OID")) - return OID(0) - } - - if vr.Type().DataType != OIDOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.OID", vr.Type().DataType))) - return OID(0) - } - - // OID needs to decode text format because it is used in loadPgTypes - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) - } - return OID(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) - return OID(0) - } - return OID(vr.ReadInt32()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return OID(0) - } -} - -func encodeOID(w *WriteBuf, oid OID, value OID) error { - if oid != OIDOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.OID", oid) - } - - w.WriteInt32(4) - w.WriteUint32(uint32(value)) - - return nil -} - // Note that we do not match negative numbers, because neither the // BlockNumber nor OffsetNumber of a Tid can be negative. var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) @@ -1764,8 +1747,6 @@ func decodeRecord(vr *ValueReader) []interface{} { record = append(record, decodeInt2(&fieldVR)) case Int4OID: record = append(record, decodeInt4(&fieldVR)) - case OIDOID: - record = append(record, decodeOID(&fieldVR)) case Float4OID: record = append(record, decodeFloat4(&fieldVR)) case Float8OID: diff --git a/values_test.go b/values_test.go index 0283f17d..65811959 100644 --- a/values_test.go +++ b/values_test.go @@ -571,7 +571,6 @@ func TestNullX(t *testing.T) { c pgx.NullChar a pgx.NullAclItem n pgx.NullName - oid pgx.NullOID tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 @@ -594,9 +593,6 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 1, Valid: true}}}, - {"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 0, Valid: false}}}, - {"select $1::oid", []interface{}{pgx.NullOID{OID: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 4294967295, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, From 071f4cc2ad1e18ea8aa6766eef8b4e85473da0a5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 13:47:28 -0600 Subject: [PATCH 089/264] Conn.Close waits for server to close connection --- conn.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index c55d5618..78f323b1 100644 --- a/conn.go +++ b/conn.go @@ -425,18 +425,38 @@ func (c *Conn) Close() (err error) { } } + defer func() { + c.conn.Close() + c.die(errors.New("Closed")) + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Closed connection") + } + }() + + err = c.conn.SetDeadline(time.Time{}) + if err != nil && c.shouldLog(LogLevelWarn) { + c.log(LogLevelWarn, "Failed to clear deadlines to send close message", "err", err) + return err + } + _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil && c.shouldLog(LogLevelWarn) { c.log(LogLevelWarn, "Failed to send terminate message", "err", err) + return err } - err = c.conn.Close() - - c.die(errors.New("Closed")) - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Closed connection") + err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil && c.shouldLog(LogLevelWarn) { + c.log(LogLevelWarn, "Failed to set read deadline to finish closing", "err", err) + return err } - return err + + _, err = c.conn.Read(make([]byte, 1)) + if err != io.EOF { + return err + } + + return nil } // ParseURI parses a database URI into ConnConfig From 5702f34407be35622dc6c6ea95b6f328795b77e4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 14:00:38 -0600 Subject: [PATCH 090/264] Fix replication with context The normal connection context timeout cancels the current query. That isn't appropriate for a replication connection. --- replication.go | 41 ++++++++++++++++++++++++++++++++++------- replication_test.go | 12 +++++++++--- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/replication.go b/replication.go index 9bc4a1a4..a251172d 100644 --- a/replication.go +++ b/replication.go @@ -270,16 +270,43 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err // // This returns the context error when there is no replication message before // the context is canceled. -func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (r *ReplicationMessage, err error) { - err = rc.c.initContext(ctx) - if err != nil { - return nil, err +func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*ReplicationMessage, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: } - defer func() { - err = rc.c.termContext(err) + + go func() { + select { + case <-ctx.Done(): + if err := rc.c.conn.SetDeadline(time.Now()); err != nil { + rc.Close() // Close connection if unable to set deadline + return + } + rc.c.closedChan <- ctx.Err() + case <-rc.c.doneChan: + } }() - return rc.readReplicationMessage() + r, opErr := rc.readReplicationMessage() + + var err error + select { + case err = <-rc.c.closedChan: + if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { + rc.Close() // Close connection if unable to disable deadline + return nil, err + } + + if opErr == nil { + err = nil + } + case rc.c.doneChan <- struct{}{}: + err = opErr + } + + return r, err } func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { diff --git a/replication_test.go b/replication_test.go index 43793f3c..1a8063e5 100644 --- a/replication_test.go +++ b/replication_test.go @@ -3,12 +3,13 @@ package pgx_test import ( "context" "fmt" - "github.com/jackc/pgx" "reflect" "strconv" "strings" "testing" "time" + + "github.com/jackc/pgx" ) // This function uses a postgresql 9.6 specific column @@ -47,14 +48,19 @@ func TestSimpleReplicationConnection(t *testing.T) { } conn := mustConnect(t, *replicationConnConfig) - defer closeConn(t, conn) + defer func() { + // Ensure replication slot is destroyed, but don't check for errors as it + // should have already been destroyed. + conn.Exec("select pg_drop_replication_slot('pgx_test')") + closeConn(t, conn) + }() replicationConn := mustReplicationConnect(t, *replicationConnConfig) defer closeReplicationConn(t, replicationConn) err = replicationConn.CreateReplicationSlot("pgx_test", "test_decoding") if err != nil { - t.Logf("replication slot create failed: %v", err) + t.Fatalf("replication slot create failed: %v", err) } // Do a simple change so we can get some wal data From 7b1dbd8558f806e01ccec31f99760393167f7169 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 6 Mar 2017 17:55:20 -0600 Subject: [PATCH 091/264] Move Name to pgtype --- conn.go | 1 + pgtype/name.go | 44 ++++++++++++++++++++ pgtype/name_test.go | 97 +++++++++++++++++++++++++++++++++++++++++++++ values.go | 58 --------------------------- values_test.go | 3 -- 5 files changed, 142 insertions(+), 61 deletions(-) create mode 100644 pgtype/name.go create mode 100644 pgtype/name_test.go diff --git a/conn.go b/conn.go index 78f323b1..023b9d97 100644 --- a/conn.go +++ b/conn.go @@ -287,6 +287,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl Int4OID: &pgtype.Int4{}, Int8ArrayOID: &pgtype.Int8Array{}, Int8OID: &pgtype.Int8{}, + NameOID: &pgtype.Name{}, OIDOID: &pgtype.OID{}, TextArrayOID: &pgtype.TextArray{}, TextOID: &pgtype.Text{}, diff --git a/pgtype/name.go b/pgtype/name.go new file mode 100644 index 00000000..3ff81f12 --- /dev/null +++ b/pgtype/name.go @@ -0,0 +1,44 @@ +package pgtype + +import ( + "io" +) + +// Name is a type used for PostgreSQL's special 63-byte +// name data type, used for identifiers like table names. +// The pg_class.relname column is a good example of where the +// name data type is used. +// +// Note that the underlying Go data type of pgx.Name is string, +// so there is no way to enforce the 63-byte length. Inputting +// a longer name into PostgreSQL will result in silent truncation +// to 63 bytes. +// +// Also, if you have custom-compiled PostgreSQL and set +// NAMEDATALEN to a different value, obviously that number of +// bytes applies, rather than the default 63. +type Name Text + +func (dst *Name) ConvertFrom(src interface{}) error { + return (*Text)(dst).ConvertFrom(src) +} + +func (src *Name) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Name) DecodeText(r io.Reader) error { + return (*Text)(dst).DecodeText(r) +} + +func (dst *Name) DecodeBinary(r io.Reader) error { + return (*Text)(dst).DecodeBinary(r) +} + +func (src Name) EncodeText(w io.Writer) error { + return (Text)(src).EncodeText(w) +} + +func (src Name) EncodeBinary(w io.Writer) error { + return (Text)(src).EncodeBinary(w) +} diff --git a/pgtype/name_test.go b/pgtype/name_test.go new file mode 100644 index 00000000..c5f7de17 --- /dev/null +++ b/pgtype/name_test.go @@ -0,0 +1,97 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestNameTranscode(t *testing.T) { + testSuccessfulTranscode(t, "name", []interface{}{ + pgtype.Name{String: "", Status: pgtype.Present}, + pgtype.Name{String: "foo", Status: pgtype.Present}, + pgtype.Name{Status: pgtype.Null}, + }) +} + +func TestNameConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Name + }{ + {source: "foo", result: pgtype.Name{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Name{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Name{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Name + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestNameAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Name{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Name + dst interface{} + expected interface{} + }{ + {src: pgtype.Name{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Name + dst interface{} + }{ + {src: pgtype.Name{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/values.go b/values.go index 59d6f3c4..8e7ef4ac 100644 --- a/values.go +++ b/values.go @@ -371,57 +371,6 @@ func (n NullAclItem) Encode(w *WriteBuf, oid OID) error { return encodeString(w, oid, string(n.AclItem)) } -// Name is a type used for PostgreSQL's special 63-byte -// name data type, used for identifiers like table names. -// The pg_class.relname column is a good example of where the -// name data type is used. -// -// Note that the underlying Go data type of pgx.Name is string, -// so there is no way to enforce the 63-byte length. Inputting -// a longer name into PostgreSQL will result in silent truncation -// to 63 bytes. -// -// Also, if you have custom-compiled PostgreSQL and set -// NAMEDATALEN to a different value, obviously that number of -// bytes applies, rather than the default 63. -type Name string - -// NullName represents a pgx.Name that may be null. NullName implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullName struct { - Name Name - Valid bool // Valid is true if Name is not NULL -} - -func (n *NullName) Scan(vr *ValueReader) error { - if vr.Type().DataType != NameOID { - return SerializationError(fmt.Sprintf("NullName.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Name, n.Valid = "", false - return nil - } - - n.Valid = true - n.Name = Name(decodeText(vr)) - return vr.Err() -} - -func (n NullName) FormatCode() int16 { return TextFormatCode } - -func (n NullName) Encode(w *WriteBuf, oid OID) error { - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeString(w, oid, string(n.Name)) -} - // The pgx.Char type is for PostgreSQL's special 8-bit-only // "char" type more akin to the C language's char type, or Go's byte type. // (Note that the name in PostgreSQL itself is "char", in double-quotes, @@ -1002,10 +951,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The aclitem data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case Name: - // The name data type goes over the wire using the same format as string, - // so just cast to string and use encodeString - return encodeString(wbuf, oid, string(arg)) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -1078,9 +1023,6 @@ func Decode(vr *ValueReader, d interface{}) error { case *AclItem: // aclitem goes over the wire just like text *v = AclItem(decodeText(vr)) - case *Name: - // name goes over the wire just like text - *v = Name(decodeText(vr)) case *Tid: *v = decodeTid(vr) case *string: diff --git a/values_test.go b/values_test.go index 65811959..0e51effe 100644 --- a/values_test.go +++ b/values_test.go @@ -570,7 +570,6 @@ func TestNullX(t *testing.T) { i32 pgx.NullInt32 c pgx.NullChar a pgx.NullAclItem - n pgx.NullName tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 @@ -596,8 +595,6 @@ func TestNullX(t *testing.T) { {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, - {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: true}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "foo", Valid: true}}}, - {"select $1::name", []interface{}{pgx.NullName{Name: "foo", Valid: false}}, []interface{}{&actual.n}, allTypes{n: pgx.NullName{Name: "", Valid: false}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks From fa36ad91967c7a90f61cfd8f14231d7b8cfe8785 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Mar 2017 19:39:57 -0600 Subject: [PATCH 092/264] Move "char" to pgtype --- conn.go | 1 + pgtype/pgtype_test.go | 22 +++++-- pgtype/qchar.go | 144 ++++++++++++++++++++++++++++++++++++++++++ pgtype/qchar_test.go | 140 ++++++++++++++++++++++++++++++++++++++++ values.go | 80 ----------------------- values_test.go | 4 -- 6 files changed, 300 insertions(+), 91 deletions(-) create mode 100644 pgtype/qchar.go create mode 100644 pgtype/qchar_test.go diff --git a/conn.go b/conn.go index 023b9d97..f9f94c43 100644 --- a/conn.go +++ b/conn.go @@ -270,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.oidPgtypeValues = map[OID]pgtype.Value{ BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, + CharOID: &pgtype.QChar{}, CIDOID: &pgtype.CID{}, CidrArrayOID: &pgtype.CidrArray{}, CidrOID: &pgtype.Inet{}, diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 304fd0ea..c1dba383 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -74,12 +74,15 @@ func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { func forceEncoder(e interface{}, formatCode int16) interface{} { switch formatCode { case pgx.TextFormatCode: - return forceTextEncoder{e: e.(pgtype.TextEncoder)} + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } case pgx.BinaryFormatCode: - return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} - default: - panic("bad encoder") + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } } + return nil } func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { @@ -105,9 +108,14 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, } - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - for i, v := range values { + for i, v := range values { + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(v, fc.formatCode) + if vEncoder == nil { + t.Logf("%v does not implement %v", fc.name) + continue + } // Derefence value if it is a pointer derefV := v refVal := reflect.ValueOf(v) diff --git a/pgtype/qchar.go b/pgtype/qchar.go new file mode 100644 index 00000000..6dd14625 --- /dev/null +++ b/pgtype/qchar.go @@ -0,0 +1,144 @@ +package pgtype + +import ( + "fmt" + "io" + "math" + "strconv" + + "github.com/jackc/pgx/pgio" +) + +// QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C +// language's char type, or Go's byte type. (Note that the name in PostgreSQL +// itself is "char", in double-quotes, and not char.) It gets used a lot in +// PostgreSQL's system tables to hold a single ASCII character value (eg +// pg_class.relkind). It is named Qchar for quoted char to disambiguate from SQL +// standard type char. +// +// Not all possible values of QChar are representable in the text format. +// Therefore, QChar does not implement TextEncoder and TextDecoder. +type QChar struct { + Int int8 + Status Status +} + +func (dst *QChar) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case QChar: + *dst = value + case int8: + *dst = QChar{Int: value, Status: Present} + case uint8: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int16: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint16: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int32: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint32: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int64: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint64: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case int: + if value < math.MinInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case uint: + if value > math.MaxInt8 { + return fmt.Errorf("%d is greater than maximum value for QChar", value) + } + *dst = QChar{Int: int8(value), Status: Present} + case string: + num, err := strconv.ParseInt(value, 10, 8) + if err != nil { + return err + } + *dst = QChar{Int: int8(num), Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to QChar", value) + } + + return nil +} + +func (src *QChar) AssignTo(dst interface{}) error { + return int64AssignTo(int64(src.Int), src.Status, dst) +} + +func (dst *QChar) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = QChar{Status: Null} + return nil + } + + if size != 1 { + return fmt.Errorf(`invalid length for "char": %v`, size) + } + + byt, err := pgio.ReadByte(r) + if err != nil { + return err + } + + *dst = QChar{Int: int8(byt), Status: Present} + return nil +} + +func (src QChar) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, 1) + if err != nil { + return nil + } + + return pgio.WriteByte(w, byte(src.Int)) +} diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go new file mode 100644 index 00000000..ea7b56a8 --- /dev/null +++ b/pgtype/qchar_test.go @@ -0,0 +1,140 @@ +package pgtype_test + +import ( + "math" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestQCharTranscode(t *testing.T) { + testSuccessfulTranscode(t, `"char"`, []interface{}{ + pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, + pgtype.QChar{Int: -1, Status: pgtype.Present}, + pgtype.QChar{Int: 0, Status: pgtype.Present}, + pgtype.QChar{Int: 1, Status: pgtype.Present}, + pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, + pgtype.QChar{Int: 0, Status: pgtype.Null}, + }) +} + +func TestQCharConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.QChar + }{ + {source: int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: int8(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int16(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int32(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: int64(-1), result: pgtype.QChar{Int: -1, Status: pgtype.Present}}, + {source: uint8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint16(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: uint64(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: "1", result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + {source: _int8(1), result: pgtype.QChar{Int: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.QChar + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestQCharAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + + simpleTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.QChar + dst interface{} + expected interface{} + }{ + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &pi8, expected: int8(42)}, + {src: pgtype.QChar{Int: 42, Status: pgtype.Present}, dst: &_pi8, expected: _int8(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.QChar + dst interface{} + }{ + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui8}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui16}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui32}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui64}, + {src: pgtype.QChar{Int: -1, Status: pgtype.Present}, dst: &ui}, + {src: pgtype.QChar{Int: 0, Status: pgtype.Null}, dst: &i16}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/values.go b/values.go index 8e7ef4ac..c724aa39 100644 --- a/values.go +++ b/values.go @@ -371,52 +371,6 @@ func (n NullAclItem) Encode(w *WriteBuf, oid OID) error { return encodeString(w, oid, string(n.AclItem)) } -// The pgx.Char type is for PostgreSQL's special 8-bit-only -// "char" type more akin to the C language's char type, or Go's byte type. -// (Note that the name in PostgreSQL itself is "char", in double-quotes, -// and not char.) It gets used a lot in PostgreSQL's system tables to hold -// a single ASCII character value (eg pg_class.relkind). -type Char byte - -// NullChar represents a pgx.Char that may be null. NullChar implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullChar struct { - Char Char - Valid bool // Valid is true if Char is not NULL -} - -func (n *NullChar) Scan(vr *ValueReader) error { - if vr.Type().DataType != CharOID { - return SerializationError(fmt.Sprintf("NullChar.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Char, n.Valid = 0, false - return nil - } - n.Valid = true - n.Char = decodeChar(vr) - return vr.Err() -} - -func (n NullChar) FormatCode() int16 { return BinaryFormatCode } - -func (n NullChar) Encode(w *WriteBuf, oid OID) error { - if oid != CharOID { - return SerializationError(fmt.Sprintf("NullChar.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeChar(w, oid, n.Char) -} - // NullInt16 represents a smallint that may be null. NullInt16 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan for prepared and unprepared queries. @@ -945,8 +899,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { } switch arg := arg.(type) { - case Char: - return encodeChar(wbuf, oid, arg) case AclItem: // The aclitem data type goes over the wire using the same format as string, // so just cast to string and use encodeString @@ -1018,8 +970,6 @@ func decodeByOID(vr *ValueReader) (interface{}, error) { // decoding to the built-in functionality. func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { - case *Char: - *v = decodeChar(vr) case *AclItem: // aclitem goes over the wire just like text *v = AclItem(decodeText(vr)) @@ -1158,30 +1108,6 @@ func decodeInt8(vr *ValueReader) int64 { return n.Int } -func decodeChar(vr *ValueReader) Char { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into char")) - return Char(0) - } - - if vr.Type().DataType != CharOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into char", vr.Type().DataType))) - return Char(0) - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Char(0) - } - - if vr.Len() != 1 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a char: %d", vr.Len()))) - return Char(0) - } - - return Char(vr.ReadByte()) -} - func decodeInt2(vr *ValueReader) int16 { if vr.Type().DataType != Int2OID { @@ -1216,12 +1142,6 @@ func decodeInt2(vr *ValueReader) int16 { return n.Int } -func encodeChar(w *WriteBuf, oid OID, value Char) error { - w.WriteInt32(1) - w.WriteByte(byte(value)) - return nil -} - func decodeInt4(vr *ValueReader) int32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into int32")) diff --git a/values_test.go b/values_test.go index 0e51effe..4c02ac0a 100644 --- a/values_test.go +++ b/values_test.go @@ -568,7 +568,6 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 - c pgx.NullChar a pgx.NullAclItem tid pgx.NullTid i64 pgx.NullInt64 @@ -592,9 +591,6 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, - {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks From bac4af13bb9df8725ea56e4aa709f8ad17bd7a0d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Mar 2017 21:07:40 -0600 Subject: [PATCH 093/264] Add bytea --- pgtype/bytea.go | 160 ++++++++++++++++++++++++++++++++++++++++++ pgtype/bytea_test.go | 73 +++++++++++++++++++ pgtype/convert.go | 21 ++++++ pgtype/pgtype_test.go | 1 + 4 files changed, 255 insertions(+) create mode 100644 pgtype/bytea.go create mode 100644 pgtype/bytea_test.go diff --git a/pgtype/bytea.go b/pgtype/bytea.go new file mode 100644 index 00000000..2532182f --- /dev/null +++ b/pgtype/bytea.go @@ -0,0 +1,160 @@ +package pgtype + +import ( + "encoding/hex" + "fmt" + "io" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +type Bytea struct { + Bytes []byte + Status Status +} + +func (dst *Bytea) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Bytea: + *dst = value + case []byte: + if value != nil { + *dst = Bytea{Bytes: value, Status: Present} + } else { + *dst = Bytea{Status: Null} + } + default: + if originalSrc, ok := underlyingBytesType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (src *Bytea) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]byte: + if src.Status == Present { + *v = src.Bytes + } else { + *v = nil + } + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +// DecodeText only supports the hex format. This has been the default since +// PostgreSQL 9.0. +func (dst *Bytea) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Bytea{Status: Null} + return nil + } + + sbuf := make([]byte, int(size)) + _, err = io.ReadFull(r, sbuf) + if err != nil { + return err + } + + if len(sbuf) < 2 || sbuf[0] != '\\' || sbuf[1] != 'x' { + return fmt.Errorf("invalid hex format") + } + + buf := make([]byte, (len(sbuf)-2)/2) + _, err = hex.Decode(buf, sbuf[2:]) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Status: Present} + return nil +} + +func (dst *Bytea) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Bytea{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Status: Present} + return nil +} + +func (src Bytea) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + str := hex.EncodeToString(src.Bytes) + + _, err := pgio.WriteInt32(w, int32(len(str)+2)) + if err != nil { + return nil + } + + _, err = io.WriteString(w, `\x`) + if err != nil { + return nil + } + + _, err = io.WriteString(w, str) + return err +} + +func (src Bytea) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, int32(len(src.Bytes))) + if err != nil { + return nil + } + + _, err = w.Write(src.Bytes) + return err +} diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go new file mode 100644 index 00000000..51941387 --- /dev/null +++ b/pgtype/bytea_test.go @@ -0,0 +1,73 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestByteaTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bytea", []interface{}{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, + }) +} + +func TestByteaConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bytea + }{ + {source: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}}, + {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, + {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bytea + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaAssignTo(t *testing.T) { + var buf []byte + var _buf _byteSlice + var pbuf *[]byte + var _pbuf *_byteSlice + + simpleTests := []struct { + src pgtype.Bytea + dst interface{} + expected interface{} + }{ + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/convert.go b/pgtype/convert.go index 31bbf060..648209f5 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -85,6 +85,27 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingBytesType gets the underlying type that can be converted to []byte +func underlyingBytesType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + if refVal.Type().Elem().Kind() == reflect.Uint8 { + convVal := refVal.Bytes() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + } + + return nil, false +} + // underlyingStringType gets the underlying type that can be converted to String func underlyingStringType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index c1dba383..6e173cbe 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -22,6 +22,7 @@ type _int32Slice []int32 type _int64Slice []int64 type _float32Slice []float32 type _float64Slice []float64 +type _byteSlice []byte func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) From ac9228a1a39d47a68d910b67b271e624625028b7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Mar 2017 21:09:36 -0600 Subject: [PATCH 094/264] Fix typed_array_gen.sh typo --- pgtype/typed_array_gen.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index f984e12e..1e2dce64 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,6 +1,6 @@ erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID typed_array.go.erb > int2array.go erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID typed_array.go.erb > int4array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID typed_array.go.erb > int2array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID typed_array.go.erb > int8array.go erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID typed_array.go.erb > boolarray.go erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go From 81626342596a5655b2a2239817c45ee29357ebce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 10 Mar 2017 16:08:47 -0600 Subject: [PATCH 095/264] Decode(Text|Binary) now accepts []byte instead of io.Reader --- pgtype/array.go | 46 ++++++++++++------------ pgtype/bool.go | 40 ++++++--------------- pgtype/boolarray.go | 52 +++++++++++++-------------- pgtype/bytea.go | 39 +++++--------------- pgtype/cid.go | 8 ++--- pgtype/cidrarray.go | 8 ++--- pgtype/date.go | 36 +++++-------------- pgtype/datearray.go | 52 +++++++++++++-------------- pgtype/float4.go | 36 +++++-------------- pgtype/float4array.go | 52 +++++++++++++-------------- pgtype/float8.go | 36 +++++-------------- pgtype/float8array.go | 52 +++++++++++++-------------- pgtype/inet.go | 60 +++++++------------------------ pgtype/inetarray.go | 52 +++++++++++++-------------- pgtype/int2.go | 39 ++++++-------------- pgtype/int2array.go | 52 +++++++++++++-------------- pgtype/int4.go | 37 +++++-------------- pgtype/int4array.go | 52 +++++++++++++-------------- pgtype/int8.go | 36 +++++-------------- pgtype/int8array.go | 52 +++++++++++++-------------- pgtype/name.go | 8 ++--- pgtype/oid.go | 8 ++--- pgtype/pgtype.go | 4 +-- pgtype/pguint32.go | 37 +++++-------------- pgtype/qchar.go | 20 +++-------- pgtype/text.go | 21 +++-------- pgtype/textarray.go | 53 ++++++++++++++------------- pgtype/timestamp.go | 36 +++++-------------- pgtype/timestamparray.go | 52 +++++++++++++-------------- pgtype/timestamptz.go | 36 +++++-------------- pgtype/timestamptzarray.go | 52 +++++++++++++-------------- pgtype/to-consider.txt | 9 +++++ pgtype/typed_array.go.erb | 58 ++++++++++++------------------ pgtype/varchararray.go | 8 ++--- pgtype/xid.go | 8 ++--- query.go | 12 +++---- value_reader.go | 29 +++------------ values.go | 73 +++++++++++--------------------------- 38 files changed, 506 insertions(+), 855 deletions(-) create mode 100644 pgtype/to-consider.txt diff --git a/pgtype/array.go b/pgtype/array.go index 76492c61..6b705103 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "strconv" @@ -25,40 +26,37 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(r io.Reader) error { - numDims, err := pgio.ReadInt32(r) - if err != nil { - return err +func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { + if len(src) < 12 { + return 0, fmt.Errorf("array header too short: %d", len(src)) } + rp := 0 + + numDims := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 + rp += 4 + + dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + if numDims > 0 { dst.Dimensions = make([]ArrayDimension, numDims) } - - containsNull, err := pgio.ReadInt32(r) - if err != nil { - return err + if len(src) < 12+numDims*8 { + return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } - dst.ContainsNull = containsNull == 1 - - dst.ElementOID, err = pgio.ReadInt32(r) - if err != nil { - return err - } - for i := range dst.Dimensions { - dst.Dimensions[i].Length, err = pgio.ReadInt32(r) - if err != nil { - return err - } + dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 - dst.Dimensions[i].LowerBound, err = pgio.ReadInt32(r) - if err != nil { - return err - } + dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:])) + rp += 4 } - return nil + return rp, nil } func (src *ArrayHeader) EncodeBinary(w io.Writer) error { diff --git a/pgtype/bool.go b/pgtype/bool.go index 076403f9..b7bc14d0 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -72,51 +72,31 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } -func (dst *Bool) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bool) DecodeText(src []byte) error { + if src == nil { *dst = Bool{Status: Null} return nil } - if size != 1 { - return fmt.Errorf("invalid length for bool: %v", size) + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = Bool{Bool: byt == 't', Status: Present} + *dst = Bool{Bool: src[0] == 't', Status: Present} return nil } -func (dst *Bool) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bool) DecodeBinary(src []byte) error { + if src == nil { *dst = Bool{Status: Null} return nil } - if size != 1 { - return fmt.Errorf("invalid length for bool: %v", size) + if len(src) != 1 { + return fmt.Errorf("invalid length for bool: %v", len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = Bool{Bool: byt == 1, Status: Present} + *dst = Bool{Bool: src[0] == 1, Status: Present} return nil } diff --git a/pgtype/boolarray.go b/pgtype/boolarray.go index b6b5db02..a9b8bf50 100644 --- a/pgtype/boolarray.go +++ b/pgtype/boolarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return nil } -func (dst *BoolArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *BoolArray) DecodeText(src []byte) error { + if src == nil { *dst = BoolArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Bool if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *BoolArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Bool - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *BoolArray) DecodeText(r io.Reader) error { return nil } -func (dst *BoolArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *BoolArray) DecodeBinary(src []byte) error { + if src == nil { *dst = BoolArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *BoolArray) DecodeBinary(r io.Reader) error { elements := make([]Bool, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *BoolArray) EncodeText(w io.Writer) error { } func (src *BoolArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, BoolOID) +} + +func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *BoolArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = BoolOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 2532182f..db20482f 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -71,29 +71,18 @@ func (src *Bytea) AssignTo(dst interface{}) error { // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. -func (dst *Bytea) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bytea) DecodeText(src []byte) error { + if src == nil { *dst = Bytea{Status: Null} return nil } - sbuf := make([]byte, int(size)) - _, err = io.ReadFull(r, sbuf) - if err != nil { - return err - } - - if len(sbuf) < 2 || sbuf[0] != '\\' || sbuf[1] != 'x' { + if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { return fmt.Errorf("invalid hex format") } - buf := make([]byte, (len(sbuf)-2)/2) - _, err = hex.Decode(buf, sbuf[2:]) + buf := make([]byte, (len(src)-2)/2) + _, err := hex.Decode(buf, src[2:]) if err != nil { return err } @@ -102,25 +91,13 @@ func (dst *Bytea) DecodeText(r io.Reader) error { return nil } -func (dst *Bytea) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Bytea) DecodeBinary(src []byte) error { + if src == nil { *dst = Bytea{Status: Null} return nil } - buf := make([]byte, int(size)) - - _, err = io.ReadFull(r, buf) - if err != nil { - return err - } - - *dst = Bytea{Bytes: buf, Status: Present} + *dst = Bytea{Bytes: src, Status: Present} return nil } diff --git a/pgtype/cid.go b/pgtype/cid.go index 21d6fb80..f8d706d0 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -30,12 +30,12 @@ func (src *CID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *CID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *CID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *CID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *CID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src CID) EncodeText(w io.Writer) error { diff --git a/pgtype/cidrarray.go b/pgtype/cidrarray.go index 66dd20d0..d95eef4a 100644 --- a/pgtype/cidrarray.go +++ b/pgtype/cidrarray.go @@ -14,12 +14,12 @@ func (src *CidrArray) AssignTo(dst interface{}) error { return (*InetArray)(src).AssignTo(dst) } -func (dst *CidrArray) DecodeText(r io.Reader) error { - return (*InetArray)(dst).DecodeText(r) +func (dst *CidrArray) DecodeText(src []byte) error { + return (*InetArray)(dst).DecodeText(src) } -func (dst *CidrArray) DecodeBinary(r io.Reader) error { - return (*InetArray)(dst).DecodeBinary(r) +func (dst *CidrArray) DecodeBinary(src []byte) error { + return (*InetArray)(dst).DecodeBinary(src) } func (src *CidrArray) EncodeText(w io.Writer) error { diff --git a/pgtype/date.go b/pgtype/date.go index 307f1e59..1bb81d35 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -66,24 +67,13 @@ func (src *Date) AssignTo(dst interface{}) error { return nil } -func (dst *Date) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Date) DecodeText(src []byte) error { + if src == nil { *dst = Date{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Date{Status: Present, InfinityModifier: Infinity} @@ -101,25 +91,17 @@ func (dst *Date) DecodeText(r io.Reader) error { return nil } -func (dst *Date) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Date) DecodeBinary(src []byte) error { + if src == nil { *dst = Date{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for date: %v", size) + if len(src) != 4 { + return fmt.Errorf("invalid length for date: %v", len(src)) } - dayOffset, err := pgio.ReadInt32(r) - if err != nil { - return err - } + dayOffset := int32(binary.BigEndian.Uint32(src)) switch dayOffset { case infinityDayOffset: diff --git a/pgtype/datearray.go b/pgtype/datearray.go index 5e93501e..e9ad1f62 100644 --- a/pgtype/datearray.go +++ b/pgtype/datearray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *DateArray) AssignTo(dst interface{}) error { return nil } -func (dst *DateArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *DateArray) DecodeText(src []byte) error { + if src == nil { *dst = DateArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Date if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *DateArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Date - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *DateArray) DecodeText(r io.Reader) error { return nil } -func (dst *DateArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *DateArray) DecodeBinary(src []byte) error { + if src == nil { *dst = DateArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *DateArray) DecodeBinary(r io.Reader) error { elements := make([]Date, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *DateArray) EncodeText(w io.Writer) error { } func (src *DateArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, DateOID) +} + +func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *DateArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = DateOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/float4.go b/pgtype/float4.go index a1e5aa18..fb0415e5 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -92,24 +93,13 @@ func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } -func (dst *Float4) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4) DecodeText(src []byte) error { + if src == nil { *dst = Float4{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseFloat(string(buf), 32) + n, err := strconv.ParseFloat(string(src), 32) if err != nil { return err } @@ -118,25 +108,17 @@ func (dst *Float4) DecodeText(r io.Reader) error { return nil } -func (dst *Float4) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4) DecodeBinary(src []byte) error { + if src == nil { *dst = Float4{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for float4: %v", size) + if len(src) != 4 { + return fmt.Errorf("invalid length for float4: %v", len(src)) } - n, err := pgio.ReadInt32(r) - if err != nil { - return err - } + n := int32(binary.BigEndian.Uint32(src)) *dst = Float4{Float: math.Float32frombits(uint32(n)), Status: Present} return nil diff --git a/pgtype/float4array.go b/pgtype/float4array.go index 8834d213..a4a72146 100644 --- a/pgtype/float4array.go +++ b/pgtype/float4array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float4Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4Array) DecodeText(src []byte) error { + if src == nil { *dst = Float4Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Float4 if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *Float4Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Float4 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *Float4Array) DecodeText(r io.Reader) error { return nil } -func (dst *Float4Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float4Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Float4Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *Float4Array) DecodeBinary(r io.Reader) error { elements := make([]Float4, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *Float4Array) EncodeText(w io.Writer) error { } func (src *Float4Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Float4OID) +} + +func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *Float4Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Float4OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/float8.go b/pgtype/float8.go index c1347cb2..a53de5e3 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -82,24 +83,13 @@ func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } -func (dst *Float8) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8) DecodeText(src []byte) error { + if src == nil { *dst = Float8{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseFloat(string(buf), 64) + n, err := strconv.ParseFloat(string(src), 64) if err != nil { return err } @@ -108,25 +98,17 @@ func (dst *Float8) DecodeText(r io.Reader) error { return nil } -func (dst *Float8) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8) DecodeBinary(src []byte) error { + if src == nil { *dst = Float8{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for float4: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for float4: %v", len(src)) } - n, err := pgio.ReadInt64(r) - if err != nil { - return err - } + n := int64(binary.BigEndian.Uint64(src)) *dst = Float8{Float: math.Float64frombits(uint64(n)), Status: Present} return nil diff --git a/pgtype/float8array.go b/pgtype/float8array.go index bad9ed9f..082e817d 100644 --- a/pgtype/float8array.go +++ b/pgtype/float8array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float8Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8Array) DecodeText(src []byte) error { + if src == nil { *dst = Float8Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Float8 if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *Float8Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Float8 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *Float8Array) DecodeText(r io.Reader) error { return nil } -func (dst *Float8Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Float8Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Float8Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *Float8Array) DecodeBinary(r io.Reader) error { elements := make([]Float8, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -236,6 +230,10 @@ func (src *Float8Array) EncodeText(w io.Writer) error { } func (src *Float8Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Float8OID) +} + +func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -256,7 +254,7 @@ func (src *Float8Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Float8OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/inet.go b/pgtype/inet.go index e47c64b0..132a876a 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -91,26 +91,16 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil } -func (dst *Inet) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Inet) DecodeText(src []byte) error { + if src == nil { *dst = Inet{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) - if err != nil { - return err - } - var ipnet *net.IPNet + var err error - if ip := net.ParseIP(string(buf)); ip != nil { + if ip := net.ParseIP(string(src)); ip != nil { ipv4 := ip.To4() if ipv4 != nil { ip = ipv4 @@ -119,7 +109,7 @@ func (dst *Inet) DecodeText(r io.Reader) error { mask := net.CIDRMask(bitCount, bitCount) ipnet = &net.IPNet{Mask: mask, IP: ip} } else { - _, ipnet, err = net.ParseCIDR(string(buf)) + _, ipnet, err = net.ParseCIDR(string(src)) if err != nil { return err } @@ -129,50 +119,24 @@ func (dst *Inet) DecodeText(r io.Reader) error { return nil } -func (dst *Inet) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Inet) DecodeBinary(src []byte) error { + if src == nil { *dst = Inet{Status: Null} return nil } - if size != 8 && size != 20 { - return fmt.Errorf("Received an invalid size for a inet: %d", size) + if len(src) != 8 && len(src) != 20 { + return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) } // ignore family - _, err = pgio.ReadByte(r) - if err != nil { - return err - } - - bits, err := pgio.ReadByte(r) - if err != nil { - return err - } - + bits := src[1] // ignore is_cidr - _, err = pgio.ReadByte(r) - if err != nil { - return err - } - - addressLength, err := pgio.ReadByte(r) - if err != nil { - return err - } + addressLength := src[3] var ipnet net.IPNet ipnet.IP = make(net.IP, int(addressLength)) - _, err = r.Read(ipnet.IP) - if err != nil { - return err - } - + copy(ipnet.IP, src[4:]) ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8) *dst = Inet{IPNet: &ipnet, Status: Present} diff --git a/pgtype/inetarray.go b/pgtype/inetarray.go index cd12e917..28de736f 100644 --- a/pgtype/inetarray.go +++ b/pgtype/inetarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "net" @@ -19,8 +20,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { switch value := src.(type) { case InetArray: *dst = value - case CidrArray: - *dst = InetArray(value) + case []*net.IPNet: if value == nil { *dst = InetArray{Status: Null} @@ -39,6 +39,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { Status: Present, } } + case []net.IP: if value == nil { *dst = InetArray{Status: Null} @@ -57,6 +58,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { Status: Present, } } + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -81,6 +83,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { } else { *v = nil } + case *[]net.IP: if src.Status == Present { *v = make([]net.IP, len(src.Elements)) @@ -103,29 +106,17 @@ func (src *InetArray) AssignTo(dst interface{}) error { return nil } -func (dst *InetArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *InetArray) DecodeText(src []byte) error { + if src == nil { *dst = InetArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Inet if len(uta.Elements) > 0 { @@ -133,8 +124,11 @@ func (dst *InetArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Inet - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -148,19 +142,14 @@ func (dst *InetArray) DecodeText(r io.Reader) error { return nil } -func (dst *InetArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *InetArray) DecodeBinary(src []byte) error { + if src == nil { *dst = InetArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -178,7 +167,14 @@ func (dst *InetArray) DecodeBinary(r io.Reader) error { elements := make([]Inet, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } diff --git a/pgtype/int2.go b/pgtype/int2.go index 8057550b..51346a43 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -88,24 +89,13 @@ func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int2) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2) DecodeText(src []byte) error { + if src == nil { *dst = Int2{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 16) + n, err := strconv.ParseInt(string(src), 10, 16) if err != nil { return err } @@ -114,27 +104,18 @@ func (dst *Int2) DecodeText(r io.Reader) error { return nil } -func (dst *Int2) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2) DecodeBinary(src []byte) error { + if src == nil { *dst = Int2{Status: Null} return nil } - if size != 2 { - return fmt.Errorf("invalid length for int2: %v", size) + if len(src) != 2 { + return fmt.Errorf("invalid length for int2: %v", len(src)) } - n, err := pgio.ReadInt16(r) - if err != nil { - return err - } - - *dst = Int2{Int: int16(n), Status: Present} + n := int16(binary.BigEndian.Uint16(src)) + *dst = Int2{Int: n, Status: Present} return nil } diff --git a/pgtype/int2array.go b/pgtype/int2array.go index a989347d..71760e1e 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int2Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2Array) DecodeText(src []byte) error { + if src == nil { *dst = Int2Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int2 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int2Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int2 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int2Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int2Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int2Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int2Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int2Array) DecodeBinary(r io.Reader) error { elements := make([]Int2, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int2Array) EncodeText(w io.Writer) error { } func (src *Int2Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int2OID) +} + +func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int2Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int2OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/int4.go b/pgtype/int4.go index 43691bb6..8a53d454 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -79,24 +80,13 @@ func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int4) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4) DecodeText(src []byte) error { + if src == nil { *dst = Int4{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 32) + n, err := strconv.ParseInt(string(src), 10, 32) if err != nil { return err } @@ -105,26 +95,17 @@ func (dst *Int4) DecodeText(r io.Reader) error { return nil } -func (dst *Int4) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4) DecodeBinary(src []byte) error { + if src == nil { *dst = Int4{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length for int4: %v", size) - } - - n, err := pgio.ReadInt32(r) - if err != nil { - return err + if len(src) != 4 { + return fmt.Errorf("invalid length for int4: %v", len(src)) } + n := int32(binary.BigEndian.Uint32(src)) *dst = Int4{Int: n, Status: Present} return nil } diff --git a/pgtype/int4array.go b/pgtype/int4array.go index 89caf263..6a202b08 100644 --- a/pgtype/int4array.go +++ b/pgtype/int4array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int4Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4Array) DecodeText(src []byte) error { + if src == nil { *dst = Int4Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int4 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int4Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int4 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int4Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int4Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int4Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int4Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int4Array) DecodeBinary(r io.Reader) error { elements := make([]Int4, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int4Array) EncodeText(w io.Writer) error { } func (src *Int4Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int4OID) +} + +func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int4Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int4OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/int8.go b/pgtype/int8.go index b87bb85a..c6bedaa6 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "math" @@ -70,24 +71,13 @@ func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int8) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8) DecodeText(src []byte) error { + if src == nil { *dst = Int8{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseInt(string(buf), 10, 64) + n, err := strconv.ParseInt(string(src), 10, 64) if err != nil { return err } @@ -96,25 +86,17 @@ func (dst *Int8) DecodeText(r io.Reader) error { return nil } -func (dst *Int8) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8) DecodeBinary(src []byte) error { + if src == nil { *dst = Int8{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for int8: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for int8: %v", len(src)) } - n, err := pgio.ReadInt64(r) - if err != nil { - return err - } + n := int64(binary.BigEndian.Uint64(src)) *dst = Int8{Int: n, Status: Present} return nil diff --git a/pgtype/int8array.go b/pgtype/int8array.go index 003ed055..f621618e 100644 --- a/pgtype/int8array.go +++ b/pgtype/int8array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -104,29 +105,17 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int8Array) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8Array) DecodeText(src []byte) error { + if src == nil { *dst = Int8Array{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Int8 if len(uta.Elements) > 0 { @@ -134,8 +123,11 @@ func (dst *Int8Array) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Int8 - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -149,19 +141,14 @@ func (dst *Int8Array) DecodeText(r io.Reader) error { return nil } -func (dst *Int8Array) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Int8Array) DecodeBinary(src []byte) error { + if src == nil { *dst = Int8Array{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -179,7 +166,14 @@ func (dst *Int8Array) DecodeBinary(r io.Reader) error { elements := make([]Int8, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -267,6 +261,10 @@ func (src *Int8Array) EncodeText(w io.Writer) error { } func (src *Int8Array) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, Int8OID) +} + +func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -287,7 +285,7 @@ func (src *Int8Array) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = Int8OID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/name.go b/pgtype/name.go index 3ff81f12..4bbc43c1 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -27,12 +27,12 @@ func (src *Name) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *Name) DecodeText(r io.Reader) error { - return (*Text)(dst).DecodeText(r) +func (dst *Name) DecodeText(src []byte) error { + return (*Text)(dst).DecodeText(src) } -func (dst *Name) DecodeBinary(r io.Reader) error { - return (*Text)(dst).DecodeBinary(r) +func (dst *Name) DecodeBinary(src []byte) error { + return (*Text)(dst).DecodeBinary(src) } func (src Name) EncodeText(w io.Writer) error { diff --git a/pgtype/oid.go b/pgtype/oid.go index d137f352..2ea9c2d1 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -24,12 +24,12 @@ func (src *OID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *OID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *OID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *OID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src OID) EncodeText(w io.Writer) error { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 15c0cc76..7928e1cc 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -74,11 +74,11 @@ type Value interface { } type BinaryDecoder interface { - DecodeBinary(r io.Reader) error + DecodeBinary(src []byte) error } type TextDecoder interface { - DecodeText(r io.Reader) error + DecodeText(src []byte) error } type BinaryEncoder interface { diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 9c1ccd6c..9bf1eef6 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "strconv" @@ -51,24 +52,13 @@ func (src *pguint32) AssignTo(dst interface{}) error { return nil } -func (dst *pguint32) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *pguint32) DecodeText(src []byte) error { + if src == nil { *dst = pguint32{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseUint(string(buf), 10, 32) + n, err := strconv.ParseUint(string(src), 10, 32) if err != nil { return err } @@ -77,26 +67,17 @@ func (dst *pguint32) DecodeText(r io.Reader) error { return nil } -func (dst *pguint32) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *pguint32) DecodeBinary(src []byte) error { + if src == nil { *dst = pguint32{Status: Null} return nil } - if size != 4 { - return fmt.Errorf("invalid length: %v", size) - } - - n, err := pgio.ReadUint32(r) - if err != nil { - return err + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) } + n := binary.BigEndian.Uint32(src) *dst = pguint32{Uint: n, Status: Present} return nil } diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 6dd14625..8abec935 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -106,27 +106,17 @@ func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *QChar) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *QChar) DecodeBinary(src []byte) error { + if src == nil { *dst = QChar{Status: Null} return nil } - if size != 1 { - return fmt.Errorf(`invalid length for "char": %v`, size) + if len(src) != 1 { + return fmt.Errorf(`invalid length for "char": %v`, len(src)) } - byt, err := pgio.ReadByte(r) - if err != nil { - return err - } - - *dst = QChar{Int: int8(byt), Status: Present} + *dst = QChar{Int: int8(src[0]), Status: Present} return nil } diff --git a/pgtype/text.go b/pgtype/text.go index c9054468..2951b5ad 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -71,29 +71,18 @@ func (src *Text) AssignTo(dst interface{}) error { return nil } -func (dst *Text) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Text) DecodeText(src []byte) error { + if src == nil { *dst = Text{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - *dst = Text{String: string(buf), Status: Present} + *dst = Text{String: string(src), Status: Present} return nil } -func (dst *Text) DecodeBinary(r io.Reader) error { - return dst.DecodeText(r) +func (dst *Text) DecodeBinary(src []byte) error { + return dst.DecodeText(src) } func (src Text) EncodeText(w io.Writer) error { diff --git a/pgtype/textarray.go b/pgtype/textarray.go index c420e5c9..e7ca3578 100644 --- a/pgtype/textarray.go +++ b/pgtype/textarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" @@ -73,29 +74,17 @@ func (src *TextArray) AssignTo(dst interface{}) error { return nil } -func (dst *TextArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TextArray) DecodeText(src []byte) error { + if src == nil { *dst = TextArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Text if len(uta.Elements) > 0 { @@ -103,8 +92,11 @@ func (dst *TextArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Text - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +110,14 @@ func (dst *TextArray) DecodeText(r io.Reader) error { return nil } -func (dst *TextArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TextArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TextArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +135,14 @@ func (dst *TextArray) DecodeBinary(r io.Reader) error { elements := make([]Text, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -211,7 +205,12 @@ func (src *TextArray) EncodeText(w io.Writer) error { } textElementWriter.Reset() - if elem.String == "" && elem.Status == Present { + if elem.Status == Null { + _, err := io.WriteString(buf, `"NULL"`) + if err != nil { + return err + } + } else if elem.String == "" { _, err := io.WriteString(buf, `""`) if err != nil { return err diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index c6933988..ca5eb738 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -72,24 +73,13 @@ func (src *Timestamp) AssignTo(dst interface{}) error { // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamp) DecodeText(src []byte) error { + if src == nil { *dst = Timestamp{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Timestamp{Status: Present, InfinityModifier: Infinity} @@ -109,25 +99,17 @@ func (dst *Timestamp) DecodeText(r io.Reader) error { // DecodeBinary decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamp) DecodeBinary(src []byte) error { + if src == nil { *dst = Timestamp{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for timestamp: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamp: %v", len(src)) } - microsecSinceY2K, err := pgio.ReadInt64(r) - if err != nil { - return err - } + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) switch microsecSinceY2K { case infinityMicrosecondOffset: diff --git a/pgtype/timestamparray.go b/pgtype/timestamparray.go index 3acbb35f..695559ac 100644 --- a/pgtype/timestamparray.go +++ b/pgtype/timestamparray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestampArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestampArray) DecodeText(src []byte) error { + if src == nil { *dst = TimestampArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Timestamp if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *TimestampArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Timestamp - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *TimestampArray) DecodeText(r io.Reader) error { return nil } -func (dst *TimestampArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestampArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TimestampArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *TimestampArray) DecodeBinary(r io.Reader) error { elements := make([]Timestamp, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *TimestampArray) EncodeText(w io.Writer) error { } func (src *TimestampArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TimestampOID) +} + +func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *TimestampArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = TimestampOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 721c8084..7255bb06 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -1,6 +1,7 @@ package pgtype import ( + "encoding/binary" "fmt" "io" "reflect" @@ -71,24 +72,13 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return nil } -func (dst *Timestamptz) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamptz) DecodeText(src []byte) error { + if src == nil { *dst = Timestamptz{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - sbuf := string(buf) + sbuf := string(src) switch sbuf { case "infinity": *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} @@ -115,25 +105,17 @@ func (dst *Timestamptz) DecodeText(r io.Reader) error { return nil } -func (dst *Timestamptz) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *Timestamptz) DecodeBinary(src []byte) error { + if src == nil { *dst = Timestamptz{Status: Null} return nil } - if size != 8 { - return fmt.Errorf("invalid length for timestamptz: %v", size) + if len(src) != 8 { + return fmt.Errorf("invalid length for timestamptz: %v", len(src)) } - microsecSinceY2K, err := pgio.ReadInt64(r) - if err != nil { - return err - } + microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) switch microsecSinceY2K { case infinityMicrosecondOffset: diff --git a/pgtype/timestamptzarray.go b/pgtype/timestamptzarray.go index 9df746e6..ca416c97 100644 --- a/pgtype/timestamptzarray.go +++ b/pgtype/timestamptzarray.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "encoding/binary" "fmt" "io" "time" @@ -74,29 +75,17 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestamptzArray) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestamptzArray) DecodeText(src []byte) error { + if src == nil { *dst = TimestamptzArray{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []Timestamptz if len(uta.Elements) > 0 { @@ -104,8 +93,11 @@ func (dst *TimestamptzArray) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem Timestamptz - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -119,19 +111,14 @@ func (dst *TimestamptzArray) DecodeText(r io.Reader) error { return nil } -func (dst *TimestamptzArray) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *TimestamptzArray) DecodeBinary(src []byte) error { + if src == nil { *dst = TimestamptzArray{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -149,7 +136,14 @@ func (dst *TimestamptzArray) DecodeBinary(r io.Reader) error { elements := make([]Timestamptz, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -237,6 +231,10 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) error { } func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TimestamptzOID) +} + +func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) error { if done, err := encodeNotPresent(w, src.Status); done { return err } @@ -257,7 +255,7 @@ func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { } } - arrayHeader.ElementOID = TimestamptzOID + arrayHeader.ElementOID = elementOID arrayHeader.Dimensions = src.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/to-consider.txt b/pgtype/to-consider.txt new file mode 100644 index 00000000..ba4f3511 --- /dev/null +++ b/pgtype/to-consider.txt @@ -0,0 +1,9 @@ +DecodeText and DecodeBinary take []byte instead of io.Reader +EncodeText and EncodeBinary do not write size +Add Nullable interface with IsNull() and SetNull() + +The above would keep types from needing to worry about writing their own size. Could make EncodeText and DecodeText easier to use with sql.Scanner and driver.Valuer. SetNull() could be removed as DecodeText and DecodeBinary could interpret a nil slice as null. + +EncodeText and EncodeBinary could return (null bool, err error). That would finish removing Nullable interface. + +Also, consider whether arrays and ranges could be represented as generic data types or more common code could be extracted instead of using code generation. diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 8c18073b..316439ef 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -73,29 +73,17 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { + if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } - buf := make([]byte, int(size)) - _, err = io.ReadFull(r, buf) + uta, err := ParseUntypedTextArray(string(src)) if err != nil { return err } - uta, err := ParseUntypedTextArray(string(buf)) - if err != nil { - return err - } - - textElementReader := NewTextElementReader(r) var elements []<%= pgtype_element_type %> if len(uta.Elements) > 0 { @@ -103,8 +91,11 @@ func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { for i, s := range uta.Elements { var elem <%= pgtype_element_type %> - textElementReader.Reset(s) - err = elem.DecodeText(textElementReader) + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) if err != nil { return err } @@ -118,19 +109,14 @@ func (dst *<%= pgtype_array_type %>) DecodeText(r io.Reader) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { + if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } var arrayHeader ArrayHeader - err = arrayHeader.DecodeBinary(r) + rp, err := arrayHeader.DecodeBinary(src) if err != nil { return err } @@ -148,7 +134,14 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(r io.Reader) error { elements := make([]<%= pgtype_element_type %>, elementCount) for i := range elements { - err = elements[i].DecodeBinary(r) + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp:rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) if err != nil { return err } @@ -211,16 +204,9 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { } textElementWriter.Reset() - if elem.String == "" && elem.Status == Present { - _, err := io.WriteString(buf, `""`) - if err != nil { - return err - } - } else { - err = elem.EncodeText(textElementWriter) - if err != nil { - return err - } + err = elem.EncodeText(textElementWriter) + if err != nil { + return err } for _, dec := range dimElemCounts { diff --git a/pgtype/varchararray.go b/pgtype/varchararray.go index 13d94bc0..3a5d8536 100644 --- a/pgtype/varchararray.go +++ b/pgtype/varchararray.go @@ -14,12 +14,12 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return (*TextArray)(src).AssignTo(dst) } -func (dst *VarcharArray) DecodeText(r io.Reader) error { - return (*TextArray)(dst).DecodeText(r) +func (dst *VarcharArray) DecodeText(src []byte) error { + return (*TextArray)(dst).DecodeText(src) } -func (dst *VarcharArray) DecodeBinary(r io.Reader) error { - return (*TextArray)(dst).DecodeBinary(r) +func (dst *VarcharArray) DecodeBinary(src []byte) error { + return (*TextArray)(dst).DecodeBinary(src) } func (src *VarcharArray) EncodeText(w io.Writer) error { diff --git a/pgtype/xid.go b/pgtype/xid.go index d4003b5d..389f93bc 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -33,12 +33,12 @@ func (src *XID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *XID) DecodeText(r io.Reader) error { - return (*pguint32)(dst).DecodeText(r) +func (dst *XID) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) } -func (dst *XID) DecodeBinary(r io.Reader) error { - return (*pguint32)(dst).DecodeBinary(r) +func (dst *XID) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) } func (src XID) EncodeText(w io.Writer) error { diff --git a/query.go b/query.go index 965f3913..71d1ba9e 100644 --- a/query.go +++ b/query.go @@ -231,14 +231,12 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { - vr.err = errRewoundLen - err = s.DecodeBinary(&valueReader2{vr}) + err = s.DecodeBinary(vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { - vr.err = errRewoundLen - err = s.DecodeText(&valueReader2{vr}) + err = s.DecodeText(vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } @@ -290,8 +288,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { switch vr.Type().FormatCode { case TextFormatCode: if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok { - vr.err = errRewoundLen - err = textDecoder.DecodeText(&valueReader2{vr}) + err = textDecoder.DecodeText(vr.bytes()) if err != nil { vr.Fatal(err) } @@ -300,8 +297,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } case BinaryFormatCode: if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok { - vr.err = errRewoundLen - err = binaryDecoder.DecodeBinary(&valueReader2{vr}) + err = binaryDecoder.DecodeBinary(vr.bytes()) if err != nil { vr.Fatal(err) } diff --git a/value_reader.go b/value_reader.go index c91a21af..85932a7d 100644 --- a/value_reader.go +++ b/value_reader.go @@ -4,8 +4,6 @@ import ( "errors" ) -var errRewoundLen = errors.New("len was rewound") - // ValueReader is used by the Scanner interface to decode values. type ValueReader struct { mr *msgReader @@ -157,27 +155,10 @@ func (r *ValueReader) ReadBytes(count int32) []byte { return r.mr.readBytes(count) } -type valueReader2 struct { - *ValueReader -} - -func (r *valueReader2) Read(dst []byte) (int, error) { - if r.err != nil { - return 0, r.err +// bytes is a compatibility function for pgtype.TextDecoder and pgtype.BinaryDecoder +func (r *ValueReader) bytes() []byte { + if r.Len() >= 0 { + return r.ReadBytes(r.Len()) } - - src := r.ReadBytes(int32(len(dst))) - - copy(dst, src) - - return len(dst), nil -} - -func (r *valueReader2) ReadUint32() (uint32, error) { - if r.err == errRewoundLen { - r.err = nil - return uint32(r.Len()), nil - } - - return r.ValueReader.ReadUint32(), nil + return nil } diff --git a/values.go b/values.go index c724aa39..796f2f3d 100644 --- a/values.go +++ b/values.go @@ -3,6 +3,7 @@ package pgx import ( "bytes" "database/sql/driver" + "encoding/binary" "encoding/json" "fmt" "io" @@ -455,23 +456,12 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { // in the PostgreSQL sources. OID cannot be NULL. To allow for NULL OIDs use pgtype.OID. type OID uint32 -func (dst *OID) DecodeText(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *OID) DecodeText(src []byte) error { + if src == nil { return fmt.Errorf("cannot decode nil into OID") } - buf := make([]byte, int(size)) - _, err = r.Read(buf) - if err != nil { - return err - } - - n, err := strconv.ParseUint(string(buf), 10, 32) + n, err := strconv.ParseUint(string(src), 10, 32) if err != nil { return err } @@ -480,25 +470,16 @@ func (dst *OID) DecodeText(r io.Reader) error { return nil } -func (dst *OID) DecodeBinary(r io.Reader) error { - size, err := pgio.ReadInt32(r) - if err != nil { - return err - } - - if size == -1 { +func (dst *OID) DecodeBinary(src []byte) error { + if src == nil { return fmt.Errorf("cannot decode nil into OID") } - if size != 4 { - return fmt.Errorf("invalid length for OID: %v", size) - } - - n, err := pgio.ReadUint32(r) - if err != nil { - return err + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) } + n := binary.BigEndian.Uint32(src) *dst = OID(n) return nil } @@ -1020,15 +1001,13 @@ func decodeBool(vr *ValueReader) bool { return false } - vr.err = errRewoundLen - var b pgtype.Bool var err error switch vr.Type().FormatCode { case TextFormatCode: - err = b.DecodeText(&valueReader2{vr}) + err = b.DecodeText(vr.bytes()) case BinaryFormatCode: - err = b.DecodeBinary(&valueReader2{vr}) + err = b.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return false @@ -1081,15 +1060,13 @@ func decodeInt8(vr *ValueReader) int64 { return 0 } - vr.err = errRewoundLen - var n pgtype.Int8 var err error switch vr.Type().FormatCode { case TextFormatCode: - err = n.DecodeText(&valueReader2{vr}) + err = n.DecodeText(vr.bytes()) case BinaryFormatCode: - err = n.DecodeBinary(&valueReader2{vr}) + err = n.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 @@ -1115,15 +1092,13 @@ func decodeInt2(vr *ValueReader) int16 { return 0 } - vr.err = errRewoundLen - var n pgtype.Int2 var err error switch vr.Type().FormatCode { case TextFormatCode: - err = n.DecodeText(&valueReader2{vr}) + err = n.DecodeText(vr.bytes()) case BinaryFormatCode: - err = n.DecodeBinary(&valueReader2{vr}) + err = n.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 @@ -1153,15 +1128,13 @@ func decodeInt4(vr *ValueReader) int32 { return 0 } - vr.err = errRewoundLen - var n pgtype.Int4 var err error switch vr.Type().FormatCode { case TextFormatCode: - err = n.DecodeText(&valueReader2{vr}) + err = n.DecodeText(vr.bytes()) case BinaryFormatCode: - err = n.DecodeBinary(&valueReader2{vr}) + err = n.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return 0 @@ -1455,15 +1428,13 @@ func decodeDate(vr *ValueReader) time.Time { return time.Time{} } - vr.err = errRewoundLen - var d pgtype.Date var err error switch vr.Type().FormatCode { case TextFormatCode: - err = d.DecodeText(&valueReader2{vr}) + err = d.DecodeText(vr.bytes()) case BinaryFormatCode: - err = d.DecodeBinary(&valueReader2{vr}) + err = d.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return time.Time{} @@ -1518,15 +1489,13 @@ func decodeTimestampTz(vr *ValueReader) time.Time { return zeroTime } - vr.err = errRewoundLen - var t pgtype.Timestamptz var err error switch vr.Type().FormatCode { case TextFormatCode: - err = t.DecodeText(&valueReader2{vr}) + err = t.DecodeText(vr.bytes()) case BinaryFormatCode: - err = t.DecodeBinary(&valueReader2{vr}) + err = t.DecodeBinary(vr.bytes()) default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) return time.Time{} From 6c26c3a4a305b2415969214094039235e5c31e46 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 11:17:51 -0600 Subject: [PATCH 096/264] Improve replication test reliability It was failing intermittently when run concurrently. --- README.md | 2 + conn_config_test.go.example | 2 + replication_test.go | 78 ++++++++++++++++--------------------- 3 files changed, 38 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index ea2038a8..b85f9c0f 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,8 @@ Change the following settings in your postgresql.conf: max_wal_senders=5 max_replication_slots=5 +Set `replicationConnConfig` appropriately in `conn_config_test.go`. + ## Version Policy pgx follows semantic versioning for the documented public API on stable releases. Branch `v2` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v2` branch (in practice, this occurs very rarely). diff --git a/conn_config_test.go.example b/conn_config_test.go.example index cac798b7..4f6a5e5a 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -23,3 +23,5 @@ var replicationConnConfig *pgx.ConnConfig = nil // var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} // var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} // var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +// var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"} + diff --git a/replication_test.go b/replication_test.go index 1a8063e5..54ef4b66 100644 --- a/replication_test.go +++ b/replication_test.go @@ -39,8 +39,6 @@ func getConfirmedFlushLsnFor(t *testing.T, conn *pgx.Conn, slot string) string { // - Checks the wal position of the slot on the server to make sure // the update succeeded func TestSimpleReplicationConnection(t *testing.T) { - t.Parallel() - var err error if replicationConnConfig == nil { @@ -74,71 +72,63 @@ func TestSimpleReplicationConnection(t *testing.T) { t.Fatalf("Failed to start replication: %v", err) } - var i int32 var insertedTimes []int64 - for i < 5 { + currentTime := time.Now().Unix() + + for i := 0; i < 5; i++ { var ct pgx.CommandTag - currentTime := time.Now().Unix() insertedTimes = append(insertedTimes, currentTime) ct, err = conn.Exec("insert into replication_test(a) values($1)", currentTime) if err != nil { t.Fatalf("Insert failed: %v", err) } t.Logf("Inserted %d rows", ct.RowsAffected()) - i++ + currentTime++ } - i = 0 var foundTimes []int64 var foundCount int var maxWal uint64 + + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + for { var message *pgx.ReplicationMessage - ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) - defer cancelFn() message, err = replicationConn.WaitForReplicationMessage(ctx) - if err != nil && err != context.DeadlineExceeded { + if err != nil { t.Fatalf("Replication failed: %v %s", err, reflect.TypeOf(err)) } - if message != nil { - if message.WalMessage != nil { - // The waldata payload with the test_decoding plugin looks like: - // public.replication_test: INSERT: a[integer]:2 - // What we wanna do here is check that once we find one of our inserted times, - // that they occur in the wal stream in the order we executed them. - walString := string(message.WalMessage.WalData) - if strings.Contains(walString, "public.replication_test: INSERT") { - stringParts := strings.Split(walString, ":") - offset, err := strconv.ParseInt(stringParts[len(stringParts)-1], 10, 64) - if err != nil { - t.Fatalf("Failed to parse walString %s", walString) - } - if foundCount > 0 || offset == insertedTimes[0] { - foundTimes = append(foundTimes, offset) - foundCount++ - } - } - if message.WalMessage.WalStart > maxWal { - maxWal = message.WalMessage.WalStart - } + if message.WalMessage != nil { + // The waldata payload with the test_decoding plugin looks like: + // public.replication_test: INSERT: a[integer]:2 + // What we wanna do here is check that once we find one of our inserted times, + // that they occur in the wal stream in the order we executed them. + walString := string(message.WalMessage.WalData) + if strings.Contains(walString, "public.replication_test: INSERT") { + stringParts := strings.Split(walString, ":") + offset, err := strconv.ParseInt(stringParts[len(stringParts)-1], 10, 64) + if err != nil { + t.Fatalf("Failed to parse walString %s", walString) + } + if foundCount > 0 || offset == insertedTimes[0] { + foundTimes = append(foundTimes, offset) + foundCount++ + } + if foundCount == len(insertedTimes) { + break + } } - if message.ServerHeartbeat != nil { - t.Logf("Got heartbeat: %s", message.ServerHeartbeat) + if message.WalMessage.WalStart > maxWal { + maxWal = message.WalMessage.WalStart } - } else { - t.Log("Timed out waiting for wal message") - i++ - } - if i > 3 { - t.Log("Actual timeout") - break - } - } - if foundCount != len(insertedTimes) { - t.Fatalf("Failed to find all inserted time values in WAL stream (found %d expected %d)", foundCount, len(insertedTimes)) + } + if message.ServerHeartbeat != nil { + t.Logf("Got heartbeat: %s", message.ServerHeartbeat) + } } for i := range insertedTimes { From 1f3e484ca15f6072173f866abb7b46b5d3ecd821 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 12:32:33 -0600 Subject: [PATCH 097/264] pgtype.Encode(Binary|Text) do not write length To aid in composability, these methods no longer write their own length. This is especially useful for text formatted arrays and may be useful for future database/sql compatibility. It also makes the code a little simpler as the types no longer have to compute their own size. Along with this, these methods cannot encode NULL. They now return a boolean if they are NULL. This also benefits text array encoding as numeric arrays require NULL to be exactly `NULL` while string arrays require NULL to be `"NULL"`. --- pgtype/bool.go | 38 ++++------ pgtype/boolarray.go | 148 +++++++++++++++++++------------------ pgtype/bytea.go | 44 +++++------ pgtype/cid.go | 4 +- pgtype/cidrarray.go | 4 +- pgtype/date.go | 36 ++++----- pgtype/datearray.go | 148 +++++++++++++++++++------------------ pgtype/float4.go | 36 ++++----- pgtype/float4array.go | 148 +++++++++++++++++++------------------ pgtype/float8.go | 36 ++++----- pgtype/float8array.go | 148 +++++++++++++++++++------------------ pgtype/inet.go | 46 +++++------- pgtype/inetarray.go | 148 +++++++++++++++++++------------------ pgtype/int2.go | 36 ++++----- pgtype/int2array.go | 148 +++++++++++++++++++------------------ pgtype/int4.go | 36 ++++----- pgtype/int4array.go | 148 +++++++++++++++++++------------------ pgtype/int8.go | 36 ++++----- pgtype/int8array.go | 148 +++++++++++++++++++------------------ pgtype/name.go | 4 +- pgtype/oid.go | 4 +- pgtype/pgtype.go | 29 ++++---- pgtype/pgtype_test.go | 4 +- pgtype/pguint32.go | 36 ++++----- pgtype/qchar.go | 16 ++-- pgtype/text.go | 22 +++--- pgtype/text_element.go | 112 ---------------------------- pgtype/textarray.go | 148 ++++++++++++++++++------------------- pgtype/timestamp.go | 40 +++++----- pgtype/timestamparray.go | 148 +++++++++++++++++++------------------ pgtype/timestamptz.go | 36 ++++----- pgtype/timestamptzarray.go | 148 +++++++++++++++++++------------------ pgtype/to-consider.txt | 9 --- pgtype/typed_array.go.erb | 148 +++++++++++++++++++------------------ pgtype/typed_array_gen.sh | 22 +++--- pgtype/varchararray.go | 4 +- pgtype/xid.go | 4 +- values.go | 120 +++++++++++++++++++++--------- 38 files changed, 1271 insertions(+), 1319 deletions(-) delete mode 100644 pgtype/text_element.go delete mode 100644 pgtype/to-consider.txt diff --git a/pgtype/bool.go b/pgtype/bool.go index b7bc14d0..9764fafe 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -5,8 +5,6 @@ import ( "io" "reflect" "strconv" - - "github.com/jackc/pgx/pgio" ) type Bool struct { @@ -100,14 +98,12 @@ func (dst *Bool) DecodeBinary(src []byte) error { return nil } -func (src Bool) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 1) - if err != nil { - return nil +func (src Bool) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var buf []byte @@ -117,18 +113,16 @@ func (src Bool) EncodeText(w io.Writer) error { buf = []byte{'f'} } - _, err = w.Write(buf) - return err + _, err := w.Write(buf) + return false, err } -func (src Bool) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 1) - if err != nil { - return nil +func (src Bool) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var buf []byte @@ -138,6 +132,6 @@ func (src Bool) EncodeBinary(w io.Writer) error { buf = []byte{0} } - _, err = w.Write(buf) - return err + _, err := w.Write(buf) + return false, err } diff --git a/pgtype/boolarray.go b/pgtype/boolarray.go index a9b8bf50..f7323281 100644 --- a/pgtype/boolarray.go +++ b/pgtype/boolarray.go @@ -152,26 +152,22 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { return nil } -func (src *BoolArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -185,100 +181,112 @@ func (src *BoolArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *BoolArray) EncodeBinary(w io.Writer) error { +func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, BoolOID) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/bytea.go b/pgtype/bytea.go index db20482f..709499d2 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -5,8 +5,6 @@ import ( "fmt" "io" "reflect" - - "github.com/jackc/pgx/pgio" ) type Bytea struct { @@ -101,37 +99,31 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } -func (src Bytea) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Bytea) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - str := hex.EncodeToString(src.Bytes) - - _, err := pgio.WriteInt32(w, int32(len(str)+2)) + _, err := io.WriteString(w, `\x`) if err != nil { - return nil + return false, err } - _, err = io.WriteString(w, `\x`) - if err != nil { - return nil - } - - _, err = io.WriteString(w, str) - return err + _, err = io.WriteString(w, hex.EncodeToString(src.Bytes)) + return false, err } -func (src Bytea) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Bytea) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, int32(len(src.Bytes))) - if err != nil { - return nil - } - - _, err = w.Write(src.Bytes) - return err + _, err := w.Write(src.Bytes) + return false, err } diff --git a/pgtype/cid.go b/pgtype/cid.go index f8d706d0..41b817bb 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -38,10 +38,10 @@ func (dst *CID) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src CID) EncodeText(w io.Writer) error { +func (src CID) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src CID) EncodeBinary(w io.Writer) error { +func (src CID) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/pgtype/cidrarray.go b/pgtype/cidrarray.go index d95eef4a..cb81d2b9 100644 --- a/pgtype/cidrarray.go +++ b/pgtype/cidrarray.go @@ -22,10 +22,10 @@ func (dst *CidrArray) DecodeBinary(src []byte) error { return (*InetArray)(dst).DecodeBinary(src) } -func (src *CidrArray) EncodeText(w io.Writer) error { +func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { return (*InetArray)(src).EncodeText(w) } -func (src *CidrArray) EncodeBinary(w io.Writer) error { +func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { return (*InetArray)(src).encodeBinary(w, CidrOID) } diff --git a/pgtype/date.go b/pgtype/date.go index 1bb81d35..b0d16e64 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -116,9 +116,12 @@ func (dst *Date) DecodeBinary(src []byte) error { return nil } -func (src Date) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Date) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var s string @@ -132,23 +135,16 @@ func (src Date) EncodeText(w io.Writer) error { s = "-infinity" } - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, s) + return false, err } -func (src Date) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err +func (src Date) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var daysSinceDateEpoch int32 @@ -165,6 +161,6 @@ func (src Date) EncodeBinary(w io.Writer) error { daysSinceDateEpoch = negativeInfinityDayOffset } - _, err = pgio.WriteInt32(w, daysSinceDateEpoch) - return err + _, err := pgio.WriteInt32(w, daysSinceDateEpoch) + return false, err } diff --git a/pgtype/datearray.go b/pgtype/datearray.go index e9ad1f62..9552739b 100644 --- a/pgtype/datearray.go +++ b/pgtype/datearray.go @@ -153,26 +153,22 @@ func (dst *DateArray) DecodeBinary(src []byte) error { return nil } -func (src *DateArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *DateArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -186,100 +182,112 @@ func (src *DateArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *DateArray) EncodeBinary(w io.Writer) error { +func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, DateOID) } -func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/float4.go b/pgtype/float4.go index fb0415e5..26609ab2 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -124,30 +124,26 @@ func (dst *Float4) DecodeBinary(src []byte) error { return nil } -func (src Float4) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Float4) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatFloat(float64(src.Float), 'f', -1, 32) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)) + return false, err } -func (src Float4) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Float4) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) - return err + _, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) + return false, err } diff --git a/pgtype/float4array.go b/pgtype/float4array.go index a4a72146..9ab08dcc 100644 --- a/pgtype/float4array.go +++ b/pgtype/float4array.go @@ -152,26 +152,22 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { return nil } -func (src *Float4Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -185,100 +181,112 @@ func (src *Float4Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Float4Array) EncodeBinary(w io.Writer) error { +func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Float4OID) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/float8.go b/pgtype/float8.go index a53de5e3..9ec9a665 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -114,30 +114,26 @@ func (dst *Float8) DecodeBinary(src []byte) error { return nil } -func (src Float8) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Float8) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatFloat(float64(src.Float), 'f', -1, 64) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)) + return false, err } -func (src Float8) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Float8) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 8) - if err != nil { - return err - } - - _, err = pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) - return err + _, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) + return false, err } diff --git a/pgtype/float8array.go b/pgtype/float8array.go index 082e817d..ce7e3b90 100644 --- a/pgtype/float8array.go +++ b/pgtype/float8array.go @@ -152,26 +152,22 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { return nil } -func (src *Float8Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -185,100 +181,112 @@ func (src *Float8Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Float8Array) EncodeBinary(w io.Writer) error { +func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Float8OID) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/inet.go b/pgtype/inet.go index 132a876a..f94622f4 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -144,61 +144,55 @@ func (dst *Inet) DecodeBinary(src []byte) error { return nil } -func (src Inet) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Inet) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := src.IPNet.String() - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, src.IPNet.String()) + return false, err } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Inet) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var size int32 var family byte switch len(src.IPNet.IP) { case net.IPv4len: - size = 8 family = defaultAFInet case net.IPv6len: - size = 20 family = defaultAFInet6 default: - return fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) - } - - if _, err := pgio.WriteInt32(w, size); err != nil { - return err + return false, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) } if err := pgio.WriteByte(w, family); err != nil { - return err + return false, err } ones, _ := src.IPNet.Mask.Size() if err := pgio.WriteByte(w, byte(ones)); err != nil { - return err + return false, err } // is_cidr is ignored on server if err := pgio.WriteByte(w, 0); err != nil { - return err + return false, err } if err := pgio.WriteByte(w, byte(len(src.IPNet.IP))); err != nil { - return err + return false, err } _, err := w.Write(src.IPNet.IP) - return err + return false, err } diff --git a/pgtype/inetarray.go b/pgtype/inetarray.go index 28de736f..32cde554 100644 --- a/pgtype/inetarray.go +++ b/pgtype/inetarray.go @@ -184,26 +184,22 @@ func (dst *InetArray) DecodeBinary(src []byte) error { return nil } -func (src *InetArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *InetArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -217,100 +213,112 @@ func (src *InetArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *InetArray) EncodeBinary(w io.Writer) error { +func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, InetOID) } -func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/int2.go b/pgtype/int2.go index 51346a43..7bdbacfe 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -119,30 +119,26 @@ func (dst *Int2) DecodeBinary(src []byte) error { return nil } -func (src Int2) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int2) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatInt(int64(src.Int), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatInt(int64(src.Int), 10)) + return false, err } -func (src Int2) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int2) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = pgio.WriteInt16(w, src.Int) - return err + _, err := pgio.WriteInt16(w, src.Int) + return false, err } diff --git a/pgtype/int2array.go b/pgtype/int2array.go index 71760e1e..f7cc2492 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -183,26 +183,22 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { return nil } -func (src *Int2Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -216,100 +212,112 @@ func (src *Int2Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Int2Array) EncodeBinary(w io.Writer) error { +func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Int2OID) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/int4.go b/pgtype/int4.go index 8a53d454..2d96ea48 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -110,30 +110,26 @@ func (dst *Int4) DecodeBinary(src []byte) error { return nil } -func (src Int4) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int4) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatInt(int64(src.Int), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatInt(int64(src.Int), 10)) + return false, err } -func (src Int4) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int4) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, src.Int) - return err + _, err := pgio.WriteInt32(w, src.Int) + return false, err } diff --git a/pgtype/int4array.go b/pgtype/int4array.go index 6a202b08..fa710af7 100644 --- a/pgtype/int4array.go +++ b/pgtype/int4array.go @@ -183,26 +183,22 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { return nil } -func (src *Int4Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -216,100 +212,112 @@ func (src *Int4Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Int4Array) EncodeBinary(w io.Writer) error { +func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Int4OID) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/int8.go b/pgtype/int8.go index c6bedaa6..91f5b877 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -102,30 +102,26 @@ func (dst *Int8) DecodeBinary(src []byte) error { return nil } -func (src Int8) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int8) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatInt(src.Int, 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatInt(src.Int, 10)) + return false, err } -func (src Int8) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Int8) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 8) - if err != nil { - return err - } - - _, err = pgio.WriteInt64(w, src.Int) - return err + _, err := pgio.WriteInt64(w, src.Int) + return false, err } diff --git a/pgtype/int8array.go b/pgtype/int8array.go index f621618e..65f42477 100644 --- a/pgtype/int8array.go +++ b/pgtype/int8array.go @@ -183,26 +183,22 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { return nil } -func (src *Int8Array) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -216,100 +212,112 @@ func (src *Int8Array) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *Int8Array) EncodeBinary(w io.Writer) error { +func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, Int8OID) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/name.go b/pgtype/name.go index 4bbc43c1..513abfc7 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -35,10 +35,10 @@ func (dst *Name) DecodeBinary(src []byte) error { return (*Text)(dst).DecodeBinary(src) } -func (src Name) EncodeText(w io.Writer) error { +func (src Name) EncodeText(w io.Writer) (bool, error) { return (Text)(src).EncodeText(w) } -func (src Name) EncodeBinary(w io.Writer) error { +func (src Name) EncodeBinary(w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(w) } diff --git a/pgtype/oid.go b/pgtype/oid.go index 2ea9c2d1..e1bee4cf 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -32,10 +32,10 @@ func (dst *OID) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src OID) EncodeText(w io.Writer) error { +func (src OID) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src OID) EncodeBinary(w io.Writer) error { +func (src OID) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7928e1cc..d6cd53c1 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -3,8 +3,6 @@ package pgtype import ( "errors" "io" - - "github.com/jackc/pgx/pgio" ) // PostgreSQL oids for common types @@ -81,23 +79,24 @@ type TextDecoder interface { DecodeText(src []byte) error } +// BinaryEncoder is implemented by types that can encode themselves into the +// PostgreSQL binary wire format. type BinaryEncoder interface { - EncodeBinary(w io.Writer) error + // EncodeBinary should encode the binary format of self to w. If self is the + // SQL value NULL then write nothing and return (true, nil). The caller of + // EncodeBinary is responsible for writing the correct NULL value or the + // length of the data written. + EncodeBinary(w io.Writer) (null bool, err error) } +// TextEncoder is implemented by types that can encode themselves into the +// PostgreSQL text wire format. type TextEncoder interface { - EncodeText(w io.Writer) error + // EncodeText should encode the text format of self to w. If self is the SQL + // value NULL then write nothing and return (true, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the length + // of the data written. + EncodeText(w io.Writer) (null bool, err error) } var errUndefined = errors.New("cannot encode status undefined") - -func encodeNotPresent(w io.Writer, status Status) (done bool, err error) { - switch status { - case Undefined: - return true, errUndefined - case Null: - _, err = pgio.WriteInt32(w, -1) - return true, err - } - return false, nil -} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 6e173cbe..07a40160 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -60,7 +60,7 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(w io.Writer) error { +func (f forceTextEncoder) EncodeText(w io.Writer) (bool, error) { return f.e.EncodeText(w) } @@ -68,7 +68,7 @@ type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(w io.Writer) error { +func (f forceBinaryEncoder) EncodeBinary(w io.Writer) (bool, error) { return f.e.EncodeBinary(w) } diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 9bf1eef6..df9e0d36 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -82,30 +82,26 @@ func (dst *pguint32) DecodeBinary(src []byte) error { return nil } -func (src pguint32) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src pguint32) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - s := strconv.FormatUint(uint64(src.Uint), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, strconv.FormatUint(uint64(src.Uint), 10)) + return false, err } -func (src pguint32) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src pguint32) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteUint32(w, src.Uint) - return err + _, err := pgio.WriteUint32(w, src.Uint) + return false, err } diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 8abec935..0da1e88b 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -120,15 +120,13 @@ func (dst *QChar) DecodeBinary(src []byte) error { return nil } -func (src QChar) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src QChar) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, 1) - if err != nil { - return nil - } - - return pgio.WriteByte(w, byte(src.Int)) + return false, pgio.WriteByte(w, byte(src.Int)) } diff --git a/pgtype/text.go b/pgtype/text.go index 2951b5ad..baf62d1e 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -4,8 +4,6 @@ import ( "fmt" "io" "reflect" - - "github.com/jackc/pgx/pgio" ) type Text struct { @@ -85,20 +83,18 @@ func (dst *Text) DecodeBinary(src []byte) error { return dst.DecodeText(src) } -func (src Text) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Text) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - _, err := pgio.WriteInt32(w, int32(len(src.String))) - if err != nil { - return nil - } - - _, err = io.WriteString(w, src.String) - return err + _, err := io.WriteString(w, src.String) + return false, err } -func (src Text) EncodeBinary(w io.Writer) error { +func (src Text) EncodeBinary(w io.Writer) (bool, error) { return src.EncodeText(w) } diff --git a/pgtype/text_element.go b/pgtype/text_element.go deleted file mode 100644 index 1a585d08..00000000 --- a/pgtype/text_element.go +++ /dev/null @@ -1,112 +0,0 @@ -package pgtype - -import ( - "bytes" - "errors" - "io" - - "github.com/jackc/pgx/pgio" -) - -// TextElementWriter is a wrapper that makes TextEncoders composable into other -// TextEncoders. TextEncoder first writes the length of the subsequent value. -// This is not necessary when the value is part of another value such as an -// array. TextElementWriter requires one int32 to be written first which it -// ignores. No other integer writes are valid. -type TextElementWriter struct { - w io.Writer - lengthHeaderIgnored bool -} - -func NewTextElementWriter(w io.Writer) *TextElementWriter { - return &TextElementWriter{w: w} -} - -func (w *TextElementWriter) WriteUint16(n uint16) (int, error) { - return 0, errors.New("WriteUint16 should never be called on TextElementWriter") -} - -func (w *TextElementWriter) WriteUint32(n uint32) (int, error) { - if !w.lengthHeaderIgnored { - w.lengthHeaderIgnored = true - - if int32(n) == -1 { - return io.WriteString(w.w, "NULL") - } - - return 4, nil - } - - return 0, errors.New("WriteUint32 should only be called once on TextElementWriter") -} - -func (w *TextElementWriter) WriteUint64(n uint64) (int, error) { - if w.lengthHeaderIgnored { - return pgio.WriteUint64(w.w, n) - } - - return 0, errors.New("WriteUint64 should never be called on TextElementWriter") -} - -func (w *TextElementWriter) Write(buf []byte) (int, error) { - if w.lengthHeaderIgnored { - return w.w.Write(buf) - } - - return 0, errors.New("int32 must be written first") -} - -func (w *TextElementWriter) Reset() { - w.lengthHeaderIgnored = false -} - -// TextElementReader is a wrapper that makes TextDecoders composable into other -// TextDecoders. TextEncoders first read the length of the subsequent value. -// This length value is not present when the value is part of another value such -// as an array. TextElementReader provides a substitute length value from the -// length of the string. No other integer reads are valid. Each time DecodeText -// is called with a TextElementReader as the source the TextElementReader must -// first have Reset called with the new element string data. -type TextElementReader struct { - buf *bytes.Buffer - lengthHeaderIgnored bool -} - -func NewTextElementReader(r io.Reader) *TextElementReader { - return &TextElementReader{buf: &bytes.Buffer{}} -} - -func (r *TextElementReader) ReadUint16() (uint16, error) { - return 0, errors.New("ReadUint16 should never be called on TextElementReader") -} - -func (r *TextElementReader) ReadUint32() (uint32, error) { - if !r.lengthHeaderIgnored { - r.lengthHeaderIgnored = true - if r.buf.String() == "NULL" { - n32 := int32(-1) - return uint32(n32), nil - } - return uint32(r.buf.Len()), nil - } - - return 0, errors.New("ReadUint32 should only be called once on TextElementReader") -} - -func (r *TextElementReader) WriteUint64(n uint64) (int, error) { - return 0, errors.New("ReadUint64 should never be called on TextElementReader") -} - -func (r *TextElementReader) Read(buf []byte) (int, error) { - if r.lengthHeaderIgnored { - return r.buf.Read(buf) - } - - return 0, errors.New("int32 must be read first") -} - -func (r *TextElementReader) Reset(s string) { - r.lengthHeaderIgnored = false - r.buf.Reset() - r.buf.WriteString(s) -} diff --git a/pgtype/textarray.go b/pgtype/textarray.go index e7ca3578..c3e595e0 100644 --- a/pgtype/textarray.go +++ b/pgtype/textarray.go @@ -152,26 +152,22 @@ func (dst *TextArray) DecodeBinary(src []byte) error { return nil } -func (src *TextArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TextArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -185,112 +181,112 @@ func (src *TextArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - if elem.Status == Null { - _, err := io.WriteString(buf, `"NULL"`) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `"NULL"`) if err != nil { - return err + return false, err } - } else if elem.String == "" { - _, err := io.WriteString(buf, `""`) + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) if err != nil { - return err + return false, err } } else { - err = elem.EncodeText(textElementWriter) + _, err = elemBuf.WriteTo(w) if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *TextArray) EncodeBinary(w io.Writer) error { +func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, TextOID) } -func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index ca5eb738..a8b628e9 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -127,12 +127,15 @@ func (dst *Timestamp) DecodeBinary(src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Timestamp) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if src.Time.Location() != time.UTC { - return fmt.Errorf("cannot encode non-UTC time into timestamp") + return false, fmt.Errorf("cannot encode non-UTC time into timestamp") } var s string @@ -146,28 +149,21 @@ func (src Timestamp) EncodeText(w io.Writer) error { s = "-infinity" } - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, s) + return false, err } // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Timestamp) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if src.Time.Location() != time.UTC { - return fmt.Errorf("cannot encode non-UTC time into timestamp") - } - - _, err := pgio.WriteInt32(w, 8) - if err != nil { - return err + return false, fmt.Errorf("cannot encode non-UTC time into timestamp") } var microsecSinceY2K int64 @@ -181,6 +177,6 @@ func (src Timestamp) EncodeBinary(w io.Writer) error { microsecSinceY2K = negativeInfinityMicrosecondOffset } - _, err = pgio.WriteInt64(w, microsecSinceY2K) - return err + _, err := pgio.WriteInt64(w, microsecSinceY2K) + return false, err } diff --git a/pgtype/timestamparray.go b/pgtype/timestamparray.go index 695559ac..21e4de98 100644 --- a/pgtype/timestamparray.go +++ b/pgtype/timestamparray.go @@ -153,26 +153,22 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestampArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -186,100 +182,112 @@ func (src *TimestampArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *TimestampArray) EncodeBinary(w io.Writer) error { +func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, TimestampOID) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 7255bb06..f4c67b0b 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -131,9 +131,12 @@ func (dst *Timestamptz) DecodeBinary(src []byte) error { return nil } -func (src Timestamptz) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var s string @@ -147,23 +150,16 @@ func (src Timestamptz) EncodeText(w io.Writer) error { s = "-infinity" } - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - - _, err = w.Write([]byte(s)) - return err + _, err := io.WriteString(w, s) + return false, err } -func (src Timestamptz) EncodeBinary(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err - } - - _, err := pgio.WriteInt32(w, 8) - if err != nil { - return err +func (src Timestamptz) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } var microsecSinceY2K int64 @@ -177,6 +173,6 @@ func (src Timestamptz) EncodeBinary(w io.Writer) error { microsecSinceY2K = negativeInfinityMicrosecondOffset } - _, err = pgio.WriteInt64(w, microsecSinceY2K) - return err + _, err := pgio.WriteInt64(w, microsecSinceY2K) + return false, err } diff --git a/pgtype/timestamptzarray.go b/pgtype/timestamptzarray.go index ca416c97..597b1842 100644 --- a/pgtype/timestamptzarray.go +++ b/pgtype/timestamptzarray.go @@ -153,26 +153,22 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -186,100 +182,112 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *TimestamptzArray) EncodeBinary(w io.Writer) error { +func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, TimestamptzOID) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/to-consider.txt b/pgtype/to-consider.txt deleted file mode 100644 index ba4f3511..00000000 --- a/pgtype/to-consider.txt +++ /dev/null @@ -1,9 +0,0 @@ -DecodeText and DecodeBinary take []byte instead of io.Reader -EncodeText and EncodeBinary do not write size -Add Nullable interface with IsNull() and SetNull() - -The above would keep types from needing to worry about writing their own size. Could make EncodeText and DecodeText easier to use with sql.Scanner and driver.Valuer. SetNull() could be removed as DecodeText and DecodeBinary could interpret a nil slice as null. - -EncodeText and EncodeBinary could return (null bool, err error). That would finish removing Nullable interface. - -Also, consider whether arrays and ranges could be represented as generic data types or more common code could be extracted instead of using code generation. diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 316439ef..2e9b77ea 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -151,26 +151,22 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { return nil } -func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } if len(src.Dimensions) == 0 { - _, err := pgio.WriteInt32(w, 2) - if err != nil { - return err - } - - _, err = w.Write([]byte("{}")) - return err + _, err := io.WriteString(w, "{}") + return false, err } - buf := &bytes.Buffer{} - - err := EncodeTextArrayDimensions(buf, src.Dimensions) + err := EncodeTextArrayDimensions(w, src.Dimensions) if err != nil { - return err + return false, err } // dimElemCounts is the multiples of elements that each array lies on. For @@ -184,100 +180,112 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) error { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } - textElementWriter := NewTextElementWriter(buf) - for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(buf, ',') + err = pgio.WriteByte(w, ',') if err != nil { - return err + return false, err } } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(buf, '{') + err = pgio.WriteByte(w, '{') if err != nil { - return err + return false, err } } } - textElementWriter.Reset() - err = elem.EncodeText(textElementWriter) + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) if err != nil { - return err + return false, err + } + if null { + _, err = io.WriteString(w, `<%= text_null %>`) + if err != nil { + return false, err + } + } else if elemBuf.Len() == 0 { + _, err = io.WriteString(w, `""`) + if err != nil { + return false, err + } + } else { + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(buf, '}') + err = pgio.WriteByte(w, '}') if err != nil { - return err + return false, err } } } } - _, err = pgio.WriteInt32(w, int32(buf.Len())) - if err != nil { - return err - } - - _, err = buf.WriteTo(w) - return err + return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) error { +func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) error { - if done, err := encodeNotPresent(w, src.Status); done { - return err +func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - var arrayHeader ArrayHeader + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. elemBuf := &bytes.Buffer{} for i := range src.Elements { - err := src.Elements[i].EncodeBinary(elemBuf) + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) if err != nil { - return err + return false, err } - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - arrayHeader.ElementOID = elementOID - arrayHeader.Dimensions = src.Dimensions - - // TODO - consider how to avoid having to buffer array before writing length - - // or how not pay allocations for the byte order conversions. - headerBuf := &bytes.Buffer{} - err := arrayHeader.EncodeBinary(headerBuf) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) - if err != nil { - return err - } - - _, err = headerBuf.WriteTo(w) - if err != nil { - return err - } - - _, err = elemBuf.WriteTo(w) - if err != nil { - return err - } - - return err + return false, err } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 1e2dce64..43109700 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,11 +1,11 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID typed_array.go.erb > int2array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID typed_array.go.erb > int4array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID typed_array.go.erb > int8array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID typed_array.go.erb > boolarray.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID typed_array.go.erb > datearray.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID typed_array.go.erb > timestamptzarray.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID typed_array.go.erb > timestamparray.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID typed_array.go.erb > textarray.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID text_null=NULL typed_array.go.erb > int2array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID text_null=NULL typed_array.go.erb > int4array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID text_null=NULL typed_array.go.erb > int8array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID text_null=NULL typed_array.go.erb > boolarray.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID text_null=NULL typed_array.go.erb > datearray.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID text_null=NULL typed_array.go.erb > timestamptzarray.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID text_null=NULL typed_array.go.erb > timestamparray.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID text_null=NULL typed_array.go.erb > float4array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go diff --git a/pgtype/varchararray.go b/pgtype/varchararray.go index 3a5d8536..9c8829d0 100644 --- a/pgtype/varchararray.go +++ b/pgtype/varchararray.go @@ -22,10 +22,10 @@ func (dst *VarcharArray) DecodeBinary(src []byte) error { return (*TextArray)(dst).DecodeBinary(src) } -func (src *VarcharArray) EncodeText(w io.Writer) error { +func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { return (*TextArray)(src).EncodeText(w) } -func (src *VarcharArray) EncodeBinary(w io.Writer) error { +func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { return (*TextArray)(src).encodeBinary(w, VarcharOID) } diff --git a/pgtype/xid.go b/pgtype/xid.go index 389f93bc..6635b21e 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -41,10 +41,10 @@ func (dst *XID) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src XID) EncodeText(w io.Writer) error { +func (src XID) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src XID) EncodeBinary(w io.Writer) error { +func (src XID) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/values.go b/values.go index 796f2f3d..88bf13d2 100644 --- a/values.go +++ b/values.go @@ -408,7 +408,10 @@ func (n NullInt16) Encode(w *WriteBuf, oid OID) error { return nil } - return pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w) + w.WriteInt32(2) + + _, err := pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w) + return err } // NullInt32 represents an integer that may be null. NullInt32 implements the @@ -447,7 +450,10 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { return nil } - return pgtype.Int4{Int: n.Int32, Status: pgtype.Present}.EncodeBinary(w) + w.WriteInt32(4) + + _, err := pgtype.Int4{Int: n.Int32, Status: pgtype.Present}.EncodeBinary(w) + return err } // OID (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, @@ -484,24 +490,14 @@ func (dst *OID) DecodeBinary(src []byte) error { return nil } -func (src OID) EncodeText(w io.Writer) error { - s := strconv.FormatUint(uint64(src), 10) - _, err := pgio.WriteInt32(w, int32(len(s))) - if err != nil { - return nil - } - _, err = w.Write([]byte(s)) - return err +func (src OID) EncodeText(w io.Writer) (bool, error) { + _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) + return false, err } -func (src OID) EncodeBinary(w io.Writer) error { - _, err := pgio.WriteInt32(w, 4) - if err != nil { - return err - } - - _, err = pgio.WriteUint32(w, uint32(src)) - return err +func (src OID) EncodeBinary(w io.Writer) (bool, error) { + _, err := pgio.WriteUint32(w, uint32(src)) + return false, err } // Tid is PostgreSQL's Tuple Identifier type. @@ -595,7 +591,10 @@ func (n NullInt64) Encode(w *WriteBuf, oid OID) error { return nil } - return pgtype.Int8{Int: n.Int64, Status: pgtype.Present}.EncodeBinary(w) + w.WriteInt32(8) + + _, err := pgtype.Int8{Int: n.Int64, Status: pgtype.Present}.EncodeBinary(w) + return err } // NullBool represents an bool that may be null. NullBool implements the Scanner @@ -634,7 +633,10 @@ func (n NullBool) Encode(w *WriteBuf, oid OID) error { return nil } - return encodeBool(w, oid, n.Bool) + w.WriteInt32(1) + + _, err := pgtype.Bool{Bool: n.Bool, Status: pgtype.Present}.EncodeBinary(w) + return err } // NullTime represents an time.Time that may be null. NullTime implements the @@ -834,9 +836,31 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { case Encoder: return arg.Encode(wbuf, oid) case pgtype.BinaryEncoder: - return arg.EncodeBinary(wbuf) + buf := &bytes.Buffer{} + null, err := arg.EncodeBinary(buf) + if err != nil { + return err + } + if null { + wbuf.WriteInt32(-1) + } else { + wbuf.WriteInt32(int32(buf.Len())) + wbuf.WriteBytes(buf.Bytes()) + } + return nil case pgtype.TextEncoder: - return arg.EncodeText(wbuf) + buf := &bytes.Buffer{} + null, err := arg.EncodeText(buf) + if err != nil { + return err + } + if null { + wbuf.WriteInt32(-1) + } else { + wbuf.WriteInt32(int32(buf.Len())) + wbuf.WriteBytes(buf.Bytes()) + } + return nil case driver.Valuer: v, err := arg.Value() if err != nil { @@ -876,7 +900,19 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { if err != nil { return err } - return value.(pgtype.BinaryEncoder).EncodeBinary(wbuf) + + buf := &bytes.Buffer{} + null, err := value.(pgtype.BinaryEncoder).EncodeBinary(buf) + if err != nil { + return err + } + if null { + wbuf.WriteInt32(-1) + } else { + wbuf.WriteInt32(int32(buf.Len())) + wbuf.WriteBytes(buf.Bytes()) + } + return nil } switch arg := arg.(type) { @@ -1026,15 +1062,6 @@ func decodeBool(vr *ValueReader) bool { return b.Bool } -func encodeBool(w *WriteBuf, oid OID, value bool) error { - if oid != BoolOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "bool", oid) - } - - b := pgtype.Bool{Bool: value, Status: pgtype.Present} - return b.EncodeBinary(w) -} - func decodeInt(vr *ValueReader) int64 { switch vr.Type().DataType { case Int2OID: @@ -1461,14 +1488,39 @@ func encodeTime(w *WriteBuf, oid OID, value time.Time) error { if err != nil { return err } - return d.EncodeBinary(w) + + buf := &bytes.Buffer{} + null, err := d.EncodeBinary(buf) + if err != nil { + return err + } + if null { + w.WriteInt32(-1) + } else { + w.WriteInt32(int32(buf.Len())) + w.WriteBytes(buf.Bytes()) + } + return nil + case TimestampTzOID, TimestampOID: var t pgtype.Timestamptz err := t.ConvertFrom(value) if err != nil { return err } - return t.EncodeBinary(w) + + buf := &bytes.Buffer{} + null, err := t.EncodeBinary(buf) + if err != nil { + return err + } + if null { + w.WriteInt32(-1) + } else { + w.WriteInt32(int32(buf.Len())) + w.WriteBytes(buf.Bytes()) + } + return nil default: return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) } From 77c57c780d07dacb63617f06b1aa93ce9bad1f23 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 13:32:32 -0600 Subject: [PATCH 098/264] Add pgtype.ByteaArray Also fix up quoting array elements for text arrays. --- conn.go | 2 + pgtype/array.go | 14 ++ pgtype/boolarray.go | 7 +- pgtype/byteaarray.go | 287 +++++++++++++++++++++++++++++++++++++ pgtype/byteaarray_test.go | 119 +++++++++++++++ pgtype/datearray.go | 7 +- pgtype/float4array.go | 7 +- pgtype/float8array.go | 7 +- pgtype/inetarray.go | 7 +- pgtype/int2array.go | 7 +- pgtype/int4array.go | 7 +- pgtype/int8array.go | 7 +- pgtype/textarray.go | 7 +- pgtype/textarray_test.go | 8 +- pgtype/timestamparray.go | 7 +- pgtype/timestamptzarray.go | 7 +- pgtype/typed_array.go.erb | 7 +- pgtype/typed_array_gen.sh | 1 + values.go | 65 --------- 19 files changed, 439 insertions(+), 141 deletions(-) create mode 100644 pgtype/byteaarray.go create mode 100644 pgtype/byteaarray_test.go diff --git a/conn.go b/conn.go index f9f94c43..e340f1c6 100644 --- a/conn.go +++ b/conn.go @@ -270,6 +270,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.oidPgtypeValues = map[OID]pgtype.Value{ BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, + ByteaArrayOID: &pgtype.ByteaArray{}, + ByteaOID: &pgtype.Bytea{}, CharOID: &pgtype.QChar{}, CIDOID: &pgtype.CID{}, CidrArrayOID: &pgtype.CidrArray{}, diff --git a/pgtype/array.go b/pgtype/array.go index 6b705103..90092c8d 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "strconv" + "strings" "unicode" "github.com/jackc/pgx/pgio" @@ -371,3 +372,16 @@ func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { return pgio.WriteByte(w, '=') } + +var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteArrayElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func QuoteArrayElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `{},"\`) { + return quoteArrayElement(src) + } + return src +} diff --git a/pgtype/boolarray.go b/pgtype/boolarray.go index f7323281..65a6bc9c 100644 --- a/pgtype/boolarray.go +++ b/pgtype/boolarray.go @@ -208,13 +208,8 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/byteaarray.go b/pgtype/byteaarray.go new file mode 100644 index 00000000..7a4f1601 --- /dev/null +++ b/pgtype/byteaarray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type ByteaArray struct { + Elements []Bytea + Dimensions []ArrayDimension + Status Status +} + +func (dst *ByteaArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ByteaArray: + *dst = value + + case [][]byte: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + elements := make([]Bytea, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = ByteaArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (src *ByteaArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[][]byte: + if src.Status == Present { + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ByteaArray) DecodeText(src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bytea + + if len(uta.Elements) > 0 { + elements = make([]Bytea, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bytea + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *ByteaArray) DecodeBinary(src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bytea, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) + if err != nil { + return err + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { + return src.encodeBinary(w, ByteaOID) +} + +func (src *ByteaArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} diff --git a/pgtype/byteaarray_test.go b/pgtype/byteaarray_test.go new file mode 100644 index 00000000..b39776d9 --- /dev/null +++ b/pgtype/byteaarray_test.go @@ -0,0 +1,119 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestByteaArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bytea[]", []interface{}{ + &pgtype.ByteaArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{Status: pgtype.Null}, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Status: pgtype.Null}, + pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestByteaArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ByteaArray + }{ + { + source: [][]byte{{1, 2, 3}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([][]byte)(nil)), + result: pgtype.ByteaArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ByteaArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaArrayAssignTo(t *testing.T) { + var byteByteSlice [][]byte + + simpleTests := []struct { + src pgtype.ByteaArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteByteSlice, + expected: [][]byte{{1, 2, 3}}, + }, + { + src: pgtype.ByteaArray{Status: pgtype.Null}, + dst: &byteByteSlice, + expected: (([][]byte)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/datearray.go b/pgtype/datearray.go index 9552739b..623ff9b3 100644 --- a/pgtype/datearray.go +++ b/pgtype/datearray.go @@ -209,13 +209,8 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/float4array.go b/pgtype/float4array.go index 9ab08dcc..c55f76d0 100644 --- a/pgtype/float4array.go +++ b/pgtype/float4array.go @@ -208,13 +208,8 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/float8array.go b/pgtype/float8array.go index ce7e3b90..d08a5351 100644 --- a/pgtype/float8array.go +++ b/pgtype/float8array.go @@ -208,13 +208,8 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/inetarray.go b/pgtype/inetarray.go index 32cde554..12d9493b 100644 --- a/pgtype/inetarray.go +++ b/pgtype/inetarray.go @@ -240,13 +240,8 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/int2array.go b/pgtype/int2array.go index f7cc2492..37ee9926 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -239,13 +239,8 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/int4array.go b/pgtype/int4array.go index fa710af7..f6f62e4b 100644 --- a/pgtype/int4array.go +++ b/pgtype/int4array.go @@ -239,13 +239,8 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/int8array.go b/pgtype/int8array.go index 65f42477..92d8ec46 100644 --- a/pgtype/int8array.go +++ b/pgtype/int8array.go @@ -239,13 +239,8 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/textarray.go b/pgtype/textarray.go index c3e595e0..182e76f5 100644 --- a/pgtype/textarray.go +++ b/pgtype/textarray.go @@ -208,13 +208,8 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/textarray_test.go b/pgtype/textarray_test.go index 29e3a6c7..a22e003d 100644 --- a/pgtype/textarray_test.go +++ b/pgtype/textarray_test.go @@ -25,12 +25,12 @@ func TestTextArrayTranscode(t *testing.T) { &pgtype.TextArray{Status: pgtype.Null}, &pgtype.TextArray{ Elements: []pgtype.Text{ - pgtype.Text{String: "bar", Status: pgtype.Present}, - pgtype.Text{String: "baz", Status: pgtype.Present}, - pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "bar ", Status: pgtype.Present}, + pgtype.Text{String: "NuLL", Status: pgtype.Present}, + pgtype.Text{String: `wow"quz\`, Status: pgtype.Present}, pgtype.Text{String: "", Status: pgtype.Present}, pgtype.Text{Status: pgtype.Null}, - pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{String: "null", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, diff --git a/pgtype/timestamparray.go b/pgtype/timestamparray.go index 21e4de98..b0fb25fa 100644 --- a/pgtype/timestamparray.go +++ b/pgtype/timestamparray.go @@ -209,13 +209,8 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/timestamptzarray.go b/pgtype/timestamptzarray.go index 597b1842..25374717 100644 --- a/pgtype/timestamptzarray.go +++ b/pgtype/timestamptzarray.go @@ -209,13 +209,8 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 2e9b77ea..f9dba308 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -207,13 +207,8 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 43109700..c63414c8 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -9,3 +9,4 @@ erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]fl erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > byteaarray.go diff --git a/values.go b/values.go index 88bf13d2..fc790dfe 100644 --- a/values.go +++ b/values.go @@ -873,8 +873,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return encodeAclItemSlice(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) - case [][]byte: - return encodeByteSliceSlice(wbuf, oid, arg) } refVal := reflect.ValueOf(arg) @@ -996,8 +994,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeText(vr) case *[]AclItem: *v = decodeAclItemArray(vr) - case *[][]byte: - *v = decodeByteaArray(vr) case *[]interface{}: *v = decodeRecord(vr) default: @@ -1684,67 +1680,6 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { return length, nil } -func decodeByteaArray(vr *ValueReader) [][]byte { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != ByteaArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into [][]byte", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([][]byte, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - switch elSize { - case -1: - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - default: - a[i] = vr.ReadBytes(elSize) - } - } - - return a -} - -func encodeByteSliceSlice(w *WriteBuf, oid OID, value [][]byte) error { - if oid != ByteaArrayOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "[][]byte", oid) - } - - size := 20 // array header size - for _, el := range value { - size += 4 + len(el) - } - - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(ByteaOID)) // type of elements - w.WriteInt32(int32(len(value))) // number of elements - w.WriteInt32(1) // index of first element - - for _, el := range value { - encodeByteSlice(w, ByteaOID, el) - } - - return nil -} - // escapeAclItem escapes an AclItem before it is added to // its aclitem[] string representation. The PostgreSQL aclitem // datatype itself can need escapes because it follows the From b0cd63bcf0cc880f2ca3051f185d873216594fb8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 15:44:21 -0600 Subject: [PATCH 099/264] Remove unused ScannerV3 --- query.go | 9 --------- values.go | 4 ---- 2 files changed, 13 deletions(-) diff --git a/query.go b/query.go index 71d1ba9e..d1191c7c 100644 --- a/query.go +++ b/query.go @@ -221,15 +221,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if s, ok := d.(ScannerV3); ok { - val, err := decodeByOID(vr) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } - err = s.ScanPgxV3(nil, val) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { err = s.DecodeBinary(vr.bytes()) if err != nil { diff --git a/values.go b/values.go index fc790dfe..80f4ee52 100644 --- a/values.go +++ b/values.go @@ -204,10 +204,6 @@ type Encoder interface { FormatCode() int16 } -type ScannerV3 interface { - ScanPgxV3(fieldDescription interface{}, src interface{}) error -} - // NullFloat32 represents an float4 that may be null. NullFloat32 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. From fa1c81fec4413a97bb267b85c19293cff10d5841 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:13:05 -0600 Subject: [PATCH 100/264] Move ACLItem to pgtype --- aclitem_parse_test.go | 126 ---------------- conn.go | 2 + pgtype/aclitem.go | 104 ++++++++++++++ pgtype/aclitem_test.go | 97 +++++++++++++ pgtype/aclitemarray.go | 186 ++++++++++++++++++++++++ pgtype/aclitemarray_test.go | 151 +++++++++++++++++++ pgtype/pgtype.go | 4 +- pgtype/typed_array_gen.sh | 1 + values.go | 280 +----------------------------------- values_test.go | 51 ------- 10 files changed, 548 insertions(+), 454 deletions(-) delete mode 100644 aclitem_parse_test.go create mode 100644 pgtype/aclitem.go create mode 100644 pgtype/aclitem_test.go create mode 100644 pgtype/aclitemarray.go create mode 100644 pgtype/aclitemarray_test.go diff --git a/aclitem_parse_test.go b/aclitem_parse_test.go deleted file mode 100644 index 5c7c748f..00000000 --- a/aclitem_parse_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package pgx - -import ( - "reflect" - "testing" -) - -func TestEscapeAclItem(t *testing.T) { - tests := []struct { - input string - expected string - }{ - { - "foo", - "foo", - }, - { - `foo, "\}`, - `foo\, \"\\\}`, - }, - } - - for i, tt := range tests { - actual, err := escapeAclItem(tt.input) - - if err != nil { - t.Errorf("%d. Unexpected error %v", i, err) - } - - if actual != tt.expected { - t.Errorf("%d.\nexpected: %s,\nactual: %s", i, tt.expected, actual) - } - } -} - -func TestParseAclItemArray(t *testing.T) { - tests := []struct { - input string - expected []AclItem - errMsg string - }{ - { - "", - []AclItem{}, - "", - }, - { - "one", - []AclItem{"one"}, - "", - }, - { - `"one"`, - []AclItem{"one"}, - "", - }, - { - "one,two,three", - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","two","three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one",two,"three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `one,two,"three"`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","two",three`, - []AclItem{"one", "two", "three"}, - "", - }, - { - `"one","t w o",three`, - []AclItem{"one", "t w o", "three"}, - "", - }, - { - `"one","t, w o\"\}\\",three`, - []AclItem{"one", `t, w o"}\`, "three"}, - "", - }, - { - `"one","two",three"`, - []AclItem{"one", "two", `three"`}, - "", - }, - { - `"one","two,"three"`, - nil, - "unexpected rune after quoted value", - }, - { - `"one","two","three`, - nil, - "unexpected end of quoted value", - }, - } - - for i, tt := range tests { - actual, err := parseAclItemArray(tt.input) - - if err != nil { - if tt.errMsg == "" { - t.Errorf("%d. Unexpected error %v", i, err) - } else if err.Error() != tt.errMsg { - t.Errorf("%d. Expected error %v did not match actual error %v", i, tt.errMsg, err.Error()) - } - } else if tt.errMsg != "" { - t.Errorf("%d. Expected error not returned: \"%v\"", i, tt.errMsg) - } - - if !reflect.DeepEqual(actual, tt.expected) { - t.Errorf("%d. Expected %v did not match actual %v", i, tt.expected, actual) - } - } -} diff --git a/conn.go b/conn.go index e340f1c6..f55dd82a 100644 --- a/conn.go +++ b/conn.go @@ -268,6 +268,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.closedChan = make(chan error) c.oidPgtypeValues = map[OID]pgtype.Value{ + ACLItemOID: &pgtype.ACLItem{}, + ACLItemArrayOID: &pgtype.ACLItemArray{}, BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, ByteaArrayOID: &pgtype.ByteaArray{}, diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go new file mode 100644 index 00000000..bd7b7d45 --- /dev/null +++ b/pgtype/aclitem.go @@ -0,0 +1,104 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" +) + +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// might look like this: +// +// postgres=arwdDxt/postgres +// +// Note, however, that because the user/role name part of an aclitem is +// an identifier, it follows all the usual formatting rules for SQL +// identifiers: if it contains spaces and other special characters, +// it should appear in double-quotes: +// +// postgres=arwdDxt/"role with spaces" +// +type ACLItem struct { + String string + Status Status +} + +func (dst *ACLItem) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ACLItem: + *dst = value + case string: + *dst = ACLItem{String: value, Status: Present} + case *string: + if value == nil { + *dst = ACLItem{Status: Null} + } else { + *dst = ACLItem{String: *value, Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (src *ACLItem) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.String + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + case reflect.String: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + el.SetString(src.String) + return nil + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ACLItem) DecodeText(src []byte) error { + if src == nil { + *dst = ACLItem{Status: Null} + return nil + } + + *dst = ACLItem{String: string(src), Status: Present} + return nil +} + +func (src ACLItem) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.String) + return false, err +} diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go new file mode 100644 index 00000000..0b2b6cfa --- /dev/null +++ b/pgtype/aclitem_test.go @@ -0,0 +1,97 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestACLItemTranscode(t *testing.T) { + testSuccessfulTranscode(t, "aclitem", []interface{}{ + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + }) +} + +func TestACLItemConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItem + }{ + {source: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.ACLItem + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestACLItemAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.ACLItem + dst interface{} + expected interface{} + }{ + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItem + dst interface{} + }{ + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/aclitemarray.go b/pgtype/aclitemarray.go new file mode 100644 index 00000000..d69cd83c --- /dev/null +++ b/pgtype/aclitemarray.go @@ -0,0 +1,186 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type ACLItemArray struct { + Elements []ACLItem + Dimensions []ArrayDimension + Status Status +} + +func (dst *ACLItemArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ACLItemArray: + *dst = value + + case []string: + if value == nil { + *dst = ACLItemArray{Status: Null} + } else if len(value) == 0 { + *dst = ACLItemArray{Status: Present} + } else { + elements := make([]ACLItem, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = ACLItemArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to ACLItem", value) + } + + return nil +} + +func (src *ACLItemArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ACLItemArray) DecodeText(src []byte) error { + if src == nil { + *dst = ACLItemArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []ACLItem + + if len(uta.Elements) > 0 { + elements = make([]ACLItem, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem ACLItem + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (src *ACLItemArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} diff --git a/pgtype/aclitemarray_test.go b/pgtype/aclitemarray_test.go new file mode 100644 index 00000000..8c01ac66 --- /dev/null +++ b/pgtype/aclitemarray_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestACLItemArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "aclitem[]", []interface{}{ + &pgtype.ACLItemArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestACLItemArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ACLItemArray + }{ + { + source: []string{"=r/postgres"}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.ACLItemArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ACLItemArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestACLItemArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.ACLItemArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"=r/postgres"}, + }, + { + src: pgtype.ACLItemArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.ACLItemArray + dst interface{} + }{ + { + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d6cd53c1..d72217ac 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -35,8 +35,8 @@ const ( Int8ArrayOID = 1016 Float4ArrayOID = 1021 Float8ArrayOID = 1022 - AclItemOID = 1033 - AclItemArrayOID = 1034 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 InetArrayOID = 1041 VarcharOID = 1043 DateOID = 1082 diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index c63414c8..876f8a3c 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -10,3 +10,4 @@ erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]fl erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > byteaarray.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitemarray.go diff --git a/values.go b/values.go index 80f4ee52..abe12d98 100644 --- a/values.go +++ b/values.go @@ -48,8 +48,8 @@ const ( Int8ArrayOID = 1016 Float4ArrayOID = 1021 Float8ArrayOID = 1022 - AclItemOID = 1033 - AclItemArrayOID = 1034 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 InetArrayOID = 1041 VarcharOID = 1043 DateOID = 1082 @@ -316,58 +316,6 @@ func (s NullString) Encode(w *WriteBuf, oid OID) error { return encodeString(w, oid, s.String) } -// AclItem is used for PostgreSQL's aclitem data type. A sample aclitem -// might look like this: -// -// postgres=arwdDxt/postgres -// -// Note, however, that because the user/role name part of an aclitem is -// an identifier, it follows all the usual formatting rules for SQL -// identifiers: if it contains spaces and other special characters, -// it should appear in double-quotes: -// -// postgres=arwdDxt/"role with spaces" -// -type AclItem string - -// NullAclItem represents a pgx.AclItem that may be null. NullAclItem implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullAclItem struct { - AclItem AclItem - Valid bool // Valid is true if AclItem is not NULL -} - -func (n *NullAclItem) Scan(vr *ValueReader) error { - if vr.Type().DataType != AclItemOID { - return SerializationError(fmt.Sprintf("NullAclItem.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.AclItem, n.Valid = "", false - return nil - } - - n.Valid = true - n.AclItem = AclItem(decodeText(vr)) - return vr.Err() -} - -// Particularly important to return TextFormatCode, seeing as Postgres -// only ever sends aclitem as text, not binary. -func (n NullAclItem) FormatCode() int16 { return TextFormatCode } - -func (n NullAclItem) Encode(w *WriteBuf, oid OID) error { - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeString(w, oid, string(n.AclItem)) -} - // NullInt16 represents a smallint that may be null. NullInt16 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan for prepared and unprepared queries. @@ -865,8 +813,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return Encode(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) - case []AclItem: - return encodeAclItemSlice(wbuf, oid, arg) case []byte: return encodeByteSlice(wbuf, oid, arg) } @@ -909,17 +855,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return nil } - switch arg := arg.(type) { - case AclItem: - // The aclitem data type goes over the wire using the same format as string, - // so just cast to string and use encodeString - return encodeString(wbuf, oid, string(arg)) - default: - if strippedArg, ok := stripNamedType(&refVal); ok { - return Encode(wbuf, oid, strippedArg) - } - return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + if strippedArg, ok := stripNamedType(&refVal); ok { + return Encode(wbuf, oid, strippedArg) } + return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } func stripNamedType(val *reflect.Value) (interface{}, bool) { @@ -981,15 +920,10 @@ func decodeByOID(vr *ValueReader) (interface{}, error) { // decoding to the built-in functionality. func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { - case *AclItem: - // aclitem goes over the wire just like text - *v = AclItem(decodeText(vr)) case *Tid: *v = decodeTid(vr) case *string: *v = decodeText(vr) - case *[]AclItem: - *v = decodeAclItemArray(vr) case *[]interface{}: *v = decodeRecord(vr) default: @@ -1675,207 +1609,3 @@ func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { return length, nil } - -// escapeAclItem escapes an AclItem before it is added to -// its aclitem[] string representation. The PostgreSQL aclitem -// datatype itself can need escapes because it follows the -// formatting rules of SQL identifiers. Think of this function -// as escaping the escapes, so that PostgreSQL's array parser -// will do the right thing. -func escapeAclItem(acl string) (string, error) { - var escapedAclItem bytes.Buffer - reader := strings.NewReader(acl) - for { - rn, _, err := reader.ReadRune() - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error. - return escapedAclItem.String(), nil - } - // This error was not expected - return "", err - } - if needsEscape(rn) { - escapedAclItem.WriteRune('\\') - } - escapedAclItem.WriteRune(rn) - } -} - -// needsEscape determines whether or not a rune needs escaping -// before being placed in the textual representation of an -// aclitem[] array. -func needsEscape(rn rune) bool { - return rn == '\\' || rn == ',' || rn == '"' || rn == '}' -} - -// encodeAclItemSlice encodes a slice of AclItems in -// their textual represention for PostgreSQL. -func encodeAclItemSlice(w *WriteBuf, oid OID, aclitems []AclItem) error { - strs := make([]string, len(aclitems)) - var escapedAclItem string - var err error - for i := range strs { - escapedAclItem, err = escapeAclItem(string(aclitems[i])) - if err != nil { - return err - } - strs[i] = string(escapedAclItem) - } - - var buf bytes.Buffer - buf.WriteRune('{') - buf.WriteString(strings.Join(strs, ",")) - buf.WriteRune('}') - str := buf.String() - w.WriteInt32(int32(len(str))) - w.WriteBytes([]byte(str)) - return nil -} - -// parseAclItemArray parses the textual representation -// of the aclitem[] type. The textual representation is chosen because -// Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin). -// See https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO -// for formatting notes. -func parseAclItemArray(arr string) ([]AclItem, error) { - reader := strings.NewReader(arr) - // Difficult to guess a performant initial capacity for a slice of - // aclitems, but let's go with 5. - aclItems := make([]AclItem, 0, 5) - // A single value - aclItem := AclItem("") - for { - // Grab the first/next/last rune to see if we are dealing with a - // quoted value, an unquoted value, or the end of the string. - rn, _, err := reader.ReadRune() - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error. - return aclItems, nil - } - // This error was not expected - return nil, err - } - - if rn == '"' { - // Discard the opening quote of the quoted value. - aclItem, err = parseQuotedAclItem(reader) - } else { - // We have just read the first rune of an unquoted (bare) value; - // put it back so that ParseBareValue can read it. - err := reader.UnreadRune() - if err != nil { - return nil, err - } - aclItem, err = parseBareAclItem(reader) - } - - if err != nil { - if err == io.EOF { - // Here, EOF is an expected end state, not an error.. - aclItems = append(aclItems, aclItem) - return aclItems, nil - } - // This error was not expected. - return nil, err - } - aclItems = append(aclItems, aclItem) - } -} - -// parseBareAclItem parses a bare (unquoted) aclitem from reader -func parseBareAclItem(reader *strings.Reader) (AclItem, error) { - var aclItem bytes.Buffer - for { - rn, _, err := reader.ReadRune() - if err != nil { - // Return the read value in case the error is a harmless io.EOF. - // (io.EOF marks the end of a bare aclitem at the end of a string) - return AclItem(aclItem.String()), err - } - if rn == ',' { - // A comma marks the end of a bare aclitem. - return AclItem(aclItem.String()), nil - } else { - aclItem.WriteRune(rn) - } - } -} - -// parseQuotedAclItem parses an aclitem which is in double quotes from reader -func parseQuotedAclItem(reader *strings.Reader) (AclItem, error) { - var aclItem bytes.Buffer - for { - rn, escaped, err := readPossiblyEscapedRune(reader) - if err != nil { - if err == io.EOF { - // Even when it is the last value, the final rune of - // a quoted aclitem should be the final closing quote, not io.EOF. - return AclItem(""), fmt.Errorf("unexpected end of quoted value") - } - // Return the read aclitem in case the error is a harmless io.EOF, - // which will be determined by the caller. - return AclItem(aclItem.String()), err - } - if !escaped && rn == '"' { - // An unescaped double quote marks the end of a quoted value. - // The next rune should either be a comma or the end of the string. - rn, _, err := reader.ReadRune() - if err != nil { - // Return the read value in case the error is a harmless io.EOF, - // which will be determined by the caller. - return AclItem(aclItem.String()), err - } - if rn != ',' { - return AclItem(""), fmt.Errorf("unexpected rune after quoted value") - } - return AclItem(aclItem.String()), nil - } - aclItem.WriteRune(rn) - } -} - -// Returns the next rune from r, unless it is a backslash; -// in that case, it returns the rune after the backslash. The second -// return value tells us whether or not the rune was -// preceeded by a backslash (escaped). -func readPossiblyEscapedRune(reader *strings.Reader) (rune, bool, error) { - rn, _, err := reader.ReadRune() - if err != nil { - return 0, false, err - } - if rn == '\\' { - // Discard the backslash and read the next rune. - rn, _, err = reader.ReadRune() - if err != nil { - return 0, false, err - } - return rn, true, nil - } - return rn, false, nil -} - -func decodeAclItemArray(vr *ValueReader) []AclItem { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into []AclItem")) - return nil - } - - str := vr.ReadString(vr.Len()) - - // Short-circuit empty array. - if str == "{}" { - return []AclItem{} - } - - // Remove the '{' at the front and the '}' at the end, - // so that parseAclItemArray doesn't have to deal with them. - str = str[1 : len(str)-1] - aclItems, err := parseAclItemArray(str) - if err != nil { - vr.Fatal(ProtocolError(err.Error())) - return nil - } - return aclItems -} diff --git a/values_test.go b/values_test.go index 4c02ac0a..9cf2b219 100644 --- a/values_test.go +++ b/values_test.go @@ -568,7 +568,6 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 - a pgx.NullAclItem tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 @@ -591,10 +590,6 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: true}}}, - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: "postgres=arwdDxt/postgres", Valid: false}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: "", Valid: false}}}, - // A tricky (and valid) aclitem can still be used, especially with Go's useful backticks - {"select $1::aclitem", []interface{}{pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}, []interface{}{&actual.a}, allTypes{a: pgx.NullAclItem{AclItem: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Valid: true}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, @@ -631,52 +626,6 @@ func TestNullX(t *testing.T) { } } -func assertAclItemSlicesEqual(t *testing.T, query, scan []pgx.AclItem) { - if !reflect.DeepEqual(query, scan) { - t.Errorf("failed to encode aclitem[]\n EXPECTED: %d %v\n ACTUAL: %d %v", len(query), query, len(scan), scan) - } -} - -func TestAclArrayDecoding(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - sql := "select $1::aclitem[]" - var scan []pgx.AclItem - - tests := []struct { - query []pgx.AclItem - }{ - { - []pgx.AclItem{}, - }, - { - []pgx.AclItem{"=r/postgres"}, - }, - { - []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres"}, - }, - { - []pgx.AclItem{"=r/postgres", "postgres=arwdDxt/postgres", `postgres=arwdDxt/" tricky, ' } "" \ test user "`}, - }, - } - for i, tt := range tests { - err := conn.QueryRow(sql, tt.query).Scan(&scan) - if err != nil { - // t.Errorf(`%d. error reading array: %v`, i, err) - t.Errorf(`%d. error reading array: %v query: %s`, i, err, tt.query) - if pgerr, ok := err.(pgx.PgError); ok { - t.Errorf(`%d. error reading array (detail): %s`, i, pgerr.Detail) - } - continue - } - assertAclItemSlicesEqual(t, tt.query, scan) - ensureConnValid(t, conn) - } -} - func TestArrayDecoding(t *testing.T) { t.Parallel() From f10ed4ff5dadf05df4164d4013332e0bcf7ddb67 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:14:13 -0600 Subject: [PATCH 101/264] Remove unused function --- values.go | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/values.go b/values.go index abe12d98..f34735d4 100644 --- a/values.go +++ b/values.go @@ -901,20 +901,6 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { return nil, false } -func decodeByOID(vr *ValueReader) (interface{}, error) { - switch vr.Type().DataType { - case Int2OID, Int4OID, Int8OID: - n := decodeInt(vr) - return n, vr.Err() - case BoolOID: - b := decodeBool(vr) - return b, vr.Err() - default: - buf := vr.ReadBytes(vr.Len()) - return buf, vr.Err() - } -} - // Decode decodes from vr into d. d must be a pointer. This allows // implementations of the Decoder interface to delegate the actual work of // decoding to the built-in functionality. From 6694e0e61876db7827791ef1af197847dee9b2e3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:48:37 -0600 Subject: [PATCH 102/264] Move Tid to pgtype --- conn.go | 3 +- pgtype/pgtype.go | 9 +++- pgtype/tid.go | 104 ++++++++++++++++++++++++++++++++++++ pgtype/tid_test.go | 15 ++++++ query.go | 8 ++- values.go | 130 +++------------------------------------------ values_test.go | 4 -- 7 files changed, 142 insertions(+), 131 deletions(-) create mode 100644 pgtype/tid.go create mode 100644 pgtype/tid_test.go diff --git a/conn.go b/conn.go index f55dd82a..c2cc5d3c 100644 --- a/conn.go +++ b/conn.go @@ -268,8 +268,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.closedChan = make(chan error) c.oidPgtypeValues = map[OID]pgtype.Value{ - ACLItemOID: &pgtype.ACLItem{}, ACLItemArrayOID: &pgtype.ACLItemArray{}, + ACLItemOID: &pgtype.ACLItem{}, BoolArrayOID: &pgtype.BoolArray{}, BoolOID: &pgtype.Bool{}, ByteaArrayOID: &pgtype.ByteaArray{}, @@ -296,6 +296,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl OIDOID: &pgtype.OID{}, TextArrayOID: &pgtype.TextArray{}, TextOID: &pgtype.Text{}, + TIDOID: &pgtype.TID{}, TimestampArrayOID: &pgtype.TimestampArray{}, TimestampOID: &pgtype.Timestamp{}, TimestampTzArrayOID: &pgtype.TimestamptzArray{}, diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d72217ac..8c67c630 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -16,7 +16,7 @@ const ( Int4OID = 23 TextOID = 25 OIDOID = 26 - TidOID = 27 + TIDOID = 27 XIDOID = 28 CIDOID = 29 JSONOID = 114 @@ -66,8 +66,13 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) -type Value interface { +type Value interface{} + +type ConverterFrom interface { ConvertFrom(src interface{}) error +} + +type AssignerTo interface { AssignTo(dst interface{}) error } diff --git a/pgtype/tid.go b/pgtype/tid.go new file mode 100644 index 00000000..804cced2 --- /dev/null +++ b/pgtype/tid.go @@ -0,0 +1,104 @@ +package pgtype + +import ( + "encoding/binary" + "fmt" + "io" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +// TID is PostgreSQL's Tuple Identifier type. +// +// When one does +// +// select ctid, * from some_table; +// +// it is the data type of the ctid hidden system column. +// +// It is currently implemented as a pair unsigned two byte integers. +// Its conversion functions can be found in src/backend/utils/adt/tid.c +// in the PostgreSQL sources. +type TID struct { + BlockNumber uint32 + OffsetNumber uint16 + Status Status +} + +func (dst *TID) DecodeText(src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for tid") + } + + blockNumber, err := strconv.ParseUint(parts[0], 10, 32) + if err != nil { + return err + } + + offsetNumber, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return err + } + + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + return nil +} + +func (dst *TID) DecodeBinary(src []byte) error { + if src == nil { + *dst = TID{Status: Null} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("invalid length for tid: %v", len(src)) + } + + *dst = TID{ + BlockNumber: binary.BigEndian.Uint32(src), + OffsetNumber: binary.BigEndian.Uint16(src[4:]), + Status: Present, + } + return nil +} + +func (src TID) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)) + return false, err +} + +func (src TID) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteUint32(w, src.BlockNumber) + if err != nil { + return false, err + } + + _, err = pgio.WriteUint16(w, src.OffsetNumber) + return false, err +} diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go new file mode 100644 index 00000000..a5aab8a3 --- /dev/null +++ b/pgtype/tid_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "tid", []interface{}{ + pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + pgtype.TID{Status: pgtype.Null}, + }) +} diff --git a/query.go b/query.go index d1191c7c..5730f1c6 100644 --- a/query.go +++ b/query.go @@ -299,8 +299,12 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) } - if err := pgVal.AssignTo(d); err != nil { - vr.Fatal(err) + if assignerTo, ok := pgVal.(pgtype.AssignerTo); ok { + if err := assignerTo.AssignTo(d); err != nil { + vr.Fatal(err) + } + } else { + vr.Fatal(fmt.Errorf("cannot assign %T", pgVal)) } } else { if err := Decode(vr, d); err != nil { diff --git a/values.go b/values.go index f34735d4..72f836bb 100644 --- a/values.go +++ b/values.go @@ -9,7 +9,6 @@ import ( "io" "math" "reflect" - "regexp" "strconv" "strings" "time" @@ -29,7 +28,7 @@ const ( Int4OID = 23 TextOID = 25 OIDOID = 26 - TidOID = 27 + TIDOID = 27 XIDOID = 28 CIDOID = 29 JSONOID = 114 @@ -444,61 +443,6 @@ func (src OID) EncodeBinary(w io.Writer) (bool, error) { return false, err } -// Tid is PostgreSQL's Tuple Identifier type. -// -// When one does -// -// select ctid, * from some_table; -// -// it is the data type of the ctid hidden system column. -// -// It is currently implemented as a pair unsigned two byte integers. -// Its conversion functions can be found in src/backend/utils/adt/tid.c -// in the PostgreSQL sources. -type Tid struct { - BlockNumber uint32 - OffsetNumber uint16 -} - -// NullTid represents a Tuple Identifier (Tid) that may be null. NullTid implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullTid struct { - Tid Tid - Valid bool // Valid is true if Tid is not NULL -} - -func (n *NullTid) Scan(vr *ValueReader) error { - if vr.Type().DataType != TidOID { - return SerializationError(fmt.Sprintf("NullTid.Scan cannot decode OID %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Tid, n.Valid = Tid{BlockNumber: 0, OffsetNumber: 0}, false - return nil - } - n.Valid = true - n.Tid = decodeTid(vr) - return vr.Err() -} - -func (n NullTid) FormatCode() int16 { return BinaryFormatCode } - -func (n NullTid) Encode(w *WriteBuf, oid OID) error { - if oid != TidOID { - return SerializationError(fmt.Sprintf("NullTid.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeTid(w, oid, n.Tid) -} - // NullInt64 represents an bigint that may be null. NullInt64 implements the // Scanner and Encoder interfaces so it may be used both as an argument to // Query[Row] and a destination for Scan. @@ -836,9 +780,13 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { } if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { - err := value.ConvertFrom(arg) - if err != nil { - return err + if converterFrom, ok := value.(pgtype.ConverterFrom); ok { + err := converterFrom.ConvertFrom(arg) + if err != nil { + return err + } + } else { + return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } buf := &bytes.Buffer{} @@ -906,8 +854,6 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { // decoding to the built-in functionality. func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { - case *Tid: - *v = decodeTid(vr) case *string: *v = decodeText(vr) case *[]interface{}: @@ -1092,66 +1038,6 @@ func decodeInt4(vr *ValueReader) int32 { return n.Int } -// Note that we do not match negative numbers, because neither the -// BlockNumber nor OffsetNumber of a Tid can be negative. -var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) - -func decodeTid(vr *ValueReader) Tid { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into Tid")) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - if vr.Type().DataType != TidOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.Tid", vr.Type().DataType))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - // Unlikely Tid will ever go over the wire as text format, but who knows? - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - - match := tidRegexp.FindStringSubmatch(s) - if match == nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - - blockNumber, err := strconv.ParseUint(s, 10, 16) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid BlockNumber part of a Tid: %v", s))) - } - - offsetNumber, err := strconv.ParseUint(s, 10, 16) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid offsetNumber part of a Tid: %v", s))) - } - return Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber)} - case BinaryFormatCode: - if vr.Len() != 6 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } - return Tid{BlockNumber: vr.ReadUint32(), OffsetNumber: vr.ReadUint16()} - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return Tid{BlockNumber: 0, OffsetNumber: 0} - } -} - -func encodeTid(w *WriteBuf, oid OID, value Tid) error { - if oid != TidOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.Tid", oid) - } - - w.WriteInt32(6) - w.WriteUint32(value.BlockNumber) - w.WriteUint16(value.OffsetNumber) - - return nil -} - func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32")) diff --git a/values_test.go b/values_test.go index 9cf2b219..eb570fe6 100644 --- a/values_test.go +++ b/values_test.go @@ -568,7 +568,6 @@ func TestNullX(t *testing.T) { s pgx.NullString i16 pgx.NullInt16 i32 pgx.NullInt32 - tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 f64 pgx.NullFloat64 @@ -590,9 +589,6 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: true}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 1, OffsetNumber: 1}, Valid: false}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 0, OffsetNumber: 0}, Valid: false}}}, - {"select $1::tid", []interface{}{pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}, []interface{}{&actual.tid}, allTypes{tid: pgx.NullTid{Tid: pgx.Tid{BlockNumber: 4294967295, OffsetNumber: 65535}, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, From 3dc509df948e5cd94ef2a5f68b7bfa30626ae4a8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 16:53:07 -0600 Subject: [PATCH 103/264] Rename array files --- pgtype/{aclitemarray.go => aclitem_array.go} | 0 ...temarray_test.go => aclitem_array_test.go} | 0 pgtype/{boolarray.go => bool_array.go} | 0 .../{boolarray_test.go => bool_array_test.go} | 0 pgtype/{byteaarray.go => bytea_array.go} | 0 ...byteaarray_test.go => bytea_array_test.go} | 0 pgtype/{cidrarray.go => cidr_array.go} | 0 pgtype/{datearray.go => date_array.go} | 0 .../{datearray_test.go => date_array_test.go} | 0 pgtype/{float4array.go => float4_array.go} | 0 ...oat4array_test.go => float4_array_test.go} | 0 pgtype/{float8array.go => float8_array.go} | 0 ...oat8array_test.go => float8_array_test.go} | 0 pgtype/{inetarray.go => inet_array.go} | 0 .../{inetarray_test.go => inet_array_test.go} | 0 pgtype/{int2array.go => int2_array.go} | 0 .../{int2array_test.go => int2_array_test.go} | 0 pgtype/{int4array.go => int4_array.go} | 0 .../{int4array_test.go => int4_array_test.go} | 0 pgtype/{int8array.go => int8_array.go} | 0 .../{int8array_test.go => int8_array_test.go} | 0 pgtype/{textarray.go => text_array.go} | 0 .../{textarray_test.go => text_array_test.go} | 0 .../{timestamparray.go => timestamp_array.go} | 0 ...parray_test.go => timestamp_array_test.go} | 0 ...mestamptzarray.go => timestamptz_array.go} | 0 ...rray_test.go => timestamptz_array_test.go} | 0 pgtype/typed_array_gen.sh | 26 +++++++++---------- pgtype/{varchararray.go => varchar_array.go} | 0 29 files changed, 13 insertions(+), 13 deletions(-) rename pgtype/{aclitemarray.go => aclitem_array.go} (100%) rename pgtype/{aclitemarray_test.go => aclitem_array_test.go} (100%) rename pgtype/{boolarray.go => bool_array.go} (100%) rename pgtype/{boolarray_test.go => bool_array_test.go} (100%) rename pgtype/{byteaarray.go => bytea_array.go} (100%) rename pgtype/{byteaarray_test.go => bytea_array_test.go} (100%) rename pgtype/{cidrarray.go => cidr_array.go} (100%) rename pgtype/{datearray.go => date_array.go} (100%) rename pgtype/{datearray_test.go => date_array_test.go} (100%) rename pgtype/{float4array.go => float4_array.go} (100%) rename pgtype/{float4array_test.go => float4_array_test.go} (100%) rename pgtype/{float8array.go => float8_array.go} (100%) rename pgtype/{float8array_test.go => float8_array_test.go} (100%) rename pgtype/{inetarray.go => inet_array.go} (100%) rename pgtype/{inetarray_test.go => inet_array_test.go} (100%) rename pgtype/{int2array.go => int2_array.go} (100%) rename pgtype/{int2array_test.go => int2_array_test.go} (100%) rename pgtype/{int4array.go => int4_array.go} (100%) rename pgtype/{int4array_test.go => int4_array_test.go} (100%) rename pgtype/{int8array.go => int8_array.go} (100%) rename pgtype/{int8array_test.go => int8_array_test.go} (100%) rename pgtype/{textarray.go => text_array.go} (100%) rename pgtype/{textarray_test.go => text_array_test.go} (100%) rename pgtype/{timestamparray.go => timestamp_array.go} (100%) rename pgtype/{timestamparray_test.go => timestamp_array_test.go} (100%) rename pgtype/{timestamptzarray.go => timestamptz_array.go} (100%) rename pgtype/{timestamptzarray_test.go => timestamptz_array_test.go} (100%) rename pgtype/{varchararray.go => varchar_array.go} (100%) diff --git a/pgtype/aclitemarray.go b/pgtype/aclitem_array.go similarity index 100% rename from pgtype/aclitemarray.go rename to pgtype/aclitem_array.go diff --git a/pgtype/aclitemarray_test.go b/pgtype/aclitem_array_test.go similarity index 100% rename from pgtype/aclitemarray_test.go rename to pgtype/aclitem_array_test.go diff --git a/pgtype/boolarray.go b/pgtype/bool_array.go similarity index 100% rename from pgtype/boolarray.go rename to pgtype/bool_array.go diff --git a/pgtype/boolarray_test.go b/pgtype/bool_array_test.go similarity index 100% rename from pgtype/boolarray_test.go rename to pgtype/bool_array_test.go diff --git a/pgtype/byteaarray.go b/pgtype/bytea_array.go similarity index 100% rename from pgtype/byteaarray.go rename to pgtype/bytea_array.go diff --git a/pgtype/byteaarray_test.go b/pgtype/bytea_array_test.go similarity index 100% rename from pgtype/byteaarray_test.go rename to pgtype/bytea_array_test.go diff --git a/pgtype/cidrarray.go b/pgtype/cidr_array.go similarity index 100% rename from pgtype/cidrarray.go rename to pgtype/cidr_array.go diff --git a/pgtype/datearray.go b/pgtype/date_array.go similarity index 100% rename from pgtype/datearray.go rename to pgtype/date_array.go diff --git a/pgtype/datearray_test.go b/pgtype/date_array_test.go similarity index 100% rename from pgtype/datearray_test.go rename to pgtype/date_array_test.go diff --git a/pgtype/float4array.go b/pgtype/float4_array.go similarity index 100% rename from pgtype/float4array.go rename to pgtype/float4_array.go diff --git a/pgtype/float4array_test.go b/pgtype/float4_array_test.go similarity index 100% rename from pgtype/float4array_test.go rename to pgtype/float4_array_test.go diff --git a/pgtype/float8array.go b/pgtype/float8_array.go similarity index 100% rename from pgtype/float8array.go rename to pgtype/float8_array.go diff --git a/pgtype/float8array_test.go b/pgtype/float8_array_test.go similarity index 100% rename from pgtype/float8array_test.go rename to pgtype/float8_array_test.go diff --git a/pgtype/inetarray.go b/pgtype/inet_array.go similarity index 100% rename from pgtype/inetarray.go rename to pgtype/inet_array.go diff --git a/pgtype/inetarray_test.go b/pgtype/inet_array_test.go similarity index 100% rename from pgtype/inetarray_test.go rename to pgtype/inet_array_test.go diff --git a/pgtype/int2array.go b/pgtype/int2_array.go similarity index 100% rename from pgtype/int2array.go rename to pgtype/int2_array.go diff --git a/pgtype/int2array_test.go b/pgtype/int2_array_test.go similarity index 100% rename from pgtype/int2array_test.go rename to pgtype/int2_array_test.go diff --git a/pgtype/int4array.go b/pgtype/int4_array.go similarity index 100% rename from pgtype/int4array.go rename to pgtype/int4_array.go diff --git a/pgtype/int4array_test.go b/pgtype/int4_array_test.go similarity index 100% rename from pgtype/int4array_test.go rename to pgtype/int4_array_test.go diff --git a/pgtype/int8array.go b/pgtype/int8_array.go similarity index 100% rename from pgtype/int8array.go rename to pgtype/int8_array.go diff --git a/pgtype/int8array_test.go b/pgtype/int8_array_test.go similarity index 100% rename from pgtype/int8array_test.go rename to pgtype/int8_array_test.go diff --git a/pgtype/textarray.go b/pgtype/text_array.go similarity index 100% rename from pgtype/textarray.go rename to pgtype/text_array.go diff --git a/pgtype/textarray_test.go b/pgtype/text_array_test.go similarity index 100% rename from pgtype/textarray_test.go rename to pgtype/text_array_test.go diff --git a/pgtype/timestamparray.go b/pgtype/timestamp_array.go similarity index 100% rename from pgtype/timestamparray.go rename to pgtype/timestamp_array.go diff --git a/pgtype/timestamparray_test.go b/pgtype/timestamp_array_test.go similarity index 100% rename from pgtype/timestamparray_test.go rename to pgtype/timestamp_array_test.go diff --git a/pgtype/timestamptzarray.go b/pgtype/timestamptz_array.go similarity index 100% rename from pgtype/timestamptzarray.go rename to pgtype/timestamptz_array.go diff --git a/pgtype/timestamptzarray_test.go b/pgtype/timestamptz_array_test.go similarity index 100% rename from pgtype/timestamptzarray_test.go rename to pgtype/timestamptz_array_test.go diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 876f8a3c..32c298cc 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,13 +1,13 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID text_null=NULL typed_array.go.erb > int2array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID text_null=NULL typed_array.go.erb > int4array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID text_null=NULL typed_array.go.erb > int8array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID text_null=NULL typed_array.go.erb > boolarray.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID text_null=NULL typed_array.go.erb > datearray.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID text_null=NULL typed_array.go.erb > timestamptzarray.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID text_null=NULL typed_array.go.erb > timestamparray.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID text_null=NULL typed_array.go.erb > float4array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > byteaarray.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitemarray.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID text_null=NULL typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID text_null=NULL typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID text_null=NULL typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID text_null=NULL typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID text_null=NULL typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID text_null=NULL typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID text_null=NULL typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID text_null=NULL typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > bytea_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/pgtype/varchararray.go b/pgtype/varchar_array.go similarity index 100% rename from pgtype/varchararray.go rename to pgtype/varchar_array.go From 743b98b298a35453b2a8abf8e4002f8897cf3d47 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 17:03:23 -0600 Subject: [PATCH 104/264] Name PG types as words Though this doesn't follow Go naming conventions exactly it makes names more consistent with PostgreSQL and it is easier to read. For example, TIDOID becomes TidOid. In addition this is one less breaking change in the move to V3. --- conn.go | 126 ++++++------- conn_pool.go | 4 +- conn_test.go | 2 +- copy_to_test.go | 2 +- doc.go | 6 +- example_custom_type_test.go | 7 +- fastpath.go | 14 +- large_objects.go | 12 +- messages.go | 4 +- pgtype/aclitem.go | 26 +-- pgtype/aclitem_array.go | 34 ++-- pgtype/aclitem_array_test.go | 74 ++++---- pgtype/aclitem_test.go | 36 ++-- pgtype/array.go | 6 +- pgtype/bool_array.go | 6 +- pgtype/bytea_array.go | 6 +- pgtype/cid.go | 20 +- pgtype/cid_test.go | 30 +-- pgtype/cidr_array.go | 2 +- pgtype/date_array.go | 6 +- pgtype/extra-interface.txt | 2 +- pgtype/float4_array.go | 6 +- pgtype/float8_array.go | 6 +- pgtype/inet_array.go | 6 +- pgtype/inet_array_test.go | 36 ++-- pgtype/inet_test.go | 38 ++-- pgtype/int2_array.go | 6 +- pgtype/int4_array.go | 6 +- pgtype/int8_array.go | 6 +- pgtype/oid.go | 20 +- pgtype/oid_test.go | 30 +-- pgtype/pgtype.go | 82 ++++----- pgtype/pgtype_test.go | 2 +- pgtype/pguint32.go | 2 +- pgtype/text_array.go | 6 +- pgtype/tid.go | 20 +- pgtype/tid_test.go | 8 +- pgtype/timestamp_array.go | 6 +- pgtype/timestamptz_array.go | 6 +- pgtype/typed_array.go.erb | 4 +- pgtype/typed_array_gen.sh | 26 +-- pgtype/varchar_array.go | 2 +- pgtype/xid.go | 20 +- pgtype/xid_test.go | 30 +-- query.go | 82 ++++----- query_test.go | 8 +- stdlib/sql.go | 28 +-- v3.md | 6 - value_reader.go | 4 +- values.go | 344 +++++++++++++++++------------------ values_test.go | 112 ++++++------ 51 files changed, 689 insertions(+), 694 deletions(-) diff --git a/conn.go b/conn.go index c2cc5d3c..21bd8f1b 100644 --- a/conn.go +++ b/conn.go @@ -77,7 +77,7 @@ type Conn struct { pid int32 // backend pid secretKey int32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server - PgTypes map[OID]PgType // oids to PgTypes + PgTypes map[Oid]PgType // oids to PgTypes config ConnConfig // config used when establishing this connection txStatus byte preparedStatements map[string]*PreparedStatement @@ -102,7 +102,7 @@ type Conn struct { doneChan chan struct{} closedChan chan error - oidPgtypeValues map[OID]pgtype.Value + oidPgtypeValues map[Oid]pgtype.Value } // PreparedStatement is a description of a prepared statement @@ -110,12 +110,12 @@ type PreparedStatement struct { Name string SQL string FieldDescriptions []FieldDescription - ParameterOIDs []OID + ParameterOids []Oid } // PrepareExOptions is an option struct that can be passed to PrepareEx type PrepareExOptions struct { - ParameterOIDs []OID + ParameterOids []Oid } // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system @@ -180,13 +180,13 @@ func Connect(config ConnConfig) (c *Conn, err error) { return connect(config, nil) } -func connect(config ConnConfig, pgTypes map[OID]PgType) (c *Conn, err error) { +func connect(config ConnConfig, pgTypes map[Oid]PgType) (c *Conn, err error) { c = new(Conn) c.config = config if pgTypes != nil { - c.PgTypes = make(map[OID]PgType, len(pgTypes)) + c.PgTypes = make(map[Oid]PgType, len(pgTypes)) for k, v := range pgTypes { c.PgTypes[k] = v } @@ -267,43 +267,43 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.doneChan = make(chan struct{}) c.closedChan = make(chan error) - c.oidPgtypeValues = map[OID]pgtype.Value{ - ACLItemArrayOID: &pgtype.ACLItemArray{}, - ACLItemOID: &pgtype.ACLItem{}, - BoolArrayOID: &pgtype.BoolArray{}, - BoolOID: &pgtype.Bool{}, - ByteaArrayOID: &pgtype.ByteaArray{}, - ByteaOID: &pgtype.Bytea{}, - CharOID: &pgtype.QChar{}, - CIDOID: &pgtype.CID{}, - CidrArrayOID: &pgtype.CidrArray{}, - CidrOID: &pgtype.Inet{}, - DateArrayOID: &pgtype.DateArray{}, - DateOID: &pgtype.Date{}, - Float4ArrayOID: &pgtype.Float4Array{}, - Float4OID: &pgtype.Float4{}, - Float8ArrayOID: &pgtype.Float8Array{}, - Float8OID: &pgtype.Float8{}, - InetArrayOID: &pgtype.InetArray{}, - InetOID: &pgtype.Inet{}, - Int2ArrayOID: &pgtype.Int2Array{}, - Int2OID: &pgtype.Int2{}, - Int4ArrayOID: &pgtype.Int4Array{}, - Int4OID: &pgtype.Int4{}, - Int8ArrayOID: &pgtype.Int8Array{}, - Int8OID: &pgtype.Int8{}, - NameOID: &pgtype.Name{}, - OIDOID: &pgtype.OID{}, - TextArrayOID: &pgtype.TextArray{}, - TextOID: &pgtype.Text{}, - TIDOID: &pgtype.TID{}, - TimestampArrayOID: &pgtype.TimestampArray{}, - TimestampOID: &pgtype.Timestamp{}, - TimestampTzArrayOID: &pgtype.TimestamptzArray{}, - TimestampTzOID: &pgtype.Timestamptz{}, - VarcharArrayOID: &pgtype.VarcharArray{}, - VarcharOID: &pgtype.Text{}, - XIDOID: &pgtype.XID{}, + c.oidPgtypeValues = map[Oid]pgtype.Value{ + AclitemArrayOid: &pgtype.AclitemArray{}, + AclitemOid: &pgtype.Aclitem{}, + BoolArrayOid: &pgtype.BoolArray{}, + BoolOid: &pgtype.Bool{}, + ByteaArrayOid: &pgtype.ByteaArray{}, + ByteaOid: &pgtype.Bytea{}, + CharOid: &pgtype.QChar{}, + CidOid: &pgtype.Cid{}, + CidrArrayOid: &pgtype.CidrArray{}, + CidrOid: &pgtype.Inet{}, + DateArrayOid: &pgtype.DateArray{}, + DateOid: &pgtype.Date{}, + Float4ArrayOid: &pgtype.Float4Array{}, + Float4Oid: &pgtype.Float4{}, + Float8ArrayOid: &pgtype.Float8Array{}, + Float8Oid: &pgtype.Float8{}, + InetArrayOid: &pgtype.InetArray{}, + InetOid: &pgtype.Inet{}, + Int2ArrayOid: &pgtype.Int2Array{}, + Int2Oid: &pgtype.Int2{}, + Int4ArrayOid: &pgtype.Int4Array{}, + Int4Oid: &pgtype.Int4{}, + Int8ArrayOid: &pgtype.Int8Array{}, + Int8Oid: &pgtype.Int8{}, + NameOid: &pgtype.Name{}, + OidOid: &pgtype.Oid{}, + TextArrayOid: &pgtype.TextArray{}, + TextOid: &pgtype.Text{}, + TidOid: &pgtype.Tid{}, + TimestampArrayOid: &pgtype.TimestampArray{}, + TimestampOid: &pgtype.Timestamp{}, + TimestampTzArrayOid: &pgtype.TimestamptzArray{}, + TimestampTzOid: &pgtype.Timestamptz{}, + VarcharArrayOid: &pgtype.VarcharArray{}, + VarcharOid: &pgtype.Text{}, + XidOid: &pgtype.Xid{}, } if tlsConfig != nil { @@ -397,7 +397,7 @@ where ( return err } - c.PgTypes = make(map[OID]PgType, 128) + c.PgTypes = make(map[Oid]PgType, 128) for rows.Next() { var oid uint32 @@ -408,7 +408,7 @@ where ( // The zero value is text format so we ignore any types without a default type format t.DefaultFormat, _ = DefaultTypeFormats[t.Name] - c.PgTypes[OID(oid)] = t + c.PgTypes[Oid(oid)] = t } return rows.Err() @@ -669,7 +669,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct +// It defers from Prepare as it allows additional options (such as parameter Oids) to be passed via struct // // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without @@ -719,11 +719,11 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared wbuf.WriteCString(sql) if opts != nil { - if len(opts.ParameterOIDs) > 65535 { - return nil, fmt.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) + if len(opts.ParameterOids) > 65535 { + return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) } - wbuf.WriteInt16(int16(len(opts.ParameterOIDs))) - for _, oid := range opts.ParameterOIDs { + wbuf.WriteInt16(int16(len(opts.ParameterOids))) + for _, oid := range opts.ParameterOids { wbuf.WriteInt32(int32(oid)) } } else { @@ -760,10 +760,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared switch t { case parameterDescription: - ps.ParameterOIDs = c.rxParameterDescription(r) + ps.ParameterOids = c.rxParameterDescription(r) - if len(ps.ParameterOIDs) > 65535 && softErr == nil { - softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) + if len(ps.ParameterOids) > 65535 && softErr == nil { + softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) } case rowDescription: ps.FieldDescriptions = c.rxRowDescription(r) @@ -970,8 +970,8 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { } func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { - if len(ps.ParameterOIDs) != len(arguments) { - return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) + if len(ps.ParameterOids) != len(arguments) { + return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) } if err := c.ensureConnectionReadyForQuery(); err != nil { @@ -983,8 +983,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteByte(0) wbuf.WriteCString(ps.Name) - wbuf.WriteInt16(int16(len(ps.ParameterOIDs))) - for i, oid := range ps.ParameterOIDs { + wbuf.WriteInt16(int16(len(ps.ParameterOids))) + for i, oid := range ps.ParameterOids { switch arg := arguments[i].(type) { case Encoder: wbuf.WriteInt16(arg.FormatCode()) @@ -1000,7 +1000,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} } wbuf.WriteInt16(int16(len(arguments))) - for i, oid := range ps.ParameterOIDs { + for i, oid := range ps.ParameterOids { if err := Encode(wbuf, oid, arguments[i]); err != nil { return err } @@ -1188,9 +1188,9 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { for i := int16(0); i < fieldCount; i++ { f := &fields[i] f.Name = r.readCString() - f.Table = OID(r.readUint32()) + f.Table = Oid(r.readUint32()) f.AttributeNumber = r.readInt16() - f.DataType = OID(r.readUint32()) + f.DataType = Oid(r.readUint32()) f.DataTypeSize = r.readInt16() f.Modifier = r.readInt32() f.FormatCode = r.readInt16() @@ -1198,7 +1198,7 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { return } -func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) { +func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { // Internally, PostgreSQL supports greater than 64k parameters to a prepared // statement. But the parameter description uses a 16-bit integer for the // count of parameters. If there are more than 64K parameters, this count is @@ -1207,10 +1207,10 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []OID) { r.readInt16() parameterCount := len(r.msgBody[r.rp:]) / 4 - parameters = make([]OID, 0, parameterCount) + parameters = make([]Oid, 0, parameterCount) for i := 0; i < parameterCount; i++ { - parameters = append(parameters, OID(r.readUint32())) + parameters = append(parameters, Oid(r.readUint32())) } return } diff --git a/conn_pool.go b/conn_pool.go index fd632006..3081105c 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -28,7 +28,7 @@ type ConnPool struct { closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration - pgTypes map[OID]PgType + pgTypes map[Oid]PgType txAfterClose func(tx *Tx) rowsAfterClose func(rows *Rows) } @@ -446,7 +446,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { // // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct +// It defers from Prepare as it allows additional options (such as parameter Oids) to be passed via struct // // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without diff --git a/conn_test.go b/conn_test.go index b44fd6db..a6034be6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1042,7 +1042,7 @@ func TestPrepareEx(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOIDs: []pgx.OID{pgx.TextOID}}) + _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}}) if err != nil { t.Errorf("Unable to prepare statement: %v", err) return diff --git a/copy_to_test.go b/copy_to_test.go index 7d5f2509..ee96054a 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -125,7 +125,7 @@ func TestConnCopyToJSON(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.OID{pgx.JSONOID, pgx.JSONBOID} { + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } diff --git a/doc.go b/doc.go index 514d51a0..5f3490ca 100644 --- a/doc.go +++ b/doc.go @@ -169,10 +169,10 @@ there. pgx.DefaultTypeFormats["point"] = pgx.BinaryFormatCode -Note that the type is referred to by name, not by OID. This is because custom -PostgreSQL types like hstore will have different OIDs on different servers. When +Note that the type is referred to by name, not by Oid. This is because custom +PostgreSQL types like hstore will have different Oids on different servers. When pgx establishes a connection it queries the pg_type table for all types. It then -matches the names in DefaultTypeFormats with the returned OIDs and stores it in +matches the names in DefaultTypeFormats with the returned Oids and stores it in Conn.PgTypes. See example_custom_type_test.go for an example of a custom type for the diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 674466f3..74fbab67 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -3,9 +3,10 @@ package pgx_test import ( "errors" "fmt" - "github.com/jackc/pgx" "regexp" "strconv" + + "github.com/jackc/pgx" ) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) @@ -20,7 +21,7 @@ type NullPoint struct { func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error { if vr.Type().DataTypeName != "point" { - return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (OID %d)", vr.Type().DataTypeName, vr.Type().DataType)) + return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (Oid %d)", vr.Type().DataTypeName, vr.Type().DataType)) } if vr.Len() == -1 { @@ -57,7 +58,7 @@ func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error { func (p NullPoint) FormatCode() int16 { return pgx.BinaryFormatCode } -func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.OID) error { +func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { if !p.Valid { w.WriteInt32(-1) return nil diff --git a/fastpath.go b/fastpath.go index af055e56..d58a7754 100644 --- a/fastpath.go +++ b/fastpath.go @@ -5,26 +5,26 @@ import ( ) func newFastpath(cn *Conn) *fastpath { - return &fastpath{cn: cn, fns: make(map[string]OID)} + return &fastpath{cn: cn, fns: make(map[string]Oid)} } type fastpath struct { cn *Conn - fns map[string]OID + fns map[string]Oid } -func (f *fastpath) functionOID(name string) OID { +func (f *fastpath) functionOid(name string) Oid { return f.fns[name] } -func (f *fastpath) addFunction(name string, oid OID) { +func (f *fastpath) addFunction(name string, oid Oid) { f.fns[name] = oid } func (f *fastpath) addFunctions(rows *Rows) error { for rows.Next() { var name string - var oid OID + var oid Oid if err := rows.Scan(&name, &oid); err != nil { return err } @@ -47,7 +47,7 @@ func fpInt64Arg(n int64) fpArg { return res } -func (f *fastpath) Call(oid OID, args []fpArg) (res []byte, err error) { +func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { if err := f.cn.ensureConnectionReadyForQuery(); err != nil { return nil, err } @@ -93,7 +93,7 @@ func (f *fastpath) Call(oid OID, args []fpArg) (res []byte, err error) { } func (f *fastpath) CallFn(fn string, args []fpArg) ([]byte, error) { - return f.Call(f.functionOID(fn), args) + return f.Call(f.functionOid(fn), args) } func fpInt32(data []byte, err error) (int32, error) { diff --git a/large_objects.go b/large_objects.go index 5b3e3a33..960e1e25 100644 --- a/large_objects.go +++ b/large_objects.go @@ -59,20 +59,20 @@ const ( ) // Create creates a new large object. If id is zero, the server assigns an -// unused OID. -func (o *LargeObjects) Create(id OID) (OID, error) { - newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) - return OID(newOID), err +// unused Oid. +func (o *LargeObjects) Create(id Oid) (Oid, error) { + newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) + return Oid(newOid), err } // Open opens an existing large object with the given mode. -func (o *LargeObjects) Open(oid OID, mode LargeObjectMode) (*LargeObject, error) { +func (o *LargeObjects) Open(oid Oid, mode LargeObjectMode) (*LargeObject, error) { fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))})) return &LargeObject{fd: fd, lo: o}, err } // Unlink removes a large object from the database. -func (o *LargeObjects) Unlink(oid OID) error { +func (o *LargeObjects) Unlink(oid Oid) error { _, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))}) return err } diff --git a/messages.go b/messages.go index f6be9ff9..0c14c61d 100644 --- a/messages.go +++ b/messages.go @@ -55,9 +55,9 @@ func (s *startupMessage) Bytes() (buf []byte) { type FieldDescription struct { Name string - Table OID + Table Oid AttributeNumber int16 - DataType OID + DataType Oid DataTypeSize int16 DataTypeName string Modifier int32 diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index bd7b7d45..821c5001 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -6,7 +6,7 @@ import ( "reflect" ) -// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem +// Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem // might look like this: // // postgres=arwdDxt/postgres @@ -18,34 +18,34 @@ import ( // // postgres=arwdDxt/"role with spaces" // -type ACLItem struct { +type Aclitem struct { String string Status Status } -func (dst *ACLItem) ConvertFrom(src interface{}) error { +func (dst *Aclitem) ConvertFrom(src interface{}) error { switch value := src.(type) { - case ACLItem: + case Aclitem: *dst = value case string: - *dst = ACLItem{String: value, Status: Present} + *dst = Aclitem{String: value, Status: Present} case *string: if value == nil { - *dst = ACLItem{Status: Null} + *dst = Aclitem{Status: Null} } else { - *dst = ACLItem{String: *value, Status: Present} + *dst = Aclitem{String: *value, Status: Present} } default: if originalSrc, ok := underlyingStringType(src); ok { return dst.ConvertFrom(originalSrc) } - return fmt.Errorf("cannot convert %v to ACLItem", value) + return fmt.Errorf("cannot convert %v to Aclitem", value) } return nil } -func (src *ACLItem) AssignTo(dst interface{}) error { +func (src *Aclitem) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: if src.Status != Present { @@ -81,17 +81,17 @@ func (src *ACLItem) AssignTo(dst interface{}) error { return nil } -func (dst *ACLItem) DecodeText(src []byte) error { +func (dst *Aclitem) DecodeText(src []byte) error { if src == nil { - *dst = ACLItem{Status: Null} + *dst = Aclitem{Status: Null} return nil } - *dst = ACLItem{String: string(src), Status: Present} + *dst = Aclitem{String: string(src), Status: Present} return nil } -func (src ACLItem) EncodeText(w io.Writer) (bool, error) { +func (src Aclitem) EncodeText(w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index d69cd83c..48f5cd38 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -8,30 +8,30 @@ import ( "github.com/jackc/pgx/pgio" ) -type ACLItemArray struct { - Elements []ACLItem +type AclitemArray struct { + Elements []Aclitem Dimensions []ArrayDimension Status Status } -func (dst *ACLItemArray) ConvertFrom(src interface{}) error { +func (dst *AclitemArray) ConvertFrom(src interface{}) error { switch value := src.(type) { - case ACLItemArray: + case AclitemArray: *dst = value case []string: if value == nil { - *dst = ACLItemArray{Status: Null} + *dst = AclitemArray{Status: Null} } else if len(value) == 0 { - *dst = ACLItemArray{Status: Present} + *dst = AclitemArray{Status: Present} } else { - elements := make([]ACLItem, len(value)) + elements := make([]Aclitem, len(value)) for i := range value { if err := elements[i].ConvertFrom(value[i]); err != nil { return err } } - *dst = ACLItemArray{ + *dst = AclitemArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -42,13 +42,13 @@ func (dst *ACLItemArray) ConvertFrom(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) } - return fmt.Errorf("cannot convert %v to ACLItem", value) + return fmt.Errorf("cannot convert %v to Aclitem", value) } return nil } -func (src *ACLItemArray) AssignTo(dst interface{}) error { +func (src *AclitemArray) AssignTo(dst interface{}) error { switch v := dst.(type) { case *[]string: @@ -73,9 +73,9 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return nil } -func (dst *ACLItemArray) DecodeText(src []byte) error { +func (dst *AclitemArray) DecodeText(src []byte) error { if src == nil { - *dst = ACLItemArray{Status: Null} + *dst = AclitemArray{Status: Null} return nil } @@ -84,13 +84,13 @@ func (dst *ACLItemArray) DecodeText(src []byte) error { return err } - var elements []ACLItem + var elements []Aclitem if len(uta.Elements) > 0 { - elements = make([]ACLItem, len(uta.Elements)) + elements = make([]Aclitem, len(uta.Elements)) for i, s := range uta.Elements { - var elem ACLItem + var elem Aclitem var elemSrc []byte if s != "NULL" { elemSrc = []byte(s) @@ -104,12 +104,12 @@ func (dst *ACLItemArray) DecodeText(src []byte) error { } } - *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = AclitemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} return nil } -func (src *ACLItemArray) EncodeText(w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go index 8c01ac66..e78f14c6 100644 --- a/pgtype/aclitem_array_test.go +++ b/pgtype/aclitem_array_test.go @@ -7,40 +7,40 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestACLItemArrayTranscode(t *testing.T) { +func TestAclitemArrayTranscode(t *testing.T) { testSuccessfulTranscode(t, "aclitem[]", []interface{}{ - &pgtype.ACLItemArray{ + &pgtype.AclitemArray{ Elements: nil, Dimensions: nil, Status: pgtype.Present, }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{Status: pgtype.Null}, + &pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{ + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.ACLItemArray{Status: pgtype.Null}, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{Status: pgtype.Null}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + &pgtype.AclitemArray{Status: pgtype.Null}, + &pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{ + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{Status: pgtype.Null}, + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{ - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{ + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -51,26 +51,26 @@ func TestACLItemArrayTranscode(t *testing.T) { }) } -func TestACLItemArrayConvertFrom(t *testing.T) { +func TestAclitemArrayConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.ACLItemArray + result pgtype.AclitemArray }{ { source: []string{"=r/postgres"}, - result: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + result: pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, { source: (([]string)(nil)), - result: pgtype.ACLItemArray{Status: pgtype.Null}, + result: pgtype.AclitemArray{Status: pgtype.Null}, }, } for i, tt := range successfulTests { - var r pgtype.ACLItemArray + var r pgtype.AclitemArray err := r.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -82,19 +82,19 @@ func TestACLItemArrayConvertFrom(t *testing.T) { } } -func TestACLItemArrayAssignTo(t *testing.T) { +func TestAclitemArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice simpleTests := []struct { - src pgtype.ACLItemArray + src pgtype.AclitemArray dst interface{} expected interface{} }{ { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + src: pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -102,8 +102,8 @@ func TestACLItemArrayAssignTo(t *testing.T) { expected: []string{"=r/postgres"}, }, { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, + src: pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -111,7 +111,7 @@ func TestACLItemArrayAssignTo(t *testing.T) { expected: _stringSlice{"=r/postgres"}, }, { - src: pgtype.ACLItemArray{Status: pgtype.Null}, + src: pgtype.AclitemArray{Status: pgtype.Null}, dst: &stringSlice, expected: (([]string)(nil)), }, @@ -129,12 +129,12 @@ func TestACLItemArrayAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.ACLItemArray + src pgtype.AclitemArray dst interface{} }{ { - src: pgtype.ACLItemArray{ - Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, + src: pgtype.AclitemArray{ + Elements: []pgtype.Aclitem{{Status: pgtype.Null}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go index 0b2b6cfa..fc429acc 100644 --- a/pgtype/aclitem_test.go +++ b/pgtype/aclitem_test.go @@ -7,26 +7,26 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestACLItemTranscode(t *testing.T) { +func TestAclitemTranscode(t *testing.T) { testSuccessfulTranscode(t, "aclitem", []interface{}{ - pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.ACLItem{Status: pgtype.Null}, + pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.Aclitem{Status: pgtype.Null}, }) } -func TestACLItemConvertFrom(t *testing.T) { +func TestAclitemConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.ACLItem + result pgtype.Aclitem }{ - {source: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, + {source: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: "postgres=arwdDxt/postgres", result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Aclitem{Status: pgtype.Null}}, } for i, tt := range successfulTests { - var d pgtype.ACLItem + var d pgtype.Aclitem err := d.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -38,17 +38,17 @@ func TestACLItemConvertFrom(t *testing.T) { } } -func TestACLItemAssignTo(t *testing.T) { +func TestAclitemAssignTo(t *testing.T) { var s string var ps *string simpleTests := []struct { - src pgtype.ACLItem + src pgtype.Aclitem dst interface{} expected interface{} }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.Aclitem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range simpleTests { @@ -63,11 +63,11 @@ func TestACLItemAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.ACLItem + src pgtype.Aclitem dst interface{} expected interface{} }{ - {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, } for i, tt := range pointerAllocTests { @@ -82,10 +82,10 @@ func TestACLItemAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.ACLItem + src pgtype.Aclitem dst interface{} }{ - {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, + {src: pgtype.Aclitem{Status: pgtype.Null}, dst: &s}, } for i, tt := range errorTests { diff --git a/pgtype/array.go b/pgtype/array.go index 90092c8d..dff0fe81 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -18,7 +18,7 @@ import ( type ArrayHeader struct { ContainsNull bool - ElementOID int32 + ElementOid int32 Dimensions []ArrayDimension } @@ -40,7 +40,7 @@ func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 - dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) + dst.ElementOid = int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 if numDims > 0 { @@ -75,7 +75,7 @@ func (src *ArrayHeader) EncodeBinary(w io.Writer) error { return err } - _, err = pgio.WriteInt32(w, src.ElementOID) + _, err = pgio.WriteInt32(w, src.ElementOid) if err != nil { return err } diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 65a6bc9c..a74e9f90 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -229,10 +229,10 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { } func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, BoolOID) + return src.encodeBinary(w, BoolOid) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 7a4f1601..9003eafd 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -229,10 +229,10 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { } func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, ByteaOID) + return src.encodeBinary(w, ByteaOid) } -func (src *ByteaArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/cid.go b/pgtype/cid.go index 41b817bb..be93a03e 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -4,7 +4,7 @@ import ( "io" ) -// CID is PostgreSQL's Command Identifier type. +// Cid is PostgreSQL's Command Identifier type. // // When one does // @@ -15,33 +15,33 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. -type CID pguint32 +type Cid pguint32 -// ConvertFrom converts from src to dst. Note that as CID is not a general +// ConvertFrom converts from src to dst. Note that as Cid is not a general // number type ConvertFrom does not do automatic type conversion as other number // types do. -func (dst *CID) ConvertFrom(src interface{}) error { +func (dst *Cid) ConvertFrom(src interface{}) error { return (*pguint32)(dst).ConvertFrom(src) } -// AssignTo assigns from src to dst. Note that as CID is not a general number +// AssignTo assigns from src to dst. Note that as Cid is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *CID) AssignTo(dst interface{}) error { +func (src *Cid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *CID) DecodeText(src []byte) error { +func (dst *Cid) DecodeText(src []byte) error { return (*pguint32)(dst).DecodeText(src) } -func (dst *CID) DecodeBinary(src []byte) error { +func (dst *Cid) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src CID) EncodeText(w io.Writer) (bool, error) { +func (src Cid) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src CID) EncodeBinary(w io.Writer) (bool, error) { +func (src Cid) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index 72f5dfea..7d9fde34 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestCIDTranscode(t *testing.T) { +func TestCidTranscode(t *testing.T) { testSuccessfulTranscode(t, "cid", []interface{}{ - pgtype.CID{Uint: 42, Status: pgtype.Present}, - pgtype.CID{Status: pgtype.Null}, + pgtype.Cid{Uint: 42, Status: pgtype.Present}, + pgtype.Cid{Status: pgtype.Null}, }) } -func TestCIDConvertFrom(t *testing.T) { +func TestCidConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.CID + result pgtype.Cid }{ - {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Cid{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.CID + var r pgtype.Cid err := r.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestCIDConvertFrom(t *testing.T) { } } -func TestCIDAssignTo(t *testing.T) { +func TestCidAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.CID + src pgtype.Cid dst interface{} expected interface{} }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.Cid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Cid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestCIDAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.CID + src pgtype.Cid dst interface{} expected interface{} }{ - {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.Cid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestCIDAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.CID + src pgtype.Cid dst interface{} }{ - {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.Cid{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index cb81d2b9..e0219ee5 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -27,5 +27,5 @@ func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { } func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { - return (*InetArray)(src).encodeBinary(w, CidrOID) + return (*InetArray)(src).encodeBinary(w, CidrOid) } diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 623ff9b3..8f7cba18 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -230,10 +230,10 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { } func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, DateOID) + return src.encodeBinary(w, DateOid) } -func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -242,7 +242,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/extra-interface.txt b/pgtype/extra-interface.txt index 16453823..f07818bc 100644 --- a/pgtype/extra-interface.txt +++ b/pgtype/extra-interface.txt @@ -1,3 +1,3 @@ Can pass function to get inet data and function to get oid/name mapping as optional interface with io.Reader or io.Writer -Could be useful for arrays of types without defined OIDs like hstore. +Could be useful for arrays of types without defined Oids like hstore. diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index c55f76d0..632e7e4b 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -229,10 +229,10 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { } func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float4OID) + return src.encodeBinary(w, Float4Oid) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index d08a5351..68cf30f2 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -229,10 +229,10 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { } func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float8OID) + return src.encodeBinary(w, Float8Oid) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 12d9493b..629cd51f 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -261,10 +261,10 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { } func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, InetOID) + return src.encodeBinary(w, InetOid) } -func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -273,7 +273,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go index 8cab5355..523a9f8d 100644 --- a/pgtype/inet_array_test.go +++ b/pgtype/inet_array_test.go @@ -17,7 +17,7 @@ func TestInetArrayTranscode(t *testing.T) { }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, @@ -26,22 +26,22 @@ func TestInetArrayTranscode(t *testing.T) { &pgtype.InetArray{Status: pgtype.Null}, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, - pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -58,9 +58,9 @@ func TestInetArrayConvertFrom(t *testing.T) { result pgtype.InetArray }{ { - source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, @@ -69,9 +69,9 @@ func TestInetArrayConvertFrom(t *testing.T) { result: pgtype.InetArray{Status: pgtype.Null}, }, { - source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, @@ -105,12 +105,12 @@ func TestInetArrayAssignTo(t *testing.T) { }{ { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, }, { src: pgtype.InetArray{ @@ -123,12 +123,12 @@ func TestInetArrayAssignTo(t *testing.T) { }, { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipSlice, - expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, }, { src: pgtype.InetArray{ diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 5e86376b..5a326810 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -11,16 +11,16 @@ import ( func TestInetTranscode(t *testing.T) { for _, pgTypeName := range []string{"inet", "cidr"} { testSuccessfulTranscode(t, pgTypeName, []interface{}{ - pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "192.168.1.0/24"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "255.255.255.255/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "::/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "::/0"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "::1/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, }) } @@ -31,10 +31,10 @@ func TestInetConvertFrom(t *testing.T) { source interface{} result pgtype.Inet }{ - {source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Null}}, - {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Null}}, + {source: mustParseCidr(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCidr(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -61,8 +61,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCidr(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCidr(t, "127.0.0.1/32").IP}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, } @@ -83,8 +83,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCidr(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCidr(t, "127.0.0.1/32").IP}, } for i, tt := range pointerAllocTests { @@ -102,7 +102,7 @@ func TestInetAssignTo(t *testing.T) { src pgtype.Inet dst interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, } diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 37ee9926..d8268c0a 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -260,10 +260,10 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { } func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int2OID) + return src.encodeBinary(w, Int2Oid) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -272,7 +272,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index f6f62e4b..dcdb50c1 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -260,10 +260,10 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { } func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int4OID) + return src.encodeBinary(w, Int4Oid) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -272,7 +272,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 92d8ec46..ed82f079 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -260,10 +260,10 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { } func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int8OID) + return src.encodeBinary(w, Int8Oid) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -272,7 +272,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/oid.go b/pgtype/oid.go index e1bee4cf..c77f3f10 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -4,38 +4,38 @@ import ( "io" ) -// OID (Object Identifier Type) is, according to +// Oid (Object Identifier Type) is, according to // https://www.postgresql.org/docs/current/static/datatype-oid.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be // found in src/include/postgres_ext.h in the PostgreSQL sources. -type OID pguint32 +type Oid pguint32 -// ConvertFrom converts from src to dst. Note that as OID is not a general +// ConvertFrom converts from src to dst. Note that as Oid is not a general // number type ConvertFrom does not do automatic type conversion as other number // types do. -func (dst *OID) ConvertFrom(src interface{}) error { +func (dst *Oid) ConvertFrom(src interface{}) error { return (*pguint32)(dst).ConvertFrom(src) } -// AssignTo assigns from src to dst. Note that as OID is not a general number +// AssignTo assigns from src to dst. Note that as Oid is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *OID) AssignTo(dst interface{}) error { +func (src *Oid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OID) DecodeText(src []byte) error { +func (dst *Oid) DecodeText(src []byte) error { return (*pguint32)(dst).DecodeText(src) } -func (dst *OID) DecodeBinary(src []byte) error { +func (dst *Oid) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src OID) EncodeText(w io.Writer) (bool, error) { +func (src Oid) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src OID) EncodeBinary(w io.Writer) (bool, error) { +func (src Oid) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/pgtype/oid_test.go b/pgtype/oid_test.go index c8e0b2d6..bbab6699 100644 --- a/pgtype/oid_test.go +++ b/pgtype/oid_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestOIDTranscode(t *testing.T) { +func TestOidTranscode(t *testing.T) { testSuccessfulTranscode(t, "oid", []interface{}{ - pgtype.OID{Uint: 42, Status: pgtype.Present}, - pgtype.OID{Status: pgtype.Null}, + pgtype.Oid{Uint: 42, Status: pgtype.Present}, + pgtype.Oid{Status: pgtype.Null}, }) } -func TestOIDConvertFrom(t *testing.T) { +func TestOidConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.OID + result pgtype.Oid }{ - {source: uint32(1), result: pgtype.OID{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Oid{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.OID + var r pgtype.Oid err := r.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestOIDConvertFrom(t *testing.T) { } } -func TestOIDAssignTo(t *testing.T) { +func TestOidAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.OID + src pgtype.Oid dst interface{} expected interface{} }{ - {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Oid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestOIDAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.OID + src pgtype.Oid dst interface{} expected interface{} }{ - {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestOIDAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.OID + src pgtype.Oid dst interface{} }{ - {src: pgtype.OID{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.Oid{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8c67c630..cbcd6bd5 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -7,47 +7,47 @@ import ( // PostgreSQL oids for common types const ( - BoolOID = 16 - ByteaOID = 17 - CharOID = 18 - NameOID = 19 - Int8OID = 20 - Int2OID = 21 - Int4OID = 23 - TextOID = 25 - OIDOID = 26 - TIDOID = 27 - XIDOID = 28 - CIDOID = 29 - JSONOID = 114 - CidrOID = 650 - CidrArrayOID = 651 - Float4OID = 700 - Float8OID = 701 - UnknownOID = 705 - InetOID = 869 - BoolArrayOID = 1000 - Int2ArrayOID = 1005 - Int4ArrayOID = 1007 - TextArrayOID = 1009 - ByteaArrayOID = 1001 - VarcharArrayOID = 1015 - Int8ArrayOID = 1016 - Float4ArrayOID = 1021 - Float8ArrayOID = 1022 - ACLItemOID = 1033 - ACLItemArrayOID = 1034 - InetArrayOID = 1041 - VarcharOID = 1043 - DateOID = 1082 - TimestampOID = 1114 - TimestampArrayOID = 1115 - DateArrayOID = 1182 - TimestamptzOID = 1184 - TimestamptzArrayOID = 1185 - RecordOID = 2249 - UUIDOID = 2950 - JSONBOID = 3802 + BoolOid = 16 + ByteaOid = 17 + CharOid = 18 + NameOid = 19 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + OidOid = 26 + TidOid = 27 + XidOid = 28 + CidOid = 29 + JsonOid = 114 + CidrOid = 650 + CidrArrayOid = 651 + Float4Oid = 700 + Float8Oid = 701 + UnknownOid = 705 + InetOid = 869 + BoolArrayOid = 1000 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + ByteaArrayOid = 1001 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + AclitemOid = 1033 + AclitemArrayOid = 1034 + InetArrayOid = 1041 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampArrayOid = 1115 + DateArrayOid = 1182 + TimestamptzOid = 1184 + TimestamptzArrayOid = 1185 + RecordOid = 2249 + UuidOid = 2950 + JsonbOid = 3802 ) type Status byte diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 07a40160..f9b6f56d 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -47,7 +47,7 @@ func mustClose(t testing.TB, conn interface { } } -func mustParseCIDR(t testing.TB, s string) *net.IPNet { +func mustParseCidr(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index df9e0d36..c636e1c4 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -10,7 +10,7 @@ import ( ) // pguint32 is the core type that is used to implement PostgreSQL types such as -// CID and XID. +// Cid and Xid. type pguint32 struct { Uint uint32 Status Status diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 182e76f5..06e3c0df 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -229,10 +229,10 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { } func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TextOID) + return src.encodeBinary(w, TextOid) } -func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -241,7 +241,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/tid.go b/pgtype/tid.go index 804cced2..b67892ff 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -10,7 +10,7 @@ import ( "github.com/jackc/pgx/pgio" ) -// TID is PostgreSQL's Tuple Identifier type. +// Tid is PostgreSQL's Tuple Identifier type. // // When one does // @@ -21,15 +21,15 @@ import ( // It is currently implemented as a pair unsigned two byte integers. // Its conversion functions can be found in src/backend/utils/adt/tid.c // in the PostgreSQL sources. -type TID struct { +type Tid struct { BlockNumber uint32 OffsetNumber uint16 Status Status } -func (dst *TID) DecodeText(src []byte) error { +func (dst *Tid) DecodeText(src []byte) error { if src == nil { - *dst = TID{Status: Null} + *dst = Tid{Status: Null} return nil } @@ -52,13 +52,13 @@ func (dst *TID) DecodeText(src []byte) error { return err } - *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + *dst = Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} return nil } -func (dst *TID) DecodeBinary(src []byte) error { +func (dst *Tid) DecodeBinary(src []byte) error { if src == nil { - *dst = TID{Status: Null} + *dst = Tid{Status: Null} return nil } @@ -66,7 +66,7 @@ func (dst *TID) DecodeBinary(src []byte) error { return fmt.Errorf("invalid length for tid: %v", len(src)) } - *dst = TID{ + *dst = Tid{ BlockNumber: binary.BigEndian.Uint32(src), OffsetNumber: binary.BigEndian.Uint16(src[4:]), Status: Present, @@ -74,7 +74,7 @@ func (dst *TID) DecodeBinary(src []byte) error { return nil } -func (src TID) EncodeText(w io.Writer) (bool, error) { +func (src Tid) EncodeText(w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -86,7 +86,7 @@ func (src TID) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src TID) EncodeBinary(w io.Writer) (bool, error) { +func (src Tid) EncodeBinary(w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index a5aab8a3..56595ef4 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -6,10 +6,10 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestTIDTranscode(t *testing.T) { +func TestTidTranscode(t *testing.T) { testSuccessfulTranscode(t, "tid", []interface{}{ - pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - pgtype.TID{Status: pgtype.Null}, + pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + pgtype.Tid{Status: pgtype.Null}, }) } diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index b0fb25fa..1ea30ba4 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -230,10 +230,10 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { } func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestampOID) + return src.encodeBinary(w, TimestampOid) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -242,7 +242,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOID int32) (bool, er } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 25374717..fc3ce08c 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -230,10 +230,10 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { } func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestamptzOID) + return src.encodeBinary(w, TimestamptzOid) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -242,7 +242,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOID int32) (bool, } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index f9dba308..98c8d845 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -231,7 +231,7 @@ func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { return src.encodeBinary(w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) (bool, error) { +func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -240,7 +240,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOID int32) } arrayHeader := ArrayHeader{ - ElementOID: elementOID, + ElementOid: elementOid, Dimensions: src.Dimensions, } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 32c298cc..41c1313f 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,13 +1,13 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2OID text_null=NULL typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4OID text_null=NULL typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8OID text_null=NULL typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOID text_null=NULL typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOID text_null=NULL typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOID text_null=NULL typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOID text_null=NULL typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID text_null=NULL typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inet_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > text_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > bytea_array.go -erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_oid=ACLItemOID text_null=NULL typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2Oid text_null=NULL typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4Oid text_null=NULL typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8Oid text_null=NULL typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOid text_null=NULL typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOid text_null=NULL typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOid text_null=NULL typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOid text_null=NULL typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go +erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 9c8829d0..b9d87b7f 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -27,5 +27,5 @@ func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { } func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { - return (*TextArray)(src).encodeBinary(w, VarcharOID) + return (*TextArray)(src).encodeBinary(w, VarcharOid) } diff --git a/pgtype/xid.go b/pgtype/xid.go index 6635b21e..7deaa4f0 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -4,7 +4,7 @@ import ( "io" ) -// XID is PostgreSQL's Transaction ID type. +// Xid is PostgreSQL's Transaction ID type. // // In later versions of PostgreSQL, it is the type used for the backend_xid // and backend_xmin columns of the pg_stat_activity system view. @@ -18,33 +18,33 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. -type XID pguint32 +type Xid pguint32 -// ConvertFrom converts from src to dst. Note that as XID is not a general +// ConvertFrom converts from src to dst. Note that as Xid is not a general // number type ConvertFrom does not do automatic type conversion as other number // types do. -func (dst *XID) ConvertFrom(src interface{}) error { +func (dst *Xid) ConvertFrom(src interface{}) error { return (*pguint32)(dst).ConvertFrom(src) } -// AssignTo assigns from src to dst. Note that as XID is not a general number +// AssignTo assigns from src to dst. Note that as Xid is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *XID) AssignTo(dst interface{}) error { +func (src *Xid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *XID) DecodeText(src []byte) error { +func (dst *Xid) DecodeText(src []byte) error { return (*pguint32)(dst).DecodeText(src) } -func (dst *XID) DecodeBinary(src []byte) error { +func (dst *Xid) DecodeBinary(src []byte) error { return (*pguint32)(dst).DecodeBinary(src) } -func (src XID) EncodeText(w io.Writer) (bool, error) { +func (src Xid) EncodeText(w io.Writer) (bool, error) { return (pguint32)(src).EncodeText(w) } -func (src XID) EncodeBinary(w io.Writer) (bool, error) { +func (src Xid) EncodeBinary(w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(w) } diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index 664920bc..a5c5df51 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestXIDTranscode(t *testing.T) { +func TestXidTranscode(t *testing.T) { testSuccessfulTranscode(t, "xid", []interface{}{ - pgtype.XID{Uint: 42, Status: pgtype.Present}, - pgtype.XID{Status: pgtype.Null}, + pgtype.Xid{Uint: 42, Status: pgtype.Present}, + pgtype.Xid{Status: pgtype.Null}, }) } -func TestXIDConvertFrom(t *testing.T) { +func TestXidConvertFrom(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.XID + result pgtype.Xid }{ - {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.Xid{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.XID + var r pgtype.Xid err := r.ConvertFrom(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestXIDConvertFrom(t *testing.T) { } } -func TestXIDAssignTo(t *testing.T) { +func TestXidAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.XID + src pgtype.Xid dst interface{} expected interface{} }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.Xid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.Xid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestXIDAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.XID + src pgtype.Xid dst interface{} expected interface{} }{ - {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.Xid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestXIDAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.XID + src pgtype.Xid dst interface{} }{ - {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.Xid{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/query.go b/query.go index 5730f1c6..2a5d88fc 100644 --- a/query.go +++ b/query.go @@ -202,7 +202,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if b, ok := d.(*[]byte); ok { // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) // Otherwise read the bytes directly regardless of what the actual type is. - if vr.Type().DataType == ByteaOID { + if vr.Type().DataType == ByteaOid { *b = decodeBytea(vr) } else { if vr.Len() != -1 { @@ -235,25 +235,25 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { var val interface{} if 0 <= vr.Len() { switch vr.Type().DataType { - case BoolOID: + case BoolOid: val = decodeBool(vr) - case Int8OID: + case Int8Oid: val = int64(decodeInt8(vr)) - case Int2OID: + case Int2Oid: val = int64(decodeInt2(vr)) - case Int4OID: + case Int4Oid: val = int64(decodeInt4(vr)) - case TextOID, VarcharOID: + case TextOid, VarcharOid: val = decodeText(vr) - case Float4OID: + case Float4Oid: val = float64(decodeFloat4(vr)) - case Float8OID: + case Float8Oid: val = decodeFloat8(vr) - case DateOID: + case DateOid: val = decodeDate(vr) - case TimestampOID: + case TimestampOid: val = decodeTimestamp(vr) - case TimestampTzOID: + case TimestampTzOid: val = decodeTimestampTz(vr) default: val = vr.ReadBytes(vr.Len()) @@ -263,14 +263,14 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if vr.Type().DataType == JSONOID { + } else if vr.Type().DataType == JsonOid { // Because the argument passed to decodeJSON will escape the heap. // This allows d to be stack allocated and only copied to the heap when // we actually are decoding JSON. This saves one memory allocation per // row. d2 := d decodeJSON(vr, &d2) - } else if vr.Type().DataType == JSONBOID { + } else if vr.Type().DataType == JsonbOid { // Same trick as above for getting stack allocation d2 := d decodeJSONB(vr, &d2) @@ -345,11 +345,11 @@ func (rows *Rows) Values() ([]interface{}, error) { // encoding so anything else should be treated as a string case TextFormatCode: switch vr.Type().DataType { - case JSONOID: + case JsonOid: var d interface{} decodeJSON(vr, &d) values = append(values, d) - case JSONBOID: + case JsonbOid: var d interface{} decodeJSONB(vr, &d) values = append(values, d) @@ -359,33 +359,33 @@ func (rows *Rows) Values() ([]interface{}, error) { case BinaryFormatCode: switch vr.Type().DataType { - case TextOID, VarcharOID: + case TextOid, VarcharOid: values = append(values, decodeText(vr)) - case BoolOID: + case BoolOid: values = append(values, decodeBool(vr)) - case ByteaOID: + case ByteaOid: values = append(values, decodeBytea(vr)) - case Int8OID: + case Int8Oid: values = append(values, decodeInt8(vr)) - case Int2OID: + case Int2Oid: values = append(values, decodeInt2(vr)) - case Int4OID: + case Int4Oid: values = append(values, decodeInt4(vr)) - case Float4OID: + case Float4Oid: values = append(values, decodeFloat4(vr)) - case Float8OID: + case Float8Oid: values = append(values, decodeFloat8(vr)) - case DateOID: + case DateOid: values = append(values, decodeDate(vr)) - case TimestampTzOID: + case TimestampTzOid: values = append(values, decodeTimestampTz(vr)) - case TimestampOID: + case TimestampOid: values = append(values, decodeTimestamp(vr)) - case JSONOID: + case JsonOid: var d interface{} decodeJSON(vr, &d) values = append(values, d) - case JSONBOID: + case JsonbOid: var d interface{} decodeJSONB(vr, &d) values = append(values, d) @@ -432,33 +432,33 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, vr.ReadString(vr.Len())) case BinaryFormatCode: switch vr.Type().DataType { - case TextOID, VarcharOID: + case TextOid, VarcharOid: values = append(values, decodeText(vr)) - case BoolOID: + case BoolOid: values = append(values, decodeBool(vr)) - case ByteaOID: + case ByteaOid: values = append(values, decodeBytea(vr)) - case Int8OID: + case Int8Oid: values = append(values, decodeInt8(vr)) - case Int2OID: + case Int2Oid: values = append(values, decodeInt2(vr)) - case Int4OID: + case Int4Oid: values = append(values, decodeInt4(vr)) - case Float4OID: + case Float4Oid: values = append(values, decodeFloat4(vr)) - case Float8OID: + case Float8Oid: values = append(values, decodeFloat8(vr)) - case DateOID: + case DateOid: values = append(values, decodeDate(vr)) - case TimestampTzOID: + case TimestampTzOid: values = append(values, decodeTimestampTz(vr)) - case TimestampOID: + case TimestampOid: values = append(values, decodeTimestamp(vr)) - case JSONOID: + case JsonOid: var d interface{} decodeJSON(vr, &d) values = append(values, d) - case JSONBOID: + case JsonbOid: var d interface{} decodeJSONB(vr, &d) values = append(values, d) diff --git a/query_test.go b/query_test.go index bbd7871e..46b012cf 100644 --- a/query_test.go +++ b/query_test.go @@ -197,7 +197,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { + if rows.Err().Error() != "can't scan into dest[0]: Can't convert Oid 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -403,7 +403,7 @@ type coreEncoder struct{} func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode } -func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.OID) error { +func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { w.WriteInt32(int32(2)) w.WriteBytes([]byte("42")) return nil @@ -438,7 +438,7 @@ func TestQueryRowCoreTypes(t *testing.T) { f64 float64 b bool t time.Time - oid pgx.OID + oid pgx.Oid } var actual, zero allTypes @@ -456,7 +456,7 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, - {"select $1::oid", []interface{}{pgx.OID(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + {"select $1::oid", []interface{}{pgx.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { diff --git a/stdlib/sql.go b/stdlib/sql.go index d138ae1e..07cca7e0 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -58,23 +58,23 @@ var openFromConnPoolCount int // oids that map to intrinsic database/sql types. These will be allowed to be // binary, anything else will be forced to text format -var databaseSqlOIDs map[pgx.OID]bool +var databaseSqlOids map[pgx.Oid]bool func init() { d := &Driver{} sql.Register("pgx", d) - databaseSqlOIDs = make(map[pgx.OID]bool) - databaseSqlOIDs[pgx.BoolOID] = true - databaseSqlOIDs[pgx.ByteaOID] = true - databaseSqlOIDs[pgx.Int2OID] = true - databaseSqlOIDs[pgx.Int4OID] = true - databaseSqlOIDs[pgx.Int8OID] = true - databaseSqlOIDs[pgx.Float4OID] = true - databaseSqlOIDs[pgx.Float8OID] = true - databaseSqlOIDs[pgx.DateOID] = true - databaseSqlOIDs[pgx.TimestampTzOID] = true - databaseSqlOIDs[pgx.TimestampOID] = true + databaseSqlOids = make(map[pgx.Oid]bool) + databaseSqlOids[pgx.BoolOid] = true + databaseSqlOids[pgx.ByteaOid] = true + databaseSqlOids[pgx.Int2Oid] = true + databaseSqlOids[pgx.Int4Oid] = true + databaseSqlOids[pgx.Int8Oid] = true + databaseSqlOids[pgx.Float4Oid] = true + databaseSqlOids[pgx.Float8Oid] = true + databaseSqlOids[pgx.DateOid] = true + databaseSqlOids[pgx.TimestampTzOid] = true + databaseSqlOids[pgx.TimestampOid] = true } type Driver struct { @@ -263,7 +263,7 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr // (e.g. []int32) func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { for i, _ := range ps.FieldDescriptions { - intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType] + intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType] if !intrinsic { ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode } @@ -280,7 +280,7 @@ func (s *Stmt) Close() error { } func (s *Stmt) NumInput() int { - return len(s.ps.ParameterOIDs) + return len(s.ps.ParameterOids) } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { diff --git a/v3.md b/v3.md index 68619d4d..2f1c353c 100644 --- a/v3.md +++ b/v3.md @@ -2,14 +2,8 @@ ## Changes -Rename Oid to OID in accordance with Go naming conventions. - -Rename Json(b) to JSON(B) in accordance with Go naming conventions. - Rename Pid to PID in accordance with Go naming conventions. -Rename Uuid to UUID in accordance with Go naming conventions. - Logger interface reduced to single Log method. Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode. diff --git a/value_reader.go b/value_reader.go index 85932a7d..364581c9 100644 --- a/value_reader.go +++ b/value_reader.go @@ -116,8 +116,8 @@ func (r *ValueReader) ReadInt64() int64 { return r.mr.readInt64() } -func (r *ValueReader) ReadOID() OID { - return OID(r.ReadUint32()) +func (r *ValueReader) ReadOid() Oid { + return Oid(r.ReadUint32()) } // ReadString reads count bytes and returns as string diff --git a/values.go b/values.go index 72f836bb..778284a4 100644 --- a/values.go +++ b/values.go @@ -19,47 +19,47 @@ import ( // PostgreSQL oids for common types const ( - BoolOID = 16 - ByteaOID = 17 - CharOID = 18 - NameOID = 19 - Int8OID = 20 - Int2OID = 21 - Int4OID = 23 - TextOID = 25 - OIDOID = 26 - TIDOID = 27 - XIDOID = 28 - CIDOID = 29 - JSONOID = 114 - CidrOID = 650 - CidrArrayOID = 651 - Float4OID = 700 - Float8OID = 701 - UnknownOID = 705 - InetOID = 869 - BoolArrayOID = 1000 - Int2ArrayOID = 1005 - Int4ArrayOID = 1007 - TextArrayOID = 1009 - ByteaArrayOID = 1001 - VarcharArrayOID = 1015 - Int8ArrayOID = 1016 - Float4ArrayOID = 1021 - Float8ArrayOID = 1022 - ACLItemOID = 1033 - ACLItemArrayOID = 1034 - InetArrayOID = 1041 - VarcharOID = 1043 - DateOID = 1082 - TimestampOID = 1114 - TimestampArrayOID = 1115 - DateArrayOID = 1182 - TimestampTzOID = 1184 - TimestampTzArrayOID = 1185 - RecordOID = 2249 - UUIDOID = 2950 - JSONBOID = 3802 + BoolOid = 16 + ByteaOid = 17 + CharOid = 18 + NameOid = 19 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + OidOid = 26 + TidOid = 27 + XidOid = 28 + CidOid = 29 + JsonOid = 114 + CidrOid = 650 + CidrArrayOid = 651 + Float4Oid = 700 + Float8Oid = 701 + UnknownOid = 705 + InetOid = 869 + BoolArrayOid = 1000 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + ByteaArrayOid = 1001 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + AclitemOid = 1033 + AclitemArrayOid = 1034 + InetArrayOid = 1041 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampArrayOid = 1115 + DateArrayOid = 1182 + TimestampTzOid = 1184 + TimestampTzArrayOid = 1185 + RecordOid = 2249 + UuidOid = 2950 + JsonbOid = 3802 ) // PostgreSQL format codes @@ -81,7 +81,7 @@ const minInt = -maxInt - 1 var DefaultTypeFormats map[string]int16 // internalNativeGoTypeFormats lists the encoding type for native Go types (not handled with Encoder interface) -var internalNativeGoTypeFormats map[OID]int16 +var internalNativeGoTypeFormats map[Oid]int16 func init() { DefaultTypeFormats = map[string]int16{ @@ -120,36 +120,36 @@ func init() { "xid": BinaryFormatCode, } - internalNativeGoTypeFormats = map[OID]int16{ - BoolArrayOID: BinaryFormatCode, - BoolOID: BinaryFormatCode, - ByteaArrayOID: BinaryFormatCode, - ByteaOID: BinaryFormatCode, - CidrArrayOID: BinaryFormatCode, - CidrOID: BinaryFormatCode, - DateOID: BinaryFormatCode, - Float4ArrayOID: BinaryFormatCode, - Float4OID: BinaryFormatCode, - Float8ArrayOID: BinaryFormatCode, - Float8OID: BinaryFormatCode, - InetArrayOID: BinaryFormatCode, - InetOID: BinaryFormatCode, - Int2ArrayOID: BinaryFormatCode, - Int2OID: BinaryFormatCode, - Int4ArrayOID: BinaryFormatCode, - Int4OID: BinaryFormatCode, - Int8ArrayOID: BinaryFormatCode, - Int8OID: BinaryFormatCode, - JSONBOID: BinaryFormatCode, - JSONOID: BinaryFormatCode, - OIDOID: BinaryFormatCode, - RecordOID: BinaryFormatCode, - TextArrayOID: BinaryFormatCode, - TimestampArrayOID: BinaryFormatCode, - TimestampOID: BinaryFormatCode, - TimestampTzArrayOID: BinaryFormatCode, - TimestampTzOID: BinaryFormatCode, - VarcharArrayOID: BinaryFormatCode, + internalNativeGoTypeFormats = map[Oid]int16{ + BoolArrayOid: BinaryFormatCode, + BoolOid: BinaryFormatCode, + ByteaArrayOid: BinaryFormatCode, + ByteaOid: BinaryFormatCode, + CidrArrayOid: BinaryFormatCode, + CidrOid: BinaryFormatCode, + DateOid: BinaryFormatCode, + Float4ArrayOid: BinaryFormatCode, + Float4Oid: BinaryFormatCode, + Float8ArrayOid: BinaryFormatCode, + Float8Oid: BinaryFormatCode, + InetArrayOid: BinaryFormatCode, + InetOid: BinaryFormatCode, + Int2ArrayOid: BinaryFormatCode, + Int2Oid: BinaryFormatCode, + Int4ArrayOid: BinaryFormatCode, + Int4Oid: BinaryFormatCode, + Int8ArrayOid: BinaryFormatCode, + Int8Oid: BinaryFormatCode, + JsonbOid: BinaryFormatCode, + JsonOid: BinaryFormatCode, + OidOid: BinaryFormatCode, + RecordOid: BinaryFormatCode, + TextArrayOid: BinaryFormatCode, + TimestampArrayOid: BinaryFormatCode, + TimestampOid: BinaryFormatCode, + TimestampTzArrayOid: BinaryFormatCode, + TimestampTzOid: BinaryFormatCode, + VarcharArrayOid: BinaryFormatCode, } } @@ -164,7 +164,7 @@ func (e SerializationError) Error() string { // server. To allow types to support pgx and database/sql.Scan this interface // has been deprecated in favor of PgxScanner. type Scanner interface { - // Scan MUST check r.Type().DataType (to check by OID) or + // Scan MUST check r.Type().DataType (to check by Oid) or // r.Type().DataTypeName (to check by name) to ensure that it is scanning an // expected column type. It also MUST check r.Type().FormatCode before // decoding. It should not assume that it was called on a data type or format @@ -176,7 +176,7 @@ type Scanner interface { // It is used exactly the same as the Scanner interface. It simply has renamed // the method. type PgxScanner interface { - // ScanPgx MUST check r.Type().DataType (to check by OID) or + // ScanPgx MUST check r.Type().DataType (to check by Oid) or // r.Type().DataTypeName (to check by name) to ensure that it is scanning an // expected column type. It also MUST check r.Type().FormatCode before // decoding. It should not assume that it was called on a data type or format @@ -196,7 +196,7 @@ type Encoder interface { // expected data size or format of the encoded data does not match. But if // the encoded data is a valid representation of the data type PostgreSQL // expects such as date and int4, incorrect data may be stored. - Encode(w *WriteBuf, oid OID) error + Encode(w *WriteBuf, oid Oid) error // FormatCode returns the format that the encoder writes the value. It must be // either pgx.TextFormatCode or pgx.BinaryFormatCode. @@ -214,8 +214,8 @@ type NullFloat32 struct { } func (n *NullFloat32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float4OID { - return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode OID %d", vr.Type().DataType)) + if vr.Type().DataType != Float4Oid { + return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode Oid %d", vr.Type().DataType)) } if vr.Len() == -1 { @@ -229,9 +229,9 @@ func (n *NullFloat32) Scan(vr *ValueReader) error { func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode } -func (n NullFloat32) Encode(w *WriteBuf, oid OID) error { - if oid != Float4OID { - return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into OID %d", oid)) +func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error { + if oid != Float4Oid { + return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into Oid %d", oid)) } if !n.Valid { @@ -253,8 +253,8 @@ type NullFloat64 struct { } func (n *NullFloat64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float8OID { - return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode OID %d", vr.Type().DataType)) + if vr.Type().DataType != Float8Oid { + return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode Oid %d", vr.Type().DataType)) } if vr.Len() == -1 { @@ -268,9 +268,9 @@ func (n *NullFloat64) Scan(vr *ValueReader) error { func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode } -func (n NullFloat64) Encode(w *WriteBuf, oid OID) error { - if oid != Float8OID { - return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into OID %d", oid)) +func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { + if oid != Float8Oid { + return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into Oid %d", oid)) } if !n.Valid { @@ -306,7 +306,7 @@ func (n *NullString) Scan(vr *ValueReader) error { func (n NullString) FormatCode() int16 { return TextFormatCode } -func (s NullString) Encode(w *WriteBuf, oid OID) error { +func (s NullString) Encode(w *WriteBuf, oid Oid) error { if !s.Valid { w.WriteInt32(-1) return nil @@ -326,8 +326,8 @@ type NullInt16 struct { } func (n *NullInt16) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int2OID { - return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode OID %d", vr.Type().DataType)) + if vr.Type().DataType != Int2Oid { + return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode Oid %d", vr.Type().DataType)) } if vr.Len() == -1 { @@ -341,9 +341,9 @@ func (n *NullInt16) Scan(vr *ValueReader) error { func (n NullInt16) FormatCode() int16 { return BinaryFormatCode } -func (n NullInt16) Encode(w *WriteBuf, oid OID) error { - if oid != Int2OID { - return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into OID %d", oid)) +func (n NullInt16) Encode(w *WriteBuf, oid Oid) error { + if oid != Int2Oid { + return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into Oid %d", oid)) } if !n.Valid { @@ -368,8 +368,8 @@ type NullInt32 struct { } func (n *NullInt32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int4OID { - return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode OID %d", vr.Type().DataType)) + if vr.Type().DataType != Int4Oid { + return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode Oid %d", vr.Type().DataType)) } if vr.Len() == -1 { @@ -383,9 +383,9 @@ func (n *NullInt32) Scan(vr *ValueReader) error { func (n NullInt32) FormatCode() int16 { return BinaryFormatCode } -func (n NullInt32) Encode(w *WriteBuf, oid OID) error { - if oid != Int4OID { - return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into OID %d", oid)) +func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { + if oid != Int4Oid { + return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into Oid %d", oid)) } if !n.Valid { @@ -399,15 +399,15 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { return err } -// OID (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, +// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, // used internally by PostgreSQL as a primary key for various system tables. It is currently implemented // as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h -// in the PostgreSQL sources. OID cannot be NULL. To allow for NULL OIDs use pgtype.OID. -type OID uint32 +// in the PostgreSQL sources. Oid cannot be NULL. To allow for NULL Oids use pgtype.Oid. +type Oid uint32 -func (dst *OID) DecodeText(src []byte) error { +func (dst *Oid) DecodeText(src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into OID") + return fmt.Errorf("cannot decode nil into Oid") } n, err := strconv.ParseUint(string(src), 10, 32) @@ -415,13 +415,13 @@ func (dst *OID) DecodeText(src []byte) error { return err } - *dst = OID(n) + *dst = Oid(n) return nil } -func (dst *OID) DecodeBinary(src []byte) error { +func (dst *Oid) DecodeBinary(src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into OID") + return fmt.Errorf("cannot decode nil into Oid") } if len(src) != 4 { @@ -429,16 +429,16 @@ func (dst *OID) DecodeBinary(src []byte) error { } n := binary.BigEndian.Uint32(src) - *dst = OID(n) + *dst = Oid(n) return nil } -func (src OID) EncodeText(w io.Writer) (bool, error) { +func (src Oid) EncodeText(w io.Writer) (bool, error) { _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) return false, err } -func (src OID) EncodeBinary(w io.Writer) (bool, error) { +func (src Oid) EncodeBinary(w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } @@ -454,8 +454,8 @@ type NullInt64 struct { } func (n *NullInt64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int8OID { - return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode OID %d", vr.Type().DataType)) + if vr.Type().DataType != Int8Oid { + return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode Oid %d", vr.Type().DataType)) } if vr.Len() == -1 { @@ -469,9 +469,9 @@ func (n *NullInt64) Scan(vr *ValueReader) error { func (n NullInt64) FormatCode() int16 { return BinaryFormatCode } -func (n NullInt64) Encode(w *WriteBuf, oid OID) error { - if oid != Int8OID { - return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into OID %d", oid)) +func (n NullInt64) Encode(w *WriteBuf, oid Oid) error { + if oid != Int8Oid { + return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into Oid %d", oid)) } if !n.Valid { @@ -496,8 +496,8 @@ type NullBool struct { } func (n *NullBool) Scan(vr *ValueReader) error { - if vr.Type().DataType != BoolOID { - return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode OID %d", vr.Type().DataType)) + if vr.Type().DataType != BoolOid { + return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode Oid %d", vr.Type().DataType)) } if vr.Len() == -1 { @@ -511,9 +511,9 @@ func (n *NullBool) Scan(vr *ValueReader) error { func (n NullBool) FormatCode() int16 { return BinaryFormatCode } -func (n NullBool) Encode(w *WriteBuf, oid OID) error { - if oid != BoolOID { - return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into OID %d", oid)) +func (n NullBool) Encode(w *WriteBuf, oid Oid) error { + if oid != BoolOid { + return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into Oid %d", oid)) } if !n.Valid { @@ -540,8 +540,8 @@ type NullTime struct { func (n *NullTime) Scan(vr *ValueReader) error { oid := vr.Type().DataType - if oid != TimestampTzOID && oid != TimestampOID && oid != DateOID { - return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType)) + if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { + return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode Oid %d", vr.Type().DataType)) } if vr.Len() == -1 { @@ -551,11 +551,11 @@ func (n *NullTime) Scan(vr *ValueReader) error { n.Valid = true switch oid { - case TimestampTzOID: + case TimestampTzOid: n.Time = decodeTimestampTz(vr) - case TimestampOID: + case TimestampOid: n.Time = decodeTimestamp(vr) - case DateOID: + case DateOid: n.Time = decodeDate(vr) } @@ -564,9 +564,9 @@ func (n *NullTime) Scan(vr *ValueReader) error { func (n NullTime) FormatCode() int16 { return BinaryFormatCode } -func (n NullTime) Encode(w *WriteBuf, oid OID) error { - if oid != TimestampTzOID && oid != TimestampOID && oid != DateOID { - return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid)) +func (n NullTime) Encode(w *WriteBuf, oid Oid) error { + if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { + return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into Oid %d", oid)) } if !n.Valid { @@ -616,7 +616,7 @@ func (h *Hstore) Scan(vr *ValueReader) error { func (h Hstore) FormatCode() int16 { return TextFormatCode } -func (h Hstore) Encode(w *WriteBuf, oid OID) error { +func (h Hstore) Encode(w *WriteBuf, oid Oid) error { var buf bytes.Buffer i := 0 @@ -682,7 +682,7 @@ func (h *NullHstore) Scan(vr *ValueReader) error { func (h NullHstore) FormatCode() int16 { return TextFormatCode } -func (h NullHstore) Encode(w *WriteBuf, oid OID) error { +func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { var buf bytes.Buffer if !h.Valid { @@ -714,7 +714,7 @@ func (h NullHstore) Encode(w *WriteBuf, oid OID) error { // Encode encodes arg into wbuf as the type oid. This allows implementations // of the Encoder interface to delegate the actual work of encoding to the // built-in functionality. -func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { +func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { if arg == nil { wbuf.WriteInt32(-1) return nil @@ -772,10 +772,10 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { return Encode(wbuf, oid, arg) } - if oid == JSONOID { + if oid == JsonOid { return encodeJSON(wbuf, oid, arg) } - if oid == JSONBOID { + if oid == JsonbOid { return encodeJSONB(wbuf, oid, arg) } @@ -890,7 +890,7 @@ func Decode(vr *ValueReader, d interface{}) error { } func decodeBool(vr *ValueReader) bool { - if vr.Type().DataType != BoolOID { + if vr.Type().DataType != BoolOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) return false } @@ -922,11 +922,11 @@ func decodeBool(vr *ValueReader) bool { func decodeInt(vr *ValueReader) int64 { switch vr.Type().DataType { - case Int2OID: + case Int2Oid: return int64(decodeInt2(vr)) - case Int4OID: + case Int4Oid: return int64(decodeInt4(vr)) - case Int8OID: + case Int8Oid: return int64(decodeInt8(vr)) } @@ -940,7 +940,7 @@ func decodeInt8(vr *ValueReader) int64 { return 0 } - if vr.Type().DataType != Int8OID { + if vr.Type().DataType != Int8Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) return 0 } @@ -972,7 +972,7 @@ func decodeInt8(vr *ValueReader) int64 { func decodeInt2(vr *ValueReader) int16 { - if vr.Type().DataType != Int2OID { + if vr.Type().DataType != Int2Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) return 0 } @@ -1008,7 +1008,7 @@ func decodeInt4(vr *ValueReader) int32 { return 0 } - if vr.Type().DataType != Int4OID { + if vr.Type().DataType != Int4Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) return 0 } @@ -1044,7 +1044,7 @@ func decodeFloat4(vr *ValueReader) float32 { return 0 } - if vr.Type().DataType != Float4OID { + if vr.Type().DataType != Float4Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) return 0 } @@ -1063,12 +1063,12 @@ func decodeFloat4(vr *ValueReader) float32 { return math.Float32frombits(uint32(i)) } -func encodeFloat32(w *WriteBuf, oid OID, value float32) error { +func encodeFloat32(w *WriteBuf, oid Oid, value float32) error { switch oid { - case Float4OID: + case Float4Oid: w.WriteInt32(4) w.WriteInt32(int32(math.Float32bits(value))) - case Float8OID: + case Float8Oid: w.WriteInt32(8) w.WriteInt64(int64(math.Float64bits(float64(value)))) default: @@ -1084,7 +1084,7 @@ func decodeFloat8(vr *ValueReader) float64 { return 0 } - if vr.Type().DataType != Float8OID { + if vr.Type().DataType != Float8Oid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType))) return 0 } @@ -1103,9 +1103,9 @@ func decodeFloat8(vr *ValueReader) float64 { return math.Float64frombits(uint64(i)) } -func encodeFloat64(w *WriteBuf, oid OID, value float64) error { +func encodeFloat64(w *WriteBuf, oid Oid, value float64) error { switch oid { - case Float8OID: + case Float8Oid: w.WriteInt32(8) w.WriteInt64(int64(math.Float64bits(value))) default: @@ -1138,7 +1138,7 @@ func decodeTextAllowBinary(vr *ValueReader) string { return vr.ReadString(vr.Len()) } -func encodeString(w *WriteBuf, oid OID, value string) error { +func encodeString(w *WriteBuf, oid Oid, value string) error { w.WriteInt32(int32(len(value))) w.WriteBytes([]byte(value)) return nil @@ -1149,7 +1149,7 @@ func decodeBytea(vr *ValueReader) []byte { return nil } - if vr.Type().DataType != ByteaOID { + if vr.Type().DataType != ByteaOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType))) return nil } @@ -1162,7 +1162,7 @@ func decodeBytea(vr *ValueReader) []byte { return vr.ReadBytes(vr.Len()) } -func encodeByteSlice(w *WriteBuf, oid OID, value []byte) error { +func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error { w.WriteInt32(int32(len(value))) w.WriteBytes(value) @@ -1174,7 +1174,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return nil } - if vr.Type().DataType != JSONOID { + if vr.Type().DataType != JsonOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) } @@ -1186,8 +1186,8 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return err } -func encodeJSON(w *WriteBuf, oid OID, value interface{}) error { - if oid != JSONOID { +func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { + if oid != JsonOid { return fmt.Errorf("cannot encode JSON into oid %v", oid) } @@ -1207,7 +1207,7 @@ func decodeJSONB(vr *ValueReader, d interface{}) error { return nil } - if vr.Type().DataType != JSONBOID { + if vr.Type().DataType != JsonbOid { err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType)) vr.Fatal(err) return err @@ -1230,8 +1230,8 @@ func decodeJSONB(vr *ValueReader, d interface{}) error { return err } -func encodeJSONB(w *WriteBuf, oid OID, value interface{}) error { - if oid != JSONBOID { +func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error { + if oid != JsonbOid { return fmt.Errorf("cannot encode JSON into oid %v", oid) } @@ -1248,7 +1248,7 @@ func encodeJSONB(w *WriteBuf, oid OID, value interface{}) error { } func decodeDate(vr *ValueReader) time.Time { - if vr.Type().DataType != DateOID { + if vr.Type().DataType != DateOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return time.Time{} } @@ -1278,9 +1278,9 @@ func decodeDate(vr *ValueReader) time.Time { return d.Time } -func encodeTime(w *WriteBuf, oid OID, value time.Time) error { +func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { switch oid { - case DateOID: + case DateOid: var d pgtype.Date err := d.ConvertFrom(value) if err != nil { @@ -1300,7 +1300,7 @@ func encodeTime(w *WriteBuf, oid OID, value time.Time) error { } return nil - case TimestampTzOID, TimestampOID: + case TimestampTzOid, TimestampOid: var t pgtype.Timestamptz err := t.ConvertFrom(value) if err != nil { @@ -1334,7 +1334,7 @@ func decodeTimestampTz(vr *ValueReader) time.Time { return zeroTime } - if vr.Type().DataType != TimestampTzOID { + if vr.Type().DataType != TimestampTzOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return zeroTime } @@ -1372,7 +1372,7 @@ func decodeTimestamp(vr *ValueReader) time.Time { return zeroTime } - if vr.Type().DataType != TimestampOID { + if vr.Type().DataType != TimestampOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) return zeroTime } @@ -1402,7 +1402,7 @@ func decodeRecord(vr *ValueReader) []interface{} { return nil } - if vr.Type().DataType != RecordOID { + if vr.Type().DataType != RecordOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType))) return nil } @@ -1413,32 +1413,32 @@ func decodeRecord(vr *ValueReader) []interface{} { for i := int32(0); i < valueCount; i++ { fd := FieldDescription{FormatCode: BinaryFormatCode} fieldVR := ValueReader{mr: vr.mr, fd: &fd} - fd.DataType = vr.ReadOID() + fd.DataType = vr.ReadOid() fieldVR.valueBytesRemaining = vr.ReadInt32() vr.valueBytesRemaining -= fieldVR.valueBytesRemaining switch fd.DataType { - case BoolOID: + case BoolOid: record = append(record, decodeBool(&fieldVR)) - case ByteaOID: + case ByteaOid: record = append(record, decodeBytea(&fieldVR)) - case Int8OID: + case Int8Oid: record = append(record, decodeInt8(&fieldVR)) - case Int2OID: + case Int2Oid: record = append(record, decodeInt2(&fieldVR)) - case Int4OID: + case Int4Oid: record = append(record, decodeInt4(&fieldVR)) - case Float4OID: + case Float4Oid: record = append(record, decodeFloat4(&fieldVR)) - case Float8OID: + case Float8Oid: record = append(record, decodeFloat8(&fieldVR)) - case DateOID: + case DateOid: record = append(record, decodeDate(&fieldVR)) - case TimestampTzOID: + case TimestampTzOid: record = append(record, decodeTimestampTz(&fieldVR)) - case TimestampOID: + case TimestampOid: record = append(record, decodeTimestamp(&fieldVR)) - case TextOID, VarcharOID, UnknownOID: + case TextOid, VarcharOid, UnknownOid: record = append(record, decodeTextAllowBinary(&fieldVR)) default: vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) diff --git a/values_test.go b/values_test.go index eb570fe6..7b82d456 100644 --- a/values_test.go +++ b/values_test.go @@ -84,7 +84,7 @@ func TestJSONAndJSONBTranscode(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.OID{pgx.JSONOID, pgx.JSONBOID} { + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } @@ -232,7 +232,7 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) } } -func mustParseCIDR(t *testing.T, s string) *net.IPNet { +func mustParseCidr(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) @@ -277,26 +277,26 @@ func TestInetCidrTranscodeIPNet(t *testing.T) { sql string value *net.IPNet }{ - {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, - {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, - {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, - {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, - {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, - {"select $1::inet", mustParseCIDR(t, "::/128")}, - {"select $1::inet", mustParseCIDR(t, "::/0")}, - {"select $1::inet", mustParseCIDR(t, "::1/128")}, - {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, - {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, - {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, - {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, - {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, - {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, - {"select $1::cidr", mustParseCIDR(t, "::/128")}, - {"select $1::cidr", mustParseCIDR(t, "::/0")}, - {"select $1::cidr", mustParseCIDR(t, "::1/128")}, - {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + {"select $1::inet", mustParseCidr(t, "0.0.0.0/32")}, + {"select $1::inet", mustParseCidr(t, "127.0.0.1/32")}, + {"select $1::inet", mustParseCidr(t, "12.34.56.0/32")}, + {"select $1::inet", mustParseCidr(t, "192.168.1.0/24")}, + {"select $1::inet", mustParseCidr(t, "255.0.0.0/8")}, + {"select $1::inet", mustParseCidr(t, "255.255.255.255/32")}, + {"select $1::inet", mustParseCidr(t, "::/128")}, + {"select $1::inet", mustParseCidr(t, "::/0")}, + {"select $1::inet", mustParseCidr(t, "::1/128")}, + {"select $1::inet", mustParseCidr(t, "2607:f8b0:4009:80b::200e/128")}, + {"select $1::cidr", mustParseCidr(t, "0.0.0.0/32")}, + {"select $1::cidr", mustParseCidr(t, "127.0.0.1/32")}, + {"select $1::cidr", mustParseCidr(t, "12.34.56.0/32")}, + {"select $1::cidr", mustParseCidr(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCidr(t, "255.0.0.0/8")}, + {"select $1::cidr", mustParseCidr(t, "255.255.255.255/32")}, + {"select $1::cidr", mustParseCidr(t, "::/128")}, + {"select $1::cidr", mustParseCidr(t, "::/0")}, + {"select $1::cidr", mustParseCidr(t, "::1/128")}, + {"select $1::cidr", mustParseCidr(t, "2607:f8b0:4009:80b::200e/128")}, } for i, tt := range tests { @@ -360,8 +360,8 @@ func TestInetCidrTranscodeIP(t *testing.T) { sql string value *net.IPNet }{ - {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::inet", mustParseCidr(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCidr(t, "192.168.1.0/24")}, } for i, tt := range failTests { var actual net.IP @@ -389,31 +389,31 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { { "select $1::inet[]", []*net.IPNet{ - mustParseCIDR(t, "0.0.0.0/32"), - mustParseCIDR(t, "127.0.0.1/32"), - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), - mustParseCIDR(t, "255.0.0.0/8"), - mustParseCIDR(t, "255.255.255.255/32"), - mustParseCIDR(t, "::/128"), - mustParseCIDR(t, "::/0"), - mustParseCIDR(t, "::1/128"), - mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + mustParseCidr(t, "0.0.0.0/32"), + mustParseCidr(t, "127.0.0.1/32"), + mustParseCidr(t, "12.34.56.0/32"), + mustParseCidr(t, "192.168.1.0/24"), + mustParseCidr(t, "255.0.0.0/8"), + mustParseCidr(t, "255.255.255.255/32"), + mustParseCidr(t, "::/128"), + mustParseCidr(t, "::/0"), + mustParseCidr(t, "::1/128"), + mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), }, }, { "select $1::cidr[]", []*net.IPNet{ - mustParseCIDR(t, "0.0.0.0/32"), - mustParseCIDR(t, "127.0.0.1/32"), - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), - mustParseCIDR(t, "255.0.0.0/8"), - mustParseCIDR(t, "255.255.255.255/32"), - mustParseCIDR(t, "::/128"), - mustParseCIDR(t, "::/0"), - mustParseCIDR(t, "::1/128"), - mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + mustParseCidr(t, "0.0.0.0/32"), + mustParseCidr(t, "127.0.0.1/32"), + mustParseCidr(t, "12.34.56.0/32"), + mustParseCidr(t, "192.168.1.0/24"), + mustParseCidr(t, "255.0.0.0/8"), + mustParseCidr(t, "255.255.255.255/32"), + mustParseCidr(t, "::/128"), + mustParseCidr(t, "::/0"), + mustParseCidr(t, "::1/128"), + mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), }, }, } @@ -490,15 +490,15 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { { "select $1::inet[]", []*net.IPNet{ - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), + mustParseCidr(t, "12.34.56.0/32"), + mustParseCidr(t, "192.168.1.0/24"), }, }, { "select $1::cidr[]", []*net.IPNet{ - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), + mustParseCidr(t, "12.34.56.0/32"), + mustParseCidr(t, "192.168.1.0/24"), }, }, } @@ -541,7 +541,7 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } for i, tt := range tests { - expected := mustParseCIDR(t, tt.value) + expected := mustParseCidr(t, tt.value) var actual net.IPNet err := conn.QueryRow(tt.sql, expected.IP).Scan(&actual) @@ -840,13 +840,13 @@ func TestNullXMismatch(t *testing.T) { err string }{ {"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, "invalid input syntax for type date"}, - {"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into OID 1082"}, - {"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into OID 23"}, + {"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into Oid 1082"}, + {"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into Oid 1082"}, + {"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into Oid 1082"}, + {"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into Oid 1082"}, + {"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into Oid 1082"}, + {"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into Oid 1082"}, + {"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into Oid 23"}, } for i, tt := range tests { From 542eac08c6bd02b174a9ec769262f54835b78f06 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 18:46:51 -0600 Subject: [PATCH 105/264] Add json/jsonb to pgtype --- conn.go | 2 + pgtype/json.go | 102 ++++++++++++++++++++++++++++++++ pgtype/json_test.go | 135 +++++++++++++++++++++++++++++++++++++++++++ pgtype/jsonb.go | 64 ++++++++++++++++++++ pgtype/jsonb_test.go | 135 +++++++++++++++++++++++++++++++++++++++++++ query.go | 11 ---- values.go | 7 --- 7 files changed, 438 insertions(+), 18 deletions(-) create mode 100644 pgtype/json.go create mode 100644 pgtype/json_test.go create mode 100644 pgtype/jsonb.go create mode 100644 pgtype/jsonb_test.go diff --git a/conn.go b/conn.go index 21bd8f1b..4085722c 100644 --- a/conn.go +++ b/conn.go @@ -292,6 +292,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl Int4Oid: &pgtype.Int4{}, Int8ArrayOid: &pgtype.Int8Array{}, Int8Oid: &pgtype.Int8{}, + JsonbOid: &pgtype.Jsonb{}, + JsonOid: &pgtype.Json{}, NameOid: &pgtype.Name{}, OidOid: &pgtype.Oid{}, TextArrayOid: &pgtype.TextArray{}, diff --git a/pgtype/json.go b/pgtype/json.go new file mode 100644 index 00000000..8a258ea4 --- /dev/null +++ b/pgtype/json.go @@ -0,0 +1,102 @@ +package pgtype + +import ( + "encoding/json" + "io" +) + +type Json struct { + Bytes []byte + Status Status +} + +func (dst *Json) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case string: + *dst = Json{Bytes: []byte(value), Status: Present} + case *string: + if value == nil { + *dst = Json{Status: Null} + } else { + *dst = Json{Bytes: []byte(*value), Status: Present} + } + case []byte: + if value == nil { + *dst = Json{Status: Null} + } else { + *dst = Json{Bytes: value, Status: Present} + } + default: + buf, err := json.Marshal(value) + if err != nil { + return err + } + *dst = Json{Bytes: buf, Status: Present} + } + + return nil +} + +func (src *Json) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + v = nil + } else { + *v = string(src.Bytes) + } + case **string: + *v = new(string) + return src.AssignTo(*v) + case *[]byte: + if src.Status != Present { + *v = nil + } else { + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + } + default: + data := src.Bytes + if data == nil || src.Status != Present { + data = []byte("null") + } + + return json.Unmarshal(data, dst) + } + + return nil +} + +func (dst *Json) DecodeText(src []byte) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + + buf := make([]byte, len(src)) + copy(buf, src) + + *dst = Json{Bytes: buf, Status: Present} + return nil +} + +func (dst *Json) DecodeBinary(src []byte) error { + return dst.DecodeText(src) +} + +func (src Json) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := w.Write(src.Bytes) + return false, err +} + +func (src Json) EncodeBinary(w io.Writer) (bool, error) { + return src.EncodeText(w) +} diff --git a/pgtype/json_test.go b/pgtype/json_test.go new file mode 100644 index 00000000..87770f31 --- /dev/null +++ b/pgtype/json_test.go @@ -0,0 +1,135 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestJsonTranscode(t *testing.T) { + testSuccessfulTranscode(t, "json", []interface{}{ + pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, + pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, + pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, + pgtype.Json{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + pgtype.Json{Status: pgtype.Null}, + }) +} + +func TestJsonConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Json + }{ + {source: "{}", result: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.Json{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.Json{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.Json{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.Json{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Json + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJsonAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.Json + dst *string + expected string + }{ + {src: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.Json + dst *[]byte + expected []byte + }{ + {src: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.Json{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.Json + dst interface{} + expected interface{} + }{ + {src: pgtype.Json{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.Json{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Json + dst **string + expected *string + }{ + {src: pgtype.Json{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst == tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go new file mode 100644 index 00000000..0739a468 --- /dev/null +++ b/pgtype/jsonb.go @@ -0,0 +1,64 @@ +package pgtype + +import ( + "fmt" + "io" +) + +type Jsonb Json + +func (dst *Jsonb) ConvertFrom(src interface{}) error { + return (*Json)(dst).ConvertFrom(src) +} + +func (src *Jsonb) AssignTo(dst interface{}) error { + return (*Json)(src).AssignTo(dst) +} + +func (dst *Jsonb) DecodeText(src []byte) error { + return (*Json)(dst).DecodeText(src) +} + +func (dst *Jsonb) DecodeBinary(src []byte) error { + if src == nil { + *dst = Jsonb{Status: Null} + return nil + } + + if len(src) == 0 { + return fmt.Errorf("jsonb too short") + } + + if src[0] != 1 { + return fmt.Errorf("unknown jsonb version number %d", src[0]) + } + src = src[1:] + + buf := make([]byte, len(src)) + copy(buf, src) + + *dst = Jsonb{Bytes: buf, Status: Present} + return nil + +} + +func (src Jsonb) EncodeText(w io.Writer) (bool, error) { + return (Json)(src).EncodeText(w) +} + +func (src Jsonb) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := w.Write([]byte{1}) + if err != nil { + return false, err + } + + _, err = w.Write(src.Bytes) + return false, err +} diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go new file mode 100644 index 00000000..e42931d5 --- /dev/null +++ b/pgtype/jsonb_test.go @@ -0,0 +1,135 @@ +package pgtype_test + +import ( + "bytes" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestJsonbTranscode(t *testing.T) { + testSuccessfulTranscode(t, "jsonb", []interface{}{ + pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, + pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, + pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, + pgtype.Jsonb{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + pgtype.Jsonb{Status: pgtype.Null}, + }) +} + +func TestJsonbConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Jsonb + }{ + {source: "{}", result: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.Jsonb{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.Jsonb{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.Jsonb{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.Jsonb{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var d pgtype.Jsonb + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(d, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestJsonbAssignTo(t *testing.T) { + var s string + var ps *string + var b []byte + + rawStringTests := []struct { + src pgtype.Jsonb + dst *string + expected string + }{ + {src: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + } + + for i, tt := range rawStringTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + rawBytesTests := []struct { + src pgtype.Jsonb + dst *[]byte + expected []byte + }{ + {src: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.Jsonb{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + } + + for i, tt := range rawBytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(tt.expected, *tt.dst) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } + + var mapDst map[string]interface{} + type structDst struct { + Name string `json:"name"` + Age int `json:"age"` + } + var strDst structDst + + unmarshalTests := []struct { + src pgtype.Jsonb + dst interface{} + expected interface{} + }{ + {src: pgtype.Jsonb{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.Jsonb{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + } + for i, tt := range unmarshalTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Jsonb + dst **string + expected *string + }{ + {src: pgtype.Jsonb{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if *tt.dst == tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/query.go b/query.go index 2a5d88fc..bc7aeda4 100644 --- a/query.go +++ b/query.go @@ -263,17 +263,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if vr.Type().DataType == JsonOid { - // Because the argument passed to decodeJSON will escape the heap. - // This allows d to be stack allocated and only copied to the heap when - // we actually are decoding JSON. This saves one memory allocation per - // row. - d2 := d - decodeJSON(vr, &d2) - } else if vr.Type().DataType == JsonbOid { - // Same trick as above for getting stack allocation - d2 := d - decodeJSONB(vr, &d2) } else { if pgVal, present := rows.conn.oidPgtypeValues[vr.Type().DataType]; present { switch vr.Type().FormatCode { diff --git a/values.go b/values.go index 778284a4..bc9e5c64 100644 --- a/values.go +++ b/values.go @@ -772,13 +772,6 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return Encode(wbuf, oid, arg) } - if oid == JsonOid { - return encodeJSON(wbuf, oid, arg) - } - if oid == JsonbOid { - return encodeJSONB(wbuf, oid, arg) - } - if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { if converterFrom, ok := value.(pgtype.ConverterFrom); ok { err := converterFrom.ConvertFrom(arg) From 57494a6a0fd5a2e7acd26fdcca7b2f340eac23a8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 19:53:02 -0600 Subject: [PATCH 106/264] Expand pgtype.Value interface - Include and rename ConvertFrom to Set - Add Get - Include AssignTo --- pgtype/aclitem.go | 15 +++++++++++++-- pgtype/aclitem_array.go | 17 ++++++++++++++--- pgtype/aclitem_array_test.go | 4 ++-- pgtype/aclitem_test.go | 4 ++-- pgtype/bool.go | 15 +++++++++++++-- pgtype/bool_array.go | 17 ++++++++++++++--- pgtype/bool_array_test.go | 4 ++-- pgtype/bool_test.go | 4 ++-- pgtype/bytea.go | 15 +++++++++++++-- pgtype/bytea_array.go | 17 ++++++++++++++--- pgtype/bytea_array_test.go | 4 ++-- pgtype/bytea_test.go | 4 ++-- pgtype/cid.go | 12 ++++++++---- pgtype/cid_test.go | 4 ++-- pgtype/cidr_array.go | 8 ++++++-- pgtype/date.go | 15 +++++++++++++-- pgtype/date_array.go | 17 ++++++++++++++--- pgtype/date_array_test.go | 4 ++-- pgtype/date_test.go | 4 ++-- pgtype/float4.go | 15 +++++++++++++-- pgtype/float4_array.go | 17 ++++++++++++++--- pgtype/float4_array_test.go | 4 ++-- pgtype/float4_test.go | 4 ++-- pgtype/float8.go | 15 +++++++++++++-- pgtype/float8_array.go | 17 ++++++++++++++--- pgtype/float8_array_test.go | 4 ++-- pgtype/float8_test.go | 4 ++-- pgtype/inet.go | 15 +++++++++++++-- pgtype/inet_array.go | 19 +++++++++++++++---- pgtype/inet_array_test.go | 4 ++-- pgtype/inet_test.go | 4 ++-- pgtype/int2.go | 15 +++++++++++++-- pgtype/int2_array.go | 19 +++++++++++++++---- pgtype/int2_array_test.go | 4 ++-- pgtype/int2_test.go | 4 ++-- pgtype/int4.go | 15 +++++++++++++-- pgtype/int4_array.go | 19 +++++++++++++++---- pgtype/int4_array_test.go | 4 ++-- pgtype/int4_test.go | 4 ++-- pgtype/int8.go | 15 +++++++++++++-- pgtype/int8_array.go | 19 +++++++++++++++---- pgtype/int8_array_test.go | 4 ++-- pgtype/int8_test.go | 4 ++-- pgtype/json.go | 18 +++++++++++++++++- pgtype/json_test.go | 4 ++-- pgtype/jsonb.go | 8 ++++++-- pgtype/jsonb_test.go | 4 ++-- pgtype/name.go | 8 ++++++-- pgtype/name_test.go | 4 ++-- pgtype/oid.go | 12 ++++++++---- pgtype/oid_test.go | 4 ++-- pgtype/pgtype.go | 13 ++++++++----- pgtype/pguint32.go | 17 ++++++++++++++--- pgtype/qchar.go | 15 +++++++++++++-- pgtype/qchar_test.go | 4 ++-- pgtype/text.go | 15 +++++++++++++-- pgtype/text_array.go | 17 ++++++++++++++--- pgtype/text_array_test.go | 4 ++-- pgtype/text_test.go | 4 ++-- pgtype/tid.go | 19 +++++++++++++++++++ pgtype/timestamp.go | 20 +++++++++++++++++--- pgtype/timestamp_array.go | 17 ++++++++++++++--- pgtype/timestamp_array_test.go | 4 ++-- pgtype/timestamp_test.go | 4 ++-- pgtype/timestamptz.go | 18 ++++++++++++++++-- pgtype/timestamptz_array.go | 17 ++++++++++++++--- pgtype/timestamptz_array_test.go | 4 ++-- pgtype/timestamptz_test.go | 4 ++-- pgtype/typed_array.go.erb | 17 ++++++++++++++--- pgtype/varchar_array.go | 8 ++++++-- pgtype/xid.go | 12 ++++++++---- pgtype/xid_test.go | 4 ++-- query.go | 8 ++------ values.go | 14 +++++--------- 74 files changed, 568 insertions(+), 185 deletions(-) diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index 821c5001..36cf3bbf 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -23,7 +23,7 @@ type Aclitem struct { Status Status } -func (dst *Aclitem) ConvertFrom(src interface{}) error { +func (dst *Aclitem) Set(src interface{}) error { switch value := src.(type) { case Aclitem: *dst = value @@ -37,7 +37,7 @@ func (dst *Aclitem) ConvertFrom(src interface{}) error { } default: if originalSrc, ok := underlyingStringType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Aclitem", value) } @@ -45,6 +45,17 @@ func (dst *Aclitem) ConvertFrom(src interface{}) error { return nil } +func (dst *Aclitem) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + func (src *Aclitem) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 48f5cd38..13952e5c 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -14,7 +14,7 @@ type AclitemArray struct { Status Status } -func (dst *AclitemArray) ConvertFrom(src interface{}) error { +func (dst *AclitemArray) Set(src interface{}) error { switch value := src.(type) { case AclitemArray: *dst = value @@ -27,7 +27,7 @@ func (dst *AclitemArray) ConvertFrom(src interface{}) error { } else { elements := make([]Aclitem, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -40,7 +40,7 @@ func (dst *AclitemArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Aclitem", value) } @@ -48,6 +48,17 @@ func (dst *AclitemArray) ConvertFrom(src interface{}) error { return nil } +func (dst *AclitemArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *AclitemArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go index e78f14c6..75c672bd 100644 --- a/pgtype/aclitem_array_test.go +++ b/pgtype/aclitem_array_test.go @@ -51,7 +51,7 @@ func TestAclitemArrayTranscode(t *testing.T) { }) } -func TestAclitemArrayConvertFrom(t *testing.T) { +func TestAclitemArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.AclitemArray @@ -71,7 +71,7 @@ func TestAclitemArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.AclitemArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go index fc429acc..47e6fa84 100644 --- a/pgtype/aclitem_test.go +++ b/pgtype/aclitem_test.go @@ -15,7 +15,7 @@ func TestAclitemTranscode(t *testing.T) { }) } -func TestAclitemConvertFrom(t *testing.T) { +func TestAclitemSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Aclitem @@ -27,7 +27,7 @@ func TestAclitemConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Aclitem - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/bool.go b/pgtype/bool.go index 9764fafe..04a261c2 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -12,7 +12,7 @@ type Bool struct { Status Status } -func (dst *Bool) ConvertFrom(src interface{}) error { +func (dst *Bool) Set(src interface{}) error { switch value := src.(type) { case Bool: *dst = value @@ -26,7 +26,7 @@ func (dst *Bool) ConvertFrom(src interface{}) error { *dst = Bool{Bool: bb, Status: Present} default: if originalSrc, ok := underlyingBoolType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Bool", value) } @@ -34,6 +34,17 @@ func (dst *Bool) ConvertFrom(src interface{}) error { return nil } +func (dst *Bool) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bool + case Null: + return nil + default: + return dst.Status + } +} + func (src *Bool) AssignTo(dst interface{}) error { switch v := dst.(type) { case *bool: diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index a74e9f90..fdcbf7a0 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -15,7 +15,7 @@ type BoolArray struct { Status Status } -func (dst *BoolArray) ConvertFrom(src interface{}) error { +func (dst *BoolArray) Set(src interface{}) error { switch value := src.(type) { case BoolArray: *dst = value @@ -28,7 +28,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { } else { elements := make([]Bool, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Bool", value) } @@ -49,6 +49,17 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { return nil } +func (dst *BoolArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *BoolArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go index c5f15f97..a526d892 100644 --- a/pgtype/bool_array_test.go +++ b/pgtype/bool_array_test.go @@ -51,7 +51,7 @@ func TestBoolArrayTranscode(t *testing.T) { }) } -func TestBoolArrayConvertFrom(t *testing.T) { +func TestBoolArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.BoolArray @@ -71,7 +71,7 @@ func TestBoolArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.BoolArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 374f07da..773bd99b 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -15,7 +15,7 @@ func TestBoolTranscode(t *testing.T) { }) } -func TestBoolConvertFrom(t *testing.T) { +func TestBoolSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Bool @@ -33,7 +33,7 @@ func TestBoolConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Bool - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 709499d2..9d2e20f3 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -12,7 +12,7 @@ type Bytea struct { Status Status } -func (dst *Bytea) ConvertFrom(src interface{}) error { +func (dst *Bytea) Set(src interface{}) error { switch value := src.(type) { case Bytea: *dst = value @@ -24,7 +24,7 @@ func (dst *Bytea) ConvertFrom(src interface{}) error { } default: if originalSrc, ok := underlyingBytesType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Bytea", value) } @@ -32,6 +32,17 @@ func (dst *Bytea) ConvertFrom(src interface{}) error { return nil } +func (dst *Bytea) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bytes + case Null: + return nil + default: + return dst.Status + } +} + func (src *Bytea) AssignTo(dst interface{}) error { switch v := dst.(type) { case *[]byte: diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 9003eafd..5362944a 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -15,7 +15,7 @@ type ByteaArray struct { Status Status } -func (dst *ByteaArray) ConvertFrom(src interface{}) error { +func (dst *ByteaArray) Set(src interface{}) error { switch value := src.(type) { case ByteaArray: *dst = value @@ -28,7 +28,7 @@ func (dst *ByteaArray) ConvertFrom(src interface{}) error { } else { elements := make([]Bytea, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *ByteaArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Bytea", value) } @@ -49,6 +49,17 @@ func (dst *ByteaArray) ConvertFrom(src interface{}) error { return nil } +func (dst *ByteaArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *ByteaArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/bytea_array_test.go b/pgtype/bytea_array_test.go index b39776d9..22c6478b 100644 --- a/pgtype/bytea_array_test.go +++ b/pgtype/bytea_array_test.go @@ -51,7 +51,7 @@ func TestByteaArrayTranscode(t *testing.T) { }) } -func TestByteaArrayConvertFrom(t *testing.T) { +func TestByteaArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.ByteaArray @@ -71,7 +71,7 @@ func TestByteaArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.ByteaArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 51941387..4655a1c1 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -15,7 +15,7 @@ func TestByteaTranscode(t *testing.T) { }) } -func TestByteaConvertFrom(t *testing.T) { +func TestByteaSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Bytea @@ -30,7 +30,7 @@ func TestByteaConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Bytea - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/cid.go b/pgtype/cid.go index be93a03e..20957f36 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -17,11 +17,15 @@ import ( // in the PostgreSQL sources. type Cid pguint32 -// ConvertFrom converts from src to dst. Note that as Cid is not a general -// number type ConvertFrom does not do automatic type conversion as other number +// Set converts from src to dst. Note that as Cid is not a general +// number type Set does not do automatic type conversion as other number // types do. -func (dst *Cid) ConvertFrom(src interface{}) error { - return (*pguint32)(dst).ConvertFrom(src) +func (dst *Cid) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *Cid) Get() interface{} { + return (*pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as Cid is not a general number diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index 7d9fde34..0d114cda 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -14,7 +14,7 @@ func TestCidTranscode(t *testing.T) { }) } -func TestCidConvertFrom(t *testing.T) { +func TestCidSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Cid @@ -24,7 +24,7 @@ func TestCidConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Cid - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index e0219ee5..c30c53d3 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -6,8 +6,12 @@ import ( type CidrArray InetArray -func (dst *CidrArray) ConvertFrom(src interface{}) error { - return (*InetArray)(dst).ConvertFrom(src) +func (dst *CidrArray) Set(src interface{}) error { + return (*InetArray)(dst).Set(src) +} + +func (dst *CidrArray) Get() interface{} { + return (*InetArray)(dst).Get() } func (src *CidrArray) AssignTo(dst interface{}) error { diff --git a/pgtype/date.go b/pgtype/date.go index b0d16e64..a3b8d99f 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -21,7 +21,7 @@ const ( infinityDayOffset = 2147483647 ) -func (dst *Date) ConvertFrom(src interface{}) error { +func (dst *Date) Set(src interface{}) error { switch value := src.(type) { case Date: *dst = value @@ -29,7 +29,7 @@ func (dst *Date) ConvertFrom(src interface{}) error { *dst = Date{Time: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Date", value) } @@ -37,6 +37,17 @@ func (dst *Date) ConvertFrom(src interface{}) error { return nil } +func (dst *Date) Get() interface{} { + switch dst.Status { + case Present: + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + func (src *Date) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 8f7cba18..ce28e236 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -16,7 +16,7 @@ type DateArray struct { Status Status } -func (dst *DateArray) ConvertFrom(src interface{}) error { +func (dst *DateArray) Set(src interface{}) error { switch value := src.(type) { case DateArray: *dst = value @@ -29,7 +29,7 @@ func (dst *DateArray) ConvertFrom(src interface{}) error { } else { elements := make([]Date, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -42,7 +42,7 @@ func (dst *DateArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Date", value) } @@ -50,6 +50,17 @@ func (dst *DateArray) ConvertFrom(src interface{}) error { return nil } +func (dst *DateArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *DateArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/date_array_test.go b/pgtype/date_array_test.go index 60f15983..a05f4254 100644 --- a/pgtype/date_array_test.go +++ b/pgtype/date_array_test.go @@ -52,7 +52,7 @@ func TestDateArrayTranscode(t *testing.T) { }) } -func TestDateArrayConvertFrom(t *testing.T) { +func TestDateArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.DateArray @@ -72,7 +72,7 @@ func TestDateArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.DateArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 3a473b6a..eff3a521 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -22,7 +22,7 @@ func TestDateTranscode(t *testing.T) { }) } -func TestDateConvertFrom(t *testing.T) { +func TestDateSet(t *testing.T) { type _time time.Time successfulTests := []struct { @@ -41,7 +41,7 @@ func TestDateConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Date - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/float4.go b/pgtype/float4.go index 26609ab2..a38d24db 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -15,7 +15,7 @@ type Float4 struct { Status Status } -func (dst *Float4) ConvertFrom(src interface{}) error { +func (dst *Float4) Set(src interface{}) error { switch value := src.(type) { case Float4: *dst = value @@ -81,7 +81,7 @@ func (dst *Float4) ConvertFrom(src interface{}) error { *dst = Float4{Float: float32(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Float8", value) } @@ -89,6 +89,17 @@ func (dst *Float4) ConvertFrom(src interface{}) error { return nil } +func (dst *Float4) Get() interface{} { + switch dst.Status { + case Present: + return dst.Float + case Null: + return nil + default: + return dst.Status + } +} + func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 632e7e4b..410a8b37 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -15,7 +15,7 @@ type Float4Array struct { Status Status } -func (dst *Float4Array) ConvertFrom(src interface{}) error { +func (dst *Float4Array) Set(src interface{}) error { switch value := src.(type) { case Float4Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { } else { elements := make([]Float4, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Float4", value) } @@ -49,6 +49,17 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Float4Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Float4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/float4_array_test.go b/pgtype/float4_array_test.go index b22f4fbc..06a1d2e0 100644 --- a/pgtype/float4_array_test.go +++ b/pgtype/float4_array_test.go @@ -51,7 +51,7 @@ func TestFloat4ArrayTranscode(t *testing.T) { }) } -func TestFloat4ArrayConvertFrom(t *testing.T) { +func TestFloat4ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Float4Array @@ -71,7 +71,7 @@ func TestFloat4ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Float4Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index 62420b8d..ea60cd3a 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -18,7 +18,7 @@ func TestFloat4Transcode(t *testing.T) { }) } -func TestFloat4ConvertFrom(t *testing.T) { +func TestFloat4Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Float4 @@ -43,7 +43,7 @@ func TestFloat4ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Float4 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/float8.go b/pgtype/float8.go index 9ec9a665..9129e8ba 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -15,7 +15,7 @@ type Float8 struct { Status Status } -func (dst *Float8) ConvertFrom(src interface{}) error { +func (dst *Float8) Set(src interface{}) error { switch value := src.(type) { case Float8: *dst = value @@ -71,7 +71,7 @@ func (dst *Float8) ConvertFrom(src interface{}) error { *dst = Float8{Float: float64(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Float8", value) } @@ -79,6 +79,17 @@ func (dst *Float8) ConvertFrom(src interface{}) error { return nil } +func (dst *Float8) Get() interface{} { + switch dst.Status { + case Present: + return dst.Float + case Null: + return nil + default: + return dst.Status + } +} + func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 68cf30f2..b2f70f51 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -15,7 +15,7 @@ type Float8Array struct { Status Status } -func (dst *Float8Array) ConvertFrom(src interface{}) error { +func (dst *Float8Array) Set(src interface{}) error { switch value := src.(type) { case Float8Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { } else { elements := make([]Float8, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Float8", value) } @@ -49,6 +49,17 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Float8Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Float8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/float8_array_test.go b/pgtype/float8_array_test.go index d4402281..635e249a 100644 --- a/pgtype/float8_array_test.go +++ b/pgtype/float8_array_test.go @@ -51,7 +51,7 @@ func TestFloat8ArrayTranscode(t *testing.T) { }) } -func TestFloat8ArrayConvertFrom(t *testing.T) { +func TestFloat8ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Float8Array @@ -71,7 +71,7 @@ func TestFloat8ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Float8Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index 748ffd25..724e9350 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -18,7 +18,7 @@ func TestFloat8Transcode(t *testing.T) { }) } -func TestFloat8ConvertFrom(t *testing.T) { +func TestFloat8Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Float8 @@ -43,7 +43,7 @@ func TestFloat8ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Float8 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/inet.go b/pgtype/inet.go index f94622f4..00bfb30c 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -23,7 +23,7 @@ type Inet struct { Status Status } -func (dst *Inet) ConvertFrom(src interface{}) error { +func (dst *Inet) Set(src interface{}) error { switch value := src.(type) { case Inet: *dst = value @@ -43,7 +43,7 @@ func (dst *Inet) ConvertFrom(src interface{}) error { *dst = Inet{IPNet: ipnet, Status: Present} default: if originalSrc, ok := underlyingPtrType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Inet", value) } @@ -51,6 +51,17 @@ func (dst *Inet) ConvertFrom(src interface{}) error { return nil } +func (dst *Inet) Get() interface{} { + switch dst.Status { + case Present: + return dst.IPNet + case Null: + return nil + default: + return dst.Status + } +} + func (src *Inet) AssignTo(dst interface{}) error { switch v := dst.(type) { case *net.IPNet: diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 629cd51f..4d865b4f 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -16,7 +16,7 @@ type InetArray struct { Status Status } -func (dst *InetArray) ConvertFrom(src interface{}) error { +func (dst *InetArray) Set(src interface{}) error { switch value := src.(type) { case InetArray: *dst = value @@ -29,7 +29,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { } else { elements := make([]Inet, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -48,7 +48,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { } else { elements := make([]Inet, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -61,7 +61,7 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Inet", value) } @@ -69,6 +69,17 @@ func (dst *InetArray) ConvertFrom(src interface{}) error { return nil } +func (dst *InetArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *InetArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go index 523a9f8d..fe22285d 100644 --- a/pgtype/inet_array_test.go +++ b/pgtype/inet_array_test.go @@ -52,7 +52,7 @@ func TestInetArrayTranscode(t *testing.T) { }) } -func TestInetArrayConvertFrom(t *testing.T) { +func TestInetArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.InetArray @@ -83,7 +83,7 @@ func TestInetArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.InetArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 5a326810..90b0723f 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -26,7 +26,7 @@ func TestInetTranscode(t *testing.T) { } } -func TestInetConvertFrom(t *testing.T) { +func TestInetSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Inet @@ -39,7 +39,7 @@ func TestInetConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Inet - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/int2.go b/pgtype/int2.go index 7bdbacfe..525427c5 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -15,7 +15,7 @@ type Int2 struct { Status Status } -func (dst *Int2) ConvertFrom(src interface{}) error { +func (dst *Int2) Set(src interface{}) error { switch value := src.(type) { case Int2: *dst = value @@ -77,7 +77,7 @@ func (dst *Int2) ConvertFrom(src interface{}) error { *dst = Int2{Int: int16(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) } @@ -85,6 +85,17 @@ func (dst *Int2) ConvertFrom(src interface{}) error { return nil } +func (dst *Int2) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index d8268c0a..28792fa5 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -15,7 +15,7 @@ type Int2Array struct { Status Status } -func (dst *Int2Array) ConvertFrom(src interface{}) error { +func (dst *Int2Array) Set(src interface{}) error { switch value := src.(type) { case Int2Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int2, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -47,7 +47,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int2, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -60,7 +60,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int2", value) } @@ -68,6 +68,17 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Int2Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/int2_array_test.go b/pgtype/int2_array_test.go index ced0eab4..8af4523d 100644 --- a/pgtype/int2_array_test.go +++ b/pgtype/int2_array_test.go @@ -51,7 +51,7 @@ func TestInt2ArrayTranscode(t *testing.T) { }) } -func TestInt2ArrayConvertFrom(t *testing.T) { +func TestInt2ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int2Array @@ -78,7 +78,7 @@ func TestInt2ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int2Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index 8601309d..2bd8e016 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -19,7 +19,7 @@ func TestInt2Transcode(t *testing.T) { }) } -func TestInt2ConvertFrom(t *testing.T) { +func TestInt2Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int2 @@ -42,7 +42,7 @@ func TestInt2ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int2 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/int4.go b/pgtype/int4.go index 2d96ea48..b3203a28 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -15,7 +15,7 @@ type Int4 struct { Status Status } -func (dst *Int4) ConvertFrom(src interface{}) error { +func (dst *Int4) Set(src interface{}) error { switch value := src.(type) { case Int4: *dst = value @@ -68,7 +68,7 @@ func (dst *Int4) ConvertFrom(src interface{}) error { *dst = Int4{Int: int32(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -76,6 +76,17 @@ func (dst *Int4) ConvertFrom(src interface{}) error { return nil } +func (dst *Int4) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index dcdb50c1..61cedb2e 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -15,7 +15,7 @@ type Int4Array struct { Status Status } -func (dst *Int4Array) ConvertFrom(src interface{}) error { +func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { case Int4Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int4, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -47,7 +47,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int4, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -60,7 +60,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int4", value) } @@ -68,6 +68,17 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Int4Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go index 38ba27cb..111cb56b 100644 --- a/pgtype/int4_array_test.go +++ b/pgtype/int4_array_test.go @@ -51,7 +51,7 @@ func TestInt4ArrayTranscode(t *testing.T) { }) } -func TestInt4ArrayConvertFrom(t *testing.T) { +func TestInt4ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int4Array @@ -78,7 +78,7 @@ func TestInt4ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int4Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go index 0ac2e5b5..3e000182 100644 --- a/pgtype/int4_test.go +++ b/pgtype/int4_test.go @@ -19,7 +19,7 @@ func TestInt4Transcode(t *testing.T) { }) } -func TestInt4ConvertFrom(t *testing.T) { +func TestInt4Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int4 @@ -42,7 +42,7 @@ func TestInt4ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int4 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/int8.go b/pgtype/int8.go index 91f5b877..15ad6715 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -15,7 +15,7 @@ type Int8 struct { Status Status } -func (dst *Int8) ConvertFrom(src interface{}) error { +func (dst *Int8) Set(src interface{}) error { switch value := src.(type) { case Int8: *dst = value @@ -59,7 +59,7 @@ func (dst *Int8) ConvertFrom(src interface{}) error { *dst = Int8{Int: num, Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -67,6 +67,17 @@ func (dst *Int8) ConvertFrom(src interface{}) error { return nil } +func (dst *Int8) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index ed82f079..9f4373e8 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -15,7 +15,7 @@ type Int8Array struct { Status Status } -func (dst *Int8Array) ConvertFrom(src interface{}) error { +func (dst *Int8Array) Set(src interface{}) error { switch value := src.(type) { case Int8Array: *dst = value @@ -28,7 +28,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int8, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -47,7 +47,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { } else { elements := make([]Int8, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -60,7 +60,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Int8", value) } @@ -68,6 +68,17 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { return nil } +func (dst *Int8Array) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *Int8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/int8_array_test.go b/pgtype/int8_array_test.go index 137768c6..349a1f7e 100644 --- a/pgtype/int8_array_test.go +++ b/pgtype/int8_array_test.go @@ -51,7 +51,7 @@ func TestInt8ArrayTranscode(t *testing.T) { }) } -func TestInt8ArrayConvertFrom(t *testing.T) { +func TestInt8ArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int8Array @@ -78,7 +78,7 @@ func TestInt8ArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int8Array - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go index 15762a50..e1fe69fb 100644 --- a/pgtype/int8_test.go +++ b/pgtype/int8_test.go @@ -19,7 +19,7 @@ func TestInt8Transcode(t *testing.T) { }) } -func TestInt8ConvertFrom(t *testing.T) { +func TestInt8Set(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Int8 @@ -42,7 +42,7 @@ func TestInt8ConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Int8 - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/json.go b/pgtype/json.go index 8a258ea4..ecdb3dab 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -10,7 +10,7 @@ type Json struct { Status Status } -func (dst *Json) ConvertFrom(src interface{}) error { +func (dst *Json) Set(src interface{}) error { switch value := src.(type) { case string: *dst = Json{Bytes: []byte(value), Status: Present} @@ -37,6 +37,22 @@ func (dst *Json) ConvertFrom(src interface{}) error { return nil } +func (dst *Json) Get() interface{} { + switch dst.Status { + case Present: + var i interface{} + err := json.Unmarshal(dst.Bytes, &i) + if err != nil { + return dst + } + return i + case Null: + return nil + default: + return dst.Status + } +} + func (src *Json) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 87770f31..b0aa8c9b 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -18,7 +18,7 @@ func TestJsonTranscode(t *testing.T) { }) } -func TestJsonConvertFrom(t *testing.T) { +func TestJsonSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Json @@ -33,7 +33,7 @@ func TestJsonConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Json - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 0739a468..13062e8e 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -7,8 +7,12 @@ import ( type Jsonb Json -func (dst *Jsonb) ConvertFrom(src interface{}) error { - return (*Json)(dst).ConvertFrom(src) +func (dst *Jsonb) Set(src interface{}) error { + return (*Json)(dst).Set(src) +} + +func (dst *Jsonb) Get() interface{} { + return (*Json)(dst).Get() } func (src *Jsonb) AssignTo(dst interface{}) error { diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index e42931d5..3978b0d4 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -18,7 +18,7 @@ func TestJsonbTranscode(t *testing.T) { }) } -func TestJsonbConvertFrom(t *testing.T) { +func TestJsonbSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Jsonb @@ -33,7 +33,7 @@ func TestJsonbConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Jsonb - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/name.go b/pgtype/name.go index 513abfc7..9eb12ece 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -19,8 +19,12 @@ import ( // bytes applies, rather than the default 63. type Name Text -func (dst *Name) ConvertFrom(src interface{}) error { - return (*Text)(dst).ConvertFrom(src) +func (dst *Name) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Name) Get() interface{} { + return (*Text)(dst).Get() } func (src *Name) AssignTo(dst interface{}) error { diff --git a/pgtype/name_test.go b/pgtype/name_test.go index c5f7de17..81a766b8 100644 --- a/pgtype/name_test.go +++ b/pgtype/name_test.go @@ -15,7 +15,7 @@ func TestNameTranscode(t *testing.T) { }) } -func TestNameConvertFrom(t *testing.T) { +func TestNameSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Name @@ -27,7 +27,7 @@ func TestNameConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Name - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/oid.go b/pgtype/oid.go index c77f3f10..e57bb2e6 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -11,11 +11,15 @@ import ( // found in src/include/postgres_ext.h in the PostgreSQL sources. type Oid pguint32 -// ConvertFrom converts from src to dst. Note that as Oid is not a general -// number type ConvertFrom does not do automatic type conversion as other number +// Set converts from src to dst. Note that as Oid is not a general +// number type Set does not do automatic type conversion as other number // types do. -func (dst *Oid) ConvertFrom(src interface{}) error { - return (*pguint32)(dst).ConvertFrom(src) +func (dst *Oid) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *Oid) Get() interface{} { + return (*pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as Oid is not a general number diff --git a/pgtype/oid_test.go b/pgtype/oid_test.go index bbab6699..b3b96959 100644 --- a/pgtype/oid_test.go +++ b/pgtype/oid_test.go @@ -14,7 +14,7 @@ func TestOidTranscode(t *testing.T) { }) } -func TestOidConvertFrom(t *testing.T) { +func TestOidSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Oid @@ -24,7 +24,7 @@ func TestOidConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Oid - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index cbcd6bd5..5a51172e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -66,13 +66,16 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) -type Value interface{} +type Value interface { + // Set converts and assigns src to itself. + Set(src interface{}) error -type ConverterFrom interface { - ConvertFrom(src interface{}) error -} + // Get returns the simplest representation of Value. If the Value is Null or + // Undefined that is the return value. If no simpler representation is + // possible, then Get() returns Value. + Get() interface{} -type AssignerTo interface { + // AssignTo converts and assigns the Value to dst. AssignTo(dst interface{}) error } diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index c636e1c4..05c79c0e 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -16,10 +16,10 @@ type pguint32 struct { Status Status } -// ConvertFrom converts from src to dst. Note that as pguint32 is not a general -// number type ConvertFrom does not do automatic type conversion as other number +// Set converts from src to dst. Note that as pguint32 is not a general +// number type Set does not do automatic type conversion as other number // types do. -func (dst *pguint32) ConvertFrom(src interface{}) error { +func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { case uint32: *dst = pguint32{Uint: value, Status: Present} @@ -30,6 +30,17 @@ func (dst *pguint32) ConvertFrom(src interface{}) error { return nil } +func (dst *pguint32) Get() interface{} { + switch dst.Status { + case Present: + return dst.Uint + case Null: + return nil + default: + return dst.Status + } +} + // AssignTo assigns from src to dst. Note that as pguint32 is not a general number // type AssignTo does not do automatic type conversion as other number types do. func (src *pguint32) AssignTo(dst interface{}) error { diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 0da1e88b..b6392cf9 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -23,7 +23,7 @@ type QChar struct { Status Status } -func (dst *QChar) ConvertFrom(src interface{}) error { +func (dst *QChar) Set(src interface{}) error { switch value := src.(type) { case QChar: *dst = value @@ -94,7 +94,7 @@ func (dst *QChar) ConvertFrom(src interface{}) error { *dst = QChar{Int: int8(num), Status: Present} default: if originalSrc, ok := underlyingNumberType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to QChar", value) } @@ -102,6 +102,17 @@ func (dst *QChar) ConvertFrom(src interface{}) error { return nil } +func (dst *QChar) Get() interface{} { + switch dst.Status { + case Present: + return dst.Int + case Null: + return nil + default: + return dst.Status + } +} + func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index ea7b56a8..a1b6d22e 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -19,7 +19,7 @@ func TestQCharTranscode(t *testing.T) { }) } -func TestQCharConvertFrom(t *testing.T) { +func TestQCharSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.QChar @@ -42,7 +42,7 @@ func TestQCharConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.QChar - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/text.go b/pgtype/text.go index baf62d1e..50db2349 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -11,7 +11,7 @@ type Text struct { Status Status } -func (dst *Text) ConvertFrom(src interface{}) error { +func (dst *Text) Set(src interface{}) error { switch value := src.(type) { case Text: *dst = value @@ -25,7 +25,7 @@ func (dst *Text) ConvertFrom(src interface{}) error { } default: if originalSrc, ok := underlyingStringType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Text", value) } @@ -33,6 +33,17 @@ func (dst *Text) ConvertFrom(src interface{}) error { return nil } +func (dst *Text) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + func (src *Text) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 06e3c0df..3a5a64ce 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -15,7 +15,7 @@ type TextArray struct { Status Status } -func (dst *TextArray) ConvertFrom(src interface{}) error { +func (dst *TextArray) Set(src interface{}) error { switch value := src.(type) { case TextArray: *dst = value @@ -28,7 +28,7 @@ func (dst *TextArray) ConvertFrom(src interface{}) error { } else { elements := make([]Text, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -41,7 +41,7 @@ func (dst *TextArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Text", value) } @@ -49,6 +49,17 @@ func (dst *TextArray) ConvertFrom(src interface{}) error { return nil } +func (dst *TextArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *TextArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go index a22e003d..5a78d7bc 100644 --- a/pgtype/text_array_test.go +++ b/pgtype/text_array_test.go @@ -51,7 +51,7 @@ func TestTextArrayTranscode(t *testing.T) { }) } -func TestTextArrayConvertFrom(t *testing.T) { +func TestTextArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.TextArray @@ -71,7 +71,7 @@ func TestTextArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.TextArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 6e944857..f5e20055 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -17,7 +17,7 @@ func TestTextTranscode(t *testing.T) { } } -func TestTextConvertFrom(t *testing.T) { +func TestTextSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Text @@ -30,7 +30,7 @@ func TestTextConvertFrom(t *testing.T) { for i, tt := range successfulTests { var d pgtype.Text - err := d.ConvertFrom(tt.source) + err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/tid.go b/pgtype/tid.go index b67892ff..20d962df 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -27,6 +27,25 @@ type Tid struct { Status Status } +func (dst *Tid) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Tid", src) +} + +func (dst *Tid) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tid) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + func (dst *Tid) DecodeText(src []byte) error { if src == nil { *dst = Tid{Status: Null} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index a8b628e9..a84f3881 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -23,9 +23,9 @@ type Timestamp struct { InfinityModifier } -// ConvertFrom converts src into a Timestamp and stores in dst. If src is a +// Set converts src into a Timestamp and stores in dst. If src is a // time.Time in a non-UTC time zone, the time zone is discarded. -func (dst *Timestamp) ConvertFrom(src interface{}) error { +func (dst *Timestamp) Set(src interface{}) error { switch value := src.(type) { case Timestamp: *dst = value @@ -33,7 +33,7 @@ func (dst *Timestamp) ConvertFrom(src interface{}) error { *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamp", value) } @@ -41,6 +41,20 @@ func (dst *Timestamp) ConvertFrom(src interface{}) error { return nil } +func (dst *Timestamp) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + func (src *Timestamp) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 1ea30ba4..ec0facb2 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -16,7 +16,7 @@ type TimestampArray struct { Status Status } -func (dst *TimestampArray) ConvertFrom(src interface{}) error { +func (dst *TimestampArray) Set(src interface{}) error { switch value := src.(type) { case TimestampArray: *dst = value @@ -29,7 +29,7 @@ func (dst *TimestampArray) ConvertFrom(src interface{}) error { } else { elements := make([]Timestamp, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -42,7 +42,7 @@ func (dst *TimestampArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamp", value) } @@ -50,6 +50,17 @@ func (dst *TimestampArray) ConvertFrom(src interface{}) error { return nil } +func (dst *TimestampArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *TimestampArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/timestamp_array_test.go b/pgtype/timestamp_array_test.go index 68189cc7..a15d3696 100644 --- a/pgtype/timestamp_array_test.go +++ b/pgtype/timestamp_array_test.go @@ -68,7 +68,7 @@ func TestTimestampArrayTranscode(t *testing.T) { }) } -func TestTimestampArrayConvertFrom(t *testing.T) { +func TestTimestampArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.TimestampArray @@ -88,7 +88,7 @@ func TestTimestampArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.TimestampArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 6d6e738c..7297ed1f 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -31,7 +31,7 @@ func TestTimestampTranscode(t *testing.T) { }) } -func TestTimestampConvertFrom(t *testing.T) { +func TestTimestampSet(t *testing.T) { type _time time.Time successfulTests := []struct { @@ -51,7 +51,7 @@ func TestTimestampConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Timestamp - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index f4c67b0b..a6922d5b 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -26,7 +26,7 @@ type Timestamptz struct { InfinityModifier } -func (dst *Timestamptz) ConvertFrom(src interface{}) error { +func (dst *Timestamptz) Set(src interface{}) error { switch value := src.(type) { case Timestamptz: *dst = value @@ -34,7 +34,7 @@ func (dst *Timestamptz) ConvertFrom(src interface{}) error { *dst = Timestamptz{Time: value, Status: Present} default: if originalSrc, ok := underlyingTimeType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamptz", value) } @@ -42,6 +42,20 @@ func (dst *Timestamptz) ConvertFrom(src interface{}) error { return nil } +func (dst *Timestamptz) Get() interface{} { + switch dst.Status { + case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } + return dst.Time + case Null: + return nil + default: + return dst.Status + } +} + func (src *Timestamptz) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index fc3ce08c..775ec970 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -16,7 +16,7 @@ type TimestamptzArray struct { Status Status } -func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { +func (dst *TimestamptzArray) Set(src interface{}) error { switch value := src.(type) { case TimestamptzArray: *dst = value @@ -29,7 +29,7 @@ func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { } else { elements := make([]Timestamptz, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -42,7 +42,7 @@ func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to Timestamptz", value) } @@ -50,6 +50,17 @@ func (dst *TimestamptzArray) ConvertFrom(src interface{}) error { return nil } +func (dst *TimestamptzArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *TimestamptzArray) AssignTo(dst interface{}) error { switch v := dst.(type) { diff --git a/pgtype/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go index af2c004b..e0017828 100644 --- a/pgtype/timestamptz_array_test.go +++ b/pgtype/timestamptz_array_test.go @@ -68,7 +68,7 @@ func TestTimestamptzArrayTranscode(t *testing.T) { }) } -func TestTimestamptzArrayConvertFrom(t *testing.T) { +func TestTimestamptzArraySet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.TimestamptzArray @@ -88,7 +88,7 @@ func TestTimestamptzArrayConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.TimestamptzArray - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 8f80ca81..242cd05f 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -31,7 +31,7 @@ func TestTimestamptzTranscode(t *testing.T) { }) } -func TestTimestamptzConvertFrom(t *testing.T) { +func TestTimestamptzSet(t *testing.T) { type _time time.Time successfulTests := []struct { @@ -50,7 +50,7 @@ func TestTimestamptzConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Timestamptz - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 98c8d845..c62e2896 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -14,7 +14,7 @@ type <%= pgtype_array_type %> struct { Status Status } -func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { +func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { switch value := src.(type) { case <%= pgtype_array_type %>: *dst = value @@ -27,7 +27,7 @@ func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { } else { elements := make([]<%= pgtype_element_type %>, len(value)) for i := range value { - if err := elements[i].ConvertFrom(value[i]); err != nil { + if err := elements[i].Set(value[i]); err != nil { return err } } @@ -40,7 +40,7 @@ func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { <% end %> default: if originalSrc, ok := underlyingSliceType(src); ok { - return dst.ConvertFrom(originalSrc) + return dst.Set(originalSrc) } return fmt.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) } @@ -48,6 +48,17 @@ func (dst *<%= pgtype_array_type %>) ConvertFrom(src interface{}) error { return nil } +func (dst *<%= pgtype_array_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { switch v := dst.(type) { <% go_array_types.split(",").each do |t| %> diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index b9d87b7f..693b9a61 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -6,8 +6,12 @@ import ( type VarcharArray TextArray -func (dst *VarcharArray) ConvertFrom(src interface{}) error { - return (*TextArray)(dst).ConvertFrom(src) +func (dst *VarcharArray) Set(src interface{}) error { + return (*TextArray)(dst).Set(src) +} + +func (dst *VarcharArray) Get() interface{} { + return (*TextArray)(dst).Get() } func (src *VarcharArray) AssignTo(dst interface{}) error { diff --git a/pgtype/xid.go b/pgtype/xid.go index 7deaa4f0..a53120de 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -20,11 +20,15 @@ import ( // in the PostgreSQL sources. type Xid pguint32 -// ConvertFrom converts from src to dst. Note that as Xid is not a general -// number type ConvertFrom does not do automatic type conversion as other number +// Set converts from src to dst. Note that as Xid is not a general +// number type Set does not do automatic type conversion as other number // types do. -func (dst *Xid) ConvertFrom(src interface{}) error { - return (*pguint32)(dst).ConvertFrom(src) +func (dst *Xid) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *Xid) Get() interface{} { + return (*pguint32)(dst).Get() } // AssignTo assigns from src to dst. Note that as Xid is not a general number diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index a5c5df51..fecfb64b 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -14,7 +14,7 @@ func TestXidTranscode(t *testing.T) { }) } -func TestXidConvertFrom(t *testing.T) { +func TestXidSet(t *testing.T) { successfulTests := []struct { source interface{} result pgtype.Xid @@ -24,7 +24,7 @@ func TestXidConvertFrom(t *testing.T) { for i, tt := range successfulTests { var r pgtype.Xid - err := r.ConvertFrom(tt.source) + err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) } diff --git a/query.go b/query.go index bc7aeda4..8adb7d80 100644 --- a/query.go +++ b/query.go @@ -288,12 +288,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) } - if assignerTo, ok := pgVal.(pgtype.AssignerTo); ok { - if err := assignerTo.AssignTo(d); err != nil { - vr.Fatal(err) - } - } else { - vr.Fatal(fmt.Errorf("cannot assign %T", pgVal)) + if err := pgVal.AssignTo(d); err != nil { + vr.Fatal(err) } } else { if err := Decode(vr, d); err != nil { diff --git a/values.go b/values.go index bc9e5c64..e976d0d3 100644 --- a/values.go +++ b/values.go @@ -773,13 +773,9 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { } if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { - if converterFrom, ok := value.(pgtype.ConverterFrom); ok { - err := converterFrom.ConvertFrom(arg) - if err != nil { - return err - } - } else { - return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + err := value.Set(arg) + if err != nil { + return err } buf := &bytes.Buffer{} @@ -1275,7 +1271,7 @@ func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { switch oid { case DateOid: var d pgtype.Date - err := d.ConvertFrom(value) + err := d.Set(value) if err != nil { return err } @@ -1295,7 +1291,7 @@ func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { case TimestampTzOid, TimestampOid: var t pgtype.Timestamptz - err := t.ConvertFrom(value) + err := t.Set(value) if err != nil { return err } From 7da69cd3db0dc389d7f0a0cb5e76402093b81e1b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 20:07:31 -0600 Subject: [PATCH 107/264] Restructure *Rows.Values() to use Get() --- query.go | 76 +++++++++++++++------------------------------ replication_test.go | 6 +--- 2 files changed, 26 insertions(+), 56 deletions(-) diff --git a/query.go b/query.go index 8adb7d80..6e191665 100644 --- a/query.go +++ b/query.go @@ -324,59 +324,33 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, nil) continue } - // TODO - consider what are the implications of returning complex types since database/sql uses this method - switch vr.Type().FormatCode { - // All intrinsic types (except string) are encoded with binary - // encoding so anything else should be treated as a string - case TextFormatCode: - switch vr.Type().DataType { - case JsonOid: - var d interface{} - decodeJSON(vr, &d) - values = append(values, d) - case JsonbOid: - var d interface{} - decodeJSONB(vr, &d) - values = append(values, d) - default: - values = append(values, vr.ReadString(vr.Len())) - } - case BinaryFormatCode: - switch vr.Type().DataType { - case TextOid, VarcharOid: - values = append(values, decodeText(vr)) - case BoolOid: - values = append(values, decodeBool(vr)) - case ByteaOid: - values = append(values, decodeBytea(vr)) - case Int8Oid: - values = append(values, decodeInt8(vr)) - case Int2Oid: - values = append(values, decodeInt2(vr)) - case Int4Oid: - values = append(values, decodeInt4(vr)) - case Float4Oid: - values = append(values, decodeFloat4(vr)) - case Float8Oid: - values = append(values, decodeFloat8(vr)) - case DateOid: - values = append(values, decodeDate(vr)) - case TimestampTzOid: - values = append(values, decodeTimestampTz(vr)) - case TimestampOid: - values = append(values, decodeTimestamp(vr)) - case JsonOid: - var d interface{} - decodeJSON(vr, &d) - values = append(values, d) - case JsonbOid: - var d interface{} - decodeJSONB(vr, &d) - values = append(values, d) - default: - rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) + pgVal := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.TextDecoder) + if pgVal == nil { + panic("need GenericText or GenericBinary") + } + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.TextDecoder) + if decoder == nil { + panic("need GenericText") } + err := decoder.DecodeText(vr.bytes()) + if err != nil { + rows.Fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + case BinaryFormatCode: + decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.BinaryDecoder) + if decoder == nil { + panic("need GenericBinary") + } + err := decoder.DecodeBinary(vr.bytes()) + if err != nil { + rows.Fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) default: rows.Fatal(errors.New("Unknown format code")) } diff --git a/replication_test.go b/replication_test.go index 54ef4b66..d75233c1 100644 --- a/replication_test.go +++ b/replication_test.go @@ -246,11 +246,7 @@ func getCurrentTimeline(t *testing.T, rc *pgx.ReplicationConn) int { if e != nil { t.Error(e) } - timeline, e := strconv.Atoi(values[1].(string)) - if e != nil { - t.Error(e) - } - return timeline + return int(values[1].(int32)) } t.Fatal("Failed to read timeline") return -1 From 5cf4b97681046da2d910808072d3a30772674c19 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 20:12:47 -0600 Subject: [PATCH 108/264] Document that Decode* must not keep src - Also fix Bytea.DecodeBinary to not keep src. --- pgtype/bytea.go | 5 ++++- pgtype/pgtype.go | 6 ++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 9d2e20f3..a8ee55ae 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -106,7 +106,10 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } - *dst = Bytea{Bytes: src, Status: Present} + buf := make([]byte, len(src)) + copy(buf, src) + + *dst = Bytea{Bytes: buf, Status: Present} return nil } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5a51172e..7b1470b7 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -80,10 +80,16 @@ type Value interface { } type BinaryDecoder interface { + // DecodeBinary decodes src into BinaryDecoder. If src is nil then the + // original SQL value is NULL. BinaryDecoder MUST not retain a reference to + // src. It MUST make a copy if it needs to retain the raw bytes. DecodeBinary(src []byte) error } type TextDecoder interface { + // DecodeText decodes src into TextDecoder. If src is nil then the original + // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST + // make a copy if it needs to retain the raw bytes. DecodeText(src []byte) error } From aac8fd66f20d1e0763230238b75266d890337c02 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 20:18:56 -0600 Subject: [PATCH 109/264] Remove Set self support from pgtype Set having the capability to assign an object of the same type was inconsistently implemented. Some places it was not implemented at all, some places it was a shallow copy, some places a deep copy. Given that it doesn't seem likely to ever be used, and if it is needed it is easy enough to do outside of the library this code has been removed. --- pgtype/aclitem.go | 2 -- pgtype/aclitem_array.go | 2 -- pgtype/aclitem_test.go | 1 - pgtype/bool.go | 2 -- pgtype/bool_array.go | 2 -- pgtype/bool_test.go | 1 - pgtype/bytea.go | 2 -- pgtype/bytea_array.go | 2 -- pgtype/bytea_test.go | 1 - pgtype/date.go | 2 -- pgtype/date_array.go | 2 -- pgtype/date_test.go | 1 - pgtype/float4.go | 2 -- pgtype/float4_array.go | 2 -- pgtype/float8.go | 2 -- pgtype/float8_array.go | 2 -- pgtype/inet.go | 2 -- pgtype/inet_array.go | 2 -- pgtype/inet_test.go | 1 - pgtype/int2.go | 2 -- pgtype/int2_array.go | 2 -- pgtype/int4.go | 2 -- pgtype/int4_array.go | 2 -- pgtype/int8.go | 2 -- pgtype/int8_array.go | 2 -- pgtype/qchar.go | 2 -- pgtype/text.go | 2 -- pgtype/text_array.go | 2 -- pgtype/text_test.go | 1 - pgtype/timestamp.go | 2 -- pgtype/timestamp_array.go | 2 -- pgtype/timestamp_test.go | 1 - pgtype/timestamptz.go | 2 -- pgtype/timestamptz_array.go | 2 -- pgtype/timestamptz_test.go | 1 - pgtype/typed_array.go.erb | 2 -- 36 files changed, 64 deletions(-) diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index 36cf3bbf..b8a1549e 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -25,8 +25,6 @@ type Aclitem struct { func (dst *Aclitem) Set(src interface{}) error { switch value := src.(type) { - case Aclitem: - *dst = value case string: *dst = Aclitem{String: value, Status: Present} case *string: diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 13952e5c..5e3647b7 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -16,8 +16,6 @@ type AclitemArray struct { func (dst *AclitemArray) Set(src interface{}) error { switch value := src.(type) { - case AclitemArray: - *dst = value case []string: if value == nil { diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go index 47e6fa84..1738025a 100644 --- a/pgtype/aclitem_test.go +++ b/pgtype/aclitem_test.go @@ -20,7 +20,6 @@ func TestAclitemSet(t *testing.T) { source interface{} result pgtype.Aclitem }{ - {source: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, {source: "postgres=arwdDxt/postgres", result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, {source: (*string)(nil), result: pgtype.Aclitem{Status: pgtype.Null}}, } diff --git a/pgtype/bool.go b/pgtype/bool.go index 04a261c2..a8e9b8e1 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -14,8 +14,6 @@ type Bool struct { func (dst *Bool) Set(src interface{}) error { switch value := src.(type) { - case Bool: - *dst = value case bool: *dst = Bool{Bool: value, Status: Present} case string: diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index fdcbf7a0..4c5fc563 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -17,8 +17,6 @@ type BoolArray struct { func (dst *BoolArray) Set(src interface{}) error { switch value := src.(type) { - case BoolArray: - *dst = value case []bool: if value == nil { diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 773bd99b..412e2fd0 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -20,7 +20,6 @@ func TestBoolSet(t *testing.T) { source interface{} result pgtype.Bool }{ - {source: pgtype.Bool{Bool: false, Status: pgtype.Null}, result: pgtype.Bool{Bool: false, Status: pgtype.Null}}, {source: true, result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, {source: false, result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, diff --git a/pgtype/bytea.go b/pgtype/bytea.go index a8ee55ae..5df05360 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -14,8 +14,6 @@ type Bytea struct { func (dst *Bytea) Set(src interface{}) error { switch value := src.(type) { - case Bytea: - *dst = value case []byte: if value != nil { *dst = Bytea{Bytes: value, Status: Present} diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 5362944a..c6f676a4 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -17,8 +17,6 @@ type ByteaArray struct { func (dst *ByteaArray) Set(src interface{}) error { switch value := src.(type) { - case ByteaArray: - *dst = value case [][]byte: if value == nil { diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 4655a1c1..e21296c6 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -20,7 +20,6 @@ func TestByteaSet(t *testing.T) { source interface{} result pgtype.Bytea }{ - {source: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}}, {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, diff --git a/pgtype/date.go b/pgtype/date.go index a3b8d99f..d0481637 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -23,8 +23,6 @@ const ( func (dst *Date) Set(src interface{}) error { switch value := src.(type) { - case Date: - *dst = value case time.Time: *dst = Date{Time: value, Status: Present} default: diff --git a/pgtype/date_array.go b/pgtype/date_array.go index ce28e236..7f602d83 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -18,8 +18,6 @@ type DateArray struct { func (dst *DateArray) Set(src interface{}) error { switch value := src.(type) { - case DateArray: - *dst = value case []time.Time: if value == nil { diff --git a/pgtype/date_test.go b/pgtype/date_test.go index eff3a521..cfc3dd70 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -29,7 +29,6 @@ func TestDateSet(t *testing.T) { source interface{} result pgtype.Date }{ - {source: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), result: pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, diff --git a/pgtype/float4.go b/pgtype/float4.go index a38d24db..053af44b 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -17,8 +17,6 @@ type Float4 struct { func (dst *Float4) Set(src interface{}) error { switch value := src.(type) { - case Float4: - *dst = value case float32: *dst = Float4{Float: value, Status: Present} case float64: diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 410a8b37..0e815e0b 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -17,8 +17,6 @@ type Float4Array struct { func (dst *Float4Array) Set(src interface{}) error { switch value := src.(type) { - case Float4Array: - *dst = value case []float32: if value == nil { diff --git a/pgtype/float8.go b/pgtype/float8.go index 9129e8ba..635b7a09 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -17,8 +17,6 @@ type Float8 struct { func (dst *Float8) Set(src interface{}) error { switch value := src.(type) { - case Float8: - *dst = value case float32: *dst = Float8{Float: float64(value), Status: Present} case float64: diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index b2f70f51..811c5a1f 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -17,8 +17,6 @@ type Float8Array struct { func (dst *Float8Array) Set(src interface{}) error { switch value := src.(type) { - case Float8Array: - *dst = value case []float64: if value == nil { diff --git a/pgtype/inet.go b/pgtype/inet.go index 00bfb30c..87d675f9 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -25,8 +25,6 @@ type Inet struct { func (dst *Inet) Set(src interface{}) error { switch value := src.(type) { - case Inet: - *dst = value case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} case *net.IPNet: diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 4d865b4f..1d1cf3fd 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -18,8 +18,6 @@ type InetArray struct { func (dst *InetArray) Set(src interface{}) error { switch value := src.(type) { - case InetArray: - *dst = value case []*net.IPNet: if value == nil { diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 90b0723f..16035fca 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -31,7 +31,6 @@ func TestInetSet(t *testing.T) { source interface{} result pgtype.Inet }{ - {source: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Null}, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Null}}, {source: mustParseCidr(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: mustParseCidr(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, diff --git a/pgtype/int2.go b/pgtype/int2.go index 525427c5..62e1bc69 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -17,8 +17,6 @@ type Int2 struct { func (dst *Int2) Set(src interface{}) error { switch value := src.(type) { - case Int2: - *dst = value case int8: *dst = Int2{Int: int16(value), Status: Present} case uint8: diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 28792fa5..3d06c018 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -17,8 +17,6 @@ type Int2Array struct { func (dst *Int2Array) Set(src interface{}) error { switch value := src.(type) { - case Int2Array: - *dst = value case []int16: if value == nil { diff --git a/pgtype/int4.go b/pgtype/int4.go index b3203a28..8eaf5094 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -17,8 +17,6 @@ type Int4 struct { func (dst *Int4) Set(src interface{}) error { switch value := src.(type) { - case Int4: - *dst = value case int8: *dst = Int4{Int: int32(value), Status: Present} case uint8: diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 61cedb2e..5cd91c04 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -17,8 +17,6 @@ type Int4Array struct { func (dst *Int4Array) Set(src interface{}) error { switch value := src.(type) { - case Int4Array: - *dst = value case []int32: if value == nil { diff --git a/pgtype/int8.go b/pgtype/int8.go index 15ad6715..2416500d 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -17,8 +17,6 @@ type Int8 struct { func (dst *Int8) Set(src interface{}) error { switch value := src.(type) { - case Int8: - *dst = value case int8: *dst = Int8{Int: int64(value), Status: Present} case uint8: diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 9f4373e8..5efc0f45 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -17,8 +17,6 @@ type Int8Array struct { func (dst *Int8Array) Set(src interface{}) error { switch value := src.(type) { - case Int8Array: - *dst = value case []int64: if value == nil { diff --git a/pgtype/qchar.go b/pgtype/qchar.go index b6392cf9..d46e716d 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -25,8 +25,6 @@ type QChar struct { func (dst *QChar) Set(src interface{}) error { switch value := src.(type) { - case QChar: - *dst = value case int8: *dst = QChar{Int: value, Status: Present} case uint8: diff --git a/pgtype/text.go b/pgtype/text.go index 50db2349..3dd082c9 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -13,8 +13,6 @@ type Text struct { func (dst *Text) Set(src interface{}) error { switch value := src.(type) { - case Text: - *dst = value case string: *dst = Text{String: value, Status: Present} case *string: diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 3a5a64ce..1e6677a9 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -17,8 +17,6 @@ type TextArray struct { func (dst *TextArray) Set(src interface{}) error { switch value := src.(type) { - case TextArray: - *dst = value case []string: if value == nil { diff --git a/pgtype/text_test.go b/pgtype/text_test.go index f5e20055..39348bcc 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -22,7 +22,6 @@ func TestTextSet(t *testing.T) { source interface{} result pgtype.Text }{ - {source: pgtype.Text{String: "foo", Status: pgtype.Present}, result: pgtype.Text{String: "foo", Status: pgtype.Present}}, {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index a84f3881..3bb8f080 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -27,8 +27,6 @@ type Timestamp struct { // time.Time in a non-UTC time zone, the time zone is discarded. func (dst *Timestamp) Set(src interface{}) error { switch value := src.(type) { - case Timestamp: - *dst = value case time.Time: *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} default: diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index ec0facb2..c955dc42 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -18,8 +18,6 @@ type TimestampArray struct { func (dst *TimestampArray) Set(src interface{}) error { switch value := src.(type) { - case TimestampArray: - *dst = value case []time.Time: if value == nil { diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 7297ed1f..58828806 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -38,7 +38,6 @@ func TestTimestampSet(t *testing.T) { source interface{} result pgtype.Timestamp }{ - {source: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), result: pgtype.Timestamp{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.UTC), Status: pgtype.Present}}, diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index a6922d5b..5b9f5038 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -28,8 +28,6 @@ type Timestamptz struct { func (dst *Timestamptz) Set(src interface{}) error { switch value := src.(type) { - case Timestamptz: - *dst = value case time.Time: *dst = Timestamptz{Time: value, Status: Present} default: diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 775ec970..cd63e02e 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -18,8 +18,6 @@ type TimestamptzArray struct { func (dst *TimestamptzArray) Set(src interface{}) error { switch value := src.(type) { - case TimestamptzArray: - *dst = value case []time.Time: if value == nil { diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 242cd05f..6ddfc1bc 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -38,7 +38,6 @@ func TestTimestamptzSet(t *testing.T) { source interface{} result pgtype.Timestamptz }{ - {source: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}}, {source: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), result: pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 12, 59, 59, 0, time.Local), Status: pgtype.Present}}, diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index c62e2896..a56097c0 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -16,8 +16,6 @@ type <%= pgtype_array_type %> struct { func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { switch value := src.(type) { - case <%= pgtype_array_type %>: - *dst = value <% go_array_types.split(",").each do |t| %> case <%= t %>: if value == nil { From 3391818847e471469870aaf8f5760a191c44c899 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 20:28:14 -0600 Subject: [PATCH 110/264] Add pgtype GenericText and GenericBinary Rows.Values uses this for unknown types. --- pgtype/generic_binary.go | 29 +++++++++++++++++++++++++++++ pgtype/generic_text.go | 29 +++++++++++++++++++++++++++++ query.go | 9 ++------- 3 files changed, 60 insertions(+), 7 deletions(-) create mode 100644 pgtype/generic_binary.go create mode 100644 pgtype/generic_text.go diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go new file mode 100644 index 00000000..ac35ea60 --- /dev/null +++ b/pgtype/generic_binary.go @@ -0,0 +1,29 @@ +package pgtype + +import ( + "io" +) + +// GenericBinary is a placeholder for binary format values that no other type exists +// to handle. +type GenericBinary Bytea + +func (dst *GenericBinary) Set(src interface{}) error { + return (*Bytea)(dst).Set(src) +} + +func (dst *GenericBinary) Get() interface{} { + return (*Bytea)(dst).Get() +} + +func (src *GenericBinary) AssignTo(dst interface{}) error { + return (*Bytea)(src).AssignTo(dst) +} + +func (dst *GenericBinary) DecodeBinary(src []byte) error { + return (*Bytea)(dst).DecodeBinary(src) +} + +func (src GenericBinary) EncodeBinary(w io.Writer) (bool, error) { + return (Bytea)(src).EncodeBinary(w) +} diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go new file mode 100644 index 00000000..19f41059 --- /dev/null +++ b/pgtype/generic_text.go @@ -0,0 +1,29 @@ +package pgtype + +import ( + "io" +) + +// GenericText is a placeholder for text format values that no other type exists +// to handle. +type GenericText Text + +func (dst *GenericText) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *GenericText) Get() interface{} { + return (*Text)(dst).Get() +} + +func (src *GenericText) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *GenericText) DecodeText(src []byte) error { + return (*Text)(dst).DecodeText(src) +} + +func (src GenericText) EncodeText(w io.Writer) (bool, error) { + return (Text)(src).EncodeText(w) +} diff --git a/query.go b/query.go index 6e191665..d8caa08d 100644 --- a/query.go +++ b/query.go @@ -325,16 +325,11 @@ func (rows *Rows) Values() ([]interface{}, error) { continue } - pgVal := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.TextDecoder) - if pgVal == nil { - panic("need GenericText or GenericBinary") - } - switch vr.Type().FormatCode { case TextFormatCode: decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.TextDecoder) if decoder == nil { - panic("need GenericText") + decoder = &pgtype.GenericText{} } err := decoder.DecodeText(vr.bytes()) if err != nil { @@ -344,7 +339,7 @@ func (rows *Rows) Values() ([]interface{}, error) { case BinaryFormatCode: decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.BinaryDecoder) if decoder == nil { - panic("need GenericBinary") + decoder = &pgtype.GenericBinary{} } err := decoder.DecodeBinary(vr.bytes()) if err != nil { From 7bb1f3677d7288a72aeed5750c022f51c5de5fa3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Mar 2017 17:06:06 -0500 Subject: [PATCH 111/264] Move hstore to pgtype Also implement binary format --- conn.go | 98 ++++++---- hstore.go | 222 --------------------- hstore_test.go | 181 ----------------- pgtype/hstore.go | 438 ++++++++++++++++++++++++++++++++++++++++++ pgtype/hstore_test.go | 108 +++++++++++ values.go | 135 ------------- 6 files changed, 603 insertions(+), 579 deletions(-) delete mode 100644 hstore.go delete mode 100644 hstore_test.go create mode 100644 pgtype/hstore.go create mode 100644 pgtype/hstore_test.go diff --git a/conn.go b/conn.go index 4085722c..d541e942 100644 --- a/conn.go +++ b/conn.go @@ -267,47 +267,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.doneChan = make(chan struct{}) c.closedChan = make(chan error) - c.oidPgtypeValues = map[Oid]pgtype.Value{ - AclitemArrayOid: &pgtype.AclitemArray{}, - AclitemOid: &pgtype.Aclitem{}, - BoolArrayOid: &pgtype.BoolArray{}, - BoolOid: &pgtype.Bool{}, - ByteaArrayOid: &pgtype.ByteaArray{}, - ByteaOid: &pgtype.Bytea{}, - CharOid: &pgtype.QChar{}, - CidOid: &pgtype.Cid{}, - CidrArrayOid: &pgtype.CidrArray{}, - CidrOid: &pgtype.Inet{}, - DateArrayOid: &pgtype.DateArray{}, - DateOid: &pgtype.Date{}, - Float4ArrayOid: &pgtype.Float4Array{}, - Float4Oid: &pgtype.Float4{}, - Float8ArrayOid: &pgtype.Float8Array{}, - Float8Oid: &pgtype.Float8{}, - InetArrayOid: &pgtype.InetArray{}, - InetOid: &pgtype.Inet{}, - Int2ArrayOid: &pgtype.Int2Array{}, - Int2Oid: &pgtype.Int2{}, - Int4ArrayOid: &pgtype.Int4Array{}, - Int4Oid: &pgtype.Int4{}, - Int8ArrayOid: &pgtype.Int8Array{}, - Int8Oid: &pgtype.Int8{}, - JsonbOid: &pgtype.Jsonb{}, - JsonOid: &pgtype.Json{}, - NameOid: &pgtype.Name{}, - OidOid: &pgtype.Oid{}, - TextArrayOid: &pgtype.TextArray{}, - TextOid: &pgtype.Text{}, - TidOid: &pgtype.Tid{}, - TimestampArrayOid: &pgtype.TimestampArray{}, - TimestampOid: &pgtype.Timestamp{}, - TimestampTzArrayOid: &pgtype.TimestamptzArray{}, - TimestampTzOid: &pgtype.Timestamptz{}, - VarcharArrayOid: &pgtype.VarcharArray{}, - VarcharOid: &pgtype.Text{}, - XidOid: &pgtype.Xid{}, - } - if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "Starting TLS handshake") @@ -317,6 +276,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } + c.loadStaticOidPgtypeValues() + c.mr.cr = chunkreader.NewChunkReader(c.conn) msg := newStartupMessage() @@ -376,6 +337,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return err } } + c.loadDynamicOidPgtypeValues() return nil default: @@ -416,6 +378,60 @@ where ( return rows.Err() } +func (c *Conn) loadStaticOidPgtypeValues() { + c.oidPgtypeValues = map[Oid]pgtype.Value{ + AclitemArrayOid: &pgtype.AclitemArray{}, + AclitemOid: &pgtype.Aclitem{}, + BoolArrayOid: &pgtype.BoolArray{}, + BoolOid: &pgtype.Bool{}, + ByteaArrayOid: &pgtype.ByteaArray{}, + ByteaOid: &pgtype.Bytea{}, + CharOid: &pgtype.QChar{}, + CidOid: &pgtype.Cid{}, + CidrArrayOid: &pgtype.CidrArray{}, + CidrOid: &pgtype.Inet{}, + DateArrayOid: &pgtype.DateArray{}, + DateOid: &pgtype.Date{}, + Float4ArrayOid: &pgtype.Float4Array{}, + Float4Oid: &pgtype.Float4{}, + Float8ArrayOid: &pgtype.Float8Array{}, + Float8Oid: &pgtype.Float8{}, + InetArrayOid: &pgtype.InetArray{}, + InetOid: &pgtype.Inet{}, + Int2ArrayOid: &pgtype.Int2Array{}, + Int2Oid: &pgtype.Int2{}, + Int4ArrayOid: &pgtype.Int4Array{}, + Int4Oid: &pgtype.Int4{}, + Int8ArrayOid: &pgtype.Int8Array{}, + Int8Oid: &pgtype.Int8{}, + JsonbOid: &pgtype.Jsonb{}, + JsonOid: &pgtype.Json{}, + NameOid: &pgtype.Name{}, + OidOid: &pgtype.Oid{}, + TextArrayOid: &pgtype.TextArray{}, + TextOid: &pgtype.Text{}, + TidOid: &pgtype.Tid{}, + TimestampArrayOid: &pgtype.TimestampArray{}, + TimestampOid: &pgtype.Timestamp{}, + TimestampTzArrayOid: &pgtype.TimestamptzArray{}, + TimestampTzOid: &pgtype.Timestamptz{}, + VarcharArrayOid: &pgtype.VarcharArray{}, + VarcharOid: &pgtype.Text{}, + XidOid: &pgtype.Xid{}, + } +} + +func (c *Conn) loadDynamicOidPgtypeValues() { + nameOids := make(map[string]Oid, len(c.PgTypes)) + for k, v := range c.PgTypes { + nameOids[v.Name] = k + } + + if oid, ok := nameOids["hstore"]; ok { + c.oidPgtypeValues[oid] = &pgtype.Hstore{} + } +} + // PID returns the backend PID for this connection. func (c *Conn) PID() int32 { return c.pid diff --git a/hstore.go b/hstore.go deleted file mode 100644 index 0ab9f779..00000000 --- a/hstore.go +++ /dev/null @@ -1,222 +0,0 @@ -package pgx - -import ( - "bytes" - "errors" - "fmt" - "unicode" - "unicode/utf8" -) - -const ( - hsPre = iota - hsKey - hsSep - hsVal - hsNul - hsNext -) - -type hstoreParser struct { - str string - pos int -} - -func newHSP(in string) *hstoreParser { - return &hstoreParser{ - pos: 0, - str: in, - } -} - -func (p *hstoreParser) Consume() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return - } - r, w := utf8.DecodeRuneInString(p.str[p.pos:]) - p.pos += w - return -} - -func (p *hstoreParser) Peek() (r rune, end bool) { - if p.pos >= len(p.str) { - end = true - return - } - r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) - return -} - -func parseHstoreToMap(s string) (m map[string]string, err error) { - keys, values, err := ParseHstore(s) - if err != nil { - return - } - m = make(map[string]string, len(keys)) - for i, key := range keys { - if !values[i].Valid { - err = fmt.Errorf("key '%s' has NULL value", key) - m = nil - return - } - m[key] = values[i].String - } - return -} - -func parseHstoreToNullHstore(s string) (store map[string]NullString, err error) { - keys, values, err := ParseHstore(s) - if err != nil { - return - } - - store = make(map[string]NullString, len(keys)) - - for i, key := range keys { - store[key] = values[i] - } - return -} - -// ParseHstore parses the string representation of an hstore column (the same -// you would get from an ordinary SELECT) into two slices of keys and values. it -// is used internally in the default parsing of hstores, but is exported for use -// in handling custom data structures backed by an hstore column without the -// overhead of creating a map[string]string -func ParseHstore(s string) (k []string, v []NullString, err error) { - if s == "" { - return - } - - buf := bytes.Buffer{} - keys := []string{} - values := []NullString{} - p := newHSP(s) - - r, end := p.Consume() - state := hsPre - - for !end { - switch state { - case hsPre: - if r == '"' { - state = hsKey - } else { - err = errors.New("String does not begin with \"") - } - case hsKey: - switch r { - case '"': //End of the key - if buf.Len() == 0 { - err = errors.New("Empty Key is invalid") - } else { - keys = append(keys, buf.String()) - buf = bytes.Buffer{} - state = hsSep - } - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsSep: - if r == '=' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=', expecting '>'") - case r == '>': - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") - case r == '"': - state = hsVal - case r == 'N': - state = hsNul - default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) - } - default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) - } - case hsVal: - switch r { - case '"': //End of the value - values = append(values, NullString{String: buf.String(), Valid: true}) - buf = bytes.Buffer{} - state = hsNext - case '\\': //Potential escaped character - n, end := p.Consume() - switch { - case end: - err = errors.New("Found EOS in key, expecting character or \"") - case n == '"', n == '\\': - buf.WriteRune(n) - default: - buf.WriteRune(r) - buf.WriteRune(n) - } - default: //Any other character - buf.WriteRune(r) - } - case hsNul: - nulBuf := make([]rune, 3) - nulBuf[0] = r - for i := 1; i < 3; i++ { - r, end = p.Consume() - if end { - err = errors.New("Found EOS in NULL value") - return - } - nulBuf[i] = r - } - if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { - values = append(values, NullString{String: "", Valid: false}) - state = hsNext - } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) - } - case hsNext: - if r == ',' { - r, end = p.Consume() - switch { - case end: - err = errors.New("Found EOS after ',', expcting space") - case (unicode.IsSpace(r)): - r, end = p.Consume() - state = hsKey - default: - err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) - } - } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) - } - } - - if err != nil { - return - } - r, end = p.Consume() - } - if state != hsNext { - err = errors.New("Improperly formatted hstore") - return - } - k = keys - v = values - return -} diff --git a/hstore_test.go b/hstore_test.go deleted file mode 100644 index c948f0cd..00000000 --- a/hstore_test.go +++ /dev/null @@ -1,181 +0,0 @@ -package pgx_test - -import ( - "github.com/jackc/pgx" - "testing" -) - -func TestHstoreTranscode(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type test struct { - hstore pgx.Hstore - description string - } - - tests := []test{ - {pgx.Hstore{}, "empty"}, - {pgx.Hstore{"foo": "bar"}, "single key/value"}, - {pgx.Hstore{"foo": "bar", "baz": "quz"}, "multiple key/values"}, - {pgx.Hstore{"NULL": "bar"}, `string "NULL" key`}, - {pgx.Hstore{"foo": "NULL"}, `string "NULL" value`}, - } - - specialStringTests := []struct { - input string - description string - }{ - {`"`, `double quote (")`}, - {`'`, `single quote (')`}, - {`\`, `backslash (\)`}, - {`\\`, `multiple backslashes (\\)`}, - {`=>`, `separator (=>)`}, - {` `, `space`}, - {`\ / / \\ => " ' " '`, `multiple special characters`}, - } - for _, sst := range specialStringTests { - tests = append(tests, test{pgx.Hstore{sst.input + "foo": "bar"}, "key with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.Hstore{"foo" + sst.input + "foo": "bar"}, "key with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.Hstore{"foo" + sst.input: "bar"}, "key with " + sst.description + " at end"}) - tests = append(tests, test{pgx.Hstore{sst.input: "bar"}, "key is " + sst.description}) - - tests = append(tests, test{pgx.Hstore{"foo": sst.input + "bar"}, "value with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input + "bar"}, "value with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.Hstore{"foo": "bar" + sst.input}, "value with " + sst.description + " at end"}) - tests = append(tests, test{pgx.Hstore{"foo": sst.input}, "value is " + sst.description}) - } - - for _, tt := range tests { - var result pgx.Hstore - err := conn.QueryRow("select $1::hstore", tt.hstore).Scan(&result) - if err != nil { - t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err) - } - - for key, inValue := range tt.hstore { - outValue, ok := result[key] - if ok { - if inValue != outValue { - t.Errorf(`%s: Key %s mismatch - expected %s, received %s`, tt.description, key, inValue, outValue) - } - } else { - t.Errorf(`%s: Missing key %s`, tt.description, key) - } - } - - ensureConnValid(t, conn) - } -} - -func TestNullHstoreTranscode(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type test struct { - nullHstore pgx.NullHstore - description string - } - - tests := []test{ - {pgx.NullHstore{}, "null"}, - {pgx.NullHstore{Valid: true}, "empty"}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}}, - Valid: true}, - "single key/value"}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar", Valid: true}, "baz": {String: "quz", Valid: true}}, - Valid: true}, - "multiple key/values"}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"NULL": {String: "bar", Valid: true}}, - Valid: true}, - `string "NULL" key`}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "NULL", Valid: true}}, - Valid: true}, - `string "NULL" value`}, - {pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "", Valid: false}}, - Valid: true}, - `NULL value`}, - } - - specialStringTests := []struct { - input string - description string - }{ - {`"`, `double quote (")`}, - {`'`, `single quote (')`}, - {`\`, `backslash (\)`}, - {`\\`, `multiple backslashes (\\)`}, - {`=>`, `separator (=>)`}, - {` `, `space`}, - {`\ / / \\ => " ' " '`, `multiple special characters`}, - } - for _, sst := range specialStringTests { - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input + "foo": {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input + "foo": {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo" + sst.input: {String: "bar", Valid: true}}, - Valid: true}, - "key with " + sst.description + " at end"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{sst.input: {String: "bar", Valid: true}}, - Valid: true}, - "key is " + sst.description}) - - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: sst.input + "bar", Valid: true}}, - Valid: true}, - "value with " + sst.description + " at beginning"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input + "bar", Valid: true}}, - Valid: true}, - "value with " + sst.description + " in middle"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: "bar" + sst.input, Valid: true}}, - Valid: true}, - "value with " + sst.description + " at end"}) - tests = append(tests, test{pgx.NullHstore{ - Hstore: map[string]pgx.NullString{"foo": {String: sst.input, Valid: true}}, - Valid: true}, - "value is " + sst.description}) - } - - for _, tt := range tests { - var result pgx.NullHstore - err := conn.QueryRow("select $1::hstore", tt.nullHstore).Scan(&result) - if err != nil { - t.Errorf(`%s: QueryRow.Scan returned an error: %v`, tt.description, err) - } - - if result.Valid != tt.nullHstore.Valid { - t.Errorf(`%s: Valid mismatch - expected %v, received %v`, tt.description, tt.nullHstore.Valid, result.Valid) - } - - for key, inValue := range tt.nullHstore.Hstore { - outValue, ok := result.Hstore[key] - if ok { - if inValue != outValue { - t.Errorf(`%s: Key %s mismatch - expected %v, received %v`, tt.description, key, inValue, outValue) - } - } else { - t.Errorf(`%s: Missing key %s`, tt.description, key) - } - } - - ensureConnValid(t, conn) - } -} diff --git a/pgtype/hstore.go b/pgtype/hstore.go new file mode 100644 index 00000000..11bfb9a7 --- /dev/null +++ b/pgtype/hstore.go @@ -0,0 +1,438 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "strings" + "unicode" + "unicode/utf8" + + "github.com/jackc/pgx/pgio" +) + +// Hstore represents an hstore column that can be null or have null values +// associated with its keys. +type Hstore struct { + Map map[string]Text + Status Status +} + +func (dst *Hstore) Set(src interface{}) error { + switch value := src.(type) { + case map[string]string: + m := make(map[string]Text, len(value)) + for k, v := range value { + m[k] = Text{String: v, Status: Present} + } + *dst = Hstore{Map: m, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Tid", src) + } + + return nil +} + +func (dst *Hstore) Get() interface{} { + switch dst.Status { + case Present: + return dst.Map + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Hstore) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *map[string]string: + switch src.Status { + case Present: + *v = make(map[string]string, len(src.Map)) + for k, val := range src.Map { + if val.Status != Present { + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + (*v)[k] = val.String + } + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Hstore) DecodeText(src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + keys, values, err := parseHstore(string(src)) + if err != nil { + return err + } + + m := make(map[string]Text, len(keys)) + for i := range keys { + m[keys[i]] = values[i] + } + + *dst = Hstore{Map: m, Status: Present} + return nil +} + +func (dst *Hstore) DecodeBinary(src []byte) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + m := make(map[string]Text, pairCount) + + for i := 0; i < pairCount; i++ { + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + if len(src[rp:]) < keyLen { + return fmt.Errorf("hstore incomplete %v", src) + } + key := string(src[rp : rp+keyLen]) + rp += keyLen + + if len(src[rp:]) < 4 { + return fmt.Errorf("hstore incomplete %v", src) + } + valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var valueBuf []byte + if valueLen >= 0 { + valueBuf = src[rp : rp+valueLen] + } + rp += valueLen + + var value Text + err := value.DecodeBinary(valueBuf) + if err != nil { + return err + } + m[key] = value + } + + *dst = Hstore{Map: m, Status: Present} + + return nil +} + +func (src Hstore) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + firstPair := true + + for k, v := range src.Map { + if firstPair { + firstPair = false + } else { + err := pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + _, err := io.WriteString(w, quoteHstoreElementIfNeeded(k)) + if err != nil { + return false, err + } + + _, err = io.WriteString(w, "=>") + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + null, err := v.EncodeText(elemBuf) + if err != nil { + return false, err + } + + if null { + _, err = io.WriteString(w, "NULL") + if err != nil { + return false, err + } + } else { + _, err := io.WriteString(w, quoteHstoreElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + } + + return false, nil +} + +func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteInt32(w, int32(len(src.Map))) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + for k, v := range src.Map { + _, err := pgio.WriteInt32(w, int32(len(k))) + if err != nil { + return false, err + } + _, err = io.WriteString(w, k) + if err != nil { + return false, err + } + + null, err := v.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err := pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err := pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} + +var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteHstoreElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func quoteHstoreElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || strings.ContainsAny(src, ` {},"\=>`) { + return quoteArrayElement(src) + } + return src +} + +const ( + hsPre = iota + hsKey + hsSep + hsVal + hsNul + hsNext +) + +type hstoreParser struct { + str string + pos int +} + +func newHSP(in string) *hstoreParser { + return &hstoreParser{ + pos: 0, + str: in, + } +} + +func (p *hstoreParser) Consume() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, w := utf8.DecodeRuneInString(p.str[p.pos:]) + p.pos += w + return +} + +func (p *hstoreParser) Peek() (r rune, end bool) { + if p.pos >= len(p.str) { + end = true + return + } + r, _ = utf8.DecodeRuneInString(p.str[p.pos:]) + return +} + +// parseHstore parses the string representation of an hstore column (the same +// you would get from an ordinary SELECT) into two slices of keys and values. it +// is used internally in the default parsing of hstores. +func parseHstore(s string) (k []string, v []Text, err error) { + if s == "" { + return + } + + buf := bytes.Buffer{} + keys := []string{} + values := []Text{} + p := newHSP(s) + + r, end := p.Consume() + state := hsPre + + for !end { + switch state { + case hsPre: + if r == '"' { + state = hsKey + } else { + err = errors.New("String does not begin with \"") + } + case hsKey: + switch r { + case '"': //End of the key + if buf.Len() == 0 { + err = errors.New("Empty Key is invalid") + } else { + keys = append(keys, buf.String()) + buf = bytes.Buffer{} + state = hsSep + } + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsSep: + if r == '=' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=', expecting '>'") + case r == '>': + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after '=>', expecting '\"' or 'NULL'") + case r == '"': + state = hsVal + case r == 'N': + state = hsNul + default: + err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + } + default: + err = fmt.Errorf("Invalid character after '=', expecting '>'") + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + } + case hsVal: + switch r { + case '"': //End of the value + values = append(values, Text{String: buf.String(), Status: Present}) + buf = bytes.Buffer{} + state = hsNext + case '\\': //Potential escaped character + n, end := p.Consume() + switch { + case end: + err = errors.New("Found EOS in key, expecting character or \"") + case n == '"', n == '\\': + buf.WriteRune(n) + default: + buf.WriteRune(r) + buf.WriteRune(n) + } + default: //Any other character + buf.WriteRune(r) + } + case hsNul: + nulBuf := make([]rune, 3) + nulBuf[0] = r + for i := 1; i < 3; i++ { + r, end = p.Consume() + if end { + err = errors.New("Found EOS in NULL value") + return + } + nulBuf[i] = r + } + if nulBuf[0] == 'U' && nulBuf[1] == 'L' && nulBuf[2] == 'L' { + values = append(values, Text{Status: Null}) + state = hsNext + } else { + err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + } + case hsNext: + if r == ',' { + r, end = p.Consume() + switch { + case end: + err = errors.New("Found EOS after ',', expcting space") + case (unicode.IsSpace(r)): + r, end = p.Consume() + state = hsKey + default: + err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + } + } else { + err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + } + } + + if err != nil { + return + } + r, end = p.Consume() + } + if state != hsNext { + err = errors.New("Improperly formatted hstore") + return + } + k = keys + v = values + return +} diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go new file mode 100644 index 00000000..fbe8dee5 --- /dev/null +++ b/pgtype/hstore_test.go @@ -0,0 +1,108 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestHstoreTranscode(t *testing.T) { + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []interface{}{ + pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + testSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + a := ai.(pgtype.Hstore) + b := bi.(pgtype.Hstore) + + if len(a.Map) != len(b.Map) || a.Status != b.Status { + return false + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + return false + } + } + + return true + }) +} + +func TestHstoreSet(t *testing.T) { + successfulTests := []struct { + src map[string]string + result pgtype.Hstore + }{ + {src: map[string]string{"foo": "bar"}, result: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var dst pgtype.Hstore + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreAssignTo(t *testing.T) { + var m map[string]string + + simpleTests := []struct { + src pgtype.Hstore + dst *map[string]string + expected map[string]string + }{ + {src: pgtype.Hstore{Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, Status: pgtype.Present}, dst: &m, expected: map[string]string{"foo": "bar"}}, + {src: pgtype.Hstore{Status: pgtype.Null}, dst: &m, expected: ((map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/values.go b/values.go index e976d0d3..e1c8f731 100644 --- a/values.go +++ b/values.go @@ -10,7 +10,6 @@ import ( "math" "reflect" "strconv" - "strings" "time" "github.com/jackc/pgx/pgio" @@ -577,140 +576,6 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error { return encodeTime(w, oid, n.Time) } -// Hstore represents an hstore column. It does not support a null column or null -// key values (use NullHstore for this). Hstore implements the Scanner and -// Encoder interfaces so it may be used both as an argument to Query[Row] and a -// destination for Scan. -type Hstore map[string]string - -func (h *Hstore) Scan(vr *ValueReader) error { - //oid for hstore not standardized, so we check its type name - if vr.Type().DataTypeName != "hstore" { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into Hstore", vr.Type().DataTypeName))) - return nil - } - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null column into Hstore")) - return nil - } - - switch vr.Type().FormatCode { - case TextFormatCode: - m, err := parseHstoreToMap(vr.ReadString(vr.Len())) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) - return nil - } - hm := Hstore(m) - *h = hm - return nil - case BinaryFormatCode: - vr.Fatal(ProtocolError("Can't decode binary hstore")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } -} - -func (h Hstore) FormatCode() int16 { return TextFormatCode } - -func (h Hstore) Encode(w *WriteBuf, oid Oid) error { - var buf bytes.Buffer - - i := 0 - for k, v := range h { - i++ - ks := strings.Replace(k, `\`, `\\`, -1) - ks = strings.Replace(ks, `"`, `\"`, -1) - vs := strings.Replace(v, `\`, `\\`, -1) - vs = strings.Replace(vs, `"`, `\"`, -1) - buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) - if i < len(h) { - buf.WriteString(", ") - } - } - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - return nil -} - -// NullHstore represents an hstore column that can be null or have null values -// associated with its keys. NullHstore implements the Scanner and Encoder -// interfaces so it may be used both as an argument to Query[Row] and a -// destination for Scan. -// -// If Valid is false, then the value of the entire hstore column is NULL -// If any of the NullString values in Store has Valid set to false, the key -// appears in the hstore column, but its value is explicitly set to NULL. -type NullHstore struct { - Hstore map[string]NullString - Valid bool -} - -func (h *NullHstore) Scan(vr *ValueReader) error { - //oid for hstore not standardized, so we check its type name - if vr.Type().DataTypeName != "hstore" { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode type %s into NullHstore", vr.Type().DataTypeName))) - return nil - } - - if vr.Len() == -1 { - h.Valid = false - return nil - } - - switch vr.Type().FormatCode { - case TextFormatCode: - store, err := parseHstoreToNullHstore(vr.ReadString(vr.Len())) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Can't decode hstore column: %v", err))) - return nil - } - h.Valid = true - h.Hstore = store - return nil - case BinaryFormatCode: - vr.Fatal(ProtocolError("Can't decode binary hstore")) - return nil - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } -} - -func (h NullHstore) FormatCode() int16 { return TextFormatCode } - -func (h NullHstore) Encode(w *WriteBuf, oid Oid) error { - var buf bytes.Buffer - - if !h.Valid { - w.WriteInt32(-1) - return nil - } - - i := 0 - for k, v := range h.Hstore { - i++ - ks := strings.Replace(k, `\`, `\\`, -1) - ks = strings.Replace(ks, `"`, `\"`, -1) - if v.Valid { - vs := strings.Replace(v.String, `\`, `\\`, -1) - vs = strings.Replace(vs, `"`, `\"`, -1) - buf.WriteString(fmt.Sprintf(`"%s"=>"%s"`, ks, vs)) - } else { - buf.WriteString(fmt.Sprintf(`"%s"=>NULL`, ks)) - } - if i < len(h.Hstore) { - buf.WriteString(", ") - } - } - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - return nil -} - // Encode encodes arg into wbuf as the type oid. This allows implementations // of the Encoder interface to delegate the actual work of encoding to the // built-in functionality. From 26d57356f7dcf5bc3986d4d50a15c401e1eda190 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Mar 2017 17:22:04 -0500 Subject: [PATCH 112/264] Remove old Scanner and Encoder system --- bench_test.go | 127 +----------- conn.go | 4 +- example_custom_type_test.go | 87 ++++----- query.go | 10 - query_test.go | 103 ---------- values.go | 375 ------------------------------------ values_test.go | 142 -------------- 7 files changed, 41 insertions(+), 807 deletions(-) diff --git a/bench_test.go b/bench_test.go index b08c2b4e..348c840c 100644 --- a/bench_test.go +++ b/bench_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) func BenchmarkConnPool(b *testing.B) { @@ -49,126 +50,6 @@ func BenchmarkConnPoolQueryRow(b *testing.B) { } } -func BenchmarkNullXWithNullValues(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email pgx.NullString - name pgx.NullString - sex pgx.NullString - birthDate pgx.NullTime - lastLoginTime pgx.NullTime - } - - err = conn.QueryRow("selectNulls").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if record.email.Valid { - b.Fatalf("bad value for email: %v", record.email) - } - if record.name.Valid { - b.Fatalf("bad value for name: %v", record.name) - } - if record.sex.Valid { - b.Fatalf("bad value for sex: %v", record.sex) - } - if record.birthDate.Valid { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if record.lastLoginTime.Valid { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - -func BenchmarkNullXWithPresentValues(b *testing.B) { - conn := mustConnect(b, *defaultConnConfig) - defer closeConn(b, conn) - - _, err := conn.Prepare("selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email pgx.NullString - name pgx.NullString - sex pgx.NullString - birthDate pgx.NullTime - lastLoginTime pgx.NullTime - } - - err = conn.QueryRow("selectNulls").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if !record.email.Valid || record.email.String != "johnsmith@example.com" { - b.Fatalf("bad value for email: %v", record.email) - } - if !record.name.Valid || record.name.String != "John Smith" { - b.Fatalf("bad value for name: %v", record.name) - } - if !record.sex.Valid || record.sex.String != "male" { - b.Fatalf("bad value for sex: %v", record.sex) - } - if !record.birthDate.Valid || record.birthDate.Time != time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if !record.lastLoginTime.Valid || record.lastLoginTime.Time != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - func BenchmarkPointerPointerWithNullValues(b *testing.B) { conn := mustConnect(b, *defaultConnConfig) defer closeConn(b, conn) @@ -475,12 +356,12 @@ func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource { row: []interface{}{ "varchar_1", "varchar_2", - pgx.NullString{}, + pgtype.Text{}, time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), - pgx.NullTime{}, + pgtype.Date{}, 1, 2, - pgx.NullInt32{}, + pgtype.Int4{}, time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), true, diff --git a/conn.go b/conn.go index d541e942..ae83fc77 100644 --- a/conn.go +++ b/conn.go @@ -1003,9 +1003,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(int16(len(ps.ParameterOids))) for i, oid := range ps.ParameterOids { - switch arg := arguments[i].(type) { - case Encoder: - wbuf.WriteInt16(arg.FormatCode()) + switch arguments[i].(type) { case pgtype.BinaryEncoder: wbuf.WriteInt16(BinaryFormatCode) case pgtype.TextEncoder: diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 74fbab67..71110f85 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -1,78 +1,63 @@ package pgx_test import ( - "errors" "fmt" + "io" "regexp" "strconv" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) -// NullPoint represents a point that may be null. -// -// If Valid is false then the value is NULL. -type NullPoint struct { - X, Y float64 // Coordinates of point - Valid bool // Valid is true if not NULL +// Point represents a point that may be null. +type Point struct { + X, Y float64 // Coordinates of point + Status pgtype.Status } -func (p *NullPoint) ScanPgx(vr *pgx.ValueReader) error { - if vr.Type().DataTypeName != "point" { - return pgx.SerializationError(fmt.Sprintf("NullPoint.Scan cannot decode %s (Oid %d)", vr.Type().DataTypeName, vr.Type().DataType)) - } - - if vr.Len() == -1 { - p.X, p.Y, p.Valid = 0, 0, false +func (dst *Point) DecodeText(src []byte) error { + if src == nil { + *dst = Point{Status: pgtype.Null} return nil } - switch vr.Type().FormatCode { - case pgx.TextFormatCode: - s := vr.ReadString(vr.Len()) - match := pointRegexp.FindStringSubmatch(s) - if match == nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - - var err error - p.X, err = strconv.ParseFloat(match[1], 64) - if err != nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - p.Y, err = strconv.ParseFloat(match[2], 64) - if err != nil { - return pgx.SerializationError(fmt.Sprintf("Received invalid point: %v", s)) - } - case pgx.BinaryFormatCode: - return errors.New("binary format not implemented") - default: - return fmt.Errorf("unknown format %v", vr.Type().FormatCode) + s := string(src) + match := pointRegexp.FindStringSubmatch(s) + if match == nil { + return fmt.Errorf("Received invalid point: %v", s) } - p.Valid = true - return vr.Err() -} - -func (p NullPoint) FormatCode() int16 { return pgx.BinaryFormatCode } - -func (p NullPoint) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { - if !p.Valid { - w.WriteInt32(-1) - return nil + x, err := strconv.ParseFloat(match[1], 64) + if err != nil { + return fmt.Errorf("Received invalid point: %v", s) + } + y, err := strconv.ParseFloat(match[2], 64) + if err != nil { + return fmt.Errorf("Received invalid point: %v", s) } - s := fmt.Sprintf("point(%v,%v)", p.X, p.Y) - w.WriteInt32(int32(len(s))) - w.WriteBytes([]byte(s)) + *dst = Point{X: x, Y: y, Status: pgtype.Present} return nil } -func (p NullPoint) String() string { - if p.Valid { +func (src Point) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, fmt.Errorf("undefined") + } + + _, err := io.WriteString(w, fmt.Sprintf("point(%v,%v)", src.X, src.Y)) + return false, err +} + +func (p Point) String() string { + if p.Status == pgtype.Present { return fmt.Sprintf("%v, %v", p.X, p.Y) } return "null point" @@ -85,7 +70,7 @@ func Example_CustomType() { return } - var p NullPoint + var p Point err = conn.QueryRow("select null::point").Scan(&p) if err != nil { fmt.Println(err) diff --git a/query.go b/query.go index d8caa08d..63ce91ed 100644 --- a/query.go +++ b/query.go @@ -211,16 +211,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *b = nil } } - } else if s, ok := d.(Scanner); ok { - err = s.Scan(vr) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } - } else if s, ok := d.(PgxScanner); ok { - err = s.ScanPgx(vr) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { err = s.DecodeBinary(vr.bytes()) if err != nil { diff --git a/query_test.go b/query_test.go index 46b012cf..8838329c 100644 --- a/query_test.go +++ b/query_test.go @@ -270,44 +270,6 @@ func TestConnQueryScanIgnoreColumn(t *testing.T) { ensureConnValid(t, conn) } -func TestConnQueryScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows, err := conn.Query("select null::int8, 1::int8") - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var n, m pgx.NullInt64 - err = rows.Scan(&n, &m) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - rows.Close() - - if n.Valid { - t.Error("Null should not be valid, but it was") - } - - if !m.Valid { - t.Error("1 should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - func TestConnQueryErrorWhileReturningRows(t *testing.T) { t.Parallel() @@ -339,42 +301,6 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } -func TestConnQueryEncoder(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - n := pgx.NullInt64{Int64: 1, Valid: true} - - rows, err := conn.Query("select $1::int8", &n) - if err != nil { - t.Fatalf("conn.Query failed: %v", err) - } - - ok := rows.Next() - if !ok { - t.Fatal("rows.Next terminated early") - } - - var m pgx.NullInt64 - err = rows.Scan(&m) - if err != nil { - t.Fatalf("rows.Scan failed: %v", err) - } - rows.Close() - - if !m.Valid { - t.Error("m should be valid, but it wasn't") - } - - if m.Int64 != 1 { - t.Errorf("m.Int64 should have been 1, but it was %v", m.Int64) - } - - ensureConnValid(t, conn) -} - func TestQueryEncodeError(t *testing.T) { t.Parallel() @@ -397,35 +323,6 @@ func TestQueryEncodeError(t *testing.T) { } } -// Ensure that an argument that implements Encoder works when the parameter type -// is a core type. -type coreEncoder struct{} - -func (n coreEncoder) FormatCode() int16 { return pgx.TextFormatCode } - -func (n *coreEncoder) Encode(w *pgx.WriteBuf, oid pgx.Oid) error { - w.WriteInt32(int32(2)) - w.WriteBytes([]byte("42")) - return nil -} - -func TestQueryEncodeCoreTextFormatError(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var n int32 - err := conn.QueryRow("select $1::integer", &coreEncoder{}).Scan(&n) - if err != nil { - t.Fatalf("Unexpected conn.QueryRow error: %v", err) - } - - if n != 42 { - t.Errorf("Expected 42, got %v", n) - } -} - func TestQueryRowCoreTypes(t *testing.T) { t.Parallel() diff --git a/values.go b/values.go index e1c8f731..e2b30087 100644 --- a/values.go +++ b/values.go @@ -159,245 +159,6 @@ func (e SerializationError) Error() string { return string(e) } -// Deprecated: Scanner is an interface used to decode values from the PostgreSQL -// server. To allow types to support pgx and database/sql.Scan this interface -// has been deprecated in favor of PgxScanner. -type Scanner interface { - // Scan MUST check r.Type().DataType (to check by Oid) or - // r.Type().DataTypeName (to check by name) to ensure that it is scanning an - // expected column type. It also MUST check r.Type().FormatCode before - // decoding. It should not assume that it was called on a data type or format - // that it understands. - Scan(r *ValueReader) error -} - -// PgxScanner is an interface used to decode values from the PostgreSQL server. -// It is used exactly the same as the Scanner interface. It simply has renamed -// the method. -type PgxScanner interface { - // ScanPgx MUST check r.Type().DataType (to check by Oid) or - // r.Type().DataTypeName (to check by name) to ensure that it is scanning an - // expected column type. It also MUST check r.Type().FormatCode before - // decoding. It should not assume that it was called on a data type or format - // that it understands. - ScanPgx(r *ValueReader) error -} - -// Encoder is an interface used to encode values for transmission to the -// PostgreSQL server. -type Encoder interface { - // Encode writes the value to w. - // - // If the value is NULL an int32(-1) should be written. - // - // Encode MUST check oid to see if the parameter data type is compatible. If - // this is not done, the PostgreSQL server may detect the error if the - // expected data size or format of the encoded data does not match. But if - // the encoded data is a valid representation of the data type PostgreSQL - // expects such as date and int4, incorrect data may be stored. - Encode(w *WriteBuf, oid Oid) error - - // FormatCode returns the format that the encoder writes the value. It must be - // either pgx.TextFormatCode or pgx.BinaryFormatCode. - FormatCode() int16 -} - -// NullFloat32 represents an float4 that may be null. NullFloat32 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullFloat32 struct { - Float32 float32 - Valid bool // Valid is true if Float32 is not NULL -} - -func (n *NullFloat32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float4Oid { - return SerializationError(fmt.Sprintf("NullFloat32.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Float32, n.Valid = 0, false - return nil - } - n.Valid = true - n.Float32 = decodeFloat4(vr) - return vr.Err() -} - -func (n NullFloat32) FormatCode() int16 { return BinaryFormatCode } - -func (n NullFloat32) Encode(w *WriteBuf, oid Oid) error { - if oid != Float4Oid { - return SerializationError(fmt.Sprintf("NullFloat32.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeFloat32(w, oid, n.Float32) -} - -// NullFloat64 represents an float8 that may be null. NullFloat64 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullFloat64 struct { - Float64 float64 - Valid bool // Valid is true if Float64 is not NULL -} - -func (n *NullFloat64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Float8Oid { - return SerializationError(fmt.Sprintf("NullFloat64.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Float64, n.Valid = 0, false - return nil - } - n.Valid = true - n.Float64 = decodeFloat8(vr) - return vr.Err() -} - -func (n NullFloat64) FormatCode() int16 { return BinaryFormatCode } - -func (n NullFloat64) Encode(w *WriteBuf, oid Oid) error { - if oid != Float8Oid { - return SerializationError(fmt.Sprintf("NullFloat64.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeFloat64(w, oid, n.Float64) -} - -// NullString represents an string that may be null. NullString implements the -// Scanner Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullString struct { - String string - Valid bool // Valid is true if String is not NULL -} - -func (n *NullString) Scan(vr *ValueReader) error { - // Not checking oid as so we can scan anything into into a NullString - may revisit this decision later - - if vr.Len() == -1 { - n.String, n.Valid = "", false - return nil - } - - n.Valid = true - n.String = decodeText(vr) - return vr.Err() -} - -func (n NullString) FormatCode() int16 { return TextFormatCode } - -func (s NullString) Encode(w *WriteBuf, oid Oid) error { - if !s.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeString(w, oid, s.String) -} - -// NullInt16 represents a smallint that may be null. NullInt16 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan for prepared and unprepared queries. -// -// If Valid is false then the value is NULL. -type NullInt16 struct { - Int16 int16 - Valid bool // Valid is true if Int16 is not NULL -} - -func (n *NullInt16) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int2Oid { - return SerializationError(fmt.Sprintf("NullInt16.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int16, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int16 = decodeInt2(vr) - return vr.Err() -} - -func (n NullInt16) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt16) Encode(w *WriteBuf, oid Oid) error { - if oid != Int2Oid { - return SerializationError(fmt.Sprintf("NullInt16.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - w.WriteInt32(2) - - _, err := pgtype.Int2{Int: n.Int16, Status: pgtype.Present}.EncodeBinary(w) - return err -} - -// NullInt32 represents an integer that may be null. NullInt32 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullInt32 struct { - Int32 int32 - Valid bool // Valid is true if Int32 is not NULL -} - -func (n *NullInt32) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int4Oid { - return SerializationError(fmt.Sprintf("NullInt32.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int32, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int32 = decodeInt4(vr) - return vr.Err() -} - -func (n NullInt32) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt32) Encode(w *WriteBuf, oid Oid) error { - if oid != Int4Oid { - return SerializationError(fmt.Sprintf("NullInt32.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - w.WriteInt32(4) - - _, err := pgtype.Int4{Int: n.Int32, Status: pgtype.Present}.EncodeBinary(w) - return err -} - // Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, // used internally by PostgreSQL as a primary key for various system tables. It is currently implemented // as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h @@ -442,140 +203,6 @@ func (src Oid) EncodeBinary(w io.Writer) (bool, error) { return false, err } -// NullInt64 represents an bigint that may be null. NullInt64 implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullInt64 struct { - Int64 int64 - Valid bool // Valid is true if Int64 is not NULL -} - -func (n *NullInt64) Scan(vr *ValueReader) error { - if vr.Type().DataType != Int8Oid { - return SerializationError(fmt.Sprintf("NullInt64.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Int64, n.Valid = 0, false - return nil - } - n.Valid = true - n.Int64 = decodeInt8(vr) - return vr.Err() -} - -func (n NullInt64) FormatCode() int16 { return BinaryFormatCode } - -func (n NullInt64) Encode(w *WriteBuf, oid Oid) error { - if oid != Int8Oid { - return SerializationError(fmt.Sprintf("NullInt64.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - w.WriteInt32(8) - - _, err := pgtype.Int8{Int: n.Int64, Status: pgtype.Present}.EncodeBinary(w) - return err -} - -// NullBool represents an bool that may be null. NullBool implements the Scanner -// and Encoder interfaces so it may be used both as an argument to Query[Row] -// and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullBool struct { - Bool bool - Valid bool // Valid is true if Bool is not NULL -} - -func (n *NullBool) Scan(vr *ValueReader) error { - if vr.Type().DataType != BoolOid { - return SerializationError(fmt.Sprintf("NullBool.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Bool, n.Valid = false, false - return nil - } - n.Valid = true - n.Bool = decodeBool(vr) - return vr.Err() -} - -func (n NullBool) FormatCode() int16 { return BinaryFormatCode } - -func (n NullBool) Encode(w *WriteBuf, oid Oid) error { - if oid != BoolOid { - return SerializationError(fmt.Sprintf("NullBool.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - w.WriteInt32(1) - - _, err := pgtype.Bool{Bool: n.Bool, Status: pgtype.Present}.EncodeBinary(w) - return err -} - -// NullTime represents an time.Time that may be null. NullTime implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. It corresponds with the PostgreSQL -// types timestamptz, timestamp, and date. -// -// If Valid is false then the value is NULL. -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -func (n *NullTime) Scan(vr *ValueReader) error { - oid := vr.Type().DataType - if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { - return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode Oid %d", vr.Type().DataType)) - } - - if vr.Len() == -1 { - n.Time, n.Valid = time.Time{}, false - return nil - } - - n.Valid = true - switch oid { - case TimestampTzOid: - n.Time = decodeTimestampTz(vr) - case TimestampOid: - n.Time = decodeTimestamp(vr) - case DateOid: - n.Time = decodeDate(vr) - } - - return vr.Err() -} - -func (n NullTime) FormatCode() int16 { return BinaryFormatCode } - -func (n NullTime) Encode(w *WriteBuf, oid Oid) error { - if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { - return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into Oid %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) - return nil - } - - return encodeTime(w, oid, n.Time) -} - // Encode encodes arg into wbuf as the type oid. This allows implementations // of the Encoder interface to delegate the actual work of encoding to the // built-in functionality. @@ -586,8 +213,6 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { } switch arg := arg.(type) { - case Encoder: - return arg.Encode(wbuf, oid) case pgtype.BinaryEncoder: buf := &bytes.Buffer{} null, err := arg.EncodeBinary(buf) diff --git a/values_test.go b/values_test.go index 7b82d456..69a91d4e 100644 --- a/values_test.go +++ b/values_test.go @@ -4,7 +4,6 @@ import ( "bytes" "net" "reflect" - "strings" "testing" "time" @@ -558,70 +557,6 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } } -func TestNullX(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - s pgx.NullString - i16 pgx.NullInt16 - i32 pgx.NullInt32 - i64 pgx.NullInt64 - f32 pgx.NullFloat32 - f64 pgx.NullFloat64 - b pgx.NullBool - t pgx.NullTime - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - expected allTypes - }{ - {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "foo", Valid: true}}}, - {"select $1::text", []interface{}{pgx.NullString{String: "foo", Valid: false}}, []interface{}{&actual.s}, allTypes{s: pgx.NullString{String: "", Valid: false}}}, - {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 1, Valid: true}}}, - {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, - {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, - {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 1, Valid: true}}}, - {"select $1::int8", []interface{}{pgx.NullInt64{Int64: 1, Valid: false}}, []interface{}{&actual.i64}, allTypes{i64: pgx.NullInt64{Int64: 0, Valid: false}}}, - {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 1.23, Valid: true}}}, - {"select $1::float4", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: false}}, []interface{}{&actual.f32}, allTypes{f32: pgx.NullFloat32{Float32: 0, Valid: false}}}, - {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, - {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: false}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 0, Valid: false}}}, - {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}}, - {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}}, - {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, - {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, - {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: true}}}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, - {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, - } - - for i, tt := range tests { - actual = zero - - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) - } - - if actual != tt.expected { - t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) - } - - ensureConnValid(t, conn) - } -} - func TestArrayDecoding(t *testing.T) { t.Parallel() @@ -736,36 +671,6 @@ func TestArrayDecoding(t *testing.T) { } } -type shortScanner struct{} - -func (*shortScanner) Scan(r *pgx.ValueReader) error { - r.ReadByte() - return nil -} - -func TestShortScanner(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - rows, err := conn.Query("select 'ab', 'cd' union select 'cd', 'ef'") - if err != nil { - t.Error(err) - } - defer rows.Close() - - for rows.Next() { - var s1, s2 shortScanner - err = rows.Scan(&s1, &s2) - if err != nil { - t.Error(err) - } - } - - ensureConnValid(t, conn) -} - func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() @@ -814,53 +719,6 @@ func TestEmptyArrayDecoding(t *testing.T) { ensureConnValid(t, conn) } -func TestNullXMismatch(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - type allTypes struct { - s pgx.NullString - i16 pgx.NullInt16 - i32 pgx.NullInt32 - i64 pgx.NullInt64 - f32 pgx.NullFloat32 - f64 pgx.NullFloat64 - b pgx.NullBool - t pgx.NullTime - } - - var actual, zero allTypes - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - err string - }{ - {"select $1::date", []interface{}{pgx.NullString{String: "foo", Valid: true}}, []interface{}{&actual.s}, "invalid input syntax for type date"}, - {"select $1::date", []interface{}{pgx.NullInt16{Int16: 1, Valid: true}}, []interface{}{&actual.i16}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullInt64{Int64: 1, Valid: true}}, []interface{}{&actual.i64}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into Oid 1082"}, - {"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into Oid 1082"}, - {"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into Oid 23"}, - } - - for i, tt := range tests { - actual = zero - - err := conn.QueryRow(tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err == nil || !strings.Contains(err.Error(), tt.err) { - t.Errorf(`%d. Expected error to contain "%s", but it didn't: %v`, i, tt.err, err) - } - - ensureConnValid(t, conn) - } -} - func TestPointerPointer(t *testing.T) { t.Parallel() From 9cd561f1a5d598e264d545ab14e66869e64ba28c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Mar 2017 20:14:08 -0500 Subject: [PATCH 113/264] Remove unused code --- values.go | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/values.go b/values.go index e2b30087..0749be92 100644 --- a/values.go +++ b/values.go @@ -937,26 +937,3 @@ func decodeRecord(vr *ValueReader) []interface{} { return record } - -func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { - numDims := vr.ReadInt32() - if numDims > 1 { - return 0, ProtocolError(fmt.Sprintf("Expected array to have 0 or 1 dimension, but it had %v", numDims)) - } - - vr.ReadInt32() // 0 if no nulls / 1 if there is one or more nulls -- but we don't care - vr.ReadInt32() // element oid - - if numDims == 0 { - return 0, nil - } - - length = vr.ReadInt32() - - idxFirstElem := vr.ReadInt32() - if idxFirstElem != 1 { - return 0, ProtocolError(fmt.Sprintf("Expected array's first element to start a index 1, but it is %d", idxFirstElem)) - } - - return length, nil -} From 7ec8d7b3434b3e763f6b6e7eacbc4bdb97a707f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Mar 2017 20:23:17 -0500 Subject: [PATCH 114/264] Fix error message for hstore --- pgtype/hstore.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 11bfb9a7..c48ae6da 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -29,7 +29,7 @@ func (dst *Hstore) Set(src interface{}) error { } *dst = Hstore{Map: m, Status: Present} default: - return fmt.Errorf("cannot convert %v to Tid", src) + return fmt.Errorf("cannot convert %v to Hstore", src) } return nil From ba5f97176a3e2bf1ff6f97f2a96f83966a64b7bc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Mar 2017 21:34:38 -0500 Subject: [PATCH 115/264] Move not null Oid to pgtype In preparation to ConnInfo implementation. --- conn.go | 40 ++++++------- conn_pool.go | 4 +- conn_test.go | 3 +- copy_to_test.go | 3 +- fastpath.go | 14 +++-- large_objects.go | 10 ++-- messages.go | 6 +- pgtype/oid.go | 58 +++++++++++-------- pgtype/oid_value.go | 45 +++++++++++++++ pgtype/{oid_test.go => oid_value_test.go} | 30 +++++----- query_test.go | 5 +- stdlib/sql.go | 5 +- value_reader.go | 6 +- values.go | 68 ++++------------------- values_test.go | 3 +- 15 files changed, 162 insertions(+), 138 deletions(-) create mode 100644 pgtype/oid_value.go rename pgtype/{oid_test.go => oid_value_test.go} (66%) diff --git a/conn.go b/conn.go index ae83fc77..cf34d267 100644 --- a/conn.go +++ b/conn.go @@ -74,11 +74,11 @@ type Conn struct { lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf - pid int32 // backend pid - secretKey int32 // key to use to send a cancel query message to the server - RuntimeParams map[string]string // parameters that have been reported by the server - PgTypes map[Oid]PgType // oids to PgTypes - config ConnConfig // config used when establishing this connection + pid int32 // backend pid + secretKey int32 // key to use to send a cancel query message to the server + RuntimeParams map[string]string // parameters that have been reported by the server + PgTypes map[pgtype.Oid]PgType // oids to PgTypes + config ConnConfig // config used when establishing this connection txStatus byte preparedStatements map[string]*PreparedStatement channels map[string]struct{} @@ -102,7 +102,7 @@ type Conn struct { doneChan chan struct{} closedChan chan error - oidPgtypeValues map[Oid]pgtype.Value + oidPgtypeValues map[pgtype.Oid]pgtype.Value } // PreparedStatement is a description of a prepared statement @@ -110,12 +110,12 @@ type PreparedStatement struct { Name string SQL string FieldDescriptions []FieldDescription - ParameterOids []Oid + ParameterOids []pgtype.Oid } // PrepareExOptions is an option struct that can be passed to PrepareEx type PrepareExOptions struct { - ParameterOids []Oid + ParameterOids []pgtype.Oid } // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system @@ -180,13 +180,13 @@ func Connect(config ConnConfig) (c *Conn, err error) { return connect(config, nil) } -func connect(config ConnConfig, pgTypes map[Oid]PgType) (c *Conn, err error) { +func connect(config ConnConfig, pgTypes map[pgtype.Oid]PgType) (c *Conn, err error) { c = new(Conn) c.config = config if pgTypes != nil { - c.PgTypes = make(map[Oid]PgType, len(pgTypes)) + c.PgTypes = make(map[pgtype.Oid]PgType, len(pgTypes)) for k, v := range pgTypes { c.PgTypes[k] = v } @@ -361,7 +361,7 @@ where ( return err } - c.PgTypes = make(map[Oid]PgType, 128) + c.PgTypes = make(map[pgtype.Oid]PgType, 128) for rows.Next() { var oid uint32 @@ -372,14 +372,14 @@ where ( // The zero value is text format so we ignore any types without a default type format t.DefaultFormat, _ = DefaultTypeFormats[t.Name] - c.PgTypes[Oid(oid)] = t + c.PgTypes[pgtype.Oid(oid)] = t } return rows.Err() } func (c *Conn) loadStaticOidPgtypeValues() { - c.oidPgtypeValues = map[Oid]pgtype.Value{ + c.oidPgtypeValues = map[pgtype.Oid]pgtype.Value{ AclitemArrayOid: &pgtype.AclitemArray{}, AclitemOid: &pgtype.Aclitem{}, BoolArrayOid: &pgtype.BoolArray{}, @@ -407,7 +407,7 @@ func (c *Conn) loadStaticOidPgtypeValues() { JsonbOid: &pgtype.Jsonb{}, JsonOid: &pgtype.Json{}, NameOid: &pgtype.Name{}, - OidOid: &pgtype.Oid{}, + OidOid: &pgtype.OidValue{}, TextArrayOid: &pgtype.TextArray{}, TextOid: &pgtype.Text{}, TidOid: &pgtype.Tid{}, @@ -422,7 +422,7 @@ func (c *Conn) loadStaticOidPgtypeValues() { } func (c *Conn) loadDynamicOidPgtypeValues() { - nameOids := make(map[string]Oid, len(c.PgTypes)) + nameOids := make(map[string]pgtype.Oid, len(c.PgTypes)) for k, v := range c.PgTypes { nameOids[v.Name] = k } @@ -1204,9 +1204,9 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { for i := int16(0); i < fieldCount; i++ { f := &fields[i] f.Name = r.readCString() - f.Table = Oid(r.readUint32()) + f.Table = pgtype.Oid(r.readUint32()) f.AttributeNumber = r.readInt16() - f.DataType = Oid(r.readUint32()) + f.DataType = pgtype.Oid(r.readUint32()) f.DataTypeSize = r.readInt16() f.Modifier = r.readInt32() f.FormatCode = r.readInt16() @@ -1214,7 +1214,7 @@ func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { return } -func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { +func (c *Conn) rxParameterDescription(r *msgReader) (parameters []pgtype.Oid) { // Internally, PostgreSQL supports greater than 64k parameters to a prepared // statement. But the parameter description uses a 16-bit integer for the // count of parameters. If there are more than 64K parameters, this count is @@ -1223,10 +1223,10 @@ func (c *Conn) rxParameterDescription(r *msgReader) (parameters []Oid) { r.readInt16() parameterCount := len(r.msgBody[r.rp:]) / 4 - parameters = make([]Oid, 0, parameterCount) + parameters = make([]pgtype.Oid, 0, parameterCount) for i := 0; i < parameterCount; i++ { - parameters = append(parameters, Oid(r.readUint32())) + parameters = append(parameters, pgtype.Oid(r.readUint32())) } return } diff --git a/conn_pool.go b/conn_pool.go index 3081105c..469f638b 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -5,6 +5,8 @@ import ( "errors" "sync" "time" + + "github.com/jackc/pgx/pgtype" ) type ConnPoolConfig struct { @@ -28,7 +30,7 @@ type ConnPool struct { closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration - pgTypes map[Oid]PgType + pgTypes map[pgtype.Oid]PgType txAfterClose func(tx *Tx) rowsAfterClose func(rows *Rows) } diff --git a/conn_test.go b/conn_test.go index a6034be6..d863999c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) func TestConnect(t *testing.T) { @@ -1042,7 +1043,7 @@ func TestPrepareEx(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}}) + _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgtype.Oid{pgx.TextOid}}) if err != nil { t.Errorf("Unable to prepare statement: %v", err) return diff --git a/copy_to_test.go b/copy_to_test.go index ee96054a..b65ea0f9 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) func TestConnCopyToSmall(t *testing.T) { @@ -125,7 +126,7 @@ func TestConnCopyToJSON(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } diff --git a/fastpath.go b/fastpath.go index d58a7754..0caba9d3 100644 --- a/fastpath.go +++ b/fastpath.go @@ -2,29 +2,31 @@ package pgx import ( "encoding/binary" + + "github.com/jackc/pgx/pgtype" ) func newFastpath(cn *Conn) *fastpath { - return &fastpath{cn: cn, fns: make(map[string]Oid)} + return &fastpath{cn: cn, fns: make(map[string]pgtype.Oid)} } type fastpath struct { cn *Conn - fns map[string]Oid + fns map[string]pgtype.Oid } -func (f *fastpath) functionOid(name string) Oid { +func (f *fastpath) functionOid(name string) pgtype.Oid { return f.fns[name] } -func (f *fastpath) addFunction(name string, oid Oid) { +func (f *fastpath) addFunction(name string, oid pgtype.Oid) { f.fns[name] = oid } func (f *fastpath) addFunctions(rows *Rows) error { for rows.Next() { var name string - var oid Oid + var oid pgtype.Oid if err := rows.Scan(&name, &oid); err != nil { return err } @@ -47,7 +49,7 @@ func fpInt64Arg(n int64) fpArg { return res } -func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { +func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) { if err := f.cn.ensureConnectionReadyForQuery(); err != nil { return nil, err } diff --git a/large_objects.go b/large_objects.go index 960e1e25..bb65e623 100644 --- a/large_objects.go +++ b/large_objects.go @@ -2,6 +2,8 @@ package pgx import ( "io" + + "github.com/jackc/pgx/pgtype" ) // LargeObjects is a structure used to access the large objects API. It is only @@ -60,19 +62,19 @@ const ( // Create creates a new large object. If id is zero, the server assigns an // unused Oid. -func (o *LargeObjects) Create(id Oid) (Oid, error) { +func (o *LargeObjects) Create(id pgtype.Oid) (pgtype.Oid, error) { newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) - return Oid(newOid), err + return pgtype.Oid(newOid), err } // Open opens an existing large object with the given mode. -func (o *LargeObjects) Open(oid Oid, mode LargeObjectMode) (*LargeObject, error) { +func (o *LargeObjects) Open(oid pgtype.Oid, mode LargeObjectMode) (*LargeObject, error) { fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))})) return &LargeObject{fd: fd, lo: o}, err } // Unlink removes a large object from the database. -func (o *LargeObjects) Unlink(oid Oid) error { +func (o *LargeObjects) Unlink(oid pgtype.Oid) error { _, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))}) return err } diff --git a/messages.go b/messages.go index 0c14c61d..68faf14c 100644 --- a/messages.go +++ b/messages.go @@ -2,6 +2,8 @@ package pgx import ( "encoding/binary" + + "github.com/jackc/pgx/pgtype" ) const ( @@ -55,9 +57,9 @@ func (s *startupMessage) Bytes() (buf []byte) { type FieldDescription struct { Name string - Table Oid + Table pgtype.Oid AttributeNumber int16 - DataType Oid + DataType pgtype.Oid DataTypeSize int16 DataTypeName string Modifier int32 diff --git a/pgtype/oid.go b/pgtype/oid.go index e57bb2e6..eab1fbcb 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -1,45 +1,57 @@ package pgtype import ( + "encoding/binary" + "fmt" "io" + "strconv" + + "github.com/jackc/pgx/pgio" ) // Oid (Object Identifier Type) is, according to // https://www.postgresql.org/docs/current/static/datatype-oid.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. -type Oid pguint32 - -// Set converts from src to dst. Note that as Oid is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *Oid) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst *Oid) Get() interface{} { - return (*pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as Oid is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *Oid) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} +// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is +// so frequently required to be in a NOT NULL condition Oid cannot be NULL. To +// allow for NULL Oids use OidValue. +type Oid uint32 func (dst *Oid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) + if src == nil { + return fmt.Errorf("cannot decode nil into Oid") + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = Oid(n) + return nil } func (dst *Oid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) + if src == nil { + return fmt.Errorf("cannot decode nil into Oid") + } + + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = Oid(n) + return nil } func (src Oid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) + _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) + return false, err } func (src Oid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) + _, err := pgio.WriteUint32(w, uint32(src)) + return false, err } diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go new file mode 100644 index 00000000..a2b2dcbe --- /dev/null +++ b/pgtype/oid_value.go @@ -0,0 +1,45 @@ +package pgtype + +import ( + "io" +) + +// OidValue (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-OidValue.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OidValue pguint32 + +// Set converts from src to dst. Note that as OidValue is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *OidValue) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *OidValue) Get() interface{} { + return (*pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as OidValue is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OidValue) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OidValue) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) +} + +func (dst *OidValue) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) +} + +func (src OidValue) EncodeText(w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(w) +} + +func (src OidValue) EncodeBinary(w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(w) +} diff --git a/pgtype/oid_test.go b/pgtype/oid_value_test.go similarity index 66% rename from pgtype/oid_test.go rename to pgtype/oid_value_test.go index b3b96959..21dd6f9d 100644 --- a/pgtype/oid_test.go +++ b/pgtype/oid_value_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestOidTranscode(t *testing.T) { +func TestOidValueTranscode(t *testing.T) { testSuccessfulTranscode(t, "oid", []interface{}{ - pgtype.Oid{Uint: 42, Status: pgtype.Present}, - pgtype.Oid{Status: pgtype.Null}, + pgtype.OidValue{Uint: 42, Status: pgtype.Present}, + pgtype.OidValue{Status: pgtype.Null}, }) } -func TestOidSet(t *testing.T) { +func TestOidValueSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Oid + result pgtype.OidValue }{ - {source: uint32(1), result: pgtype.Oid{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.OidValue{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.Oid + var r pgtype.OidValue err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestOidSet(t *testing.T) { } } -func TestOidAssignTo(t *testing.T) { +func TestOidValueAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} expected interface{} }{ - {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Oid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OidValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestOidAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} expected interface{} }{ - {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestOidAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} }{ - {src: pgtype.Oid{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.OidValue{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/query_test.go b/query_test.go index 8838329c..01889444 100644 --- a/query_test.go +++ b/query_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" "github.com/shopspring/decimal" ) @@ -335,7 +336,7 @@ func TestQueryRowCoreTypes(t *testing.T) { f64 float64 b bool t time.Time - oid pgx.Oid + oid pgtype.Oid } var actual, zero allTypes @@ -353,7 +354,7 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, - {"select $1::oid", []interface{}{pgx.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + {"select $1::oid", []interface{}{pgtype.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { diff --git a/stdlib/sql.go b/stdlib/sql.go index 07cca7e0..7ab4cdbe 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -52,19 +52,20 @@ import ( "io" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) var openFromConnPoolCount int // oids that map to intrinsic database/sql types. These will be allowed to be // binary, anything else will be forced to text format -var databaseSqlOids map[pgx.Oid]bool +var databaseSqlOids map[pgtype.Oid]bool func init() { d := &Driver{} sql.Register("pgx", d) - databaseSqlOids = make(map[pgx.Oid]bool) + databaseSqlOids = make(map[pgtype.Oid]bool) databaseSqlOids[pgx.BoolOid] = true databaseSqlOids[pgx.ByteaOid] = true databaseSqlOids[pgx.Int2Oid] = true diff --git a/value_reader.go b/value_reader.go index 364581c9..fea21d49 100644 --- a/value_reader.go +++ b/value_reader.go @@ -2,6 +2,8 @@ package pgx import ( "errors" + + "github.com/jackc/pgx/pgtype" ) // ValueReader is used by the Scanner interface to decode values. @@ -116,8 +118,8 @@ func (r *ValueReader) ReadInt64() int64 { return r.mr.readInt64() } -func (r *ValueReader) ReadOid() Oid { - return Oid(r.ReadUint32()) +func (r *ValueReader) ReadOid() pgtype.Oid { + return pgtype.Oid(r.ReadUint32()) } // ReadString reads count bytes and returns as string diff --git a/values.go b/values.go index 0749be92..d90c363b 100644 --- a/values.go +++ b/values.go @@ -3,16 +3,12 @@ package pgx import ( "bytes" "database/sql/driver" - "encoding/binary" "encoding/json" "fmt" - "io" "math" "reflect" - "strconv" "time" - "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" ) @@ -80,7 +76,7 @@ const minInt = -maxInt - 1 var DefaultTypeFormats map[string]int16 // internalNativeGoTypeFormats lists the encoding type for native Go types (not handled with Encoder interface) -var internalNativeGoTypeFormats map[Oid]int16 +var internalNativeGoTypeFormats map[pgtype.Oid]int16 func init() { DefaultTypeFormats = map[string]int16{ @@ -119,7 +115,7 @@ func init() { "xid": BinaryFormatCode, } - internalNativeGoTypeFormats = map[Oid]int16{ + internalNativeGoTypeFormats = map[pgtype.Oid]int16{ BoolArrayOid: BinaryFormatCode, BoolOid: BinaryFormatCode, ByteaArrayOid: BinaryFormatCode, @@ -159,54 +155,10 @@ func (e SerializationError) Error() string { return string(e) } -// Oid (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, -// used internally by PostgreSQL as a primary key for various system tables. It is currently implemented -// as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h -// in the PostgreSQL sources. Oid cannot be NULL. To allow for NULL Oids use pgtype.Oid. -type Oid uint32 - -func (dst *Oid) DecodeText(src []byte) error { - if src == nil { - return fmt.Errorf("cannot decode nil into Oid") - } - - n, err := strconv.ParseUint(string(src), 10, 32) - if err != nil { - return err - } - - *dst = Oid(n) - return nil -} - -func (dst *Oid) DecodeBinary(src []byte) error { - if src == nil { - return fmt.Errorf("cannot decode nil into Oid") - } - - if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) - } - - n := binary.BigEndian.Uint32(src) - *dst = Oid(n) - return nil -} - -func (src Oid) EncodeText(w io.Writer) (bool, error) { - _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) - return false, err -} - -func (src Oid) EncodeBinary(w io.Writer) (bool, error) { - _, err := pgio.WriteUint32(w, uint32(src)) - return false, err -} - // Encode encodes arg into wbuf as the type oid. This allows implementations // of the Encoder interface to delegate the actual work of encoding to the // built-in functionality. -func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { +func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { if arg == nil { wbuf.WriteInt32(-1) return nil @@ -542,7 +494,7 @@ func decodeFloat4(vr *ValueReader) float32 { return math.Float32frombits(uint32(i)) } -func encodeFloat32(w *WriteBuf, oid Oid, value float32) error { +func encodeFloat32(w *WriteBuf, oid pgtype.Oid, value float32) error { switch oid { case Float4Oid: w.WriteInt32(4) @@ -582,7 +534,7 @@ func decodeFloat8(vr *ValueReader) float64 { return math.Float64frombits(uint64(i)) } -func encodeFloat64(w *WriteBuf, oid Oid, value float64) error { +func encodeFloat64(w *WriteBuf, oid pgtype.Oid, value float64) error { switch oid { case Float8Oid: w.WriteInt32(8) @@ -617,7 +569,7 @@ func decodeTextAllowBinary(vr *ValueReader) string { return vr.ReadString(vr.Len()) } -func encodeString(w *WriteBuf, oid Oid, value string) error { +func encodeString(w *WriteBuf, oid pgtype.Oid, value string) error { w.WriteInt32(int32(len(value))) w.WriteBytes([]byte(value)) return nil @@ -641,7 +593,7 @@ func decodeBytea(vr *ValueReader) []byte { return vr.ReadBytes(vr.Len()) } -func encodeByteSlice(w *WriteBuf, oid Oid, value []byte) error { +func encodeByteSlice(w *WriteBuf, oid pgtype.Oid, value []byte) error { w.WriteInt32(int32(len(value))) w.WriteBytes(value) @@ -665,7 +617,7 @@ func decodeJSON(vr *ValueReader, d interface{}) error { return err } -func encodeJSON(w *WriteBuf, oid Oid, value interface{}) error { +func encodeJSON(w *WriteBuf, oid pgtype.Oid, value interface{}) error { if oid != JsonOid { return fmt.Errorf("cannot encode JSON into oid %v", oid) } @@ -709,7 +661,7 @@ func decodeJSONB(vr *ValueReader, d interface{}) error { return err } -func encodeJSONB(w *WriteBuf, oid Oid, value interface{}) error { +func encodeJSONB(w *WriteBuf, oid pgtype.Oid, value interface{}) error { if oid != JsonbOid { return fmt.Errorf("cannot encode JSON into oid %v", oid) } @@ -757,7 +709,7 @@ func decodeDate(vr *ValueReader) time.Time { return d.Time } -func encodeTime(w *WriteBuf, oid Oid, value time.Time) error { +func encodeTime(w *WriteBuf, oid pgtype.Oid, value time.Time) error { switch oid { case DateOid: var d pgtype.Date diff --git a/values_test.go b/values_test.go index 69a91d4e..e7ae7e1d 100644 --- a/values_test.go +++ b/values_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) func TestDateTranscode(t *testing.T) { @@ -83,7 +84,7 @@ func TestJSONAndJSONBTranscode(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } From 5eb19bc66a0a3cba288ca510062d251072b0b9fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 10 Mar 2017 14:20:54 -0600 Subject: [PATCH 116/264] Add *Conn.CopyFrom This replaces *Conn.CopyTo. CopyTo was named incorrectly. In PostgreSQL COPY FROM is the command that copies from the client to the server. In addition, CopyTo does not accept a schema qualified table name. This commit introduces the Identifier type which handles multi-part names and correctly quotes/sanitizes them. The new CopyFrom method uses this Identifier type. Conn.CopyTo is deprecated. refs #243 and #190 --- conn.go | 13 ++ conn_test.go | 37 ++++ copy_from.go | 241 ++++++++++++++++++++++++++ copy_from_test.go | 428 ++++++++++++++++++++++++++++++++++++++++++++++ copy_to.go | 33 +--- copy_to_test.go | 61 ------- 6 files changed, 726 insertions(+), 87 deletions(-) create mode 100644 copy_from.go create mode 100644 copy_from_test.go diff --git a/conn.go b/conn.go index cf34d267..1007811e 100644 --- a/conn.go +++ b/conn.go @@ -146,6 +146,19 @@ func (ct CommandTag) RowsAffected() int64 { return n } +// Identifier a PostgreSQL identifier or name. Identifiers can be composed of +// multiple parts such as ["schema", "table"] or ["table", "column"]. +type Identifier []string + +// Sanitize returns a sanitized string safe for SQL interpolation. +func (ident Identifier) Sanitize() string { + parts := make([]string, len(ident)) + for i := range ident { + parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + } + return strings.Join(parts, ".") +} + // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") diff --git a/conn_test.go b/conn_test.go index d863999c..e1c780b8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1541,3 +1541,40 @@ func TestSetLogLevel(t *testing.T) { t.Fatal("Expected logger to be called, but it wasn't") } } + +func TestIdentifierSanitize(t *testing.T) { + t.Parallel() + + tests := []struct { + ident pgx.Identifier + expected string + }{ + { + ident: pgx.Identifier{`foo`}, + expected: `"foo"`, + }, + { + ident: pgx.Identifier{`select`}, + expected: `"select"`, + }, + { + ident: pgx.Identifier{`foo`, `bar`}, + expected: `"foo"."bar"`, + }, + { + ident: pgx.Identifier{`you should " not do this`}, + expected: `"you should "" not do this"`, + }, + { + ident: pgx.Identifier{`you should " not do this`, `please don't`}, + expected: `"you should "" not do this"."please don't"`, + }, + } + + for i, tt := range tests { + qval := tt.ident.Sanitize() + if qval != tt.expected { + t.Errorf("%d. Expected Sanitize %v to return %v but it was %v", i, tt.ident, tt.expected, qval) + } + } +} diff --git a/copy_from.go b/copy_from.go new file mode 100644 index 00000000..1f8a2306 --- /dev/null +++ b/copy_from.go @@ -0,0 +1,241 @@ +package pgx + +import ( + "bytes" + "fmt" +) + +// CopyFromRows returns a CopyFromSource interface over the provided rows slice +// making it usable by *Conn.CopyFrom. +func CopyFromRows(rows [][]interface{}) CopyFromSource { + return ©FromRows{rows: rows, idx: -1} +} + +type copyFromRows struct { + rows [][]interface{} + idx int +} + +func (ctr *copyFromRows) Next() bool { + ctr.idx++ + return ctr.idx < len(ctr.rows) +} + +func (ctr *copyFromRows) Values() ([]interface{}, error) { + return ctr.rows[ctr.idx], nil +} + +func (ctr *copyFromRows) Err() error { + return nil +} + +// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. +type CopyFromSource interface { + // Next returns true if there is another row and makes the next row data + // available to Values(). When there are no more rows available or an error + // has occurred it returns false. + Next() bool + + // Values returns the values for the current row. + Values() ([]interface{}, error) + + // Err returns any error that has been encountered by the CopyFromSource. If + // this is not nil *Conn.CopyFrom will abort the copy. + Err() error +} + +type copyFrom struct { + conn *Conn + tableName Identifier + columnNames []string + rowSrc CopyFromSource + readerErrChan chan error +} + +func (ct *copyFrom) readUntilReadyForQuery() { + for { + t, r, err := ct.conn.rxMsg() + if err != nil { + ct.readerErrChan <- err + close(ct.readerErrChan) + return + } + + switch t { + case readyForQuery: + ct.conn.rxReadyForQuery(r) + close(ct.readerErrChan) + return + case commandComplete: + case errorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(r) + default: + err = ct.conn.processContextFreeMsg(t, r) + if err != nil { + ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + } + } + } +} + +func (ct *copyFrom) waitForReaderDone() error { + var err error + for err = range ct.readerErrChan { + } + return err +} + +func (ct *copyFrom) run() (int, error) { + quotedTableName := ct.tableName.Sanitize() + buf := &bytes.Buffer{} + for i, cn := range ct.columnNames { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(quoteIdentifier(cn)) + } + quotedColumnNames := buf.String() + + ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) + if err != nil { + return 0, err + } + + err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + if err != nil { + return 0, err + } + + err = ct.conn.readUntilCopyInResponse() + if err != nil { + return 0, err + } + + go ct.readUntilReadyForQuery() + defer ct.waitForReaderDone() + + wbuf := newWriteBuf(ct.conn, copyData) + + wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) + wbuf.WriteInt32(0) + wbuf.WriteInt32(0) + + var sentCount int + + for ct.rowSrc.Next() { + select { + case err = <-ct.readerErrChan: + return 0, err + default: + } + + if len(wbuf.buf) > 65536 { + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + // Directly manipulate wbuf to reset to reuse the same buffer + wbuf.buf = wbuf.buf[0:5] + wbuf.buf[0] = copyData + wbuf.sizeIdx = 1 + } + + sentCount++ + + values, err := ct.rowSrc.Values() + if err != nil { + ct.cancelCopyIn() + return 0, err + } + if len(values) != len(ct.columnNames) { + ct.cancelCopyIn() + return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + } + + wbuf.WriteInt16(int16(len(ct.columnNames))) + for i, val := range values { + err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + if err != nil { + ct.cancelCopyIn() + return 0, err + } + + } + } + + if ct.rowSrc.Err() != nil { + ct.cancelCopyIn() + return 0, ct.rowSrc.Err() + } + + wbuf.WriteInt16(-1) // terminate the copy stream + + wbuf.startMsg(copyDone) + wbuf.closeMsg() + _, err = ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return 0, err + } + + err = ct.waitForReaderDone() + if err != nil { + return 0, err + } + return sentCount, nil +} + +func (c *Conn) readUntilCopyInResponse() error { + for { + var t byte + var r *msgReader + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case copyInResponse: + return nil + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } +} + +func (ct *copyFrom) cancelCopyIn() error { + wbuf := newWriteBuf(ct.conn, copyFail) + wbuf.WriteCString("client error: abort") + wbuf.closeMsg() + _, err := ct.conn.conn.Write(wbuf.buf) + if err != nil { + ct.conn.die(err) + return err + } + + return nil +} + +// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. +// It returns the number of rows copied and an error. +// +// CopyFrom requires all values use the binary format. Almost all types +// implemented by pgx use the binary format by default. Types implementing +// Encoder can only be used if they encode to the binary format. +func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { + ct := ©From{ + conn: c, + tableName: tableName, + columnNames: columnNames, + rowSrc: rowSrc, + readerErrChan: make(chan error), + } + + return ct.run() +} diff --git a/copy_from_test.go b/copy_from_test.go new file mode 100644 index 00000000..54da6989 --- /dev/null +++ b/copy_from_test.go @@ -0,0 +1,428 @@ +package pgx_test + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/jackc/pgx" +) + +func TestConnCopyFromSmall(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz + )`) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromLarge(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g timestamptz, + h bytea + )`) + + inputRows := [][]interface{}{} + + for i := 0; i < 10000; i++ { + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal") + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromJSON(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + if _, ok := conn.PgTypes[oid]; !ok { + return // No JSON/JSONB type -- must be running against old PostgreSQL + } + } + + mustExec(t, conn, `create temporary table foo( + a json, + b jsonb + )`) + + inputRows := [][]interface{}{ + {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, + {nil, nil}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + if err != nil { + t.Errorf("Unexpected error for CopyFrom: %v", err) + } + if copyCount != len(inputRows) { + t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if !reflect.DeepEqual(inputRows, outputRows) { + t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) + } + + ensureConnValid(t, conn) +} + +func TestConnCopyFromFailServerSideMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int4, + b varchar not null + )`) + + inputRows := [][]interface{}{ + {int32(1), "abc"}, + {int32(2), nil}, // this row should trigger a failure + {int32(3), "def"}, + } + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type failSource struct { + count int +} + +func (fs *failSource) Next() bool { + time.Sleep(time.Millisecond * 100) + fs.count++ + return fs.count < 100 +} + +func (fs *failSource) Values() ([]interface{}, error) { + if fs.count == 3 { + return []interface{}{nil}, nil + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (fs *failSource) Err() error { + return nil +} + +func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + startTime := time.Now() + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(pgx.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + endTime := time.Now() + copyTime := endTime.Sub(startTime) + if copyTime > time.Second { + t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFailSource struct { + count int + err error +} + +func (cfs *clientFailSource) Next() bool { + cfs.count++ + return cfs.count < 100 +} + +func (cfs *clientFailSource) Values() ([]interface{}, error) { + if cfs.count == 3 { + cfs.err = fmt.Errorf("client error") + return nil, cfs.err + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFailSource) Err() error { + return cfs.err +} + +func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} + +type clientFinalErrSource struct { + count int +} + +func (cfs *clientFinalErrSource) Next() bool { + cfs.count++ + return cfs.count < 5 +} + +func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFinalErrSource) Err() error { + return fmt.Errorf("final error") +} + +func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a bytea not null + )`) + + copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + rows, err := conn.Query("select * from foo") + if err != nil { + t.Errorf("Unexpected error for Query: %v", err) + } + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + if rows.Err() != nil { + t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) + } + + if len(outputRows) != 0 { + t.Errorf("Expected 0 rows, but got %v", outputRows) + } + + ensureConnValid(t, conn) +} diff --git a/copy_to.go b/copy_to.go index dd70ada3..b6cf16c8 100644 --- a/copy_to.go +++ b/copy_to.go @@ -5,8 +5,8 @@ import ( "fmt" ) -// CopyToRows returns a CopyToSource interface over the provided rows slice -// making it usable by *Conn.CopyTo. +// Deprecated. Use CopyFromRows instead. CopyToRows returns a CopyToSource +// interface over the provided rows slice making it usable by *Conn.CopyTo. func CopyToRows(rows [][]interface{}) CopyToSource { return ©ToRows{rows: rows, idx: -1} } @@ -29,7 +29,8 @@ func (ctr *copyToRows) Err() error { return nil } -// CopyToSource is the interface used by *Conn.CopyTo as the source for copy data. +// Deprecated. Use CopyFromSource instead. CopyToSource is the interface used by +// *Conn.CopyTo as the source for copy data. type CopyToSource interface { // Next returns true if there is another row and makes the next row data // available to Values(). When there are no more rows available or an error @@ -187,27 +188,6 @@ func (ct *copyTo) run() (int, error) { return sentCount, nil } -func (c *Conn) readUntilCopyInResponse() error { - for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() - if err != nil { - return err - } - - switch t { - case copyInResponse: - return nil - default: - err = c.processContextFreeMsg(t, r) - if err != nil { - return err - } - } - } -} - func (ct *copyTo) cancelCopyIn() error { wbuf := newWriteBuf(ct.conn, copyFail) wbuf.WriteCString("client error: abort") @@ -221,8 +201,9 @@ func (ct *copyTo) cancelCopyIn() error { return nil } -// CopyTo uses the PostgreSQL copy protocol to perform bulk data insertion. -// It returns the number of rows copied and an error. +// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to +// perform bulk data insertion. It returns the number of rows copied and an +// error. // // CopyTo requires all values use the binary format. Almost all types // implemented by pgx use the binary format by default. Types implementing diff --git a/copy_to_test.go b/copy_to_test.go index b65ea0f9..afe22ca2 100644 --- a/copy_to_test.go +++ b/copy_to_test.go @@ -1,7 +1,6 @@ package pgx_test import ( - "fmt" "reflect" "testing" "time" @@ -228,27 +227,6 @@ func TestConnCopyToFailServerSideMidway(t *testing.T) { ensureConnValid(t, conn) } -type failSource struct { - count int -} - -func (fs *failSource) Next() bool { - time.Sleep(time.Millisecond * 100) - fs.count++ - return fs.count < 100 -} - -func (fs *failSource) Values() ([]interface{}, error) { - if fs.count == 3 { - return []interface{}{nil}, nil - } - return []interface{}{make([]byte, 100000)}, nil -} - -func (fs *failSource) Err() error { - return nil -} - func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { t.Parallel() @@ -303,28 +281,6 @@ func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { ensureConnValid(t, conn) } -type clientFailSource struct { - count int - err error -} - -func (cfs *clientFailSource) Next() bool { - cfs.count++ - return cfs.count < 100 -} - -func (cfs *clientFailSource) Values() ([]interface{}, error) { - if cfs.count == 3 { - cfs.err = fmt.Errorf("client error") - return nil, cfs.err - } - return []interface{}{make([]byte, 100000)}, nil -} - -func (cfs *clientFailSource) Err() error { - return cfs.err -} - func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { t.Parallel() @@ -368,23 +324,6 @@ func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { ensureConnValid(t, conn) } -type clientFinalErrSource struct { - count int -} - -func (cfs *clientFinalErrSource) Next() bool { - cfs.count++ - return cfs.count < 5 -} - -func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { - return []interface{}{make([]byte, 100000)}, nil -} - -func (cfs *clientFinalErrSource) Err() error { - return fmt.Errorf("final error") -} - func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) { t.Parallel() From 94749e580f315f3dbe7f5e01ac11e138899b0ff3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Mar 2017 14:17:53 -0500 Subject: [PATCH 117/264] Remove CopyTo --- bench_test.go | 22 +-- conn_pool.go | 6 +- copy_from_test.go | 51 +++---- copy_to.go | 221 ---------------------------- copy_to_test.go | 368 ---------------------------------------------- tx.go | 6 +- v3.md | 2 + 7 files changed, 45 insertions(+), 631 deletions(-) delete mode 100644 copy_to.go delete mode 100644 copy_to_test.go diff --git a/bench_test.go b/bench_test.go index 348c840c..69d17c39 100644 --- a/bench_test.go +++ b/bench_test.go @@ -331,27 +331,27 @@ const benchmarkWriteTableInsertSQL = `insert into t( $13::bool )` -type benchmarkWriteTableCopyToSrc struct { +type benchmarkWriteTableCopyFromSrc struct { count int idx int row []interface{} } -func (s *benchmarkWriteTableCopyToSrc) Next() bool { +func (s *benchmarkWriteTableCopyFromSrc) Next() bool { s.idx++ return s.idx < s.count } -func (s *benchmarkWriteTableCopyToSrc) Values() ([]interface{}, error) { +func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) { return s.row, nil } -func (s *benchmarkWriteTableCopyToSrc) Err() error { +func (s *benchmarkWriteTableCopyFromSrc) Err() error { return nil } -func newBenchmarkWriteTableCopyToSrc(count int) pgx.CopyToSource { - return &benchmarkWriteTableCopyToSrc{ +func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { + return &benchmarkWriteTableCopyFromSrc{ count: count, row: []interface{}{ "varchar_1", @@ -384,7 +384,7 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { b.ResetTimer() for i := 0; i < b.N; i++ { - src := newBenchmarkWriteTableCopyToSrc(n) + src := newBenchmarkWriteTableCopyFromSrc(n) tx, err := conn.Begin() if err != nil { @@ -407,7 +407,7 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { // note this function is only used for benchmarks -- it doesn't escape tableName // or columnNames -func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyToSource) (int, error) { +func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyFromSource) (int, error) { maxRowsPerInsert := 65535 / len(columnNames) rowsThisInsert := 0 rowCount := 0 @@ -495,7 +495,7 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { b.ResetTimer() for i := 0; i < b.N; i++ { - src := newBenchmarkWriteTableCopyToSrc(n) + src := newBenchmarkWriteTableCopyFromSrc(n) _, err := multiInsert(conn, "t", []string{"varchar_1", @@ -527,9 +527,9 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { b.ResetTimer() for i := 0; i < b.N; i++ { - src := newBenchmarkWriteTableCopyToSrc(n) + src := newBenchmarkWriteTableCopyFromSrc(n) - _, err := conn.CopyTo("t", + _, err := conn.CopyFrom(pgx.Identifier{"t"}, []string{"varchar_1", "varchar_2", "varchar_null_1", diff --git a/conn_pool.go b/conn_pool.go index 469f638b..653ed0ba 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -540,13 +540,13 @@ func (p *ConnPool) BeginEx(txOptions *TxOptions) (*Tx, error) { } } -// CopyTo acquires a connection, delegates the call to that connection, and releases the connection -func (p *ConnPool) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { +// CopyFrom acquires a connection, delegates the call to that connection, and releases the connection +func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { c, err := p.Acquire() if err != nil { return 0, err } defer p.Release(c) - return c.CopyTo(tableName, columnNames, rowSrc) + return c.CopyFrom(tableName, columnNames, rowSrc) } diff --git a/copy_from_test.go b/copy_from_test.go index 54da6989..e17575de 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" ) func TestConnCopyFromSmall(t *testing.T) { @@ -26,7 +27,7 @@ func TestConnCopyFromSmall(t *testing.T) { )`) inputRows := [][]interface{}{ - {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, {nil, nil, nil, nil, nil, nil, nil}, } @@ -83,7 +84,7 @@ func TestConnCopyFromLarge(t *testing.T) { inputRows := [][]interface{}{} for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) + inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) } copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) @@ -125,7 +126,7 @@ func TestConnCopyFromJSON(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgx.Oid{pgx.JsonOid, pgx.JsonbOid} { + for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { if _, ok := conn.PgTypes[oid]; !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } @@ -174,6 +175,28 @@ func TestConnCopyFromJSON(t *testing.T) { ensureConnValid(t, conn) } +type clientFailSource struct { + count int + err error +} + +func (cfs *clientFailSource) Next() bool { + cfs.count++ + return cfs.count < 100 +} + +func (cfs *clientFailSource) Values() ([]interface{}, error) { + if cfs.count == 3 { + cfs.err = fmt.Errorf("client error") + return nil, cfs.err + } + return []interface{}{make([]byte, 100000)}, nil +} + +func (cfs *clientFailSource) Err() error { + return cfs.err +} + func TestConnCopyFromFailServerSideMidway(t *testing.T) { t.Parallel() @@ -302,28 +325,6 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { ensureConnValid(t, conn) } -type clientFailSource struct { - count int - err error -} - -func (cfs *clientFailSource) Next() bool { - cfs.count++ - return cfs.count < 100 -} - -func (cfs *clientFailSource) Values() ([]interface{}, error) { - if cfs.count == 3 { - cfs.err = fmt.Errorf("client error") - return nil, cfs.err - } - return []interface{}{make([]byte, 100000)}, nil -} - -func (cfs *clientFailSource) Err() error { - return cfs.err -} - func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Parallel() diff --git a/copy_to.go b/copy_to.go deleted file mode 100644 index b6cf16c8..00000000 --- a/copy_to.go +++ /dev/null @@ -1,221 +0,0 @@ -package pgx - -import ( - "bytes" - "fmt" -) - -// Deprecated. Use CopyFromRows instead. CopyToRows returns a CopyToSource -// interface over the provided rows slice making it usable by *Conn.CopyTo. -func CopyToRows(rows [][]interface{}) CopyToSource { - return ©ToRows{rows: rows, idx: -1} -} - -type copyToRows struct { - rows [][]interface{} - idx int -} - -func (ctr *copyToRows) Next() bool { - ctr.idx++ - return ctr.idx < len(ctr.rows) -} - -func (ctr *copyToRows) Values() ([]interface{}, error) { - return ctr.rows[ctr.idx], nil -} - -func (ctr *copyToRows) Err() error { - return nil -} - -// Deprecated. Use CopyFromSource instead. CopyToSource is the interface used by -// *Conn.CopyTo as the source for copy data. -type CopyToSource interface { - // Next returns true if there is another row and makes the next row data - // available to Values(). When there are no more rows available or an error - // has occurred it returns false. - Next() bool - - // Values returns the values for the current row. - Values() ([]interface{}, error) - - // Err returns any error that has been encountered by the CopyToSource. If - // this is not nil *Conn.CopyTo will abort the copy. - Err() error -} - -type copyTo struct { - conn *Conn - tableName string - columnNames []string - rowSrc CopyToSource - readerErrChan chan error -} - -func (ct *copyTo) readUntilReadyForQuery() { - for { - t, r, err := ct.conn.rxMsg() - if err != nil { - ct.readerErrChan <- err - close(ct.readerErrChan) - return - } - - switch t { - case readyForQuery: - ct.conn.rxReadyForQuery(r) - close(ct.readerErrChan) - return - case errorResponse: - ct.readerErrChan <- ct.conn.rxErrorResponse(r) - default: - err = ct.conn.processContextFreeMsg(t, r) - if err != nil { - ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) - } - } - } -} - -func (ct *copyTo) waitForReaderDone() error { - var err error - for err = range ct.readerErrChan { - } - return err -} - -func (ct *copyTo) run() (int, error) { - quotedTableName := quoteIdentifier(ct.tableName) - buf := &bytes.Buffer{} - for i, cn := range ct.columnNames { - if i != 0 { - buf.WriteString(", ") - } - buf.WriteString(quoteIdentifier(cn)) - } - quotedColumnNames := buf.String() - - ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) - if err != nil { - return 0, err - } - - err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) - if err != nil { - return 0, err - } - - err = ct.conn.readUntilCopyInResponse() - if err != nil { - return 0, err - } - - go ct.readUntilReadyForQuery() - defer ct.waitForReaderDone() - - wbuf := newWriteBuf(ct.conn, copyData) - - wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) - wbuf.WriteInt32(0) - wbuf.WriteInt32(0) - - var sentCount int - - for ct.rowSrc.Next() { - select { - case err = <-ct.readerErrChan: - return 0, err - default: - } - - if len(wbuf.buf) > 65536 { - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) - if err != nil { - ct.conn.die(err) - return 0, err - } - - // Directly manipulate wbuf to reset to reuse the same buffer - wbuf.buf = wbuf.buf[0:5] - wbuf.buf[0] = copyData - wbuf.sizeIdx = 1 - } - - sentCount++ - - values, err := ct.rowSrc.Values() - if err != nil { - ct.cancelCopyIn() - return 0, err - } - if len(values) != len(ct.columnNames) { - ct.cancelCopyIn() - return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) - } - - wbuf.WriteInt16(int16(len(ct.columnNames))) - for i, val := range values { - err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) - if err != nil { - ct.cancelCopyIn() - return 0, err - } - - } - } - - if ct.rowSrc.Err() != nil { - ct.cancelCopyIn() - return 0, ct.rowSrc.Err() - } - - wbuf.WriteInt16(-1) // terminate the copy stream - - wbuf.startMsg(copyDone) - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) - if err != nil { - ct.conn.die(err) - return 0, err - } - - err = ct.waitForReaderDone() - if err != nil { - return 0, err - } - return sentCount, nil -} - -func (ct *copyTo) cancelCopyIn() error { - wbuf := newWriteBuf(ct.conn, copyFail) - wbuf.WriteCString("client error: abort") - wbuf.closeMsg() - _, err := ct.conn.conn.Write(wbuf.buf) - if err != nil { - ct.conn.die(err) - return err - } - - return nil -} - -// Deprecated. Use CopyFrom instead. CopyTo uses the PostgreSQL copy protocol to -// perform bulk data insertion. It returns the number of rows copied and an -// error. -// -// CopyTo requires all values use the binary format. Almost all types -// implemented by pgx use the binary format by default. Types implementing -// Encoder can only be used if they encode to the binary format. -func (c *Conn) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { - ct := ©To{ - conn: c, - tableName: tableName, - columnNames: columnNames, - rowSrc: rowSrc, - readerErrChan: make(chan error), - } - - return ct.run() -} diff --git a/copy_to_test.go b/copy_to_test.go deleted file mode 100644 index afe22ca2..00000000 --- a/copy_to_test.go +++ /dev/null @@ -1,368 +0,0 @@ -package pgx_test - -import ( - "reflect" - "testing" - "time" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" -) - -func TestConnCopyToSmall(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g timestamptz - )`) - - inputRows := [][]interface{}{ - {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)}, - {nil, nil, nil, nil, nil, nil, nil}, - } - - copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyToRows(inputRows)) - if err != nil { - t.Errorf("Unexpected error for CopyTo: %v", err) - } - if copyCount != len(inputRows) { - t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if !reflect.DeepEqual(inputRows, outputRows) { - t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToLarge(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g timestamptz, - h bytea - )`) - - inputRows := [][]interface{}{} - - for i := 0; i < 10000; i++ { - inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local), []byte{111, 111, 111, 111}}) - } - - copyCount, err := conn.CopyTo("foo", []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyToRows(inputRows)) - if err != nil { - t.Errorf("Unexpected error for CopyTo: %v", err) - } - if copyCount != len(inputRows) { - t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if !reflect.DeepEqual(inputRows, outputRows) { - t.Errorf("Input rows and output rows do not equal") - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToJSON(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { - return // No JSON/JSONB type -- must be running against old PostgreSQL - } - } - - mustExec(t, conn, `create temporary table foo( - a json, - b jsonb - )`) - - inputRows := [][]interface{}{ - {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, - {nil, nil}, - } - - copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) - if err != nil { - t.Errorf("Unexpected error for CopyTo: %v", err) - } - if copyCount != len(inputRows) { - t.Errorf("Expected CopyTo to return %d copied rows, but got %d", len(inputRows), copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if !reflect.DeepEqual(inputRows, outputRows) { - t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToFailServerSideMidway(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a int4, - b varchar not null - )`) - - inputRows := [][]interface{}{ - {int32(1), "abc"}, - {int32(2), nil}, // this row should trigger a failure - {int32(3), "def"}, - } - - copyCount, err := conn.CopyTo("foo", []string{"a", "b"}, pgx.CopyToRows(inputRows)) - if err == nil { - t.Errorf("Expected CopyTo return error, but it did not") - } - if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) - } - if copyCount != 0 { - t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a bytea not null - )`) - - startTime := time.Now() - - copyCount, err := conn.CopyTo("foo", []string{"a"}, &failSource{}) - if err == nil { - t.Errorf("Expected CopyTo return error, but it did not") - } - if _, ok := err.(pgx.PgError); !ok { - t.Errorf("Expected CopyTo return pgx.PgError, but instead it returned: %v", err) - } - if copyCount != 0 { - t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) - } - - endTime := time.Now() - copyTime := endTime.Sub(startTime) - if copyTime > time.Second { - t.Errorf("Failing CopyTo shouldn't have taken so long: %v", copyTime) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToCopyToSourceErrorMidway(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a bytea not null - )`) - - copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFailSource{}) - if err == nil { - t.Errorf("Expected CopyTo return error, but it did not") - } - if copyCount != 0 { - t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", outputRows) - } - - ensureConnValid(t, conn) -} - -func TestConnCopyToCopyToSourceErrorEnd(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - mustExec(t, conn, `create temporary table foo( - a bytea not null - )`) - - copyCount, err := conn.CopyTo("foo", []string{"a"}, &clientFinalErrSource{}) - if err == nil { - t.Errorf("Expected CopyTo return error, but it did not") - } - if copyCount != 0 { - t.Errorf("Expected CopyTo to return 0 copied rows, but got %d", copyCount) - } - - rows, err := conn.Query("select * from foo") - if err != nil { - t.Errorf("Unexpected error for Query: %v", err) - } - - var outputRows [][]interface{} - for rows.Next() { - row, err := rows.Values() - if err != nil { - t.Errorf("Unexpected error for rows.Values(): %v", err) - } - outputRows = append(outputRows, row) - } - - if rows.Err() != nil { - t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) - } - - if len(outputRows) != 0 { - t.Errorf("Expected 0 rows, but got %v", outputRows) - } - - ensureConnValid(t, conn) -} diff --git a/tx.go b/tx.go index a636b364..099ef180 100644 --- a/tx.go +++ b/tx.go @@ -185,13 +185,13 @@ func (tx *Tx) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } -// CopyTo delegates to the underlying *Conn -func (tx *Tx) CopyTo(tableName string, columnNames []string, rowSrc CopyToSource) (int, error) { +// CopyFrom delegates to the underlying *Conn +func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { if tx.status != TxStatusInProgress { return 0, ErrTxClosed } - return tx.conn.CopyTo(tableName, columnNames, rowSrc) + return tx.conn.CopyFrom(tableName, columnNames, rowSrc) } // Conn returns the *Conn this transaction is using. diff --git a/v3.md b/v3.md index 2f1c353c..6f5fd412 100644 --- a/v3.md +++ b/v3.md @@ -26,6 +26,8 @@ ReplicationConn.WaitForReplicationMessage now takes context.Context instead of t Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228 +Remove CopyTo + ## TODO / Possible / Investigate Organize errors better From 94d56e8556ead24df412fcbce5184834b56dd370 Mon Sep 17 00:00:00 2001 From: j7b Date: Fri, 17 Mar 2017 16:59:10 +0000 Subject: [PATCH 118/264] Support pgpass --- conn.go | 12 +++++-- pgpass.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgpass_test.go | 57 +++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 3 deletions(-) create mode 100644 pgpass.go create mode 100644 pgpass_test.go diff --git a/conn.go b/conn.go index 1007811e..0c86d169 100644 --- a/conn.go +++ b/conn.go @@ -542,7 +542,9 @@ func ParseURI(uri string) (ConnConfig, error) { cp.RuntimeParams[k] = v[0] } - + if cp.Password == "" { + pgpass(&cp) + } return cp, nil } @@ -595,7 +597,9 @@ func ParseDSN(s string) (ConnConfig, error) { if err != nil { return cp, err } - + if cp.Password == "" { + pgpass(&cp) + } return cp, nil } @@ -658,7 +662,9 @@ func ParseEnvLibpq() (ConnConfig, error) { if appname := os.Getenv("PGAPPNAME"); appname != "" { cc.RuntimeParams["application_name"] = appname } - + if cc.Password == "" { + pgpass(&cc) + } return cc, nil } diff --git a/pgpass.go b/pgpass.go new file mode 100644 index 00000000..b6f028d2 --- /dev/null +++ b/pgpass.go @@ -0,0 +1,85 @@ +package pgx + +import ( + "bufio" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" +) + +func parsepgpass(cfg *ConnConfig, line string) *string { + const ( + backslash = "\r" + colon = "\n" + ) + const ( + host int = iota + port + database + username + pw + ) + line = strings.Replace(line, `\:`, colon, -1) + line = strings.Replace(line, `\\`, backslash, -1) + parts := strings.Split(line, `:`) + if len(parts) != 5 { + return nil + } + for i := range parts { + if parts[i] == `*` { + continue + } + parts[i] = strings.Replace(strings.Replace(parts[i], backslash, `\`, -1), colon, `:`, -1) + switch i { + case host: + if parts[i] != cfg.Host { + return nil + } + case port: + portstr := fmt.Sprintf(`%v`, cfg.Port) + if portstr == "0" { + portstr = "5432" + } + if parts[i] != portstr { + return nil + } + case database: + if parts[i] != cfg.Database { + return nil + } + case username: + if parts[i] != cfg.User { + return nil + } + } + } + return &parts[4] +} + +func pgpass(cfg *ConnConfig) (found bool) { + passfile := os.Getenv("PGPASSFILE") + if passfile == "" { + u, err := user.Current() + if err != nil { + return + } + passfile = filepath.Join(u.HomeDir, ".pgpass") + } + f, err := os.Open(passfile) + if err != nil { + return + } + defer f.Close() + scanner := bufio.NewScanner(f) + var pw *string + for scanner.Scan() { + pw = parsepgpass(cfg, scanner.Text()) + if pw != nil { + cfg.Password = *pw + return true + } + } + return false +} diff --git a/pgpass_test.go b/pgpass_test.go new file mode 100644 index 00000000..f6094c82 --- /dev/null +++ b/pgpass_test.go @@ -0,0 +1,57 @@ +package pgx + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" +) + +func unescape(s string) string { + s = strings.Replace(s, `\:`, `:`, -1) + s = strings.Replace(s, `\\`, `\`, -1) + return s +} + +var passfile = [][]string{ + []string{"test1", "5432", "larrydb", "larry", "whatstheidea"}, + []string{"test1", "5432", "moedb", "moe", "imbecile"}, + []string{"test1", "5432", "curlydb", "curly", "nyuknyuknyuk"}, + []string{"test2", "5432", "*", "shemp", "heymoe"}, + []string{"test2", "5432", "*", "*", `test\\ing\:`}, +} + +func TestPGPass(t *testing.T) { + tf, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer tf.Close() + defer os.Remove(tf.Name()) + os.Setenv("PGPASSFILE", tf.Name()) + for _, l := range passfile { + _, err := fmt.Fprintln(tf, strings.Join(l, `:`)) + if err != nil { + t.Fatal(err) + } + } + if err = tf.Close(); err != nil { + t.Fatal(err) + } + for i, l := range passfile { + cfg := ConnConfig{Host: l[0], Database: l[2], User: l[3]} + found := pgpass(&cfg) + if !found { + t.Fatalf("Entry %v not found", i) + } + if cfg.Password != unescape(l[4]) { + t.Fatalf(`Password mismatch entry %v want %s got %s`, i, unescape(l[4]), cfg.Password) + } + } + cfg := ConnConfig{Host: "derp", Database: "herp", User: "joe"} + found := pgpass(&cfg) + if found { + t.Fatal("bad found") + } +} From 19c668975218ca857f07e0506cdbcaa83f68fb24 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 12:01:16 -0500 Subject: [PATCH 119/264] Add pgtype.Record and prerequisite restructuring Because reading a record type requires the decoder to be able to look up oid to type mapping and types such as hstore have types that are not fixed between different PostgreSQL servers it was necessary to restructure the pgtype system so all encoders and decodes take a *ConnInfo that includes oid/name/type information. --- conn.go | 148 ++++-------- conn_pool.go | 8 +- copy_from_test.go | 5 +- example_custom_type_test.go | 4 +- pgtype/aclitem.go | 4 +- pgtype/aclitem_array.go | 8 +- pgtype/array.go | 4 +- pgtype/bool.go | 8 +- pgtype/bool_array.go | 24 +- pgtype/bytea.go | 8 +- pgtype/bytea_array.go | 24 +- pgtype/cid.go | 16 +- pgtype/cidr.go | 35 +++ pgtype/cidr_array.go | 317 +++++++++++++++++++++++- pgtype/cidr_array_test.go | 164 +++++++++++++ pgtype/database_sql.go | 66 +++++ pgtype/date.go | 11 +- pgtype/date_array.go | 24 +- pgtype/float4.go | 8 +- pgtype/float4_array.go | 24 +- pgtype/float8.go | 8 +- pgtype/float8_array.go | 24 +- pgtype/generic_binary.go | 8 +- pgtype/generic_text.go | 8 +- pgtype/hstore.go | 14 +- pgtype/inet.go | 8 +- pgtype/inet_array.go | 24 +- pgtype/int2.go | 8 +- pgtype/int2_array.go | 24 +- pgtype/int4.go | 8 +- pgtype/int4_array.go | 24 +- pgtype/int8.go | 8 +- pgtype/int8_array.go | 24 +- pgtype/json.go | 12 +- pgtype/jsonb.go | 12 +- pgtype/name.go | 16 +- pgtype/oid.go | 8 +- pgtype/oid_value.go | 16 +- pgtype/pgtype.go | 129 +++++++++- pgtype/pgtype_test.go | 10 +- pgtype/pguint32.go | 8 +- pgtype/qchar.go | 4 +- pgtype/record.go | 123 ++++++++++ pgtype/record_test.go | 150 ++++++++++++ pgtype/text.go | 12 +- pgtype/text_array.go | 24 +- pgtype/tid.go | 8 +- pgtype/timestamp.go | 8 +- pgtype/timestamp_array.go | 24 +- pgtype/timestamptz.go | 8 +- pgtype/timestamptz_array.go | 24 +- pgtype/typed_array.go.erb | 24 +- pgtype/typed_array_gen.sh | 2 + pgtype/unknown.go | 32 +++ pgtype/varchar.go | 40 ++++ pgtype/varchar_array.go | 285 +++++++++++++++++++++- pgtype/varchar_array_test.go | 151 ++++++++++++ pgtype/xid.go | 16 +- query.go | 208 ++++++++-------- query_test.go | 2 +- values.go | 451 +---------------------------------- values_test.go | 265 ++++++++++---------- 62 files changed, 2067 insertions(+), 1105 deletions(-) create mode 100644 pgtype/cidr.go create mode 100644 pgtype/cidr_array_test.go create mode 100644 pgtype/database_sql.go create mode 100644 pgtype/record.go create mode 100644 pgtype/record_test.go create mode 100644 pgtype/unknown.go create mode 100644 pgtype/varchar.go create mode 100644 pgtype/varchar_array_test.go diff --git a/conn.go b/conn.go index 0c86d169..3414d7cf 100644 --- a/conn.go +++ b/conn.go @@ -31,6 +31,20 @@ const ( connStatusBusy ) +// minimalConnInfo has just enough static type information to establish the +// connection and retrieve the type data. +var minimalConnInfo *pgtype.ConnInfo + +func init() { + minimalConnInfo = pgtype.NewConnInfo() + minimalConnInfo.InitializeDataTypes(map[string]pgtype.Oid{ + "int4": Int4Oid, + "name": NameOid, + "oid": OidOid, + "text": TextOid, + }) +} + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -74,11 +88,10 @@ type Conn struct { lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf - pid int32 // backend pid - secretKey int32 // key to use to send a cancel query message to the server - RuntimeParams map[string]string // parameters that have been reported by the server - PgTypes map[pgtype.Oid]PgType // oids to PgTypes - config ConnConfig // config used when establishing this connection + pid int32 // backend pid + secretKey int32 // key to use to send a cancel query message to the server + RuntimeParams map[string]string // parameters that have been reported by the server + config ConnConfig // config used when establishing this connection txStatus byte preparedStatements map[string]*PreparedStatement channels map[string]struct{} @@ -102,7 +115,7 @@ type Conn struct { doneChan chan struct{} closedChan chan error - oidPgtypeValues map[pgtype.Oid]pgtype.Value + ConnInfo *pgtype.ConnInfo } // PreparedStatement is a description of a prepared statement @@ -125,12 +138,6 @@ type Notification struct { Payload string } -// PgType is information about PostgreSQL type and how to encode and decode it -type PgType struct { - Name string // name of type e.g. int4, text, date - DefaultFormat int16 // default format (text or binary) this type will be requested in -} - // CommandTag is the result of an Exec function type CommandTag string @@ -190,20 +197,14 @@ func (e ProtocolError) Error() string { // config.Host must be specified. config.User will default to the OS user name. // Other config fields are optional. func Connect(config ConnConfig) (c *Conn, err error) { - return connect(config, nil) + return connect(config, minimalConnInfo) } -func connect(config ConnConfig, pgTypes map[pgtype.Oid]PgType) (c *Conn, err error) { +func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { c = new(Conn) c.config = config - - if pgTypes != nil { - c.PgTypes = make(map[pgtype.Oid]PgType, len(pgTypes)) - for k, v := range pgTypes { - c.PgTypes[k] = v - } - } + c.ConnInfo = connInfo if c.config.LogLevel != 0 { c.logLevel = c.config.LogLevel @@ -289,8 +290,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.loadStaticOidPgtypeValues() - c.mr.cr = chunkreader.NewChunkReader(c.conn) msg := newStartupMessage() @@ -344,13 +343,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return nil } - if c.PgTypes == nil { - err = c.loadPgTypes() + if c.ConnInfo == minimalConnInfo { + err = c.initConnInfo() if err != nil { return err } } - c.loadDynamicOidPgtypeValues() return nil default: @@ -361,88 +359,37 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } -func (c *Conn) loadPgTypes() error { +func (c *Conn) initConnInfo() error { + nameOids := make(map[string]pgtype.Oid, 256) + rows, err := c.Query(`select t.oid, t.typname from pg_type t left join pg_type base_type on t.typelem=base_type.oid where ( - t.typtype='b' - and (base_type.oid is null or base_type.typtype='b') - ) - or t.typname in('record');`) + t.typtype in('b', 'p') + and (base_type.oid is null or base_type.typtype in('b', 'p')) + )`) if err != nil { return err } - c.PgTypes = make(map[pgtype.Oid]PgType, 128) - for rows.Next() { - var oid uint32 - var t PgType + var oid pgtype.Oid + var name pgtype.Text + if err := rows.Scan(&oid, &name); err != nil { + return err + } - rows.Scan(&oid, &t.Name) - - // The zero value is text format so we ignore any types without a default type format - t.DefaultFormat, _ = DefaultTypeFormats[t.Name] - - c.PgTypes[pgtype.Oid(oid)] = t + nameOids[name.String] = oid } - return rows.Err() -} - -func (c *Conn) loadStaticOidPgtypeValues() { - c.oidPgtypeValues = map[pgtype.Oid]pgtype.Value{ - AclitemArrayOid: &pgtype.AclitemArray{}, - AclitemOid: &pgtype.Aclitem{}, - BoolArrayOid: &pgtype.BoolArray{}, - BoolOid: &pgtype.Bool{}, - ByteaArrayOid: &pgtype.ByteaArray{}, - ByteaOid: &pgtype.Bytea{}, - CharOid: &pgtype.QChar{}, - CidOid: &pgtype.Cid{}, - CidrArrayOid: &pgtype.CidrArray{}, - CidrOid: &pgtype.Inet{}, - DateArrayOid: &pgtype.DateArray{}, - DateOid: &pgtype.Date{}, - Float4ArrayOid: &pgtype.Float4Array{}, - Float4Oid: &pgtype.Float4{}, - Float8ArrayOid: &pgtype.Float8Array{}, - Float8Oid: &pgtype.Float8{}, - InetArrayOid: &pgtype.InetArray{}, - InetOid: &pgtype.Inet{}, - Int2ArrayOid: &pgtype.Int2Array{}, - Int2Oid: &pgtype.Int2{}, - Int4ArrayOid: &pgtype.Int4Array{}, - Int4Oid: &pgtype.Int4{}, - Int8ArrayOid: &pgtype.Int8Array{}, - Int8Oid: &pgtype.Int8{}, - JsonbOid: &pgtype.Jsonb{}, - JsonOid: &pgtype.Json{}, - NameOid: &pgtype.Name{}, - OidOid: &pgtype.OidValue{}, - TextArrayOid: &pgtype.TextArray{}, - TextOid: &pgtype.Text{}, - TidOid: &pgtype.Tid{}, - TimestampArrayOid: &pgtype.TimestampArray{}, - TimestampOid: &pgtype.Timestamp{}, - TimestampTzArrayOid: &pgtype.TimestamptzArray{}, - TimestampTzOid: &pgtype.Timestamptz{}, - VarcharArrayOid: &pgtype.VarcharArray{}, - VarcharOid: &pgtype.Text{}, - XidOid: &pgtype.Xid{}, - } -} - -func (c *Conn) loadDynamicOidPgtypeValues() { - nameOids := make(map[string]pgtype.Oid, len(c.PgTypes)) - for k, v := range c.PgTypes { - nameOids[v.Name] = k + if rows.Err() != nil { + return rows.Err() } - if oid, ok := nameOids["hstore"]; ok { - c.oidPgtypeValues[oid] = &pgtype.Hstore{} - } + c.ConnInfo = pgtype.NewConnInfo() + c.ConnInfo.InitializeDataTypes(nameOids) + return nil } // PID returns the backend PID for this connection. @@ -805,9 +752,16 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared case rowDescription: ps.FieldDescriptions = c.rxRowDescription(r) for i := range ps.FieldDescriptions { - t, _ := c.PgTypes[ps.FieldDescriptions[i].DataType] - ps.FieldDescriptions[i].DataTypeName = t.Name - ps.FieldDescriptions[i].FormatCode = t.DefaultFormat + if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok { + ps.FieldDescriptions[i].DataTypeName = dt.Name + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + ps.FieldDescriptions[i].FormatCode = BinaryFormatCode + } else { + ps.FieldDescriptions[i].FormatCode = TextFormatCode + } + } else { + return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) + } } case readyForQuery: c.rxReadyForQuery(r) diff --git a/conn_pool.go b/conn_pool.go index 653ed0ba..44559ea8 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -30,7 +30,7 @@ type ConnPool struct { closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration - pgTypes map[pgtype.Oid]PgType + connInfo *pgtype.ConnInfo txAfterClose func(tx *Tx) rowsAfterClose func(rows *Rows) } @@ -49,6 +49,7 @@ var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool") func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p = new(ConnPool) p.config = config.ConnConfig + p.connInfo = minimalConnInfo p.maxConnections = config.MaxConnections if p.maxConnections == 0 { p.maxConnections = 5 @@ -95,6 +96,7 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { } p.allConnections = append(p.allConnections, c) p.availableConnections = append(p.availableConnections, c) + p.connInfo = c.ConnInfo.DeepCopy() return } @@ -294,7 +296,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.pgTypes) + c, err := connect(p.config, p.connInfo) if err != nil { return nil, err } @@ -329,8 +331,6 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // afterConnectionCreated executes (if it is) afterConnect() callback and prepares // all the known statements for the new connection. func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { - p.pgTypes = c.PgTypes - if p.afterConnect != nil { err := p.afterConnect(c) if err != nil { diff --git a/copy_from_test.go b/copy_from_test.go index e17575de..6df4ebb1 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" ) func TestConnCopyFromSmall(t *testing.T) { @@ -126,8 +125,8 @@ func TestConnCopyFromJSON(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { + for _, typeName := range []string{"json", "jsonb"} { + if _, ok := conn.ConnInfo.DataTypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 71110f85..1c21c7e6 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -18,7 +18,7 @@ type Point struct { Status pgtype.Status } -func (dst *Point) DecodeText(src []byte) error { +func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { *dst = Point{Status: pgtype.Null} return nil @@ -44,7 +44,7 @@ func (dst *Point) DecodeText(src []byte) error { return nil } -func (src Point) EncodeText(w io.Writer) (bool, error) { +func (src Point) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { switch src.Status { case pgtype.Null: return true, nil diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index b8a1549e..f9faab20 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -90,7 +90,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error { return nil } -func (dst *Aclitem) DecodeText(src []byte) error { +func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Aclitem{Status: Null} return nil @@ -100,7 +100,7 @@ func (dst *Aclitem) DecodeText(src []byte) error { return nil } -func (src Aclitem) EncodeText(w io.Writer) (bool, error) { +func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 5e3647b7..f02d339e 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -82,7 +82,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { return nil } -func (dst *AclitemArray) DecodeText(src []byte) error { +func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = AclitemArray{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,7 +118,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { return nil } -func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -165,7 +165,7 @@ func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/array.go b/pgtype/array.go index dff0fe81..9561afe5 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -27,7 +27,7 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { +func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) } @@ -60,7 +60,7 @@ func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(w io.Writer) error { +func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, w io.Writer) error { _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) if err != nil { return err diff --git a/pgtype/bool.go b/pgtype/bool.go index a8e9b8e1..87316381 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -79,7 +79,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } -func (dst *Bool) DecodeText(src []byte) error { +func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Bool) DecodeText(src []byte) error { return nil } -func (dst *Bool) DecodeBinary(src []byte) error { +func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Bool) DecodeBinary(src []byte) error { return nil } -func (src Bool) EncodeText(w io.Writer) (bool, error) { +func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -126,7 +126,7 @@ func (src Bool) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bool) EncodeBinary(w io.Writer) (bool, error) { +func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 4c5fc563..1cb46cf6 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -83,7 +83,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return nil } -func (dst *BoolArray) DecodeText(src []byte) error { +func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *BoolArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *BoolArray) DecodeText(src []byte) error { return nil } -func (dst *BoolArray) DecodeBinary(src []byte) error { +func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { return nil } -func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { +func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, BoolOid) +func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, BoolOid) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 5df05360..dc1e9c07 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -78,7 +78,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. -func (dst *Bytea) DecodeText(src []byte) error { +func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -98,7 +98,7 @@ func (dst *Bytea) DecodeText(src []byte) error { return nil } -func (dst *Bytea) DecodeBinary(src []byte) error { +func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } -func (src Bytea) EncodeText(w io.Writer) (bool, error) { +func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -128,7 +128,7 @@ func (src Bytea) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bytea) EncodeBinary(w io.Writer) (bool, error) { +func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index c6f676a4..30405509 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -83,7 +83,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return nil } -func (dst *ByteaArray) DecodeText(src []byte) error { +func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *ByteaArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *ByteaArray) DecodeText(src []byte) error { return nil } -func (dst *ByteaArray) DecodeBinary(src []byte) error { +func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { return nil } -func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, ByteaOid) +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, ByteaOid) } -func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/cid.go b/pgtype/cid.go index 20957f36..d86e8063 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -34,18 +34,18 @@ func (src *Cid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Cid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Cid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Cid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Cid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Cid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/pgtype/cidr.go b/pgtype/cidr.go new file mode 100644 index 00000000..463b279d --- /dev/null +++ b/pgtype/cidr.go @@ -0,0 +1,35 @@ +package pgtype + +import ( + "io" +) + +type Cidr Inet + +func (dst *Cidr) Set(src interface{}) error { + return (*Inet)(dst).Set(src) +} + +func (dst *Cidr) Get() interface{} { + return (*Inet)(dst).Get() +} + +func (src *Cidr) AssignTo(dst interface{}) error { + return (*Inet)(src).AssignTo(dst) +} + +func (dst *Cidr) DecodeText(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeText(ci, src) +} + +func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeBinary(ci, src) +} + +func (src Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeText(ci, w) +} + +func (src Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeBinary(ci, w) +} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index c30c53d3..32d2e7bf 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -1,35 +1,328 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + "net" + + "github.com/jackc/pgx/pgio" ) -type CidrArray InetArray +type CidrArray struct { + Elements []Cidr + Dimensions []ArrayDimension + Status Status +} func (dst *CidrArray) Set(src interface{}) error { - return (*InetArray)(dst).Set(src) + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Cidr", value) + } + + return nil } func (dst *CidrArray) Get() interface{} { - return (*InetArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *CidrArray) AssignTo(dst interface{}) error { - return (*InetArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]*net.IPNet: + if src.Status == Present { + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]net.IP: + if src.Status == Present { + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *CidrArray) DecodeText(src []byte) error { - return (*InetArray)(dst).DecodeText(src) +func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Cidr + + if len(uta.Elements) > 0 { + elements = make([]Cidr, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Cidr + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = CidrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *CidrArray) DecodeBinary(src []byte) error { - return (*InetArray)(dst).DecodeBinary(src) +func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = CidrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Cidr, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = CidrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { - return (*InetArray)(src).EncodeText(w) +func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { - return (*InetArray)(src).encodeBinary(w, CidrOid) +func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, CidrOid) +} + +func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go new file mode 100644 index 00000000..ec105914 --- /dev/null +++ b/pgtype/cidr_array_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCidrArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "cidr[]", []interface{}{ + &pgtype.CidrArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{Status: pgtype.Null}, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + pgtype.Cidr{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestCidrArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CidrArray + }{ + { + source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.CidrArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCidrArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.CidrArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go new file mode 100644 index 00000000..969d6542 --- /dev/null +++ b/pgtype/database_sql.go @@ -0,0 +1,66 @@ +package pgtype + +import ( + "bytes" + "errors" +) + +func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { + switch src := src.(type) { + case *Bool: + return src.Bool, nil + case *Bytea: + return src.Bytes, nil + case *Date: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Float4: + return float64(src.Float), nil + case *Float8: + return src.Float, nil + case *GenericBinary: + return src.Bytes, nil + case *GenericText: + return src.String, nil + case *Int2: + return int64(src.Int), nil + case *Int4: + return int64(src.Int), nil + case *Int8: + return int64(src.Int), nil + case *Text: + return src.String, nil + case *Timestamp: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Timestamptz: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Unknown: + return src.String, nil + case *Varchar: + return src.String, nil + } + + buf := &bytes.Buffer{} + if textEncoder, ok := src.(TextEncoder); ok { + _, err := textEncoder.EncodeText(ci, buf) + if err != nil { + return nil, err + } + return buf.String(), nil + } + + if binaryEncoder, ok := src.(BinaryEncoder); ok { + _, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil + } + + return nil, errors.New("cannot convert to database/sql compatible value") +} diff --git a/pgtype/date.go b/pgtype/date.go index d0481637..b6cc8329 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -38,6 +38,9 @@ func (dst *Date) Set(src interface{}) error { func (dst *Date) Get() interface{} { switch dst.Status { case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } return dst.Time case Null: return nil @@ -76,7 +79,7 @@ func (src *Date) AssignTo(dst interface{}) error { return nil } -func (dst *Date) DecodeText(src []byte) error { +func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -100,7 +103,7 @@ func (dst *Date) DecodeText(src []byte) error { return nil } -func (dst *Date) DecodeBinary(src []byte) error { +func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -125,7 +128,7 @@ func (dst *Date) DecodeBinary(src []byte) error { return nil } -func (src Date) EncodeText(w io.Writer) (bool, error) { +func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -148,7 +151,7 @@ func (src Date) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Date) EncodeBinary(w io.Writer) (bool, error) { +func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 7f602d83..ba68d561 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -84,7 +84,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return nil } -func (dst *DateArray) DecodeText(src []byte) error { +func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *DateArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *DateArray) DecodeText(src []byte) error { return nil } -func (dst *DateArray) DecodeBinary(src []byte) error { +func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { return nil } -func (src *DateArray) EncodeText(w io.Writer) (bool, error) { +func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, DateOid) +func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, DateOid) } -func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/float4.go b/pgtype/float4.go index 053af44b..94b7b7a1 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -102,7 +102,7 @@ func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } -func (dst *Float4) DecodeText(src []byte) error { +func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Float4) DecodeText(src []byte) error { return nil } -func (dst *Float4) DecodeBinary(src []byte) error { +func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -133,7 +133,7 @@ func (dst *Float4) DecodeBinary(src []byte) error { return nil } -func (src Float4) EncodeText(w io.Writer) (bool, error) { +func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -145,7 +145,7 @@ func (src Float4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float4) EncodeBinary(w io.Writer) (bool, error) { +func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 0e815e0b..40152bcf 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -83,7 +83,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float4Array) DecodeText(src []byte) error { +func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float4Array) DecodeText(src []byte) error { return nil } -func (dst *Float4Array) DecodeBinary(src []byte) error { +func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { return nil } -func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float4Oid) +func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float4Oid) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/float8.go b/pgtype/float8.go index 635b7a09..dd2d592d 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -92,7 +92,7 @@ func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } -func (dst *Float8) DecodeText(src []byte) error { +func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Float8) DecodeText(src []byte) error { return nil } -func (dst *Float8) DecodeBinary(src []byte) error { +func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -123,7 +123,7 @@ func (dst *Float8) DecodeBinary(src []byte) error { return nil } -func (src Float8) EncodeText(w io.Writer) (bool, error) { +func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -135,7 +135,7 @@ func (src Float8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float8) EncodeBinary(w io.Writer) (bool, error) { +func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 811c5a1f..d0ee0d70 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -83,7 +83,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float8Array) DecodeText(src []byte) error { +func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float8Array) DecodeText(src []byte) error { return nil } -func (dst *Float8Array) DecodeBinary(src []byte) error { +func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { return nil } -func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float8Oid) +func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float8Oid) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go index ac35ea60..aa28bb62 100644 --- a/pgtype/generic_binary.go +++ b/pgtype/generic_binary.go @@ -20,10 +20,10 @@ func (src *GenericBinary) AssignTo(dst interface{}) error { return (*Bytea)(src).AssignTo(dst) } -func (dst *GenericBinary) DecodeBinary(src []byte) error { - return (*Bytea)(dst).DecodeBinary(src) +func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src GenericBinary) EncodeBinary(w io.Writer) (bool, error) { - return (Bytea)(src).EncodeBinary(w) +func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Bytea)(src).EncodeBinary(ci, w) } diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go index 19f41059..bd75e0d0 100644 --- a/pgtype/generic_text.go +++ b/pgtype/generic_text.go @@ -20,10 +20,10 @@ func (src *GenericText) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *GenericText) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (src GenericText) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } diff --git a/pgtype/hstore.go b/pgtype/hstore.go index c48ae6da..d771d6e6 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -70,7 +70,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { return nil } -func (dst *Hstore) DecodeText(src []byte) error { +func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -90,7 +90,7 @@ func (dst *Hstore) DecodeText(src []byte) error { return nil } -func (dst *Hstore) DecodeBinary(src []byte) error { +func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -132,7 +132,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { rp += valueLen var value Text - err := value.DecodeBinary(valueBuf) + err := value.DecodeBinary(ci, valueBuf) if err != nil { return err } @@ -144,7 +144,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { return nil } -func (src Hstore) EncodeText(w io.Writer) (bool, error) { +func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -175,7 +175,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -196,7 +196,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { +func (src Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -220,7 +220,7 @@ func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { return false, err } - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/inet.go b/pgtype/inet.go index 87d675f9..b83bd1c9 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -100,7 +100,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil } -func (dst *Inet) DecodeText(src []byte) error { +func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Inet) DecodeText(src []byte) error { return nil } -func (dst *Inet) DecodeBinary(src []byte) error { +func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -153,7 +153,7 @@ func (dst *Inet) DecodeBinary(src []byte) error { return nil } -func (src Inet) EncodeText(w io.Writer) (bool, error) { +func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Inet) EncodeText(w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(w io.Writer) (bool, error) { +func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 1d1cf3fd..6cad82e7 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -115,7 +115,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return nil } -func (dst *InetArray) DecodeText(src []byte) error { +func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil @@ -137,7 +137,7 @@ func (dst *InetArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -151,14 +151,14 @@ func (dst *InetArray) DecodeText(src []byte) error { return nil } -func (dst *InetArray) DecodeBinary(src []byte) error { +func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -183,7 +183,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -193,7 +193,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { return nil } -func (src *InetArray) EncodeText(w io.Writer) (bool, error) { +func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -240,7 +240,7 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -269,11 +269,11 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, InetOid) +func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, InetOid) } -func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -293,7 +293,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -303,7 +303,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int2.go b/pgtype/int2.go index 62e1bc69..6996cd4f 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -98,7 +98,7 @@ func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int2) DecodeText(src []byte) error { +func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -113,7 +113,7 @@ func (dst *Int2) DecodeText(src []byte) error { return nil } -func (dst *Int2) DecodeBinary(src []byte) error { +func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Int2) DecodeBinary(src []byte) error { return nil } -func (src Int2) EncodeText(w io.Writer) (bool, error) { +func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -140,7 +140,7 @@ func (src Int2) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int2) EncodeBinary(w io.Writer) (bool, error) { +func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 3d06c018..2bf1c237 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -114,7 +114,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int2Array) DecodeText(src []byte) error { +func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int2Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int2Array) DecodeText(src []byte) error { return nil } -func (dst *Int2Array) DecodeBinary(src []byte) error { +func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { return nil } -func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int2Oid) +func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int2Oid) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int4.go b/pgtype/int4.go index 8eaf5094..62ee366f 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -89,7 +89,7 @@ func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int4) DecodeText(src []byte) error { +func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *Int4) DecodeText(src []byte) error { return nil } -func (dst *Int4) DecodeBinary(src []byte) error { +func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -119,7 +119,7 @@ func (dst *Int4) DecodeBinary(src []byte) error { return nil } -func (src Int4) EncodeText(w io.Writer) (bool, error) { +func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -131,7 +131,7 @@ func (src Int4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int4) EncodeBinary(w io.Writer) (bool, error) { +func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 5cd91c04..dda88eaf 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -114,7 +114,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int4Array) DecodeText(src []byte) error { +func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int4Array) DecodeText(src []byte) error { return nil } -func (dst *Int4Array) DecodeBinary(src []byte) error { +func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { return nil } -func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int4Oid) +func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int4Oid) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int8.go b/pgtype/int8.go index 2416500d..7ed54f8e 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -80,7 +80,7 @@ func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int8) DecodeText(src []byte) error { +func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -95,7 +95,7 @@ func (dst *Int8) DecodeText(src []byte) error { return nil } -func (dst *Int8) DecodeBinary(src []byte) error { +func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Int8) DecodeBinary(src []byte) error { return nil } -func (src Int8) EncodeText(w io.Writer) (bool, error) { +func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -123,7 +123,7 @@ func (src Int8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int8) EncodeBinary(w io.Writer) (bool, error) { +func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 5efc0f45..468c126b 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -114,7 +114,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int8Array) DecodeText(src []byte) error { +func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int8Array) DecodeText(src []byte) error { return nil } -func (dst *Int8Array) DecodeBinary(src []byte) error { +func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { return nil } -func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int8Oid) +func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int8Oid) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/json.go b/pgtype/json.go index ecdb3dab..bfffae14 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -84,7 +84,7 @@ func (src *Json) AssignTo(dst interface{}) error { return nil } -func (dst *Json) DecodeText(src []byte) error { +func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Json{Status: Null} return nil @@ -97,11 +97,11 @@ func (dst *Json) DecodeText(src []byte) error { return nil } -func (dst *Json) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Json) EncodeText(w io.Writer) (bool, error) { +func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -113,6 +113,6 @@ func (src Json) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Json) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 13062e8e..e44f3c41 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -19,11 +19,11 @@ func (src *Jsonb) AssignTo(dst interface{}) error { return (*Json)(src).AssignTo(dst) } -func (dst *Jsonb) DecodeText(src []byte) error { - return (*Json)(dst).DecodeText(src) +func (dst *Jsonb) DecodeText(ci *ConnInfo, src []byte) error { + return (*Json)(dst).DecodeText(ci, src) } -func (dst *Jsonb) DecodeBinary(src []byte) error { +func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Jsonb{Status: Null} return nil @@ -46,11 +46,11 @@ func (dst *Jsonb) DecodeBinary(src []byte) error { } -func (src Jsonb) EncodeText(w io.Writer) (bool, error) { - return (Json)(src).EncodeText(w) +func (src Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Json)(src).EncodeText(ci, w) } -func (src Jsonb) EncodeBinary(w io.Writer) (bool, error) { +func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/name.go b/pgtype/name.go index 9eb12ece..9ebf63d3 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -31,18 +31,18 @@ func (src *Name) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *Name) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (dst *Name) DecodeBinary(src []byte) error { - return (*Text)(dst).DecodeBinary(src) +func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) } -func (src Name) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } -func (src Name) EncodeBinary(w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(w) +func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) } diff --git a/pgtype/oid.go b/pgtype/oid.go index eab1fbcb..3edd7f3c 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -18,7 +18,7 @@ import ( // allow for NULL Oids use OidValue. type Oid uint32 -func (dst *Oid) DecodeText(src []byte) error { +func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -32,7 +32,7 @@ func (dst *Oid) DecodeText(src []byte) error { return nil } -func (dst *Oid) DecodeBinary(src []byte) error { +func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -46,12 +46,12 @@ func (dst *Oid) DecodeBinary(src []byte) error { return nil } -func (src Oid) EncodeText(w io.Writer) (bool, error) { +func (src Oid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) return false, err } -func (src Oid) EncodeBinary(w io.Writer) (bool, error) { +func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go index a2b2dcbe..1bce6e11 100644 --- a/pgtype/oid_value.go +++ b/pgtype/oid_value.go @@ -28,18 +28,18 @@ func (src *OidValue) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OidValue) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *OidValue) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *OidValue) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src OidValue) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src OidValue) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7b1470b7..674c0db7 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -3,6 +3,7 @@ package pgtype import ( "errors" "io" + "reflect" ) // PostgreSQL oids for common types @@ -83,14 +84,14 @@ type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder MUST not retain a reference to // src. It MUST make a copy if it needs to retain the raw bytes. - DecodeBinary(src []byte) error + DecodeBinary(ci *ConnInfo, src []byte) error } type TextDecoder interface { // DecodeText decodes src into TextDecoder. If src is nil then the original // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST // make a copy if it needs to retain the raw bytes. - DecodeText(src []byte) error + DecodeText(ci *ConnInfo, src []byte) error } // BinaryEncoder is implemented by types that can encode themselves into the @@ -100,7 +101,7 @@ type BinaryEncoder interface { // SQL value NULL then write nothing and return (true, nil). The caller of // EncodeBinary is responsible for writing the correct NULL value or the // length of the data written. - EncodeBinary(w io.Writer) (null bool, err error) + EncodeBinary(ci *ConnInfo, w io.Writer) (null bool, err error) } // TextEncoder is implemented by types that can encode themselves into the @@ -110,7 +111,127 @@ type TextEncoder interface { // value NULL then write nothing and return (true, nil). The caller of // EncodeText is responsible for writing the correct NULL value or the length // of the data written. - EncodeText(w io.Writer) (null bool, err error) + EncodeText(ci *ConnInfo, w io.Writer) (null bool, err error) } var errUndefined = errors.New("cannot encode status undefined") + +type DataType struct { + Value Value + Name string + Oid Oid +} + +type ConnInfo struct { + oidToDataType map[Oid]*DataType + nameToDataType map[string]*DataType + reflectTypeToDataType map[reflect.Type]*DataType +} + +func NewConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, 256), + nameToDataType: make(map[string]*DataType, 256), + reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), + } +} + +func (ci *ConnInfo) InitializeDataTypes(nameOids map[string]Oid) { + for name, oid := range nameOids { + var value Value + if t, ok := nameValues[name]; ok { + value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) + } else { + value = &GenericText{} + } + ci.RegisterDataType(DataType{Value: value, Name: name, Oid: oid}) + } +} + +func (ci *ConnInfo) RegisterDataType(t DataType) { + ci.oidToDataType[t.Oid] = &t + ci.nameToDataType[t.Name] = &t + ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t +} + +func (ci *ConnInfo) DataTypeForOid(oid Oid) (*DataType, bool) { + dt, ok := ci.oidToDataType[oid] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { + dt, ok := ci.nameToDataType[name] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { + dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] + return dt, ok +} + +// DeepCopy makes a deep copy of the ConnInfo. +func (ci *ConnInfo) DeepCopy() *ConnInfo { + ci2 := &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, len(ci.oidToDataType)), + nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), + reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), + } + + for _, dt := range ci.oidToDataType { + ci2.RegisterDataType(DataType{ + Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), + Name: dt.Name, + Oid: dt.Oid, + }) + } + + return ci2 +} + +var nameValues map[string]Value + +func init() { + nameValues = map[string]Value{ + "_aclitem": &AclitemArray{}, + "_bool": &BoolArray{}, + "_bytea": &ByteaArray{}, + "_cidr": &CidrArray{}, + "_date": &DateArray{}, + "_float4": &Float4Array{}, + "_float8": &Float8Array{}, + "_inet": &InetArray{}, + "_int2": &Int2Array{}, + "_int4": &Int4Array{}, + "_int8": &Int8Array{}, + "_text": &TextArray{}, + "_timestamp": &TimestampArray{}, + "_timestamptz": &TimestamptzArray{}, + "_varchar": &VarcharArray{}, + "aclitem": &Aclitem{}, + "bool": &Bool{}, + "bytea": &Bytea{}, + "char": &QChar{}, + "cid": &Cid{}, + "cidr": &Cidr{}, + "date": &Date{}, + "float4": &Float4{}, + "float8": &Float8{}, + "hstore": &Hstore{}, + "inet": &Inet{}, + "int2": &Int2{}, + "int4": &Int4{}, + "int8": &Int8{}, + "json": &Json{}, + "jsonb": &Jsonb{}, + "name": &Name{}, + "oid": &OidValue{}, + "record": &Record{}, + "text": &Text{}, + "tid": &Tid{}, + "timestamp": &Timestamp{}, + "timestamptz": &Timestamptz{}, + "unknown": &Unknown{}, + "varchar": &Varchar{}, + "xid": &Xid{}, + } +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index f9b6f56d..391fed57 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -60,16 +60,16 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(w io.Writer) (bool, error) { - return f.e.EncodeText(w) +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeText(ci, w) } type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(w io.Writer) (bool, error) { - return f.e.EncodeBinary(w) +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeBinary(ci, w) } func forceEncoder(e interface{}, formatCode int16) interface{} { @@ -114,7 +114,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%v does not implement %v", fc.name) + t.Logf("%#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 05c79c0e..3f9e7bf7 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -63,7 +63,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { return nil } -func (dst *pguint32) DecodeText(src []byte) error { +func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -78,7 +78,7 @@ func (dst *pguint32) DecodeText(src []byte) error { return nil } -func (dst *pguint32) DecodeBinary(src []byte) error { +func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *pguint32) DecodeBinary(src []byte) error { return nil } -func (src pguint32) EncodeText(w io.Writer) (bool, error) { +func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src pguint32) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src pguint32) EncodeBinary(w io.Writer) (bool, error) { +func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/qchar.go b/pgtype/qchar.go index d46e716d..4b32ee4a 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -115,7 +115,7 @@ func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *QChar) DecodeBinary(src []byte) error { +func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = QChar{Status: Null} return nil @@ -129,7 +129,7 @@ func (dst *QChar) DecodeBinary(src []byte) error { return nil } -func (src QChar) EncodeBinary(w io.Writer) (bool, error) { +func (src QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/record.go b/pgtype/record.go new file mode 100644 index 00000000..1bfd05b9 --- /dev/null +++ b/pgtype/record.go @@ -0,0 +1,123 @@ +package pgtype + +import ( + "encoding/binary" + "fmt" +) + +// Record is the generic PostgreSQL record type such as is created with the +// "row" function. Record only implements BinaryEncoder and Value. The text +// format output format from PostgreSQL does not include type information and is +// therefore impossible to decode. No encoders are implemented because +// PostgreSQL does not support input of generic records. +type Record struct { + Fields []Value + Status Status +} + +func (dst *Record) Set(src interface{}) error { + switch value := src.(type) { + case []Value: + *dst = Record{Fields: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Record", src) + } + + return nil +} + +func (dst *Record) Get() interface{} { + switch dst.Status { + case Present: + return dst.Fields + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Record) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]Value: + switch src.Status { + case Present: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + case *[]interface{}: + switch src.Status { + case Present: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + fields := make([]Value, fieldCount) + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 8 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldOid := Oid(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var binaryDecoder BinaryDecoder + if dt, ok := ci.DataTypeForOid(fieldOid); ok { + if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { + return fmt.Errorf("unknown oid while decoding record: %v", fieldOid) + } + } + + var fieldBytes []byte + if fieldLen >= 0 { + if len(src[rp:]) < fieldLen { + return fmt.Errorf("Record incomplete %v", src) + } + fieldBytes = src[rp : rp+fieldLen] + rp += fieldLen + } + + if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + fields[i] = binaryDecoder.(Value) + } + + *dst = Record{Fields: fields, Status: Present} + + return nil +} diff --git a/pgtype/record_test.go b/pgtype/record_test.go new file mode 100644 index 00000000..bc6e5893 --- /dev/null +++ b/pgtype/record_test.go @@ -0,0 +1,150 @@ +package pgtype_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestRecordTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + tests := []struct { + sql string + expected pgtype.Record + }{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{ + Status: pgtype.Null, + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.sql) + if err != nil { + t.Fatal(err) + } + ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode + + var result pgtype.Record + if err := conn.QueryRow(psName).Scan(&result); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("%d: expected %v, got %v", i, tt.expected, result) + } + } +} + +func TestRecordAssignTo(t *testing.T) { + var valueSlice []pgtype.Value + var interfaceSlice []interface{} + + simpleTests := []struct { + src pgtype.Record + dst interface{} + expected interface{} + }{ + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &valueSlice, + expected: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + }, + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &interfaceSlice, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &valueSlice, + expected: (([]pgtype.Value)(nil)), + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &interfaceSlice, + expected: (([]interface{})(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/text.go b/pgtype/text.go index 3dd082c9..f1a76b6e 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -78,7 +78,7 @@ func (src *Text) AssignTo(dst interface{}) error { return nil } -func (dst *Text) DecodeText(src []byte) error { +func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Text{Status: Null} return nil @@ -88,11 +88,11 @@ func (dst *Text) DecodeText(src []byte) error { return nil } -func (dst *Text) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Text) EncodeText(w io.Writer) (bool, error) { +func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -104,6 +104,6 @@ func (src Text) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Text) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 1e6677a9..6e89708f 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -83,7 +83,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return nil } -func (dst *TextArray) DecodeText(src []byte) error { +func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *TextArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *TextArray) DecodeText(src []byte) error { return nil } -func (dst *TextArray) DecodeBinary(src []byte) error { +func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { return nil } -func (src *TextArray) EncodeText(w io.Writer) (bool, error) { +func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TextOid) +func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TextOid) } -func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/tid.go b/pgtype/tid.go index 20d962df..b91711d3 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -46,7 +46,7 @@ func (src *Tid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } -func (dst *Tid) DecodeText(src []byte) error { +func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -75,7 +75,7 @@ func (dst *Tid) DecodeText(src []byte) error { return nil } -func (dst *Tid) DecodeBinary(src []byte) error { +func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Tid) DecodeBinary(src []byte) error { return nil } -func (src Tid) EncodeText(w io.Writer) (bool, error) { +func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src Tid) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Tid) EncodeBinary(w io.Writer) (bool, error) { +func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 3bb8f080..9a9e74ea 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -85,7 +85,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeText(src []byte) error { +func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Timestamp) DecodeText(src []byte) error { // DecodeBinary decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeBinary(src []byte) error { +func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -139,7 +139,7 @@ func (dst *Timestamp) DecodeBinary(src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(w io.Writer) (bool, error) { +func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -167,7 +167,7 @@ func (src Timestamp) EncodeText(w io.Writer) (bool, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index c955dc42..064ad483 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -84,7 +84,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestampArray) DecodeText(src []byte) error { +func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestampArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestampArray) DecodeText(src []byte) error { return nil } -func (dst *TimestampArray) DecodeBinary(src []byte) error { +func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestampOid) +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestampOid) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 5b9f5038..7f57f4b7 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -84,7 +84,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return nil } -func (dst *Timestamptz) DecodeText(src []byte) error { +func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Timestamptz) DecodeText(src []byte) error { return nil } -func (dst *Timestamptz) DecodeBinary(src []byte) error { +func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -143,7 +143,7 @@ func (dst *Timestamptz) DecodeBinary(src []byte) error { return nil } -func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Timestamptz) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index cd63e02e..4af1460b 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -84,7 +84,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestamptzArray) DecodeText(src []byte) error { +func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { return nil } -func (dst *TimestamptzArray) DecodeBinary(src []byte) error { +func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestamptzOid) +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestamptzOid) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index a56097c0..2a46a658 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -82,7 +82,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,14 +118,14 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -150,7 +150,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { elemSrc = src[rp:rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -160,7 +160,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { return nil } -func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { +func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -207,7 +207,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -236,11 +236,11 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, <%= element_oid %>) +func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -260,7 +260,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -270,7 +270,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 41c1313f..5fde32aa 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -8,6 +8,8 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_oid=CidrOid text_null=NULL typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_oid=VarcharOid text_null='"NULL"' typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/pgtype/unknown.go b/pgtype/unknown.go new file mode 100644 index 00000000..b951ad99 --- /dev/null +++ b/pgtype/unknown.go @@ -0,0 +1,32 @@ +package pgtype + +// Unknown represents the PostgreSQL unknown type. It is either a string literal +// or NULL. It is used when PostgreSQL does not know the type of a value. In +// general, this will only be used in pgx when selecting a null value without +// type information. e.g. SELECT NULL; +type Unknown struct { + String string + Status Status +} + +func (dst *Unknown) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Unknown) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Unknown is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Unknown) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} diff --git a/pgtype/varchar.go b/pgtype/varchar.go new file mode 100644 index 00000000..adda6c49 --- /dev/null +++ b/pgtype/varchar.go @@ -0,0 +1,40 @@ +package pgtype + +import ( + "io" +) + +type Varchar Text + +// Set converts from src to dst. Note that as Varchar is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *Varchar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Varchar) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Varchar is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Varchar) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) +} + +func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) +} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 693b9a61..21e9ccff 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -1,35 +1,296 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + + "github.com/jackc/pgx/pgio" ) -type VarcharArray TextArray +type VarcharArray struct { + Elements []Varchar + Dimensions []ArrayDimension + Status Status +} func (dst *VarcharArray) Set(src interface{}) error { - return (*TextArray)(dst).Set(src) + switch value := src.(type) { + + case []string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Varchar", value) + } + + return nil } func (dst *VarcharArray) Get() interface{} { - return (*TextArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *VarcharArray) AssignTo(dst interface{}) error { - return (*TextArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *VarcharArray) DecodeText(src []byte) error { - return (*TextArray)(dst).DecodeText(src) +func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Varchar + + if len(uta.Elements) > 0 { + elements = make([]Varchar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Varchar + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *VarcharArray) DecodeBinary(src []byte) error { - return (*TextArray)(dst).DecodeBinary(src) +func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Varchar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { - return (*TextArray)(src).EncodeText(w) +func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `"NULL"`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { - return (*TextArray)(src).encodeBinary(w, VarcharOid) +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, VarcharOid) +} + +func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go new file mode 100644 index 00000000..4a8b09b8 --- /dev/null +++ b/pgtype/varchar_array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestVarcharArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "varchar[]", []interface{}{ + &pgtype.VarcharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{Status: pgtype.Null}, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar ", Status: pgtype.Present}, + pgtype.Varchar{String: "NuLL", Status: pgtype.Present}, + pgtype.Varchar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.Varchar{String: "", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + pgtype.Varchar{String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar", Status: pgtype.Present}, + pgtype.Varchar{String: "baz", Status: pgtype.Present}, + pgtype.Varchar{String: "quz", Status: pgtype.Present}, + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestVarcharArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.VarcharArray + }{ + { + source: []string{"foo"}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.VarcharArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.VarcharArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestVarcharArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.VarcharArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.VarcharArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.VarcharArray + dst interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/xid.go b/pgtype/xid.go index a53120de..c76548a4 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -37,18 +37,18 @@ func (src *Xid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Xid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Xid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Xid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Xid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Xid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/query.go b/query.go index 63ce91ed..48a657f9 100644 --- a/query.go +++ b/query.go @@ -212,74 +212,86 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } } } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { - err = s.DecodeBinary(vr.bytes()) + err = s.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { - err = s.DecodeText(vr.bytes()) + err = s.DecodeText(rows.conn.ConnInfo, vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(sql.Scanner); ok { - var val interface{} + var sqlSrc interface{} if 0 <= vr.Len() { - switch vr.Type().DataType { - case BoolOid: - val = decodeBool(vr) - case Int8Oid: - val = int64(decodeInt8(vr)) - case Int2Oid: - val = int64(decodeInt2(vr)) - case Int4Oid: - val = int64(decodeInt4(vr)) - case TextOid, VarcharOid: - val = decodeText(vr) - case Float4Oid: - val = float64(decodeFloat4(vr)) - case Float8Oid: - val = decodeFloat8(vr) - case DateOid: - val = decodeDate(vr) - case TimestampOid: - val = decodeTimestamp(vr) - case TimestampTzOid: - val = decodeTimestampTz(vr) - default: - val = vr.ReadBytes(vr.Len()) + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + default: + rows.Fatal(errors.New("Unknown format code")) + } + + sqlSrc, err = pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.Fatal(err) + } + } else { + rows.Fatal(errors.New("Unknown type")) } } - err = s.Scan(val) + err = s.Scan(sqlSrc) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else { - if pgVal, present := rows.conn.oidPgtypeValues[vr.Type().DataType]; present { + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value switch vr.Type().FormatCode { case TextFormatCode: - if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok { - err = textDecoder.DecodeText(vr.bytes()) + if textDecoder, ok := value.(pgtype.TextDecoder); ok { + err = textDecoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) if err != nil { vr.Fatal(err) } } else { - vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", pgVal)) + vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", value)) } case BinaryFormatCode: - if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok { - err = binaryDecoder.DecodeBinary(vr.bytes()) + if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { + err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) if err != nil { vr.Fatal(err) } } else { - vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", pgVal)) + vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", value)) } default: vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) } - if err := pgVal.AssignTo(d); err != nil { - vr.Fatal(err) + if vr.Err() == nil { + if err := value.AssignTo(d); err != nil { + vr.Fatal(err) + } } } else { if err := Decode(vr, d); err != nil { @@ -315,29 +327,35 @@ func (rows *Rows) Values() ([]interface{}, error) { continue } - switch vr.Type().FormatCode { - case TextFormatCode: - decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.TextDecoder) - if decoder == nil { - decoder = &pgtype.GenericText{} + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + values = append(values, value.Get()) + default: + rows.Fatal(errors.New("Unknown format code")) } - err := decoder.DecodeText(vr.bytes()) - if err != nil { - rows.Fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - case BinaryFormatCode: - decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.BinaryDecoder) - if decoder == nil { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(vr.bytes()) - if err != nil { - rows.Fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - default: - rows.Fatal(errors.New("Unknown format code")) + } else { + rows.Fatal(errors.New("Unknown type")) } if vr.Err() != nil { @@ -368,49 +386,41 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, nil) continue } - // TODO - consider what are the implications of returning complex types since database/sql uses this method - switch vr.Type().FormatCode { - // All intrinsic types (except string) are encoded with binary - // encoding so anything else should be treated as a string - case TextFormatCode: - values = append(values, vr.ReadString(vr.Len())) - case BinaryFormatCode: - switch vr.Type().DataType { - case TextOid, VarcharOid: - values = append(values, decodeText(vr)) - case BoolOid: - values = append(values, decodeBool(vr)) - case ByteaOid: - values = append(values, decodeBytea(vr)) - case Int8Oid: - values = append(values, decodeInt8(vr)) - case Int2Oid: - values = append(values, decodeInt2(vr)) - case Int4Oid: - values = append(values, decodeInt4(vr)) - case Float4Oid: - values = append(values, decodeFloat4(vr)) - case Float8Oid: - values = append(values, decodeFloat8(vr)) - case DateOid: - values = append(values, decodeDate(vr)) - case TimestampTzOid: - values = append(values, decodeTimestampTz(vr)) - case TimestampOid: - values = append(values, decodeTimestamp(vr)) - case JsonOid: - var d interface{} - decodeJSON(vr, &d) - values = append(values, d) - case JsonbOid: - var d interface{} - decodeJSONB(vr, &d) - values = append(values, d) + + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } default: - rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) + rows.Fatal(errors.New("Unknown format code")) } - default: - rows.Fatal(errors.New("Unknown format code")) + + sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.Fatal(err) + } + + values = append(values, sqlSrc) + } else { + rows.Fatal(errors.New("Unknown type")) } if vr.Err() != nil { diff --git a/query_test.go b/query_test.go index 01889444..480959e8 100644 --- a/query_test.go +++ b/query_test.go @@ -776,7 +776,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 600"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Text"}, } for i, tt := range tests { diff --git a/values.go b/values.go index d90c363b..4eb24eef 100644 --- a/values.go +++ b/values.go @@ -5,9 +5,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "math" "reflect" - "time" "github.com/jackc/pgx/pgtype" ) @@ -167,7 +165,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { switch arg := arg.(type) { case pgtype.BinaryEncoder: buf := &bytes.Buffer{} - null, err := arg.EncodeBinary(buf) + null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -180,7 +178,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { return nil case pgtype.TextEncoder: buf := &bytes.Buffer{} - null, err := arg.EncodeText(buf) + null, err := arg.EncodeText(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -214,14 +212,15 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { return Encode(wbuf, oid, arg) } - if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { + if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok { + value := dt.Value err := value.Set(arg) if err != nil { return err } buf := &bytes.Buffer{} - null, err := value.(pgtype.BinaryEncoder).EncodeBinary(buf) + null, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -287,8 +286,6 @@ func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { case *string: *v = decodeText(vr) - case *[]interface{}: - *v = decodeRecord(vr) default: if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { el := v.Elem() @@ -320,232 +317,6 @@ func Decode(vr *ValueReader, d interface{}) error { return nil } -func decodeBool(vr *ValueReader) bool { - if vr.Type().DataType != BoolOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) - return false - } - - var b pgtype.Bool - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = b.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = b.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return false - } - - if err != nil { - vr.Fatal(err) - return false - } - - if b.Status != pgtype.Present { - vr.Fatal(fmt.Errorf("Cannot decode null into bool")) - return false - } - - return b.Bool -} - -func decodeInt(vr *ValueReader) int64 { - switch vr.Type().DataType { - case Int2Oid: - return int64(decodeInt2(vr)) - case Int4Oid: - return int64(decodeInt4(vr)) - case Int8Oid: - return int64(decodeInt8(vr)) - } - - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType))) - return 0 -} - -func decodeInt8(vr *ValueReader) int64 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int64")) - return 0 - } - - if vr.Type().DataType != Int8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int8 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeInt2(vr *ValueReader) int16 { - - if vr.Type().DataType != Int2Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int2 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeInt4(vr *ValueReader) int32 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int32")) - return 0 - } - - if vr.Type().DataType != Int4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int4 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeFloat4(vr *ValueReader) float32 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into float32")) - return 0 - } - - if vr.Type().DataType != Float4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt32() - return math.Float32frombits(uint32(i)) -} - -func encodeFloat32(w *WriteBuf, oid pgtype.Oid, value float32) error { - switch oid { - case Float4Oid: - w.WriteInt32(4) - w.WriteInt32(int32(math.Float32bits(value))) - case Float8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(float64(value)))) - default: - return fmt.Errorf("cannot encode %s into oid %v", "float32", oid) - } - - return nil -} - -func decodeFloat8(vr *ValueReader) float64 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into float64")) - return 0 - } - - if vr.Type().DataType != Float8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt64() - return math.Float64frombits(uint64(i)) -} - -func encodeFloat64(w *WriteBuf, oid pgtype.Oid, value float64) error { - switch oid { - case Float8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(value))) - default: - return fmt.Errorf("cannot encode %s into oid %v", "float64", oid) - } - - return nil -} - func decodeText(vr *ValueReader) string { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into string")) @@ -677,215 +448,3 @@ func encodeJSONB(w *WriteBuf, oid pgtype.Oid, value interface{}) error { return nil } - -func decodeDate(vr *ValueReader) time.Time { - if vr.Type().DataType != DateOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return time.Time{} - } - - var d pgtype.Date - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = d.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = d.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return time.Time{} - } - - if err != nil { - vr.Fatal(err) - return time.Time{} - } - - if d.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return time.Time{} - } - - return d.Time -} - -func encodeTime(w *WriteBuf, oid pgtype.Oid, value time.Time) error { - switch oid { - case DateOid: - var d pgtype.Date - err := d.Set(value) - if err != nil { - return err - } - - buf := &bytes.Buffer{} - null, err := d.EncodeBinary(buf) - if err != nil { - return err - } - if null { - w.WriteInt32(-1) - } else { - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - } - return nil - - case TimestampTzOid, TimestampOid: - var t pgtype.Timestamptz - err := t.Set(value) - if err != nil { - return err - } - - buf := &bytes.Buffer{} - null, err := t.EncodeBinary(buf) - if err != nil { - return err - } - if null { - w.WriteInt32(-1) - } else { - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - } - return nil - default: - return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) - } -} - -const microsecFromUnixEpochToY2K = 946684800 * 1000000 - -func decodeTimestampTz(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return zeroTime - } - - if vr.Type().DataType != TimestampTzOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - var t pgtype.Timestamptz - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = t.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = t.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return time.Time{} - } - - if err != nil { - vr.Fatal(err) - return time.Time{} - } - - if t.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return time.Time{} - } - - return t.Time -} - -func decodeTimestamp(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into timestamp")) - return zeroTime - } - - if vr.Type().DataType != TimestampOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len()))) - return zeroTime - } - - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) -} - -func decodeRecord(vr *ValueReader) []interface{} { - if vr.Len() == -1 { - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - if vr.Type().DataType != RecordOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType))) - return nil - } - - valueCount := vr.ReadInt32() - record := make([]interface{}, 0, int(valueCount)) - - for i := int32(0); i < valueCount; i++ { - fd := FieldDescription{FormatCode: BinaryFormatCode} - fieldVR := ValueReader{mr: vr.mr, fd: &fd} - fd.DataType = vr.ReadOid() - fieldVR.valueBytesRemaining = vr.ReadInt32() - vr.valueBytesRemaining -= fieldVR.valueBytesRemaining - - switch fd.DataType { - case BoolOid: - record = append(record, decodeBool(&fieldVR)) - case ByteaOid: - record = append(record, decodeBytea(&fieldVR)) - case Int8Oid: - record = append(record, decodeInt8(&fieldVR)) - case Int2Oid: - record = append(record, decodeInt2(&fieldVR)) - case Int4Oid: - record = append(record, decodeInt4(&fieldVR)) - case Float4Oid: - record = append(record, decodeFloat4(&fieldVR)) - case Float8Oid: - record = append(record, decodeFloat8(&fieldVR)) - case DateOid: - record = append(record, decodeDate(&fieldVR)) - case TimestampTzOid: - record = append(record, decodeTimestampTz(&fieldVR)) - case TimestampOid: - record = append(record, decodeTimestamp(&fieldVR)) - case TextOid, VarcharOid, UnknownOid: - record = append(record, decodeTextAllowBinary(&fieldVR)) - default: - vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) - return nil - } - - // Consume any remaining data - if fieldVR.Len() > 0 { - fieldVR.ReadBytes(fieldVR.Len()) - } - - if fieldVR.Err() != nil { - vr.Fatal(fieldVR.Err()) - return nil - } - } - - return record -} diff --git a/values_test.go b/values_test.go index e7ae7e1d..1d09eb18 100644 --- a/values_test.go +++ b/values_test.go @@ -6,9 +6,6 @@ import ( "reflect" "testing" "time" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" ) func TestDateTranscode(t *testing.T) { @@ -78,159 +75,161 @@ func TestTimestampTzTranscode(t *testing.T) { } } -func TestJSONAndJSONBTranscode(t *testing.T) { - t.Parallel() +// TODO - move these tests to pgtype - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) +// func TestJSONAndJSONBTranscode(t *testing.T) { +// t.Parallel() - for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { - return // No JSON/JSONB type -- must be running against old PostgreSQL - } +// conn := mustConnect(t, *defaultConnConfig) +// defer closeConn(t, conn) - for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { - pgtype := conn.PgTypes[oid] - pgtype.DefaultFormat = format - conn.PgTypes[oid] = pgtype +// for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { +// if _, ok := conn.ConnInfo.DataTypeForOid(oid); !ok { +// return // No JSON/JSONB type -- must be running against old PostgreSQL +// } - typename := conn.PgTypes[oid].Name +// for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { +// pgtype := conn.PgTypes[oid] +// pgtype.DefaultFormat = format +// conn.PgTypes[oid] = pgtype - testJSONString(t, conn, typename, format) - testJSONStringPointer(t, conn, typename, format) - testJSONSingleLevelStringMap(t, conn, typename, format) - testJSONNestedMap(t, conn, typename, format) - testJSONStringArray(t, conn, typename, format) - testJSONInt64Array(t, conn, typename, format) - testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format) - testJSONStruct(t, conn, typename, format) - } - } -} +// typename := conn.PgTypes[oid].Name -func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := `{"key": "value"}` - expectedOutput := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// testJSONString(t, conn, typename, format) +// testJSONStringPointer(t, conn, typename, format) +// testJSONSingleLevelStringMap(t, conn, typename, format) +// testJSONNestedMap(t, conn, typename, format) +// testJSONStringArray(t, conn, typename, format) +// testJSONInt64Array(t, conn, typename, format) +// testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format) +// testJSONStruct(t, conn, typename, format) +// } +// } +// } - if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) - return - } -} +// func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := `{"key": "value"}` +// expectedOutput := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := `{"key": "value"}` - expectedOutput := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(expectedOutput, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) +// return +// } +// } - if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) - return - } -} +// func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := `{"key": "value"}` +// expectedOutput := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(expectedOutput, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) - return - } -} +// func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := map[string]interface{}{ - "name": "Uncanny", - "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, - "inventory": []interface{}{"phone", "key"}, - } - var output map[string]interface{} - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) - return - } -} +// func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := map[string]interface{}{ +// "name": "Uncanny", +// "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, +// "inventory": []interface{}{"phone", "key"}, +// } +// var output map[string]interface{} +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []string{"foo", "bar", "baz"} - var output []string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) - } -} +// func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []string{"foo", "bar", "baz"} +// var output []string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } -func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []int64{1, 2, 234432} - var output []int64 - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) - } -} +// func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []int64{1, 2, 234432} +// var output []int64 +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } -func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []int{1, 2, 234432} - var output []int16 - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { - t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) - } -} +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) +// } +// } -func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { - type person struct { - Name string `json:"name"` - Age int `json:"age"` - } +// func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []int{1, 2, 234432} +// var output []int16 +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { +// t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) +// } +// } - input := person{ - Name: "John", - Age: 42, - } +// func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// type person struct { +// Name string `json:"name"` +// Age int `json:"age"` +// } - var output person +// input := person{ +// Name: "John", +// Age: 42, +// } - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// var output person - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) - } -} +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } + +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) +// } +// } func mustParseCidr(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) From cf70e6b9f491fc8bc51ae3dbf8eda2f4470e9511 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 12:40:54 -0500 Subject: [PATCH 120/264] Add pgtype.HstoreArray This required restructuring array types to lookup oid of element instead of hard-coding it due to hstore having a variable oid. --- pgtype/bool_array.go | 11 +- pgtype/bytea_array.go | 11 +- pgtype/cidr_array.go | 11 +- pgtype/date_array.go | 11 +- pgtype/float4_array.go | 11 +- pgtype/float8_array.go | 11 +- pgtype/hstore_array.go | 297 ++++++++++++++++++++++++++++++++++++ pgtype/hstore_array_test.go | 183 ++++++++++++++++++++++ pgtype/inet_array.go | 11 +- pgtype/int2_array.go | 11 +- pgtype/int4_array.go | 11 +- pgtype/int8_array.go | 11 +- pgtype/text_array.go | 11 +- pgtype/timestamp_array.go | 11 +- pgtype/timestamptz_array.go | 11 +- pgtype/typed_array.go.erb | 11 +- pgtype/typed_array_gen.sh | 31 ++-- pgtype/varchar_array.go | 11 +- 18 files changed, 586 insertions(+), 90 deletions(-) create mode 100644 pgtype/hstore_array.go create mode 100644 pgtype/hstore_array_test.go diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 1cb46cf6..6adfbb00 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -238,10 +238,6 @@ func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, BoolOid) -} - -func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("bool"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "bool") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 30405509..d318fa3b 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -238,10 +238,6 @@ func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, ByteaOid) -} - -func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("bytea"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "bytea") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 32d2e7bf..3ab83ecd 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -270,10 +270,6 @@ func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, CidrOid) -} - -func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -282,10 +278,15 @@ func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("cidr"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "cidr") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/date_array.go b/pgtype/date_array.go index ba68d561..8bc8ff72 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -239,10 +239,6 @@ func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, DateOid) -} - -func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("date"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "date") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 40152bcf..6abc1a31 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -238,10 +238,6 @@ func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Float4Oid) -} - -func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("float4"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "float4") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index d0ee0d70..050efa3f 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -238,10 +238,6 @@ func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Float8Oid) -} - -func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("float8"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "float8") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go new file mode 100644 index 00000000..ba192462 --- /dev/null +++ b/pgtype/hstore_array.go @@ -0,0 +1,297 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type HstoreArray struct { + Elements []Hstore + Dimensions []ArrayDimension + Status Status +} + +func (dst *HstoreArray) Set(src interface{}) error { + switch value := src.(type) { + + case []map[string]string: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + elements := make([]Hstore, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = HstoreArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Hstore", value) + } + + return nil +} + +func (dst *HstoreArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *HstoreArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]map[string]string: + if src.Status == Present { + *v = make([]map[string]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Hstore + + if len(uta.Elements) > 0 { + elements = make([]Hstore, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Hstore + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Hstore, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *HstoreArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("hstore"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "hstore") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go new file mode 100644 index 00000000..e23c7b3b --- /dev/null +++ b/pgtype/hstore_array_test.go @@ -0,0 +1,183 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestHstoreArrayTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []pgtype.Hstore{ + pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + src := pgtype.HstoreArray{ + Elements: values, + Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, + Status: pgtype.Present, + } + + ps, err := conn.Prepare("test", "select $1::hstore[]") + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(src, fc.formatCode) + if vEncoder == nil { + t.Logf("%#v does not implement %v", src, fc.name) + continue + } + + var result pgtype.HstoreArray + err := conn.QueryRow("test", vEncoder).Scan(&result) + if err != nil { + t.Errorf("%v: %v", fc.name, err) + continue + } + + if result.Status != src.Status { + t.Errorf("%v: expected Status %v, got %v", fc.formatCode, src.Status, result.Status) + continue + } + + if len(result.Elements) != len(src.Elements) { + t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) + continue + } + + for i := range result.Elements { + a := src.Elements[i] + b := result.Elements[i] + + if a.Status != b.Status { + t.Errorf("%v element idx %d: expected status %v, got %v", fc.formatCode, i, a.Status, b.Status) + } + + if len(a.Map) != len(b.Map) { + t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) + } + } + } + } +} + +func TestHstoreArraySet(t *testing.T) { + successfulTests := []struct { + src []map[string]string + result pgtype.HstoreArray + }{ + { + src: []map[string]string{map[string]string{"foo": "bar"}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + } + + for i, tt := range successfulTests { + var dst pgtype.HstoreArray + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreArrayAssignTo(t *testing.T) { + var m []map[string]string + + simpleTests := []struct { + src pgtype.HstoreArray + dst *[]map[string]string + expected []map[string]string + }{ + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &m, + expected: []map[string]string{{"foo": "bar"}}}, + {src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &m, expected: (([]map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 6cad82e7..d893a724 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -270,10 +270,6 @@ func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, InetOid) -} - -func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -282,10 +278,15 @@ func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("inet"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "inet") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 2bf1c237..b93a4fa3 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -269,10 +269,6 @@ func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int2Oid) -} - -func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int2"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int2") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index dda88eaf..0b96b7a4 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -269,10 +269,6 @@ func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int4Oid) -} - -func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int4"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int4") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 468c126b..02a240f4 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -269,10 +269,6 @@ func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int8Oid) -} - -func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int8"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int8") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 6e89708f..9f25727e 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -238,10 +238,6 @@ func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TextOid) -} - -func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("text"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "text") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 064ad483..bb19e502 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -239,10 +239,6 @@ func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TimestampOid) -} - -func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid in } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("timestamp"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "timestamp") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 4af1460b..6a85cefa 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -239,10 +239,6 @@ func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) } func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TimestamptzOid) -} - -func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("timestamptz"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 2a46a658..2b81666e 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -237,10 +237,6 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool } func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, <%= element_oid %>) -} - -func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -249,10 +245,15 @@ func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, ele } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 5fde32aa..166f8802 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,15 +1,16 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2Oid text_null=NULL typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4Oid text_null=NULL typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8Oid text_null=NULL typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOid text_null=NULL typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOid text_null=NULL typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOid text_null=NULL typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOid text_null=NULL typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go -erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_oid=CidrOid text_null=NULL typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_oid=VarcharOid text_null='"NULL"' typed_array.go.erb > varchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go -erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' typed_array.go.erb > varchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL typed_array.go.erb > bytea_array.go +erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL typed_array.go.erb > hstore_array.go diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 21e9ccff..158ece94 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -238,10 +238,6 @@ func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, VarcharOid) -} - -func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int3 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("varchar"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "varchar") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true From b9e2f0e8141e6713887216adbdc78807830525b5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 13:54:42 -0500 Subject: [PATCH 121/264] Remove a lot of unused code --- query.go | 4 +- values.go | 126 ------------------------------------------------------ 2 files changed, 1 insertion(+), 129 deletions(-) diff --git a/query.go b/query.go index 48a657f9..a76a99bc 100644 --- a/query.go +++ b/query.go @@ -294,9 +294,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } } } else { - if err := Decode(vr, d); err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } + rows.Fatal(scanArgError{col: i, err: fmt.Errorf("unknown oid: %v", vr.Type().DataType)}) } } if vr.Err() != nil { diff --git a/values.go b/values.go index 4eb24eef..682245eb 100644 --- a/values.go +++ b/values.go @@ -3,7 +3,6 @@ package pgx import ( "bytes" "database/sql/driver" - "encoding/json" "fmt" "reflect" @@ -279,44 +278,6 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { return nil, false } -// Decode decodes from vr into d. d must be a pointer. This allows -// implementations of the Decoder interface to delegate the actual work of -// decoding to the built-in functionality. -func Decode(vr *ValueReader, d interface{}) error { - switch v := d.(type) { - case *string: - *v = decodeText(vr) - default: - if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if d is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - // -1 is a null value - if vr.Len() == -1 { - if !el.IsNil() { - // if the destination pointer is not nil, nil it out - el.Set(reflect.Zero(el.Type())) - } - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - d = el.Interface() - return Decode(vr, d) - case reflect.String: - el.SetString(decodeText(vr)) - return nil - } - } - return fmt.Errorf("Scan cannot decode into %T", d) - } - - return nil -} - func decodeText(vr *ValueReader) string { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into string")) @@ -331,15 +292,6 @@ func decodeText(vr *ValueReader) string { return vr.ReadString(vr.Len()) } -func decodeTextAllowBinary(vr *ValueReader) string { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into string")) - return "" - } - - return vr.ReadString(vr.Len()) -} - func encodeString(w *WriteBuf, oid pgtype.Oid, value string) error { w.WriteInt32(int32(len(value))) w.WriteBytes([]byte(value)) @@ -370,81 +322,3 @@ func encodeByteSlice(w *WriteBuf, oid pgtype.Oid, value []byte) error { return nil } - -func decodeJSON(vr *ValueReader, d interface{}) error { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != JsonOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))) - } - - bytes := vr.ReadBytes(vr.Len()) - err := json.Unmarshal(bytes, d) - if err != nil { - vr.Fatal(err) - } - return err -} - -func encodeJSON(w *WriteBuf, oid pgtype.Oid, value interface{}) error { - if oid != JsonOid { - return fmt.Errorf("cannot encode JSON into oid %v", oid) - } - - s, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("Failed to encode json from type: %T", value) - } - - w.WriteInt32(int32(len(s))) - w.WriteBytes(s) - - return nil -} - -func decodeJSONB(vr *ValueReader, d interface{}) error { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != JsonbOid { - err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType)) - vr.Fatal(err) - return err - } - - bytes := vr.ReadBytes(vr.Len()) - if vr.Type().FormatCode == BinaryFormatCode { - if bytes[0] != 1 { - err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", bytes[0])) - vr.Fatal(err) - return err - } - bytes = bytes[1:] - } - - err := json.Unmarshal(bytes, d) - if err != nil { - vr.Fatal(err) - } - return err -} - -func encodeJSONB(w *WriteBuf, oid pgtype.Oid, value interface{}) error { - if oid != JsonbOid { - return fmt.Errorf("cannot encode JSON into oid %v", oid) - } - - s, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("Failed to encode json from type: %T", value) - } - - w.WriteInt32(int32(len(s) + 1)) - w.WriteByte(1) // JSONB format header - w.WriteBytes(s) - - return nil -} From ad2ce2ce3c0ed6d12b96fd63fba69966ac628c81 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 14:02:55 -0500 Subject: [PATCH 122/264] Remove internalNativeGoTypeFormats --- conn.go | 13 ++++++++++++- values.go | 35 ----------------------------------- 2 files changed, 12 insertions(+), 36 deletions(-) diff --git a/conn.go b/conn.go index 3414d7cf..bdb229a9 100644 --- a/conn.go +++ b/conn.go @@ -984,7 +984,18 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} case string, *string: wbuf.WriteInt16(TextFormatCode) default: - wbuf.WriteInt16(internalNativeGoTypeFormats[oid]) + if dt, ok := c.ConnInfo.DataTypeForOid(oid); ok { + switch dt.Value.(type) { + case pgtype.BinaryEncoder: + wbuf.WriteInt16(BinaryFormatCode) + case pgtype.TextEncoder: + wbuf.WriteInt16(TextFormatCode) + default: + return fmt.Errorf("value for oid %v does not implement pgtype.BinaryEncoder or pgtype.TextEncoder", oid) + } + } else { + return fmt.Errorf("unknown type for oid %v", oid) + } } } diff --git a/values.go b/values.go index 682245eb..1df63945 100644 --- a/values.go +++ b/values.go @@ -72,9 +72,6 @@ const minInt = -maxInt - 1 // set here. var DefaultTypeFormats map[string]int16 -// internalNativeGoTypeFormats lists the encoding type for native Go types (not handled with Encoder interface) -var internalNativeGoTypeFormats map[pgtype.Oid]int16 - func init() { DefaultTypeFormats = map[string]int16{ "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) @@ -111,38 +108,6 @@ func init() { "timestamptz": BinaryFormatCode, "xid": BinaryFormatCode, } - - internalNativeGoTypeFormats = map[pgtype.Oid]int16{ - BoolArrayOid: BinaryFormatCode, - BoolOid: BinaryFormatCode, - ByteaArrayOid: BinaryFormatCode, - ByteaOid: BinaryFormatCode, - CidrArrayOid: BinaryFormatCode, - CidrOid: BinaryFormatCode, - DateOid: BinaryFormatCode, - Float4ArrayOid: BinaryFormatCode, - Float4Oid: BinaryFormatCode, - Float8ArrayOid: BinaryFormatCode, - Float8Oid: BinaryFormatCode, - InetArrayOid: BinaryFormatCode, - InetOid: BinaryFormatCode, - Int2ArrayOid: BinaryFormatCode, - Int2Oid: BinaryFormatCode, - Int4ArrayOid: BinaryFormatCode, - Int4Oid: BinaryFormatCode, - Int8ArrayOid: BinaryFormatCode, - Int8Oid: BinaryFormatCode, - JsonbOid: BinaryFormatCode, - JsonOid: BinaryFormatCode, - OidOid: BinaryFormatCode, - RecordOid: BinaryFormatCode, - TextArrayOid: BinaryFormatCode, - TimestampArrayOid: BinaryFormatCode, - TimestampOid: BinaryFormatCode, - TimestampTzArrayOid: BinaryFormatCode, - TimestampTzOid: BinaryFormatCode, - VarcharArrayOid: BinaryFormatCode, - } } // SerializationError occurs on failure to encode or decode a value From 9e289cb1863f6a62c5d31cd726891aee8bcf67cf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 14:09:55 -0500 Subject: [PATCH 123/264] Remove unused DefaultTypeFormats --- doc.go | 23 ++++------------------- values.go | 46 ---------------------------------------------- 2 files changed, 4 insertions(+), 65 deletions(-) diff --git a/doc.go b/doc.go index 5f3490ca..2d782c5e 100644 --- a/doc.go +++ b/doc.go @@ -155,25 +155,10 @@ netmask for IPv4 and a /128 for IPv6. Custom Type Support pgx includes support for the common data types like integers, floats, strings, -dates, and times that have direct mappings between Go and SQL. Support can be -added for additional types like point, hstore, numeric, etc. that do not have -direct mappings in Go by the types implementing ScannerPgx and Encoder. - -Custom types can support text or binary formats. Binary format can provide a -large performance increase. The natural place for deciding the format for a -value would be in ScannerPgx as it is responsible for decoding the returned -data. However, that is impossible as the query has already been sent by the time -the ScannerPgx is invoked. The solution to this is the global -DefaultTypeFormats. If a custom type prefers binary format it should register it -there. - - pgx.DefaultTypeFormats["point"] = pgx.BinaryFormatCode - -Note that the type is referred to by name, not by Oid. This is because custom -PostgreSQL types like hstore will have different Oids on different servers. When -pgx establishes a connection it queries the pg_type table for all types. It then -matches the names in DefaultTypeFormats with the returned Oids and stores it in -Conn.PgTypes. +dates, and times that have direct mappings between Go and SQL. In addition, +pgx uses the github.com/jackc/pgx/pgtype library to support more types. See +documention for that library for instructions on how to implement custom +types. See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. diff --git a/values.go b/values.go index 1df63945..aec3cda7 100644 --- a/values.go +++ b/values.go @@ -64,52 +64,6 @@ const maxUint = ^uint(0) const maxInt = int(maxUint >> 1) const minInt = -maxInt - 1 -// DefaultTypeFormats maps type names to their default requested format (text -// or binary). In theory the Scanner interface should be the one to determine -// the format of the returned values. However, the query has already been -// executed by the time Scan is called so it has no chance to set the format. -// So for types that should always be returned in binary the format should be -// set here. -var DefaultTypeFormats map[string]int16 - -func init() { - DefaultTypeFormats = map[string]int16{ - "_aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) - "_bool": BinaryFormatCode, - "_bytea": BinaryFormatCode, - "_cidr": BinaryFormatCode, - "_float4": BinaryFormatCode, - "_float8": BinaryFormatCode, - "_inet": BinaryFormatCode, - "_int2": BinaryFormatCode, - "_int4": BinaryFormatCode, - "_int8": BinaryFormatCode, - "_text": BinaryFormatCode, - "_timestamp": BinaryFormatCode, - "_timestamptz": BinaryFormatCode, - "_varchar": BinaryFormatCode, - "aclitem": TextFormatCode, // Pg's src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin) - "bool": BinaryFormatCode, - "bytea": BinaryFormatCode, - "char": BinaryFormatCode, - "cid": BinaryFormatCode, - "cidr": BinaryFormatCode, - "date": BinaryFormatCode, - "float4": BinaryFormatCode, - "float8": BinaryFormatCode, - "inet": BinaryFormatCode, - "int2": BinaryFormatCode, - "int4": BinaryFormatCode, - "int8": BinaryFormatCode, - "oid": BinaryFormatCode, - "record": BinaryFormatCode, - "tid": BinaryFormatCode, - "timestamp": BinaryFormatCode, - "timestamptz": BinaryFormatCode, - "xid": BinaryFormatCode, - } -} - // SerializationError occurs on failure to encode or decode a value type SerializationError string From a636ef31a4433fc275b0de485d07f401f9bbfca0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 14:23:04 -0500 Subject: [PATCH 124/264] Refactor encoding parameters for prepared statements --- conn.go | 24 ++---------------------- copy_from.go | 2 +- values.go | 31 ++++++++++++++++++++++++------- 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/conn.go b/conn.go index bdb229a9..6c6998b5 100644 --- a/conn.go +++ b/conn.go @@ -976,32 +976,12 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(int16(len(ps.ParameterOids))) for i, oid := range ps.ParameterOids { - switch arguments[i].(type) { - case pgtype.BinaryEncoder: - wbuf.WriteInt16(BinaryFormatCode) - case pgtype.TextEncoder: - wbuf.WriteInt16(TextFormatCode) - case string, *string: - wbuf.WriteInt16(TextFormatCode) - default: - if dt, ok := c.ConnInfo.DataTypeForOid(oid); ok { - switch dt.Value.(type) { - case pgtype.BinaryEncoder: - wbuf.WriteInt16(BinaryFormatCode) - case pgtype.TextEncoder: - wbuf.WriteInt16(TextFormatCode) - default: - return fmt.Errorf("value for oid %v does not implement pgtype.BinaryEncoder or pgtype.TextEncoder", oid) - } - } else { - return fmt.Errorf("unknown type for oid %v", oid) - } - } + wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) } wbuf.WriteInt16(int16(len(arguments))) for i, oid := range ps.ParameterOids { - if err := Encode(wbuf, oid, arguments[i]); err != nil { + if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil { return err } } diff --git a/copy_from.go b/copy_from.go index 1f8a2306..9fc76a7b 100644 --- a/copy_from.go +++ b/copy_from.go @@ -157,7 +157,7 @@ func (ct *copyFrom) run() (int, error) { wbuf.WriteInt16(int16(len(ct.columnNames))) for i, val := range values { - err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val) if err != nil { ct.cancelCopyIn() return 0, err diff --git a/values.go b/values.go index aec3cda7..49df5d89 100644 --- a/values.go +++ b/values.go @@ -71,10 +71,7 @@ func (e SerializationError) Error() string { return string(e) } -// Encode encodes arg into wbuf as the type oid. This allows implementations -// of the Encoder interface to delegate the actual work of encoding to the -// built-in functionality. -func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { +func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { if arg == nil { wbuf.WriteInt32(-1) return nil @@ -112,7 +109,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { if err != nil { return err } - return Encode(wbuf, oid, v) + return encodePreparedStatementArgument(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) case []byte: @@ -127,7 +124,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { return nil } arg = refVal.Elem().Interface() - return Encode(wbuf, oid, arg) + return encodePreparedStatementArgument(wbuf, oid, arg) } if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok { @@ -152,11 +149,31 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { } if strippedArg, ok := stripNamedType(&refVal); ok { - return Encode(wbuf, oid, strippedArg) + return encodePreparedStatementArgument(wbuf, oid, strippedArg) } return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } +// chooseParameterFormatCode determines the correct format code for an +// argument to a prepared statement. It defaults to TextFormatCode if no +// determination can be made. +func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interface{}) int16 { + switch arg.(type) { + case pgtype.BinaryEncoder: + return BinaryFormatCode + case string, *string, pgtype.TextEncoder: + return TextFormatCode + } + + if dt, ok := ci.DataTypeForOid(oid); ok { + if _, ok := dt.Value.(pgtype.BinaryEncoder); ok { + return BinaryFormatCode + } + } + + return TextFormatCode +} + func stripNamedType(val *reflect.Value) (interface{}, bool) { switch val.Kind() { case reflect.Int: From 015108be9adc34d565613c1b6ae68f743657e899 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 14:28:06 -0500 Subject: [PATCH 125/264] Remove unused code --- values.go | 35 ++++++----------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/values.go b/values.go index 49df5d89..8bfcb2ec 100644 --- a/values.go +++ b/values.go @@ -111,9 +111,13 @@ func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interfa } return encodePreparedStatementArgument(wbuf, oid, v) case string: - return encodeString(wbuf, oid, arg) + wbuf.WriteInt32(int32(len(arg))) + wbuf.WriteBytes([]byte(arg)) + return nil case []byte: - return encodeByteSlice(wbuf, oid, arg) + wbuf.WriteInt32(int32(len(arg))) + wbuf.WriteBytes(arg) + return nil } refVal := reflect.ValueOf(arg) @@ -214,26 +218,6 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { return nil, false } -func decodeText(vr *ValueReader) string { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into string")) - return "" - } - - if vr.Type().FormatCode == BinaryFormatCode { - vr.Fatal(ProtocolError("cannot decode binary value into string")) - return "" - } - - return vr.ReadString(vr.Len()) -} - -func encodeString(w *WriteBuf, oid pgtype.Oid, value string) error { - w.WriteInt32(int32(len(value))) - w.WriteBytes([]byte(value)) - return nil -} - func decodeBytea(vr *ValueReader) []byte { if vr.Len() == -1 { return nil @@ -251,10 +235,3 @@ func decodeBytea(vr *ValueReader) []byte { return vr.ReadBytes(vr.Len()) } - -func encodeByteSlice(w *WriteBuf, oid pgtype.Oid, value []byte) error { - w.WriteInt32(int32(len(value))) - w.WriteBytes(value) - - return nil -} From 92cff1d96155b53b88a02bd0524bdbc20c3afe35 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 14:42:36 -0500 Subject: [PATCH 126/264] Simplify []byte scanning --- doc.go | 5 +---- pgtype/text.go | 10 ++++++++++ pgtype/text_test.go | 27 +++++++++++++++++++++++++-- query.go | 15 +-------------- query_test.go | 1 - v3.md | 2 ++ values.go | 18 ------------------ 7 files changed, 39 insertions(+), 39 deletions(-) diff --git a/doc.go b/doc.go index 2d782c5e..0921242a 100644 --- a/doc.go +++ b/doc.go @@ -169,10 +169,7 @@ and database/sql/driver.Valuer interfaces. Raw Bytes Mapping []byte passed as arguments to Query, QueryRow, and Exec are passed unmodified -to PostgreSQL. In like manner, a *[]byte passed to Scan will be filled with -the raw bytes returned by PostgreSQL. This can be especially useful for reading -varchar, text, json, and jsonb values directly into a []byte and avoiding the -type conversion from string. +to PostgreSQL. Transactions diff --git a/pgtype/text.go b/pgtype/text.go index f1a76b6e..af7f16fc 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -49,6 +49,16 @@ func (src *Text) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = src.String + case *[]byte: + switch src.Status { + case Present: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + case Null: + *v = nil + default: + return fmt.Errorf("unknown status") + } default: if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { el := v.Elem() diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 39348bcc..34b6a784 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "bytes" "reflect" "testing" @@ -44,7 +45,7 @@ func TestTextAssignTo(t *testing.T) { var s string var ps *string - simpleTests := []struct { + stringTests := []struct { src pgtype.Text dst interface{} expected interface{} @@ -53,7 +54,7 @@ func TestTextAssignTo(t *testing.T) { {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } - for i, tt := range simpleTests { + for i, tt := range stringTests { err := tt.src.AssignTo(tt.dst) if err != nil { t.Errorf("%d: %v", i, err) @@ -64,6 +65,28 @@ func TestTextAssignTo(t *testing.T) { } } + var buf []byte + + bytesTests := []struct { + src pgtype.Text + dst *[]byte + expected []byte + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &buf, expected: []byte("foo")}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + pointerAllocTests := []struct { src pgtype.Text dst interface{} diff --git a/query.go b/query.go index a76a99bc..0b5cc911 100644 --- a/query.go +++ b/query.go @@ -198,20 +198,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { continue } - // Check for []byte first as we allow sidestepping the decoding process and retrieving the raw bytes - if b, ok := d.(*[]byte); ok { - // If it actually is a bytea then pass it through decodeBytea (so it can be decoded if it is in text format) - // Otherwise read the bytes directly regardless of what the actual type is. - if vr.Type().DataType == ByteaOid { - *b = decodeBytea(vr) - } else { - if vr.Len() != -1 { - *b = vr.ReadBytes(vr.Len()) - } else { - *b = nil - } - } - } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { + if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { err = s.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) diff --git a/query_test.go b/query_test.go index 480959e8..b053e26d 100644 --- a/query_test.go +++ b/query_test.go @@ -684,7 +684,6 @@ func TestQueryRowCoreByteSlice(t *testing.T) { }{ {"select $1::text", "Jack", []byte("Jack")}, {"select $1::text", []byte("Jack"), []byte("Jack")}, - {"select $1::int4", int32(239023409), []byte{14, 63, 53, 49}}, {"select $1::varchar", []byte("Jack"), []byte("Jack")}, {"select $1::bytea", []byte{0, 15, 255, 17}, []byte{0, 15, 255, 17}}, } diff --git a/v3.md b/v3.md index 6f5fd412..8fe30bf4 100644 --- a/v3.md +++ b/v3.md @@ -28,6 +28,8 @@ Reject scanning binary format values into a string (e.g. binary encoded timestam Remove CopyTo +No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed. + ## TODO / Possible / Investigate Organize errors better diff --git a/values.go b/values.go index 8bfcb2ec..734e1fa5 100644 --- a/values.go +++ b/values.go @@ -217,21 +217,3 @@ func stripNamedType(val *reflect.Value) (interface{}, bool) { return nil, false } - -func decodeBytea(vr *ValueReader) []byte { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != ByteaOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - return vr.ReadBytes(vr.Len()) -} From 6f0ec4c470012f42f222949e2d89e7acb0753c80 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 14:53:51 -0500 Subject: [PATCH 127/264] Renable json tests --- values_test.go | 258 ++++++++++++++++++++++++------------------------- 1 file changed, 126 insertions(+), 132 deletions(-) diff --git a/values_test.go b/values_test.go index 1d09eb18..37bf91cc 100644 --- a/values_test.go +++ b/values_test.go @@ -6,6 +6,8 @@ import ( "reflect" "testing" "time" + + "github.com/jackc/pgx" ) func TestDateTranscode(t *testing.T) { @@ -77,159 +79,151 @@ func TestTimestampTzTranscode(t *testing.T) { // TODO - move these tests to pgtype -// func TestJSONAndJSONBTranscode(t *testing.T) { -// t.Parallel() +func TestJSONAndJSONBTranscode(t *testing.T) { + t.Parallel() -// conn := mustConnect(t, *defaultConnConfig) -// defer closeConn(t, conn) + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) -// for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { -// if _, ok := conn.ConnInfo.DataTypeForOid(oid); !ok { -// return // No JSON/JSONB type -- must be running against old PostgreSQL -// } + for _, typename := range []string{"json", "jsonb"} { + if _, ok := conn.ConnInfo.DataTypeForName(typename); !ok { + continue // No JSON/JSONB type -- must be running against old PostgreSQL + } -// for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { -// pgtype := conn.PgTypes[oid] -// pgtype.DefaultFormat = format -// conn.PgTypes[oid] = pgtype + testJSONString(t, conn, typename) + testJSONStringPointer(t, conn, typename) + testJSONSingleLevelStringMap(t, conn, typename) + testJSONNestedMap(t, conn, typename) + testJSONStringArray(t, conn, typename) + testJSONInt64Array(t, conn, typename) + testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) + testJSONStruct(t, conn, typename) + } +} -// typename := conn.PgTypes[oid].Name +func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { + input := `{"key": "value"}` + expectedOutput := map[string]string{"key": "value"} + var output map[string]string + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + return + } -// testJSONString(t, conn, typename, format) -// testJSONStringPointer(t, conn, typename, format) -// testJSONSingleLevelStringMap(t, conn, typename, format) -// testJSONNestedMap(t, conn, typename, format) -// testJSONStringArray(t, conn, typename, format) -// testJSONInt64Array(t, conn, typename, format) -// testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format) -// testJSONStruct(t, conn, typename, format) -// } -// } -// } + if !reflect.DeepEqual(expectedOutput, output) { + t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) + return + } +} -// func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) { -// input := `{"key": "value"}` -// expectedOutput := map[string]string{"key": "value"} -// var output map[string]string -// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) -// if err != nil { -// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) -// return -// } +func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { + input := `{"key": "value"}` + expectedOutput := map[string]string{"key": "value"} + var output map[string]string + err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + return + } -// if !reflect.DeepEqual(expectedOutput, output) { -// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) -// return -// } -// } + if !reflect.DeepEqual(expectedOutput, output) { + t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) + return + } +} -// func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { -// input := `{"key": "value"}` -// expectedOutput := map[string]string{"key": "value"} -// var output map[string]string -// err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) -// if err != nil { -// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) -// return -// } +func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { + input := map[string]string{"key": "value"} + var output map[string]string + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + return + } -// if !reflect.DeepEqual(expectedOutput, output) { -// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) -// return -// } -// } + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output) + return + } +} -// func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { -// input := map[string]string{"key": "value"} -// var output map[string]string -// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) -// if err != nil { -// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) -// return -// } +func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { + input := map[string]interface{}{ + "name": "Uncanny", + "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, + "inventory": []interface{}{"phone", "key"}, + } + var output map[string]interface{} + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + return + } -// if !reflect.DeepEqual(input, output) { -// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) -// return -// } -// } + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) + return + } +} -// func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { -// input := map[string]interface{}{ -// "name": "Uncanny", -// "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, -// "inventory": []interface{}{"phone", "key"}, -// } -// var output map[string]interface{} -// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) -// if err != nil { -// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) -// return -// } +func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) { + input := []string{"foo", "bar", "baz"} + var output []string + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + } -// if !reflect.DeepEqual(input, output) { -// t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) -// return -// } -// } + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output) + } +} -// func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { -// input := []string{"foo", "bar", "baz"} -// var output []string -// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) -// if err != nil { -// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) -// } +func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) { + input := []int64{1, 2, 234432} + var output []int64 + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + } -// if !reflect.DeepEqual(input, output) { -// t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) -// } -// } + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output) + } +} -// func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { -// input := []int64{1, 2, 234432} -// var output []int64 -// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) -// if err != nil { -// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) -// } +func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { + input := []int{1, 2, 234432} + var output []int16 + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { + t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) + } +} -// if !reflect.DeepEqual(input, output) { -// t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) -// } -// } +func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { + type person struct { + Name string `json:"name"` + Age int `json:"age"` + } -// func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { -// input := []int{1, 2, 234432} -// var output []int16 -// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) -// if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { -// t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) -// } -// } + input := person{ + Name: "John", + Age: 42, + } -// func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { -// type person struct { -// Name string `json:"name"` -// Age int `json:"age"` -// } + var output person -// input := person{ -// Name: "John", -// Age: 42, -// } + err := conn.QueryRow("select $1::"+typename, input).Scan(&output) + if err != nil { + t.Errorf("%s: QueryRow Scan failed: %v", typename, err) + } -// var output person - -// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) -// if err != nil { -// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) -// } - -// if !reflect.DeepEqual(input, output) { -// t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) -// } -// } + if !reflect.DeepEqual(input, output) { + t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output) + } +} func mustParseCidr(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) From 9ab59a74a95b0f9a65f875f58712f3e7ca501c71 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 14:59:14 -0500 Subject: [PATCH 128/264] Remove oid constants from pgx --- conn.go | 8 ++++---- conn_test.go | 2 +- stdlib/sql.go | 20 ++++++++++---------- v3.md | 2 ++ values.go | 45 --------------------------------------------- 5 files changed, 17 insertions(+), 60 deletions(-) diff --git a/conn.go b/conn.go index 6c6998b5..509e9d8e 100644 --- a/conn.go +++ b/conn.go @@ -38,10 +38,10 @@ var minimalConnInfo *pgtype.ConnInfo func init() { minimalConnInfo = pgtype.NewConnInfo() minimalConnInfo.InitializeDataTypes(map[string]pgtype.Oid{ - "int4": Int4Oid, - "name": NameOid, - "oid": OidOid, - "text": TextOid, + "int4": pgtype.Int4Oid, + "name": pgtype.NameOid, + "oid": pgtype.OidOid, + "text": pgtype.TextOid, }) } diff --git a/conn_test.go b/conn_test.go index e1c780b8..13367c6a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1043,7 +1043,7 @@ func TestPrepareEx(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgtype.Oid{pgx.TextOid}}) + _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgtype.Oid{pgtype.TextOid}}) if err != nil { t.Errorf("Unable to prepare statement: %v", err) return diff --git a/stdlib/sql.go b/stdlib/sql.go index 7ab4cdbe..6889a2b6 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -66,16 +66,16 @@ func init() { sql.Register("pgx", d) databaseSqlOids = make(map[pgtype.Oid]bool) - databaseSqlOids[pgx.BoolOid] = true - databaseSqlOids[pgx.ByteaOid] = true - databaseSqlOids[pgx.Int2Oid] = true - databaseSqlOids[pgx.Int4Oid] = true - databaseSqlOids[pgx.Int8Oid] = true - databaseSqlOids[pgx.Float4Oid] = true - databaseSqlOids[pgx.Float8Oid] = true - databaseSqlOids[pgx.DateOid] = true - databaseSqlOids[pgx.TimestampTzOid] = true - databaseSqlOids[pgx.TimestampOid] = true + databaseSqlOids[pgtype.BoolOid] = true + databaseSqlOids[pgtype.ByteaOid] = true + databaseSqlOids[pgtype.Int2Oid] = true + databaseSqlOids[pgtype.Int4Oid] = true + databaseSqlOids[pgtype.Int8Oid] = true + databaseSqlOids[pgtype.Float4Oid] = true + databaseSqlOids[pgtype.Float8Oid] = true + databaseSqlOids[pgtype.DateOid] = true + databaseSqlOids[pgtype.TimestamptzOid] = true + databaseSqlOids[pgtype.TimestampOid] = true } type Driver struct { diff --git a/v3.md b/v3.md index 8fe30bf4..3e0aae82 100644 --- a/v3.md +++ b/v3.md @@ -30,6 +30,8 @@ Remove CopyTo No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed. +OID constants moved from pgx to pgtype package + ## TODO / Possible / Investigate Organize errors better diff --git a/values.go b/values.go index 734e1fa5..3491efed 100644 --- a/values.go +++ b/values.go @@ -9,51 +9,6 @@ import ( "github.com/jackc/pgx/pgtype" ) -// PostgreSQL oids for common types -const ( - BoolOid = 16 - ByteaOid = 17 - CharOid = 18 - NameOid = 19 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - TidOid = 27 - XidOid = 28 - CidOid = 29 - JsonOid = 114 - CidrOid = 650 - CidrArrayOid = 651 - Float4Oid = 700 - Float8Oid = 701 - UnknownOid = 705 - InetOid = 869 - BoolArrayOid = 1000 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - ByteaArrayOid = 1001 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - AclitemOid = 1033 - AclitemArrayOid = 1034 - InetArrayOid = 1041 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampArrayOid = 1115 - DateArrayOid = 1182 - TimestampTzOid = 1184 - TimestampTzArrayOid = 1185 - RecordOid = 2249 - UuidOid = 2950 - JsonbOid = 3802 -) - // PostgreSQL format codes const ( TextFormatCode = 0 From 1bea9d3f7e1419e5f9f3c52a92088b84620e728b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 15:00:41 -0500 Subject: [PATCH 129/264] Remove int bound constants --- values.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/values.go b/values.go index 3491efed..c399b42c 100644 --- a/values.go +++ b/values.go @@ -15,10 +15,6 @@ const ( BinaryFormatCode = 1 ) -const maxUint = ^uint(0) -const maxInt = int(maxUint >> 1) -const minInt = -maxInt - 1 - // SerializationError occurs on failure to encode or decode a value type SerializationError string From 264823e6abd8032711ab99cd539bcaab3a9bfb28 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 15:51:16 -0500 Subject: [PATCH 130/264] Remove unneeded idea file --- pgtype/extra-interface.txt | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 pgtype/extra-interface.txt diff --git a/pgtype/extra-interface.txt b/pgtype/extra-interface.txt deleted file mode 100644 index f07818bc..00000000 --- a/pgtype/extra-interface.txt +++ /dev/null @@ -1,3 +0,0 @@ -Can pass function to get inet data and function to get oid/name mapping as optional interface with io.Reader or io.Writer - -Could be useful for arrays of types without defined Oids like hstore. From 4d9c44fc014bcd2e4d0efca99a9b355036c84899 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 16:54:08 -0500 Subject: [PATCH 131/264] Factor out duplication in AssignTo --- pgtype/aclitem.go | 42 +++++---------- pgtype/aclitem_array.go | 23 ++++---- pgtype/bool.go | 42 +++++---------- pgtype/bool_array.go | 23 ++++---- pgtype/bytea.go | 43 +++++---------- pgtype/bytea_array.go | 23 ++++---- pgtype/cidr_array.go | 30 +++++------ pgtype/convert.go | 102 ++++++++++++++++++++++++++++-------- pgtype/date.go | 39 ++++++-------- pgtype/date_array.go | 23 ++++---- pgtype/float4_array.go | 23 ++++---- pgtype/float8_array.go | 23 ++++---- pgtype/hstore.go | 21 ++++---- pgtype/hstore_array.go | 23 ++++---- pgtype/inet.go | 44 +++++----------- pgtype/inet_array.go | 30 +++++------ pgtype/int2_array.go | 30 +++++------ pgtype/int4_array.go | 30 +++++------ pgtype/int8_array.go | 30 +++++------ pgtype/record.go | 31 +++++------ pgtype/text.go | 50 +++++------------- pgtype/text_array.go | 23 ++++---- pgtype/timestamp.go | 39 ++++++-------- pgtype/timestamp_array.go | 23 ++++---- pgtype/timestamptz.go | 39 ++++++-------- pgtype/timestamptz_array.go | 23 ++++---- pgtype/typed_array.go.erb | 27 +++++----- pgtype/varchar_array.go | 23 ++++---- 28 files changed, 430 insertions(+), 492 deletions(-) diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index f9faab20..e8386ae7 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" ) // Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem @@ -55,39 +54,22 @@ func (dst *Aclitem) Get() interface{} { } func (src *Aclitem) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *string: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.String - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.String: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetString(src.String) - return nil + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index f02d339e..1c97e74f 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -58,28 +58,29 @@ func (dst *AclitemArray) Get() interface{} { } func (src *AclitemArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/bool.go b/pgtype/bool.go index 87316381..608a6f95 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" "strconv" ) @@ -44,39 +43,22 @@ func (dst *Bool) Get() interface{} { } func (src *Bool) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *bool: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Bool - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.Bool: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetBool(src.Bool) - return nil + switch src.Status { + case Present: + switch v := dst.(type) { + case *bool: + *v = src.Bool + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 6adfbb00..cdfe9685 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -59,28 +59,29 @@ func (dst *BoolArray) Get() interface{} { } func (src *BoolArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]bool: - if src.Status == Present { + case *[]bool: *v = make([]bool, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/bytea.go b/pgtype/bytea.go index dc1e9c07..00bed8e8 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "fmt" "io" - "reflect" ) type Bytea struct { @@ -42,38 +41,24 @@ func (dst *Bytea) Get() interface{} { } func (src *Bytea) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *[]byte: - if src.Status == Present { - *v = src.Bytes - } else { - *v = nil - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) - } + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]byte: + buf := make([]byte, len(src.Bytes)) + copy(buf, src.Bytes) + *v = buf + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } // DecodeText only supports the hex format. This has been the default since diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index d318fa3b..175ca2f6 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -59,28 +59,29 @@ func (dst *ByteaArray) Get() interface{} { } func (src *ByteaArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[][]byte: - if src.Status == Present { + case *[][]byte: *v = make([][]byte, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 3ab83ecd..49a2728b 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -79,40 +79,38 @@ func (dst *CidrArray) Get() interface{} { } func (src *CidrArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]*net.IPNet: - if src.Status == Present { + case *[]*net.IPNet: *v = make([]*net.IPNet, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]net.IP: - if src.Status == Present { + case *[]net.IP: *v = make([]net.IP, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/convert.go b/pgtype/convert.go index 648209f5..4fba8430 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -184,28 +184,6 @@ func underlyingSliceType(val interface{}) (interface{}, bool) { return nil, false } -func underlyingPtrSliceType(val interface{}) (interface{}, bool) { - refVal := reflect.ValueOf(val) - - if refVal.Kind() != reflect.Ptr { - return nil, false - } - if refVal.IsNil() { - return nil, false - } - - sliceVal := refVal.Elem().Interface() - baseSliceType := reflect.SliceOf(reflect.TypeOf(sliceVal).Elem()) - ptrBaseSliceType := reflect.PtrTo(baseSliceType) - - if refVal.Type().ConvertibleTo(ptrBaseSliceType) { - convVal := refVal.Convert(ptrBaseSliceType) - return convVal.Interface(), reflect.TypeOf(convVal.Interface()) != refVal.Type() - } - - return nil, false -} - func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { if srcStatus == Present { switch v := dst.(type) { @@ -363,3 +341,83 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } + +func nullAssignTo(dst interface{}) error { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return fmt.Errorf("cannot assign NULL to %T", dst) + } + + dstVal := dstPtr.Elem() + + switch dstVal.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map: + dstVal.Set(reflect.Zero(dstVal.Type())) + return nil + } + + return fmt.Errorf("cannot assign NULL to %T", dst) +} + +var kindTypes map[reflect.Kind]reflect.Type + +// GetAssignToDstType attempts to convert dst to something AssignTo can assign +// to. If dst is a pointer to pointer it allocates a value and returns the +// dereferences pointer. If dst is a named type such as *Foo where Foo is type +// Foo int16, it converts dst to *int16. +// +// GetAssignToDstType returns the converted dst and a bool representing if any +// change was made. +func GetAssignToDstType(dst interface{}) (interface{}, bool) { + dstPtr := reflect.ValueOf(dst) + + // AssignTo dst must always be a pointer + if dstPtr.Kind() != reflect.Ptr { + return nil, false + } + + dstVal := dstPtr.Elem() + + // if dst is a pointer to pointer, allocate space try again with the dereferenced pointer + if dstVal.Kind() == reflect.Ptr { + dstVal.Set(reflect.New(dstVal.Type().Elem())) + return dstVal.Interface(), true + } + + // if dst is pointer to a base type that has been renamed + if baseValType, ok := kindTypes[dstVal.Kind()]; ok { + nextDst := dstPtr.Convert(reflect.PtrTo(baseValType)) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + + if dstVal.Kind() == reflect.Slice { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseSliceType := reflect.PtrTo(reflect.SliceOf(baseElemType)) + nextDst := dstPtr.Convert(baseSliceType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + + return nil, false +} + +func init() { + kindTypes = map[reflect.Kind]reflect.Type{ + reflect.Bool: reflect.TypeOf(false), + reflect.Float32: reflect.TypeOf(float32(0)), + reflect.Float64: reflect.TypeOf(float64(0)), + reflect.Int: reflect.TypeOf(int(0)), + reflect.Int8: reflect.TypeOf(int8(0)), + reflect.Int16: reflect.TypeOf(int16(0)), + reflect.Int32: reflect.TypeOf(int32(0)), + reflect.Int64: reflect.TypeOf(int64(0)), + reflect.Uint: reflect.TypeOf(uint(0)), + reflect.Uint8: reflect.TypeOf(uint8(0)), + reflect.Uint16: reflect.TypeOf(uint16(0)), + reflect.Uint32: reflect.TypeOf(uint32(0)), + reflect.Uint64: reflect.TypeOf(uint64(0)), + reflect.String: reflect.TypeOf(""), + } +} diff --git a/pgtype/date.go b/pgtype/date.go index b6cc8329..ab854eb2 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -50,33 +49,25 @@ func (dst *Date) Get() interface{} { } func (src *Date) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 8bc8ff72..bf791677 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -60,28 +60,29 @@ func (dst *DateArray) Get() interface{} { } func (src *DateArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 6abc1a31..b4d05c55 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -59,28 +59,29 @@ func (dst *Float4Array) Get() interface{} { } func (src *Float4Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]float32: - if src.Status == Present { + case *[]float32: *v = make([]float32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 050efa3f..e000807e 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -59,28 +59,29 @@ func (dst *Float8Array) Get() interface{} { } func (src *Float8Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]float64: - if src.Status == Present { + case *[]float64: *v = make([]float64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/hstore.go b/pgtype/hstore.go index d771d6e6..8dc5b4d8 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -47,10 +47,10 @@ func (dst *Hstore) Get() interface{} { } func (src *Hstore) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *map[string]string: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *map[string]string: *v = make(map[string]string, len(src.Map)) for k, val := range src.Map { if val.Status != Present { @@ -58,16 +58,17 @@ func (src *Hstore) AssignTo(dst interface{}) error { } (*v)[k] = val.String } - case Null: - *v = nil + return nil default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index ba192462..9bd0ed3b 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -59,28 +59,29 @@ func (dst *HstoreArray) Get() interface{} { } func (src *HstoreArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]map[string]string: - if src.Status == Present { + case *[]map[string]string: *v = make([]map[string]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/inet.go b/pgtype/inet.go index b83bd1c9..13764814 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "net" - "reflect" "github.com/jackc/pgx/pgio" ) @@ -61,43 +60,28 @@ func (dst *Inet) Get() interface{} { } func (src *Inet) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *net.IPNet: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = *src.IPNet - case *net.IP: - if src.Status == Present { - + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.IPNet: + *v = *src.IPNet + return nil + case *net.IP: if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { return fmt.Errorf("cannot assign %v to %T", src, dst) } *v = src.IPNet.IP - } else { - *v = nil - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index d893a724..1988a145 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -79,40 +79,38 @@ func (dst *InetArray) Get() interface{} { } func (src *InetArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]*net.IPNet: - if src.Status == Present { + case *[]*net.IPNet: *v = make([]*net.IPNet, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]net.IP: - if src.Status == Present { + case *[]net.IP: *v = make([]net.IP, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index b93a4fa3..531e7dd6 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -78,40 +78,38 @@ func (dst *Int2Array) Get() interface{} { } func (src *Int2Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int16: - if src.Status == Present { + case *[]int16: *v = make([]int16, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint16: - if src.Status == Present { + case *[]uint16: *v = make([]uint16, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 0b96b7a4..3617050f 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -78,40 +78,38 @@ func (dst *Int4Array) Get() interface{} { } func (src *Int4Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int32: - if src.Status == Present { + case *[]int32: *v = make([]int32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint32: - if src.Status == Present { + case *[]uint32: *v = make([]uint32, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 02a240f4..4f04b660 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -78,40 +78,38 @@ func (dst *Int8Array) Get() interface{} { } func (src *Int8Array) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]int64: - if src.Status == Present { + case *[]int64: *v = make([]int64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - case *[]uint64: - if src.Status == Present { + case *[]uint64: *v = make([]uint64, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/record.go b/pgtype/record.go index 1bfd05b9..89e081ca 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -38,34 +38,29 @@ func (dst *Record) Get() interface{} { } func (src *Record) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *[]Value: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *[]Value: *v = make([]Value, len(src.Fields)) copy(*v, src.Fields) - case Null: - *v = nil - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) - } - case *[]interface{}: - switch src.Status { - case Present: + return nil + case *[]interface{}: *v = make([]interface{}, len(src.Fields)) for i := range *v { (*v)[i] = src.Fields[i].Get() } - case Null: - *v = nil + return nil default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - default: - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { diff --git a/pgtype/text.go b/pgtype/text.go index af7f16fc..dbc9362b 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -3,7 +3,6 @@ package pgtype import ( "fmt" "io" - "reflect" ) type Text struct { @@ -43,49 +42,26 @@ func (dst *Text) Get() interface{} { } func (src *Text) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *string: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.String - case *[]byte: - switch src.Status { - case Present: + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: *v = make([]byte, len(src.String)) copy(*v, src.String) - case Null: - *v = nil + return nil default: - return fmt.Errorf("unknown status") - } - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) - case reflect.String: - if src.Status != Present { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - el.SetString(src.String) - return nil + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 9f25727e..6e8ead26 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -59,28 +59,29 @@ func (dst *TextArray) Get() interface{} { } func (src *TextArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 9a9e74ea..4b42f3cf 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -54,33 +53,25 @@ func (dst *Timestamp) Get() interface{} { } func (src *Timestamp) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot assign %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } // DecodeText decodes from src into dst. The decoded time is considered to diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index bb19e502..6a6950c7 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -60,28 +60,29 @@ func (dst *TimestampArray) Get() interface{} { } func (src *TimestampArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 7f57f4b7..ba849ac8 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "reflect" "time" "github.com/jackc/pgx/pgio" @@ -55,33 +54,25 @@ func (dst *Timestamptz) Get() interface{} { } func (src *Timestamptz) AssignTo(dst interface{}) error { - switch v := dst.(type) { - case *time.Time: - if src.Status != Present || src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) - } - *v = src.Time - default: - if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { - el := v.Elem() - switch el.Kind() { - // if dst is a pointer to pointer, strip the pointer and try again - case reflect.Ptr: - if src.Status == Null { - el.Set(reflect.Zero(el.Type())) - return nil - } - if el.IsNil() { - // allocate destination - el.Set(reflect.New(el.Type().Elem())) - } - return src.AssignTo(el.Interface()) + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Time: + if src.InfinityModifier != None { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.Time + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) } } - return fmt.Errorf("cannot assign %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 6a85cefa..347d0b8b 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -60,28 +60,29 @@ func (dst *TimestamptzArray) Get() interface{} { } func (src *TimestamptzArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]time.Time: - if src.Status == Present { + case *[]time.Time: *v = make([]time.Time, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 2b81666e..26c4671c 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -58,28 +58,29 @@ func (dst *<%= pgtype_array_type %>) Get() interface{} { } func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { - switch v := dst.(type) { - <% go_array_types.split(",").each do |t| %> - case *<%= t %>: - if src.Status == Present { + switch src.Status { + case Present: + switch v := dst.(type) { + <% go_array_types.split(",").each do |t| %> + case *<%= t %>: *v = make(<%= t %>, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil + return nil + <% end %> + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - <% end %> - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) - } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 158ece94..e1dd3910 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -59,28 +59,29 @@ func (dst *VarcharArray) Get() interface{} { } func (src *VarcharArray) AssignTo(dst interface{}) error { - switch v := dst.(type) { + switch src.Status { + case Present: + switch v := dst.(type) { - case *[]string: - if src.Status == Present { + case *[]string: *v = make([]string, len(src.Elements)) for i := range src.Elements { if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { return err } } - } else { - *v = nil - } + return nil - default: - if originalDst, ok := underlyingPtrSliceType(dst); ok { - return src.AssignTo(originalDst) + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } } - return fmt.Errorf("cannot decode %v into %T", src, dst) + case Null: + return nullAssignTo(dst) } - return nil + return fmt.Errorf("cannot decode %v into %T", src, dst) } func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { From 5572c002dc954b8a99fa129f0fad73d2cac92c7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 17:38:52 -0500 Subject: [PATCH 132/264] Optionally generate binary array format --- pgtype/typed_array.go.erb | 92 ++++++++++++++++++++------------------- pgtype/typed_array_gen.sh | 32 +++++++------- 2 files changed, 64 insertions(+), 60 deletions(-) diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 26c4671c..0e5725ce 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -119,6 +119,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error return nil } +<% if binary_format == "true" %> func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} @@ -160,6 +161,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro *dst = <%= pgtype_array_type %>{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} return nil } +<% end %> func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { @@ -237,61 +239,63 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - switch src.Status { - case Null: - return true, nil - case Undefined: - return false, errUndefined - } - - arrayHeader := ArrayHeader{ - Dimensions: src.Dimensions, - } - - if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { - arrayHeader.ElementOid = int32(dt.Oid) - } else { - return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") - } - - for i := range src.Elements { - if src.Elements[i].Status == Null { - arrayHeader.ContainsNull = true - break +<% if binary_format == "true" %> + func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined } - } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } - elemBuf := &bytes.Buffer{} + if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + } - for i := range src.Elements { - elemBuf.Reset() + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } - if null { - _, err = pgio.WriteInt32(w, -1) + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } } } - } - return false, err -} + return false, err + } +<% end %> diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 166f8802..d77c8ca3 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,16 +1,16 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL typed_array.go.erb > inet_array.go -erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' typed_array.go.erb > varchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL typed_array.go.erb > bytea_array.go -erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL typed_array.go.erb > aclitem_array.go -erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL typed_array.go.erb > hstore_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL binary_format=true typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL binary_format=true typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL binary_format=true typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL binary_format=true typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL binary_format=true typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL binary_format=true typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL binary_format=true typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go +erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go From bec9bd261b6a0e11ecde620f2f7fd8a7ea07e529 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 21:11:43 -0500 Subject: [PATCH 133/264] Add database/sql support to pgtype --- pgtype/aclitem.go | 30 +++++++++ pgtype/aclitem_array.go | 31 ++++++++++ pgtype/bool.go | 33 ++++++++++ pgtype/bool_array.go | 31 ++++++++++ pgtype/bytea.go | 38 ++++++++++++ pgtype/bytea_array.go | 31 ++++++++++ pgtype/cid.go | 11 ++++ pgtype/cidr_array.go | 31 ++++++++++ pgtype/database_sql.go | 52 +++++----------- pgtype/date.go | 47 ++++++++++++++- pgtype/date_array.go | 31 ++++++++++ pgtype/date_test.go | 7 ++- pgtype/float4.go | 38 ++++++++++++ pgtype/float4_array.go | 31 ++++++++++ pgtype/float8.go | 38 ++++++++++++ pgtype/float8_array.go | 31 ++++++++++ pgtype/generic_binary.go | 11 ++++ pgtype/generic_text.go | 11 ++++ pgtype/hstore.go | 28 +++++++++ pgtype/hstore_array.go | 31 ++++++++++ pgtype/inet.go | 28 +++++++++ pgtype/inet_array.go | 31 ++++++++++ pgtype/int2.go | 44 ++++++++++++++ pgtype/int2_array.go | 31 ++++++++++ pgtype/int4.go | 46 +++++++++++++- pgtype/int4_array.go | 31 ++++++++++ pgtype/int8.go | 38 ++++++++++++ pgtype/int8_array.go | 31 ++++++++++ pgtype/json.go | 36 +++++++++++ pgtype/jsonb.go | 11 ++++ pgtype/name.go | 11 ++++ pgtype/oid.go | 25 ++++++++ pgtype/oid_value.go | 11 ++++ pgtype/pgtype.go | 13 ++++ pgtype/pgtype_test.go | 61 ++++++++++++++++++- pgtype/pguint32.go | 45 ++++++++++++++ pgtype/qchar.go | 9 ++- pgtype/qchar_test.go | 4 +- pgtype/record.go | 5 ++ pgtype/text.go | 41 +++++++++++++ pgtype/text_array.go | 31 ++++++++++ pgtype/tid.go | 23 +++++++ pgtype/timestamp.go | 47 ++++++++++++++- pgtype/timestamp_array.go | 31 ++++++++++ pgtype/timestamptz.go | 47 ++++++++++++++- pgtype/timestamptz_array.go | 31 ++++++++++ pgtype/typed_array.go.erb | 30 +++++++++ pgtype/unknown.go | 12 ++++ pgtype/varchar.go | 11 ++++ pgtype/varchar_array.go | 31 ++++++++++ pgtype/xid.go | 11 ++++ query.go | 117 +++--------------------------------- query_test.go | 24 -------- stdlib/sql.go | 66 +++++++++++++++----- values.go | 4 -- 55 files changed, 1459 insertions(+), 201 deletions(-) diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index e8386ae7..77e385e6 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -93,3 +94,32 @@ func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, src.String) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Aclitem) Scan(src interface{}) error { + if src == nil { + *dst = Aclitem{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Aclitem) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 1c97e74f..20a7636a 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "fmt" "io" @@ -194,3 +195,33 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } + +// Scan implements the database/sql Scanner interface. +func (dst *AclitemArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *AclitemArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/bool.go b/pgtype/bool.go index 608a6f95..736d19cf 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" "strconv" @@ -126,3 +127,35 @@ func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(buf) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Bool) Scan(src interface{}) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + switch src := src.(type) { + case bool: + *dst = Bool{Bool: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bool) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bool, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index cdfe9685..4705d734 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *BoolArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *BoolArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 00bed8e8..9f0266e7 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/hex" "fmt" "io" @@ -12,6 +13,11 @@ type Bytea struct { } func (dst *Bytea) Set(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + switch value := src.(type) { case []byte: if value != nil { @@ -124,3 +130,35 @@ func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(src.Bytes) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Bytea) Scan(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + buf := make([]byte, len(src)) + copy(buf, src) + *dst = Bytea{Bytes: buf, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bytea) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 175ca2f6..268364c1 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *ByteaArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *ByteaArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/cid.go b/pgtype/cid.go index d86e8063..63ba6a2f 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -49,3 +50,13 @@ func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Cid) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Cid) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 49a2728b..6643bb47 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -325,3 +326,33 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *CidrArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *CidrArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go index 969d6542..2ddd842d 100644 --- a/pgtype/database_sql.go +++ b/pgtype/database_sql.go @@ -2,47 +2,13 @@ package pgtype import ( "bytes" + "database/sql/driver" "errors" ) func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { - switch src := src.(type) { - case *Bool: - return src.Bool, nil - case *Bytea: - return src.Bytes, nil - case *Date: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Float4: - return float64(src.Float), nil - case *Float8: - return src.Float, nil - case *GenericBinary: - return src.Bytes, nil - case *GenericText: - return src.String, nil - case *Int2: - return int64(src.Int), nil - case *Int4: - return int64(src.Int), nil - case *Int8: - return int64(src.Int), nil - case *Text: - return src.String, nil - case *Timestamp: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Timestamptz: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Unknown: - return src.String, nil - case *Varchar: - return src.String, nil + if valuer, ok := src.(driver.Valuer); ok { + return valuer.Value() } buf := &bytes.Buffer{} @@ -64,3 +30,15 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return nil, errors.New("cannot convert to database/sql compatible value") } + +func encodeValueText(src TextEncoder) (interface{}, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + return buf.String(), err +} diff --git a/pgtype/date.go b/pgtype/date.go index ab854eb2..7dd2c4f0 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -10,9 +11,9 @@ import ( ) type Date struct { - Time time.Time - Status Status - InfinityModifier + Time time.Time + Status Status + InfinityModifier InfinityModifier } const ( @@ -21,6 +22,11 @@ const ( ) func (dst *Date) Set(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Date{Time: value, Status: Present} @@ -167,3 +173,38 @@ func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, daysSinceDateEpoch) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Date) Scan(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Date{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Date) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/date_array.go b/pgtype/date_array.go index bf791677..f58de011 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *DateArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *DateArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/date_test.go b/pgtype/date_test.go index cfc3dd70..1832b5b4 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -9,7 +9,7 @@ import ( ) func TestDateTranscode(t *testing.T) { - testSuccessfulTranscode(t, "date", []interface{}{ + testSuccessfulTranscodeEqFunc(t, "date", []interface{}{ pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, @@ -19,6 +19,11 @@ func TestDateTranscode(t *testing.T) { pgtype.Date{Status: pgtype.Null}, pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Date) + bt := b.(pgtype.Date) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier }) } diff --git a/pgtype/float4.go b/pgtype/float4.go index 94b7b7a1..e92149a6 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Float4 struct { } func (dst *Float4) Set(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + switch value := src.(type) { case float32: *dst = Float4{Float: value, Status: Present} @@ -156,3 +162,35 @@ func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float4) Scan(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float4{Float: float32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return float64(src.Float), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index b4d05c55..b9ee4b9e 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float4Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/float8.go b/pgtype/float8.go index dd2d592d..4d094757 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Float8 struct { } func (dst *Float8) Set(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + switch value := src.(type) { case float32: *dst = Float8{Float: float64(value), Status: Present} @@ -146,3 +152,35 @@ func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float8{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Float, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index e000807e..d49f18a7 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float8Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go index aa28bb62..f834bfb2 100644 --- a/pgtype/generic_binary.go +++ b/pgtype/generic_binary.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -27,3 +28,13 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Bytea)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *GenericBinary) Scan(src interface{}) error { + return (*Bytea)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericBinary) Value() (driver.Value, error) { + return (Bytea)(src).Value() +} diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go index bd75e0d0..053ec504 100644 --- a/pgtype/generic_text.go +++ b/pgtype/generic_text.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -27,3 +28,13 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *GenericText) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericText) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 8dc5b4d8..b8b0c6f3 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "errors" "fmt" @@ -21,6 +22,11 @@ type Hstore struct { } func (dst *Hstore) Set(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + switch value := src.(type) { case map[string]string: m := make(map[string]Text, len(value)) @@ -437,3 +443,25 @@ func parseHstore(s string) (k []string, v []Text, err error) { v = values return } + +// Scan implements the database/sql Scanner interface. +func (dst *Hstore) Scan(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Hstore) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index 9bd0ed3b..097fec7b 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *HstoreArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *HstoreArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/inet.go b/pgtype/inet.go index 13764814..0ca3ee7a 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" "net" @@ -23,6 +24,11 @@ type Inet struct { } func (dst *Inet) Set(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + switch value := src.(type) { case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} @@ -189,3 +195,25 @@ func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(src.IPNet.IP) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Inet) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 1988a145..a108d75b 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -325,3 +326,33 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *InetArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *InetArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/int2.go b/pgtype/int2.go index 6996cd4f..3bcac63c 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int2 struct { } func (dst *Int2) Set(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int2{Int: int16(value), Status: Present} @@ -151,3 +157,41 @@ func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt16(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + if src > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + *dst = Int2{Int: int16(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 531e7dd6..bddb5ac2 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int2Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int2Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/int4.go b/pgtype/int4.go index 62ee366f..5069dab4 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int4 struct { } func (dst *Int4) Set(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int4{Int: int32(value), Status: Present} @@ -68,7 +74,7 @@ func (dst *Int4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int8", value) + return fmt.Errorf("cannot convert %v to Int4", value) } return nil @@ -142,3 +148,41 @@ func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + if src > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + *dst = Int4{Int: int32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 3617050f..d5c8f911 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int4Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/int8.go b/pgtype/int8.go index 7ed54f8e..cf701dc6 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int8 struct { } func (dst *Int8) Set(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int8{Int: int64(value), Status: Present} @@ -134,3 +140,35 @@ func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + *dst = Int8{Int: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 4f04b660..ae2521fa 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int8Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/json.go b/pgtype/json.go index bfffae14..05d965ca 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -1,7 +1,9 @@ package pgtype import ( + "database/sql/driver" "encoding/json" + "fmt" "io" ) @@ -11,6 +13,11 @@ type Json struct { } func (dst *Json) Set(src interface{}) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + switch value := src.(type) { case string: *dst = Json{Bytes: []byte(value), Status: Present} @@ -116,3 +123,32 @@ func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Json) Scan(src interface{}) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Json) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index e44f3c41..f47476d6 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -66,3 +67,13 @@ func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err = w.Write(src.Bytes) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Jsonb) Scan(src interface{}) error { + return (*Json)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Jsonb) Value() (driver.Value, error) { + return (Json)(src).Value() +} diff --git a/pgtype/name.go b/pgtype/name.go index 9ebf63d3..cc4ae23b 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -46,3 +47,13 @@ func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Name) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Name) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/oid.go b/pgtype/oid.go index 3edd7f3c..339dee0f 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -55,3 +56,27 @@ func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Oid) Scan(src interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", src) + } + + switch src := src.(type) { + case int64: + *dst = Oid(src) + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Oid) Value() (driver.Value, error) { + return int64(src), nil +} diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go index 1bce6e11..cb03802e 100644 --- a/pgtype/oid_value.go +++ b/pgtype/oid_value.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -43,3 +44,13 @@ func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *OidValue) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OidValue) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 674c0db7..7e6633d9 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -67,6 +67,19 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) +func (im InfinityModifier) String() string { + switch im { + case None: + return "none" + case Infinity: + return "infinity" + case NegativeInfinity: + return "-infinity" + default: + return "invalid" + } +} + type Value interface { // Set converts and assigns src to itself. Set(src interface{}) error diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 391fed57..16cabfd1 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "database/sql" "fmt" "io" "net" @@ -10,6 +11,8 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + _ "github.com/jackc/pgx/stdlib" + _ "github.com/lib/pq" ) // Test for renamed types @@ -24,6 +27,25 @@ type _float32Slice []float32 type _float64Slice []float64 type _byteSlice []byte +func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + return db +} + func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) if err != nil { @@ -93,6 +115,13 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface } func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := mustConnectPgx(t) defer mustClose(t, conn) @@ -114,7 +143,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%#v does not implement %v", v, fc.name) + t.Logf("Skipping: %#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer @@ -136,3 +165,33 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int } } } + +func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } + +} diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 3f9e7bf7..7138a409 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -1,9 +1,11 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" + "math" "strconv" "github.com/jackc/pgx/pgio" @@ -21,6 +23,14 @@ type pguint32 struct { // types do. func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { + case int64: + if value < 0 { + return fmt.Errorf("%d is less than minimum value for pguint32", value) + } + if value > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for pguint32", value) + } + *dst = pguint32{Uint: uint32(value), Status: Present} case uint32: *dst = pguint32{Uint: value, Status: Present} default: @@ -116,3 +126,38 @@ func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, src.Uint) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *pguint32) Scan(src interface{}) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + switch src := src.(type) { + case uint32: + *dst = pguint32{Uint: src, Status: Present} + return nil + case int64: + *dst = pguint32{Uint: uint32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src pguint32) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Uint), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 4b32ee4a..49475bd3 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -17,13 +17,20 @@ import ( // standard type char. // // Not all possible values of QChar are representable in the text format. -// Therefore, QChar does not implement TextEncoder and TextDecoder. +// Therefore, QChar does not implement TextEncoder and TextDecoder. In +// addition, database/sql Scanner and database/sql/driver Value are not +// implemented. type QChar struct { Int int8 Status Status } func (dst *QChar) Set(src interface{}) error { + if src == nil { + *dst = QChar{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = QChar{Int: value, Status: Present} diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index a1b6d22e..afac5016 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -9,13 +9,15 @@ import ( ) func TestQCharTranscode(t *testing.T) { - testSuccessfulTranscode(t, `"char"`, []interface{}{ + testPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, pgtype.QChar{Int: -1, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Present}, pgtype.QChar{Int: 1, Status: pgtype.Present}, pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Null}, + }, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) }) } diff --git a/pgtype/record.go b/pgtype/record.go index 89e081ca..9c42c907 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -16,6 +16,11 @@ type Record struct { } func (dst *Record) Set(src interface{}) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + switch value := src.(type) { case []Value: *dst = Record{Fields: value, Status: Present} diff --git a/pgtype/text.go b/pgtype/text.go index dbc9362b..482c9023 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -11,6 +12,11 @@ type Text struct { } func (dst *Text) Set(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + switch value := src.(type) { case string: *dst = Text{String: value, Status: Present} @@ -20,6 +26,12 @@ func (dst *Text) Set(src interface{}) error { } else { *dst = Text{String: *value, Status: Present} } + case []byte: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: string(value), Status: Present} + } default: if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) @@ -93,3 +105,32 @@ func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 6e8ead26..64728048 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TextArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TextArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/tid.go b/pgtype/tid.go index b91711d3..b363c1f9 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -121,3 +122,25 @@ func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err = pgio.WriteUint16(w, src.OffsetNumber) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Tid) Scan(src interface{}) error { + if src == nil { + *dst = Tid{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tid) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 4b42f3cf..78c6355e 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -17,14 +18,19 @@ const pgTimestampFormat = "2006-01-02 15:04:05.999999999" // recommended to use timestamptz whenever possible. Timestamp methods either // convert to UTC or return an error on non-UTC times. type Timestamp struct { - Time time.Time // Time must always be in UTC. - Status Status - InfinityModifier + Time time.Time // Time must always be in UTC. + Status Status + InfinityModifier InfinityModifier } // Set converts src into a Timestamp and stores in dst. If src is a // time.Time in a non-UTC time zone, the time zone is discarded. func (dst *Timestamp) Set(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} @@ -183,3 +189,38 @@ func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, microsecSinceY2K) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Timestamp{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 6a6950c7..5d08f9cc 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TimestampArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TimestampArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index ba849ac8..50370335 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -20,12 +21,17 @@ const ( ) type Timestamptz struct { - Time time.Time - Status Status - InfinityModifier + Time time.Time + Status Status + InfinityModifier InfinityModifier } func (dst *Timestamptz) Set(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Timestamptz{Time: value, Status: Present} @@ -179,3 +185,38 @@ func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, microsecSinceY2K) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Timestamptz{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 347d0b8b..107be06a 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TimestamptzArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TimestamptzArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 0e5725ce..4b8f1a28 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -299,3 +299,33 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool return false, err } <% end %> + +// Scan implements the database/sql Scanner interface. +func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/unknown.go b/pgtype/unknown.go index b951ad99..2dca0f87 100644 --- a/pgtype/unknown.go +++ b/pgtype/unknown.go @@ -1,5 +1,7 @@ package pgtype +import "database/sql/driver" + // Unknown represents the PostgreSQL unknown type. It is either a string literal // or NULL. It is used when PostgreSQL does not know the type of a value. In // general, this will only be used in pgx when selecting a null value without @@ -30,3 +32,13 @@ func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } + +// Scan implements the database/sql Scanner interface. +func (dst *Unknown) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Unknown) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/varchar.go b/pgtype/varchar.go index adda6c49..f25ada5d 100644 --- a/pgtype/varchar.go +++ b/pgtype/varchar.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -38,3 +39,13 @@ func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Varchar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Varchar) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index e1dd3910..2712b4d2 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *VarcharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *VarcharArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/xid.go b/pgtype/xid.go index c76548a4..0a7fc7d9 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -52,3 +53,13 @@ func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Xid) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Xid) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/query.go b/query.go index 0b5cc911..e820fabc 100644 --- a/query.go +++ b/query.go @@ -208,47 +208,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if s, ok := d.(sql.Scanner); ok { - var sqlSrc interface{} - if 0 <= vr.Len() { - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { - value := dt.Value - - switch vr.Type().FormatCode { - case TextFormatCode: - decoder := value.(pgtype.TextDecoder) - if decoder == nil { - decoder = &pgtype.GenericText{} - } - err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) - if err != nil { - rows.Fatal(err) - } - case BinaryFormatCode: - decoder := value.(pgtype.BinaryDecoder) - if decoder == nil { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) - if err != nil { - rows.Fatal(err) - } - default: - rows.Fatal(errors.New("Unknown format code")) - } - - sqlSrc, err = pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) - if err != nil { - rows.Fatal(err) - } - } else { - rows.Fatal(errors.New("Unknown type")) - } - } - err = s.Scan(sqlSrc) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } } else { if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { value := dt.Value @@ -276,7 +235,16 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } if vr.Err() == nil { - if err := value.AssignTo(d); err != nil { + if scanner, ok := d.(sql.Scanner); ok { + sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.Fatal(err) + } + err = scanner.Scan(sqlSrc) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if err := value.AssignTo(d); err != nil { vr.Fatal(err) } } @@ -355,71 +323,6 @@ func (rows *Rows) Values() ([]interface{}, error) { return values, rows.Err() } -// ValuesForStdlib is a temporary function to keep all systems operational -// while refactoring. Do not use. -func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { - if rows.closed { - return nil, errors.New("rows is closed") - } - - values := make([]interface{}, 0, len(rows.fields)) - - for range rows.fields { - vr, _ := rows.nextColumn() - - if vr.Len() == -1 { - values = append(values, nil) - continue - } - - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { - value := dt.Value - - switch vr.Type().FormatCode { - case TextFormatCode: - decoder := value.(pgtype.TextDecoder) - if decoder == nil { - decoder = &pgtype.GenericText{} - } - err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) - if err != nil { - rows.Fatal(err) - } - case BinaryFormatCode: - decoder := value.(pgtype.BinaryDecoder) - if decoder == nil { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) - if err != nil { - rows.Fatal(err) - } - default: - rows.Fatal(errors.New("Unknown format code")) - } - - sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) - if err != nil { - rows.Fatal(err) - } - - values = append(values, sqlSrc) - } else { - rows.Fatal(errors.New("Unknown type")) - } - - if vr.Err() != nil { - rows.Fatal(vr.Err()) - } - - if rows.Err() != nil { - return nil, rows.Err() - } - } - - return values, rows.Err() -} - // AfterClose adds f to a LILO queue of functions that will be called when // rows is closed. func (rows *Rows) AfterClose(f func(*Rows)) { diff --git a/query_test.go b/query_test.go index b053e26d..25347ec5 100644 --- a/query_test.go +++ b/query_test.go @@ -704,30 +704,6 @@ func TestQueryRowCoreByteSlice(t *testing.T) { } } -func TestQueryRowByteSliceArgument(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - sql := "select $1::int4" - queryArg := []byte{14, 63, 53, 49} - expected := int32(239023409) - - var actual int32 - - err := conn.QueryRow(sql, queryArg).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } - - if expected != actual { - t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql) - } - - ensureConnValid(t, conn) -} - func TestQueryRowUnknownType(t *testing.T) { t.Parallel() diff --git a/stdlib/sql.go b/stdlib/sql.go index 6889a2b6..affa93b6 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -68,14 +68,17 @@ func init() { databaseSqlOids = make(map[pgtype.Oid]bool) databaseSqlOids[pgtype.BoolOid] = true databaseSqlOids[pgtype.ByteaOid] = true + databaseSqlOids[pgtype.CidOid] = true + databaseSqlOids[pgtype.DateOid] = true + databaseSqlOids[pgtype.Float4Oid] = true + databaseSqlOids[pgtype.Float8Oid] = true databaseSqlOids[pgtype.Int2Oid] = true databaseSqlOids[pgtype.Int4Oid] = true databaseSqlOids[pgtype.Int8Oid] = true - databaseSqlOids[pgtype.Float4Oid] = true - databaseSqlOids[pgtype.Float8Oid] = true - databaseSqlOids[pgtype.DateOid] = true - databaseSqlOids[pgtype.TimestamptzOid] = true + databaseSqlOids[pgtype.OidOid] = true databaseSqlOids[pgtype.TimestampOid] = true + databaseSqlOids[pgtype.TimestamptzOid] = true + databaseSqlOids[pgtype.XidOid] = true } type Driver struct { @@ -292,9 +295,9 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return s.conn.queryPrepared(s.ps.Name, argsV) } -// TODO - rename to avoid alloc type Rows struct { - rows *pgx.Rows + rows *pgx.Rows + values []interface{} } func (r *Rows) Columns() []string { @@ -312,6 +315,42 @@ func (r *Rows) Close() error { } func (r *Rows) Next(dest []driver.Value) error { + if r.values == nil { + r.values = make([]interface{}, len(r.rows.FieldDescriptions())) + for i, fd := range r.rows.FieldDescriptions() { + switch fd.DataType { + case pgtype.BoolOid: + r.values[i] = &pgtype.Bool{} + case pgtype.ByteaOid: + r.values[i] = &pgtype.Bytea{} + case pgtype.CidOid: + r.values[i] = &pgtype.Cid{} + case pgtype.DateOid: + r.values[i] = &pgtype.Date{} + case pgtype.Float4Oid: + r.values[i] = &pgtype.Float4{} + case pgtype.Float8Oid: + r.values[i] = &pgtype.Float8{} + case pgtype.Int2Oid: + r.values[i] = &pgtype.Int2{} + case pgtype.Int4Oid: + r.values[i] = &pgtype.Int4{} + case pgtype.Int8Oid: + r.values[i] = &pgtype.Int8{} + case pgtype.OidOid: + r.values[i] = &pgtype.OidValue{} + case pgtype.TimestampOid: + r.values[i] = &pgtype.Timestamp{} + case pgtype.TimestamptzOid: + r.values[i] = &pgtype.Timestamptz{} + case pgtype.XidOid: + r.values[i] = &pgtype.Xid{} + default: + r.values[i] = &pgtype.GenericText{} + } + } + } + more := r.rows.Next() if !more { if r.rows.Err() == nil { @@ -321,19 +360,16 @@ func (r *Rows) Next(dest []driver.Value) error { } } - values, err := r.rows.ValuesForStdlib() + err := r.rows.Scan(r.values...) if err != nil { return err } - if len(dest) < len(values) { - fmt.Printf("%d: %#v\n", len(dest), dest) - fmt.Printf("%d: %#v\n", len(values), values) - return errors.New("expected more values than were received") - } - - for i, v := range values { - dest[i] = driver.Value(v) + for i, v := range r.values { + dest[i], err = v.(driver.Valuer).Value() + if err != nil { + return err + } } return nil diff --git a/values.go b/values.go index c399b42c..5370bf47 100644 --- a/values.go +++ b/values.go @@ -65,10 +65,6 @@ func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interfa wbuf.WriteInt32(int32(len(arg))) wbuf.WriteBytes([]byte(arg)) return nil - case []byte: - wbuf.WriteInt32(int32(len(arg))) - wbuf.WriteBytes(arg) - return nil } refVal := reflect.ValueOf(arg) From db6c5daa70caee11804c2988e82e87ed5f54dd63 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 Mar 2017 08:00:43 -0500 Subject: [PATCH 134/264] Run goimports as part of array gen script --- pgtype/typed_array_gen.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index d77c8ca3..52612466 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -14,3 +14,4 @@ erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[] erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +goimports -w *_array.go From ed8bfa4f42b482b7898bb44a21441f75816434f7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 Mar 2017 08:38:06 -0500 Subject: [PATCH 135/264] pgtype tests now require pq --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index cd9ab572..069cfcb6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -53,6 +53,7 @@ install: - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake - go get -u github.com/jackc/pgmock/pgmsg + - go get -u github.com/lib/pq script: - go test -v -race ./... From 120da8df8fc1bc9f82454ff18029598d4d6dfca7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 20 Mar 2017 08:58:28 -0500 Subject: [PATCH 136/264] Skip jsonb test if no jsonb type --- pgtype/jsonb_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 3978b0d4..91637eb8 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -9,6 +9,12 @@ import ( ) func TestJsonbTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { + t.Skip("Skipping due to no jsonb type") + } + testSuccessfulTranscode(t, "jsonb", []interface{}{ pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, From 1a99c0e5c478b5c14472cfef7924caff4947e2d0 Mon Sep 17 00:00:00 2001 From: Terin Stock Date: Mon, 20 Mar 2017 13:24:44 -0700 Subject: [PATCH 137/264] fix(stdlib): lock openFromConnPoolCount while using Locks the `openFromConnPoolCount` counter while formatting the driver name and incrementing to avoid a data race of multiple goroutines modifying the counter and registering the same name. `sql.Register` panics if a driver name has already been registered. --- stdlib/sql.go | 10 +++++++++- stdlib/sql_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index affa93b6..e3d46cab 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -50,12 +50,16 @@ import ( "errors" "fmt" "io" + "sync" "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" ) -var openFromConnPoolCount int +var ( + openFromConnPoolCountMu sync.Mutex + openFromConnPoolCount int +) // oids that map to intrinsic database/sql types. These will be allowed to be // binary, anything else will be forced to text format @@ -120,8 +124,12 @@ func (d *Driver) Open(name string) (driver.Conn, error) { // pool connection size must be at least 2. func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) { d := &Driver{Pool: pool} + + openFromConnPoolCountMu.Lock() name := fmt.Sprintf("pgx-%d", openFromConnPoolCount) openFromConnPoolCount++ + openFromConnPoolCountMu.Unlock() + sql.Register(name, d) db, err := sql.Open(name, "") if err != nil { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index c8062c61..641ba9fe 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -3,6 +3,7 @@ package stdlib_test import ( "bytes" "database/sql" + "sync" "testing" "github.com/jackc/pgx" @@ -164,6 +165,43 @@ func TestOpenFromConnPool(t *testing.T) { } } +func TestOpenFromConnPoolRace(t *testing.T) { + wg := &sync.WaitGroup{} + connConfig := pgx.ConnConfig{ + Host: "127.0.0.1", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + } + + config := pgx.ConnPoolConfig{ConnConfig: connConfig} + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + db, err := stdlib.OpenFromConnPool(pool) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer closeDB(t, db) + + // Can get pgx.ConnPool from driver + driver := db.Driver().(*stdlib.Driver) + if driver.Pool == nil { + t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") + } + }() + } + + wg.Wait() +} + func TestStmtExec(t *testing.T) { db := openDB(t) defer closeDB(t, db) From 7eae904eba5210b3e3903b4cda1963ad985dd228 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Mar 2017 18:41:52 -0500 Subject: [PATCH 138/264] Add int4range --- conn.go | 4 +- pgtype/int4range.go | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/int4range_test.go | 25 ++++ pgtype/pgtype.go | 1 + pgtype/pgtype_test.go | 93 +++++++++++++ pgtype/range.go | 273 +++++++++++++++++++++++++++++++++++++++ pgtype/range_test.go | 177 +++++++++++++++++++++++++ 7 files changed, 839 insertions(+), 2 deletions(-) create mode 100644 pgtype/int4range.go create mode 100644 pgtype/int4range_test.go create mode 100644 pgtype/range.go create mode 100644 pgtype/range_test.go diff --git a/conn.go b/conn.go index 509e9d8e..d79a4e97 100644 --- a/conn.go +++ b/conn.go @@ -366,8 +366,8 @@ func (c *Conn) initConnInfo() error { from pg_type t left join pg_type base_type on t.typelem=base_type.oid where ( - t.typtype in('b', 'p') - and (base_type.oid is null or base_type.typtype in('b', 'p')) + t.typtype in('b', 'p', 'r') + and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) )`) if err != nil { return err diff --git a/pgtype/int4range.go b/pgtype/int4range.go new file mode 100644 index 00000000..cac4484c --- /dev/null +++ b/pgtype/int4range.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int4range struct { + Lower Int4 + Upper Int4 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int4range) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Int4range", src) +} + +func (dst *Int4range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int4range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int4range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int4range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int4range) Scan(src interface{}) error { + if src == nil { + *dst = Int4range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4range) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go new file mode 100644 index 00000000..c96fe9cd --- /dev/null +++ b/pgtype/int4range_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt4rangeTranscode(t *testing.T) { + testSuccessfulTranscode(t, "int4range", []interface{}{ + pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int4range{Status: pgtype.Null}, + }) +} + +func TestInt4rangeNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select int4range(1, 10, '(]')", + value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7e6633d9..7a95994c 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -233,6 +233,7 @@ func init() { "inet": &Inet{}, "int2": &Int2{}, "int4": &Int4{}, + "int4range": &Int4range{}, "int8": &Int8{}, "json": &Json{}, "jsonb": &Jsonb{}, diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 16cabfd1..298cff64 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -189,6 +189,99 @@ func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeNa t.Errorf("%v %d: %v", driverName, i, err) } + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +type normalizeTest struct { + sql string + value interface{} +} + +func testSuccessfulNormalize(t testing.TB, tests []normalizeTest) { + testSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func testSuccessfulNormalizeEqFunc(t testing.TB, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { + testPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) + } +} + +func testPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, fc := range formats { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.sql) + if err != nil { + t.Fatal(err) + } + + ps.FieldDescriptions[0].FormatCode = fc.formatCode + if forceEncoder(tt.value, fc.formatCode) == nil { + t.Logf("Skipping: %#v does not implement %v", tt.value, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := tt.value + refVal := reflect.ValueOf(tt.value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = conn.QueryRow(psName).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func testDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + for i, tt := range tests { + ps, err := conn.Prepare(tt.sql) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + // Derefence value if it is a pointer + derefV := tt.value + refVal := reflect.ValueOf(tt.value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + if !eqFunc(result.Elem().Interface(), derefV) { t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) } diff --git a/pgtype/range.go b/pgtype/range.go new file mode 100644 index 00000000..76daf8cc --- /dev/null +++ b/pgtype/range.go @@ -0,0 +1,273 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type BoundType byte + +const ( + Inclusive = BoundType('i') + Exclusive = BoundType('e') + Unbounded = BoundType('U') + Empty = BoundType('E') +) + +type UntypedTextRange struct { + Lower string + Upper string + LowerType BoundType + UpperType BoundType +} + +func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { + utr := &UntypedTextRange{} + if src == "empty" { + utr.LowerType = 'E' + utr.UpperType = 'E' + return utr, nil + } + + buf := bytes.NewBufferString(src) + + skipWhitespace(buf) + + r, _, err := buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower bound: %v", err) + } + switch r { + case '(': + utr.LowerType = Exclusive + case '[': + utr.LowerType = Inclusive + default: + return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + buf.UnreadRune() + + if r == ',' { + utr.LowerType = Unbounded + } else { + utr.Lower, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid lower value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing range separator: %v", err) + } + if r != ',' { + return nil, fmt.Errorf("missing range separator: %v", r) + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + buf.UnreadRune() + + if r == ')' || r == ']' { + utr.UpperType = Unbounded + } else { + utr.Upper, err = rangeParseValue(buf) + if err != nil { + return nil, fmt.Errorf("invalid upper value: %v", err) + } + } + + r, _, err = buf.ReadRune() + if err != nil { + return nil, fmt.Errorf("missing upper bound: %v", err) + } + switch r { + case ')': + utr.UpperType = Exclusive + case ']': + utr.UpperType = Inclusive + default: + return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + } + + skipWhitespace(buf) + + if buf.Len() > 0 { + return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + } + + return utr, nil +} + +func rangeParseValue(buf *bytes.Buffer) (string, error) { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + if r == '"' { + return rangeParseQuotedValue(buf) + } + buf.UnreadRune() + + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case ',', '[', ']', '(', ')': + buf.UnreadRune() + return s.String(), nil + } + + s.WriteRune(r) + } +} + +func rangeParseQuotedValue(buf *bytes.Buffer) (string, error) { + s := &bytes.Buffer{} + + for { + r, _, err := buf.ReadRune() + if err != nil { + return "", err + } + + switch r { + case '\\': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + case '"': + r, _, err = buf.ReadRune() + if err != nil { + return "", err + } + if r != '"' { + buf.UnreadRune() + return s.String(), nil + } + } + s.WriteRune(r) + } +} + +type UntypedBinaryRange struct { + Lower []byte + Upper []byte + LowerType BoundType + UpperType BoundType +} + +// 0 = () = 00000 +// 1 = empty = 00001 +// 2 = [) = 00010 +// 4 = (] = 00100 +// 6 = [] = 00110 +// 8 = ) = 01000 +// 12 = ] = 01100 +// 16 = ( = 10000 +// 18 = [ = 10010 +// 24 = = 11000 + +const emptyMask = 1 +const lowerInclusiveMask = 2 +const upperInclusiveMask = 4 +const lowerUnboundedMask = 8 +const upperUnboundedMask = 16 + +func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { + ubr := &UntypedBinaryRange{} + + if len(src) == 0 { + return nil, fmt.Errorf("range too short: %v", len(src)) + } + + rangeType := src[0] + rp := 1 + + if rangeType&emptyMask > 0 { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + } + ubr.LowerType = Empty + ubr.UpperType = Empty + return ubr, nil + } + + if rangeType&lowerInclusiveMask > 0 { + ubr.LowerType = Inclusive + } else if rangeType&lowerUnboundedMask > 0 { + ubr.LowerType = Unbounded + } else { + ubr.LowerType = Exclusive + } + + if rangeType&upperInclusiveMask > 0 { + ubr.UpperType = Inclusive + } else if rangeType&upperUnboundedMask > 0 { + ubr.UpperType = Unbounded + } else { + ubr.UpperType = Exclusive + } + + if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + } + return ubr, nil + } + + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + val := src[rp : rp+valueLen] + rp += valueLen + + if ubr.LowerType != Unbounded { + ubr.Lower = val + } else { + ubr.Upper = val + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + return ubr, nil + } + + if ubr.UpperType != Unbounded { + if len(src[rp:]) < 4 { + return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + } + valueLen := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + ubr.Upper = src[rp : rp+valueLen] + rp += valueLen + } + + if len(src[rp:]) > 0 { + return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + } + + return ubr, nil + +} diff --git a/pgtype/range_test.go b/pgtype/range_test.go new file mode 100644 index 00000000..9e16df59 --- /dev/null +++ b/pgtype/range_test.go @@ -0,0 +1,177 @@ +package pgtype + +import ( + "bytes" + "testing" +) + +func TestParseUntypedTextRange(t *testing.T) { + tests := []struct { + src string + result UntypedTextRange + err error + }{ + { + src: `[1,2)`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[1,2]`, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: `(1,3)`, + result: UntypedTextRange{Lower: "1", Upper: "3", LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: ` [1,2) `, + result: UntypedTextRange{Lower: "1", Upper: "2", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[ foo , bar )`, + result: UntypedTextRange{Lower: " foo ", Upper: " bar ", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["foo","bar")`, + result: UntypedTextRange{Lower: "foo", Upper: "bar", LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["f""oo","b""ar")`, + result: UntypedTextRange{Lower: `f"oo`, Upper: `b"ar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `["","bar")`, + result: UntypedTextRange{Lower: ``, Upper: `bar`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `[f\"oo\,,b\\ar\))`, + result: UntypedTextRange{Lower: `f"oo,`, Upper: `b\ar)`, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: `empty`, + result: UntypedTextRange{Lower: "", Upper: "", LowerType: Empty, UpperType: Empty}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedTextRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if r.Lower != tt.result.Lower { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if r.Upper != tt.result.Upper { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} + +func TestParseUntypedBinaryRange(t *testing.T) { + tests := []struct { + src []byte + result UntypedBinaryRange + err error + }{ + { + src: []byte{0, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{1}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Empty, UpperType: Empty}, + err: nil, + }, + { + src: []byte{2, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{4, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Exclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{6, 0, 0, 0, 2, 0, 4, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: []byte{0, 5}, LowerType: Inclusive, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{8, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Exclusive}, + err: nil, + }, + { + src: []byte{12, 0, 0, 0, 2, 0, 5}, + result: UntypedBinaryRange{Lower: nil, Upper: []byte{0, 5}, LowerType: Unbounded, UpperType: Inclusive}, + err: nil, + }, + { + src: []byte{16, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Exclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{18, 0, 0, 0, 2, 0, 4}, + result: UntypedBinaryRange{Lower: []byte{0, 4}, Upper: nil, LowerType: Inclusive, UpperType: Unbounded}, + err: nil, + }, + { + src: []byte{24}, + result: UntypedBinaryRange{Lower: nil, Upper: nil, LowerType: Unbounded, UpperType: Unbounded}, + err: nil, + }, + } + + for i, tt := range tests { + r, err := ParseUntypedBinaryRange(tt.src) + if err != tt.err { + t.Errorf("%d. `%v`: expected err %v, got %v", i, tt.src, tt.err, err) + continue + } + + if r.LowerType != tt.result.LowerType { + t.Errorf("%d. `%v`: expected result lower type %v, got %v", i, tt.src, string(tt.result.LowerType), string(r.LowerType)) + } + + if r.UpperType != tt.result.UpperType { + t.Errorf("%d. `%v`: expected result upper type %v, got %v", i, tt.src, string(tt.result.UpperType), string(r.UpperType)) + } + + if bytes.Compare(r.Lower, tt.result.Lower) != 0 { + t.Errorf("%d. `%v`: expected result lower %v, got %v", i, tt.src, tt.result.Lower, r.Lower) + } + + if bytes.Compare(r.Upper, tt.result.Upper) != 0 { + t.Errorf("%d. `%v`: expected result upper %v, got %v", i, tt.src, tt.result.Upper, r.Upper) + } + } +} From d7973d87ddad5713b22191c75163870fd784d852 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Mar 2017 13:27:04 -0500 Subject: [PATCH 139/264] Fix TestParseEnvLibpq when PGSSLMODE is set --- conn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_test.go b/conn_test.go index 13367c6a..4550e63a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -531,7 +531,7 @@ func TestParseDSN(t *testing.T) { } func TestParseEnvLibpq(t *testing.T) { - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME"} + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"} savedEnv := make(map[string]string) for _, n := range pgEnvvars { From 7312fb20e8702393e5da6038dc4d87e41921a6be Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Mar 2017 13:36:10 -0500 Subject: [PATCH 140/264] Add Int8range Add code generation for ranges --- pgtype/int8range.go | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/int8range_test.go | 25 ++++ pgtype/typed_range.go.erb | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/typed_range_gen.sh | 3 + 4 files changed, 564 insertions(+) create mode 100644 pgtype/int8range.go create mode 100644 pgtype/int8range_test.go create mode 100644 pgtype/typed_range.go.erb create mode 100644 pgtype/typed_range_gen.sh diff --git a/pgtype/int8range.go b/pgtype/int8range.go new file mode 100644 index 00000000..44946be9 --- /dev/null +++ b/pgtype/int8range.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Int8range struct { + Lower Int8 + Upper Int8 + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Int8range) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Int8range", src) +} + +func (dst *Int8range) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Int8range) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Int8range{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Int8range) Scan(src interface{}) error { + if src == nil { + *dst = Int8range{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8range) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go new file mode 100644 index 00000000..1b3e594c --- /dev/null +++ b/pgtype/int8range_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestInt8rangeTranscode(t *testing.T) { + testSuccessfulTranscode(t, "Int8range", []interface{}{ + pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + pgtype.Int8range{Status: pgtype.Null}, + }) +} + +func TestInt8rangeNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select Int8range(1, 10, '(]')", + value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + }, + }) +} diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb new file mode 100644 index 00000000..922b98b4 --- /dev/null +++ b/pgtype/typed_range.go.erb @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type <%= range_type %> struct { + Lower <%= element_type %> + Upper <%= element_type %> + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *<%= range_type %>) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to <%= range_type %>", src) +} + +func (dst *<%= range_type %>) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *<%= range_type %>) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = <%= range_type %>{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src <%= range_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *<%= range_type %>) Scan(src interface{}) error { + if src == nil { + *dst = <%= range_type %>{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src <%= range_type %>) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh new file mode 100644 index 00000000..af3e2cd1 --- /dev/null +++ b/pgtype/typed_range_gen.sh @@ -0,0 +1,3 @@ +erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go +erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go +goimports -w *range.go From fffeb1d5dc8694b8e89436ce53d1b2f02fdd501c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 24 Mar 2017 14:17:49 -0500 Subject: [PATCH 141/264] Add daterange, tsrange, and tstzrange --- pgtype/daterange.go | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/daterange_test.go | 66 ++++++++++ pgtype/pgtype.go | 4 + pgtype/tsrange.go | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/tsrange_test.go | 40 ++++++ pgtype/tstzrange.go | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/tstzrange_test.go | 40 ++++++ pgtype/typed_range_gen.sh | 3 + 8 files changed, 957 insertions(+) create mode 100644 pgtype/daterange.go create mode 100644 pgtype/daterange_test.go create mode 100644 pgtype/tsrange.go create mode 100644 pgtype/tsrange_test.go create mode 100644 pgtype/tstzrange.go create mode 100644 pgtype/tstzrange_test.go diff --git a/pgtype/daterange.go b/pgtype/daterange.go new file mode 100644 index 00000000..fbf51980 --- /dev/null +++ b/pgtype/daterange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Daterange struct { + Lower Date + Upper Date + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Daterange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Daterange", src) +} + +func (dst *Daterange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Daterange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Daterange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Daterange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Daterange) Scan(src interface{}) error { + if src == nil { + *dst = Daterange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Daterange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go new file mode 100644 index 00000000..8501cc7e --- /dev/null +++ b/pgtype/daterange_test.go @@ -0,0 +1,66 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestDaterangeTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ + pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Daterange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} + +func TestDaterangeNormalize(t *testing.T) { + testSuccessfulNormalizeEqFunc(t, []normalizeTest{ + { + sql: "select daterange('2010-01-01', '2010-01-11', '(]')", + value: pgtype.Daterange{ + Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + }, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Daterange) + b := bb.(pgtype.Daterange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7a95994c..3d691044 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -227,6 +227,7 @@ func init() { "cid": &Cid{}, "cidr": &Cidr{}, "date": &Date{}, + "daterange": &Daterange{}, "float4": &Float4{}, "float8": &Float8{}, "hstore": &Hstore{}, @@ -235,6 +236,7 @@ func init() { "int4": &Int4{}, "int4range": &Int4range{}, "int8": &Int8{}, + "int8range": &Int8range{}, "json": &Json{}, "jsonb": &Jsonb{}, "name": &Name{}, @@ -244,6 +246,8 @@ func init() { "tid": &Tid{}, "timestamp": &Timestamp{}, "timestamptz": &Timestamptz{}, + "tsrange": &Tsrange{}, + "tstzrange": &Tstzrange{}, "unknown": &Unknown{}, "varchar": &Varchar{}, "xid": &Xid{}, diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go new file mode 100644 index 00000000..48992829 --- /dev/null +++ b/pgtype/tsrange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Tsrange struct { + Lower Timestamp + Upper Timestamp + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Tsrange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Tsrange", src) +} + +func (dst *Tsrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tsrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tsrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tsrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tsrange) Scan(src interface{}) error { + if src == nil { + *dst = Tsrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tsrange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go new file mode 100644 index 00000000..448cb92f --- /dev/null +++ b/pgtype/tsrange_test.go @@ -0,0 +1,40 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTsrangeTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ + pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Tsrange{ + Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Tsrange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tsrange) + b := bb.(pgtype.Tsrange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go new file mode 100644 index 00000000..61e94ab4 --- /dev/null +++ b/pgtype/tstzrange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Tstzrange struct { + Lower Timestamptz + Upper Timestamptz + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Tstzrange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Tstzrange", src) +} + +func (dst *Tstzrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Tstzrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Tstzrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Tstzrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Tstzrange) Scan(src interface{}) error { + if src == nil { + *dst = Tstzrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tstzrange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go new file mode 100644 index 00000000..197aabbc --- /dev/null +++ b/pgtype/tstzrange_test.go @@ -0,0 +1,40 @@ +package pgtype_test + +import ( + "testing" + "time" + + "github.com/jackc/pgx/pgtype" +) + +func TestTstzrangeTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ + pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Tstzrange{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Tstzrange) + b := bb.(pgtype.Tstzrange) + + return a.Status == b.Status && + a.Lower.Time.Equal(b.Lower.Time) && + a.Lower.Status == b.Lower.Status && + a.Lower.InfinityModifier == b.Lower.InfinityModifier && + a.Upper.Time.Equal(b.Upper.Time) && + a.Upper.Status == b.Upper.Status && + a.Upper.InfinityModifier == b.Upper.InfinityModifier + }) +} diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh index af3e2cd1..b4220f09 100644 --- a/pgtype/typed_range_gen.sh +++ b/pgtype/typed_range_gen.sh @@ -1,3 +1,6 @@ erb range_type=Int4range element_type=Int4 typed_range.go.erb > int4range.go erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go +erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go +erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go +erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go goimports -w *range.go From 09078d2470d48838df4ea09c531e895bff68cb01 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Mar 2017 20:11:18 -0500 Subject: [PATCH 142/264] Add interval type --- pgtype/interval.go | 271 ++++++++++++++++++++++++++++++++++++++++ pgtype/interval_test.go | 62 +++++++++ 2 files changed, 333 insertions(+) create mode 100644 pgtype/interval.go create mode 100644 pgtype/interval_test.go diff --git a/pgtype/interval.go b/pgtype/interval.go new file mode 100644 index 00000000..7eddb10f --- /dev/null +++ b/pgtype/interval.go @@ -0,0 +1,271 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgio" +) + +const ( + microsecondsPerSecond = 1000000 + microsecondsPerMinute = 60 * microsecondsPerSecond + microsecondsPerHour = 60 * microsecondsPerMinute +) + +type Interval struct { + Microseconds int64 + Days int32 + Months int32 + Status Status +} + +func (dst *Interval) Set(src interface{}) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + switch value := src.(type) { + case time.Duration: + *dst = Interval{Microseconds: int64(value) / 1000, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Interval", value) + } + + return nil +} + +func (dst *Interval) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Interval) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *time.Duration: + if src.Days > 0 || src.Months > 0 { + return fmt.Errorf("interval with months or days cannot be decoded into %T", dst) + } + *v = time.Duration(src.Microseconds) * time.Microsecond + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + var microseconds int64 + var days int32 + var months int32 + + parts := strings.Split(string(src), " ") + + for i := 0; i < len(parts)-1; i += 2 { + scalar, err := strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return fmt.Errorf("bad interval format") + } + + switch parts[i+1] { + case "year", "years": + months += int32(scalar * 12) + case "mon", "mons": + months += int32(scalar) + case "day", "days": + days = int32(scalar) + } + } + + if len(parts)%2 == 1 { + timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) + if len(timeParts) != 3 { + return fmt.Errorf("bad interval format") + } + + var negative bool + if timeParts[0][0] == '-' { + negative = true + timeParts[0] = timeParts[0][1:] + } + + hours, err := strconv.ParseInt(timeParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval hour format: %s", hours) + } + + minutes, err := strconv.ParseInt(timeParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval minute format: %s", minutes) + } + + secondParts := strings.SplitN(timeParts[2], ".", 2) + + seconds, err := strconv.ParseInt(secondParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval second format: %s", seconds) + } + + var uSeconds int64 + if len(secondParts) == 2 { + uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval decimal format: %s", seconds) + } + + for i := 0; i < 6-len(secondParts[1]); i++ { + uSeconds *= 10 + } + } + + microseconds = hours * microsecondsPerHour + microseconds += minutes * microsecondsPerMinute + microseconds += seconds * microsecondsPerSecond + microseconds += uSeconds + + if negative { + microseconds = -microseconds + } + } + + *dst = Interval{Months: months, Days: days, Microseconds: microseconds, Status: Present} + return nil +} + +func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + } + + microseconds := int64(binary.BigEndian.Uint64(src)) + days := int32(binary.BigEndian.Uint32(src[8:])) + months := int32(binary.BigEndian.Uint32(src[12:])) + + *dst = Interval{Microseconds: microseconds, Days: days, Months: months, Status: Present} + return nil +} + +func (src Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if src.Months != 0 { + if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Months), 10)); err != nil { + return false, err + } + + if _, err := io.WriteString(w, " mon "); err != nil { + return false, err + } + } + + if src.Days != 0 { + if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Days), 10)); err != nil { + return false, err + } + + if _, err := io.WriteString(w, " day "); err != nil { + return false, err + } + } + + absMicroseconds := src.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + + if err := pgio.WriteByte(w, '-'); err != nil { + return false, err + } + } + + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + microseconds := absMicroseconds % microsecondsPerSecond + + timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + + _, err := io.WriteString(w, timeStr) + return false, err +} + +// EncodeBinary encodes src into w. +func (src Interval) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteInt64(w, src.Microseconds); err != nil { + return false, err + } + if _, err := pgio.WriteInt32(w, src.Days); err != nil { + return false, err + } + if _, err := pgio.WriteInt32(w, src.Months); err != nil { + return false, err + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Interval) Scan(src interface{}) error { + if src == nil { + *dst = Interval{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Interval) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go new file mode 100644 index 00000000..db9614ef --- /dev/null +++ b/pgtype/interval_test.go @@ -0,0 +1,62 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestIntervalTranscode(t *testing.T) { + testSuccessfulTranscode(t, "interval", []interface{}{ + pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, + pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + pgtype.Interval{Days: 1, Status: pgtype.Present}, + pgtype.Interval{Months: 1, Status: pgtype.Present}, + pgtype.Interval{Months: 12, Status: pgtype.Present}, + pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, + pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, + pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, + pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, + pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, + pgtype.Interval{Days: -1, Status: pgtype.Present}, + pgtype.Interval{Months: -1, Status: pgtype.Present}, + pgtype.Interval{Months: -12, Status: pgtype.Present}, + pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, + pgtype.Interval{Status: pgtype.Null}, + }) +} + +func TestIntervalNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select '1 second'::interval", + value: pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + }, + { + sql: "select '1.000001 second'::interval", + value: pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + }, + { + sql: "select '34223 hours'::interval", + value: pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + }, + { + sql: "select '1 day'::interval", + value: pgtype.Interval{Days: 1, Status: pgtype.Present}, + }, + { + sql: "select '1 month'::interval", + value: pgtype.Interval{Months: 1, Status: pgtype.Present}, + }, + { + sql: "select '1 year'::interval", + value: pgtype.Interval{Months: 12, Status: pgtype.Present}, + }, + { + sql: "select '-13 mon'::interval", + value: pgtype.Interval{Months: -13, Status: pgtype.Present}, + }, + }) +} From c5d247830cc6409e7e7c898caa1ed6441d3da236 Mon Sep 17 00:00:00 2001 From: James Lawrence Date: Fri, 31 Mar 2017 20:02:17 -0400 Subject: [PATCH 143/264] enable sql.Open to support both DSN and URI based connection strings --- conn.go | 9 +++ conn_test.go | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++ stdlib/sql.go | 9 ++- 3 files changed, 236 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index d79a4e97..6078cca2 100644 --- a/conn.go +++ b/conn.go @@ -550,6 +550,15 @@ func ParseDSN(s string) (ConnConfig, error) { return cp, nil } +// ParseConnectionString parses either a URI or a DSN connection string. +// see ParseURI and ParseDSN for details. +func ParseConnectionString(s string) (ConnConfig, error) { + if strings.HasPrefix(s, "postgres://") || strings.HasPrefix(s, "postgresql://") { + return ParseURI(s) + } + return ParseDSN(s) +} + // ParseEnvLibpq parses the environment like libpq does into a ConnConfig // // See http://www.postgresql.org/docs/9.4/static/libpq-envars.html for details diff --git a/conn_test.go b/conn_test.go index 4550e63a..50ea68f6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -530,6 +530,225 @@ func TestParseDSN(t *testing.T) { } } +func TestParseConnectionString(t *testing.T) { + t.Parallel() + + tests := []struct { + url string + connParams pgx.ConnConfig + }{ + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + UseFallbackTLS: false, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack:secret@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgresql://jack:secret@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost/mydb?application_name=pgxtest&search_path=myschema", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb application_name=pgxtest search_path=myschema", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + } + + for i, tt := range tests { + connParams, err := pgx.ParseConnectionString(tt.url) + if err != nil { + t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err) + continue + } + + if !reflect.DeepEqual(connParams, tt.connParams) { + t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams) + } + } +} + func TestParseEnvLibpq(t *testing.T) { pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"} diff --git a/stdlib/sql.go b/stdlib/sql.go index e3d46cab..000f0fbf 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -7,6 +7,13 @@ // return err // } // +// Or from a DSN string. +// +// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") +// if err != nil { +// return err +// } +// // Or a normal pgx connection pool can be established and the database/sql // connection can be created through stdlib.OpenFromConnPool(). This allows // more control over the connection process (such as TLS), more control @@ -99,7 +106,7 @@ func (d *Driver) Open(name string) (driver.Conn, error) { return &Conn{conn: conn, pool: d.Pool}, nil } - connConfig, err := pgx.ParseURI(name) + connConfig, err := pgx.ParseConnectionString(name) if err != nil { return nil, err } From 5ad2c4e2b9ce871a9577469dd49668c7b01a4e2b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Apr 2017 23:33:04 -0500 Subject: [PATCH 144/264] Add pgtype.Numeric --- pgtype/decimal.go | 35 +++ pgtype/numeric.go | 602 +++++++++++++++++++++++++++++++++++++++++ pgtype/numeric_test.go | 315 +++++++++++++++++++++ pgtype/pgtype.go | 2 + values.go | 10 + 5 files changed, 964 insertions(+) create mode 100644 pgtype/decimal.go create mode 100644 pgtype/numeric.go create mode 100644 pgtype/numeric_test.go diff --git a/pgtype/decimal.go b/pgtype/decimal.go new file mode 100644 index 00000000..728c748e --- /dev/null +++ b/pgtype/decimal.go @@ -0,0 +1,35 @@ +package pgtype + +import ( + "io" +) + +type Decimal Numeric + +func (dst *Decimal) Set(src interface{}) error { + return (*Numeric)(dst).Set(src) +} + +func (dst *Decimal) Get() interface{} { + return (*Numeric)(dst).Get() +} + +func (src *Decimal) AssignTo(dst interface{}) error { + return (*Numeric)(src).AssignTo(dst) +} + +func (dst *Decimal) DecodeText(ci *ConnInfo, src []byte) error { + return (*Numeric)(dst).DecodeText(ci, src) +} + +func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Numeric)(dst).DecodeBinary(ci, src) +} + +func (src *Decimal) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Numeric)(src).EncodeText(ci, w) +} + +func (src *Decimal) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Numeric)(src).EncodeBinary(ci, w) +} diff --git a/pgtype/numeric.go b/pgtype/numeric.go new file mode 100644 index 00000000..0f3f6529 --- /dev/null +++ b/pgtype/numeric.go @@ -0,0 +1,602 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "math/big" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 +const nbase = 10000 + +var big0 *big.Int = big.NewInt(0) +var big10 *big.Int = big.NewInt(10) +var big100 *big.Int = big.NewInt(100) +var big1000 *big.Int = big.NewInt(1000) + +var bigMaxInt8 *big.Int = big.NewInt(math.MaxInt8) +var bigMinInt8 *big.Int = big.NewInt(math.MinInt8) +var bigMaxInt16 *big.Int = big.NewInt(math.MaxInt16) +var bigMinInt16 *big.Int = big.NewInt(math.MinInt16) +var bigMaxInt32 *big.Int = big.NewInt(math.MaxInt32) +var bigMinInt32 *big.Int = big.NewInt(math.MinInt32) +var bigMaxInt64 *big.Int = big.NewInt(math.MaxInt64) +var bigMinInt64 *big.Int = big.NewInt(math.MinInt64) +var bigMaxInt *big.Int = big.NewInt(int64(maxInt)) +var bigMinInt *big.Int = big.NewInt(int64(minInt)) + +var bigMaxUint8 *big.Int = big.NewInt(math.MaxUint8) +var bigMaxUint16 *big.Int = big.NewInt(math.MaxUint16) +var bigMaxUint32 *big.Int = big.NewInt(math.MaxUint32) +var bigMaxUint64 *big.Int = (&big.Int{}).SetUint64(uint64(math.MaxUint64)) +var bigMaxUint *big.Int = (&big.Int{}).SetUint64(uint64(maxUint)) + +var bigNBase *big.Int = big.NewInt(nbase) +var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) +var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) +var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) + +type Numeric struct { + Int *big.Int + Exp int32 + Status Status +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch value := src.(type) { + case float32: + num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case float64: + num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + case int8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint8: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint16: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint32: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case int64: + *dst = Numeric{Int: big.NewInt(value), Status: Present} + case uint64: + *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Status: Present} + case int: + *dst = Numeric{Int: big.NewInt(int64(value)), Status: Present} + case uint: + *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Status: Present} + case string: + num, exp, err := parseNumericString(value) + if err != nil { + return err + } + *dst = Numeric{Int: num, Exp: exp, Status: Present} + default: + if originalSrc, ok := underlyingNumberType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst *Numeric) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *float32: + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *float64: + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return err + } + return float64AssignTo(f, src.Status, dst) + case *int: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int(normalizedInt.Int64()) + case *int8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt8) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt8) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int8(normalizedInt.Int64()) + case *int16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt16) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt16) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int16(normalizedInt.Int64()) + case *int32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt32) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt32) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = int32(normalizedInt.Int64()) + case *int64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(bigMaxInt64) > 0 { + return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + } + if normalizedInt.Cmp(bigMinInt64) < 0 { + return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Int64() + case *uint: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint(normalizedInt.Uint64()) + case *uint8: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint8) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint8(normalizedInt.Uint64()) + case *uint16: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint16) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint16(normalizedInt.Uint64()) + case *uint32: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint32) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = uint32(normalizedInt.Uint64()) + case *uint64: + normalizedInt, err := src.toBigInt() + if err != nil { + return err + } + if normalizedInt.Cmp(big0) < 0 { + return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + } else if normalizedInt.Cmp(bigMaxUint64) > 0 { + return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + } + *v = normalizedInt.Uint64() + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) toBigInt() (*big.Int, error) { + if dst.Exp == 0 { + return dst.Int, nil + } + + num := &big.Int{} + num.Set(dst.Int) + if dst.Exp > 0 { + mul := &big.Int{} + mul.Exp(big10, big.NewInt(int64(dst.Exp)), nil) + num.Mul(num, mul) + return num, nil + } + + div := &big.Int{} + div.Exp(big10, big.NewInt(int64(-dst.Exp)), nil) + remainder := &big.Int{} + num.DivMod(num, div, remainder) + if remainder.Cmp(big0) != 0 { + return nil, fmt.Errorf("cannot convert %v to integer", dst) + } + return num, nil +} + +func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + num, exp, err := parseNumericString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Int: num, Exp: exp, Status: Present} + return nil +} + +func parseNumericString(str string) (n *big.Int, exp int32, err error) { + parts := strings.SplitN(str, ".", 2) + digits := strings.Join(parts, "") + + if len(parts) > 1 { + exp = int32(-len(parts[1])) + } else { + for len(digits) > 1 && digits[len(digits)-1] == '0' { + digits = digits[:len(digits)-1] + exp++ + } + } + + accum := &big.Int{} + if _, ok := accum.SetString(digits, 10); !ok { + return nil, 0, fmt.Errorf("%s is not a number", str) + } + + return accum, exp, nil +} + +func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + if len(src) < 8 { + return fmt.Errorf("numeric incomplete %v", src) + } + + rp := 0 + ndigits := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if ndigits == 0 { + *dst = Numeric{Int: big.NewInt(0), Status: Present} + return nil + } + + weight := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + sign := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + dscale := int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + if len(src[rp:]) < int(ndigits)*2 { + return fmt.Errorf("numeric incomplete %v", src) + } + + accum := &big.Int{} + + for i := 0; i < int(ndigits+3)/4; i++ { + int64accum, bytesRead, digitsRead := nbaseDigitsToInt64(src[rp:]) + rp += bytesRead + + if i > 0 { + var mul *big.Int + switch digitsRead { + case 1: + mul = bigNBase + case 2: + mul = bigNBaseX2 + case 3: + mul = bigNBaseX3 + case 4: + mul = bigNBaseX4 + default: + return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + } + accum.Mul(accum, mul) + } + + accum.Add(accum, big.NewInt(int64accum)) + } + + exp := (int32(weight) - int32(ndigits) + 1) * 4 + + if dscale > 0 { + fracNBaseDigits := ndigits - weight - 1 + fracDecimalDigits := fracNBaseDigits * 4 + + if dscale > fracDecimalDigits { + multCount := int(dscale - fracDecimalDigits) + for i := 0; i < multCount; i++ { + accum.Mul(accum, big10) + exp-- + } + } else if dscale < fracDecimalDigits { + divCount := int(fracDecimalDigits - dscale) + for i := 0; i < divCount; i++ { + accum.Div(accum, big10) + exp++ + } + } + } + + reduced := &big.Int{} + remainder := &big.Int{} + if exp >= 0 { + for { + reduced.DivMod(accum, big10, remainder) + if remainder.Cmp(big0) != 0 { + break + } + accum.Set(reduced) + exp++ + } + } + + if sign != 0 { + accum.Neg(accum) + } + + *dst = Numeric{Int: accum, Exp: exp, Status: Present} + + return nil + +} + +func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { + digits := len(src) / 2 + if digits > 4 { + digits = 4 + } + + rp := 0 + + for i := 0; i < digits; i++ { + if i > 0 { + accum *= nbase + } + accum += int64(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return accum, rp, digits +} + +func (src *Numeric) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := io.WriteString(w, src.Int.String()); err != nil { + return false, err + } + + if err := pgio.WriteByte(w, 'e'); err != nil { + return false, err + } + + if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Exp), 10)); err != nil { + return false, err + } + + return false, nil + +} + +func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var sign int16 + if src.Int.Cmp(big0) < 0 { + sign = 16384 + } + + absInt := &big.Int{} + wholePart := &big.Int{} + fracPart := &big.Int{} + remainder := &big.Int{} + absInt.Abs(src.Int) + + // Normalize absInt and exp to where exp is always a multiple of 4. This makes + // converting to 16-bit base 10,000 digits easier. + var exp int32 + switch src.Exp % 4 { + case 1, -3: + exp = src.Exp - 1 + absInt.Mul(absInt, big10) + case 2, -2: + exp = src.Exp - 2 + absInt.Mul(absInt, big100) + case 3, -1: + exp = src.Exp - 3 + absInt.Mul(absInt, big1000) + default: + exp = src.Exp + } + + if exp < 0 { + divisor := &big.Int{} + divisor.Exp(big10, big.NewInt(int64(-exp)), nil) + wholePart.DivMod(absInt, divisor, fracPart) + } else { + wholePart = absInt + } + + var wholeDigits, fracDigits []int16 + + for wholePart.Cmp(big0) != 0 { + wholePart.DivMod(wholePart, bigNBase, remainder) + wholeDigits = append(wholeDigits, int16(remainder.Int64())) + } + + for fracPart.Cmp(big0) != 0 { + fracPart.DivMod(fracPart, bigNBase, remainder) + fracDigits = append(fracDigits, int16(remainder.Int64())) + } + + if _, err := pgio.WriteInt16(w, int16(len(wholeDigits)+len(fracDigits))); err != nil { + return false, err + } + + var weight int16 + if len(wholeDigits) > 0 { + weight = int16(len(wholeDigits) - 1) + if exp > 0 { + weight += int16(exp / 4) + } + } else { + weight = int16(exp/4) - 1 + int16(len(fracDigits)) + } + if _, err := pgio.WriteInt16(w, weight); err != nil { + return false, err + } + + if _, err := pgio.WriteInt16(w, sign); err != nil { + return false, err + } + + var dscale int16 + if src.Exp < 0 { + dscale = int16(-src.Exp) + } + if _, err := pgio.WriteInt16(w, dscale); err != nil { + return false, err + } + + for i := len(wholeDigits) - 1; i >= 0; i-- { + if _, err := pgio.WriteInt16(w, wholeDigits[i]); err != nil { + return false, err + } + } + + for i := len(fracDigits) - 1; i >= 0; i-- { + if _, err := pgio.WriteInt16(w, fracDigits[i]); err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + // TODO + // *dst = Numeric{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Numeric) Value() (driver.Value, error) { + switch src.Status { + case Present: + buf := &bytes.Buffer{} + _, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + + return buf.String(), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go new file mode 100644 index 00000000..64dea847 --- /dev/null +++ b/pgtype/numeric_test.go @@ -0,0 +1,315 @@ +package pgtype_test + +import ( + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +// For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) +func numericEqual(left, right *pgtype.Numeric) bool { + return left.Status == right.Status && + left.Exp == right.Exp && + ((left.Int == nil && right.Int == nil) || (left.Int != nil && right.Int != nil && left.Int.Cmp(right.Int) == 0)) +} + +// For test purposes only. +func numericNormalizedEqual(left, right *pgtype.Numeric) bool { + if left.Status != right.Status { + return false + } + + normLeft := &pgtype.Numeric{Int: (&big.Int{}).Set(left.Int), Status: left.Status} + normRight := &pgtype.Numeric{Int: (&big.Int{}).Set(right.Int), Status: right.Status} + + if left.Exp < right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(right.Exp-left.Exp)), nil) + normRight.Int.Mul(normRight.Int, mul) + } else if left.Exp > right.Exp { + mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(left.Exp-right.Exp)), nil) + normLeft.Int.Mul(normLeft.Int, mul) + } + + return normLeft.Int.Cmp(normRight.Int) == 0 +} + +func mustParseBigInt(t *testing.T, src string) *big.Int { + i := &big.Int{} + if _, ok := i.SetString(src, 10); !ok { + t.Fatalf("could not parse big.Int: %s", src) + } + return i +} + +func TestNumericNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select '0'::numeric", + value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '10.00'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + }, + { + sql: "select '1e-3'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + }, + { + sql: "select '-1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '10000'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + }, + { + sql: "select '3.14'::numeric", + value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + }, + { + sql: "select '1.1'::numeric", + value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + }, + { + sql: "select '100010001'::numeric", + value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + }, + { + sql: "select '100010001.0001'::numeric", + value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + }, + { + sql: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), + Exp: -41, + Status: pgtype.Present, + }, + }, + { + sql: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Exp: -196, + Status: pgtype.Present, + }, + }, + { + sql: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + value: pgtype.Numeric{ + Int: mustParseBigInt(t, "123"), + Exp: -186, + Status: pgtype.Present, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + testSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(1), Exp: 6, Status: pgtype.Present}, + + // preserves significant zeroes + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -1, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -3, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -4, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -5, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(10000000), Exp: -6, Status: pgtype.Present}, + + &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -7, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -8, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -9, Status: pgtype.Present}, + &pgtype.Numeric{Int: big.NewInt(123), Exp: -1500, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3723409723490243842378942378901237502734019231380123"), Exp: 81, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "723409723490243842378942378901237502734019231380123"), Exp: 82, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409723490243842378942378901237502734019231380123"), Exp: 83, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409723490243842378942378901237502734019231380123"), Exp: 84, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3423409823409243892349028349023482934092340892390101"), Exp: -91, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "423409823409243892349028349023482934092340892390101"), Exp: -92, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "23409823409243892349028349023482934092340892390101"), Exp: -93, Status: pgtype.Present}, + &pgtype.Numeric{Int: mustParseBigInt(t, "3409823409243892349028349023482934092340892390101"), Exp: -94, Status: pgtype.Present}, + &pgtype.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericEqual(&a, &b) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 10; i++ { + for j := -50; j < 50; j++ { + num := (&big.Int{}).Rand(r, max) + negNum := &big.Int{} + negNum.Neg(num) + values = append(values, &pgtype.Numeric{Int: num, Exp: int32(j), Status: pgtype.Present}) + values = append(values, &pgtype.Numeric{Int: negNum, Exp: int32(j), Status: pgtype.Present}) + } + } + + testSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(pgtype.Numeric) + b := bb.(pgtype.Numeric) + + return numericNormalizedEqual(&a, &b) + }) +} + +func TestNumericSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result *pgtype.Numeric + }{ + {source: float32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: int8(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int16(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int32(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: int64(-1), result: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}}, + {source: uint8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint16(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint32(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: uint64(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: "1", result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: _int8(1), result: &pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}}, + {source: float64(1000), result: &pgtype.Numeric{Int: big.NewInt(1), Exp: 3, Status: pgtype.Present}}, + {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, + {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, + {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &pgtype.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !numericEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *pgtype.Numeric + dst interface{} + expected interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *pgtype.Numeric + dst interface{} + }{ + {src: &pgtype.Numeric{Int: big.NewInt(150), Status: pgtype.Present}, dst: &i8}, + {src: &pgtype.Numeric{Int: big.NewInt(40000), Status: pgtype.Present}, dst: &i16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui8}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui16}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui32}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui64}, + {src: &pgtype.Numeric{Int: big.NewInt(-1), Status: pgtype.Present}, dst: &ui}, + {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3d691044..84939b58 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -228,6 +228,7 @@ func init() { "cidr": &Cidr{}, "date": &Date{}, "daterange": &Daterange{}, + "decimal": &Decimal{}, "float4": &Float4{}, "float8": &Float8{}, "hstore": &Hstore{}, @@ -240,6 +241,7 @@ func init() { "json": &Json{}, "jsonb": &Jsonb{}, "name": &Name{}, + "numeric": &Numeric{}, "oid": &OidValue{}, "record": &Record{}, "text": &Text{}, diff --git a/values.go b/values.go index 5370bf47..71c4cc5c 100644 --- a/values.go +++ b/values.go @@ -118,6 +118,16 @@ func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interfac if dt, ok := ci.DataTypeForOid(oid); ok { if _, ok := dt.Value.(pgtype.BinaryEncoder); ok { + if arg, ok := arg.(driver.Valuer); ok { + if err := dt.Value.Set(arg); err != nil { + if value, err := arg.Value(); err == nil { + if _, ok := value.(string); ok { + return TextFormatCode + } + } + } + } + return BinaryFormatCode } } From 9e5e02cc837c2be57041231d6c2ce9d19eed6f55 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Apr 2017 23:44:03 -0500 Subject: [PATCH 145/264] Add pgtype TODO notes --- v3.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/v3.md b/v3.md index 3e0aae82..a2384ace 100644 --- a/v3.md +++ b/v3.md @@ -62,3 +62,21 @@ Further clean up logging interface -- still some pre-loglevel code in place Possibly integrate internal logging support with context. Possibly add method that adds arbitrary pgx log data to context. Or add ability to configure what key(s) pgx looks at for additional log context. Consider whether to switch to logrus style or stick with log15 style logs Keep ability to change logging while running + +consider test to ensure that AssignTo makes copy of reference types +something like: +select array[1,2,3], array[4,5,6,7] + +pgtype TODO: +numrange +numeric[] +point +line +lseg +box +path +path +polygon +circle +macaddr +varbit From 6ca1c1e41e9d0748665393751ba9163837255a1e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 07:35:19 -0500 Subject: [PATCH 146/264] Add pgtype.Numrange --- pgtype/numrange.go | 268 ++++++++++++++++++++++++++++++++++++++ pgtype/numrange_test.go | 33 +++++ pgtype/pgtype.go | 1 + pgtype/typed_range_gen.sh | 1 + v3.md | 1 - 5 files changed, 303 insertions(+), 1 deletion(-) create mode 100644 pgtype/numrange.go create mode 100644 pgtype/numrange_test.go diff --git a/pgtype/numrange.go b/pgtype/numrange.go new file mode 100644 index 00000000..cf42dcbd --- /dev/null +++ b/pgtype/numrange.go @@ -0,0 +1,268 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Numrange struct { + Lower Numeric + Upper Numeric + LowerType BoundType + UpperType BoundType + Status Status +} + +func (dst *Numrange) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Numrange", src) +} + +func (dst *Numrange) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Numrange) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + utr, err := ParseUntypedTextRange(string(src)) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = utr.LowerType + dst.UpperType = utr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeText(ci, []byte(utr.Lower)); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeText(ci, []byte(utr.Upper)); err != nil { + return err + } + } + + return nil +} + +func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + ubr, err := ParseUntypedBinaryRange(src) + if err != nil { + return err + } + + *dst = Numrange{Status: Present} + + dst.LowerType = ubr.LowerType + dst.UpperType = ubr.UpperType + + if dst.LowerType == Empty { + return nil + } + + if dst.LowerType == Inclusive || dst.LowerType == Exclusive { + if err := dst.Lower.DecodeBinary(ci, ubr.Lower); err != nil { + return err + } + } + + if dst.UpperType == Inclusive || dst.UpperType == Exclusive { + if err := dst.Upper.DecodeBinary(ci, ubr.Upper); err != nil { + return err + } + } + + return nil +} + +func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + switch src.LowerType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, '['); err != nil { + return false, err + } + case Empty: + _, err := io.WriteString(w, "empty") + return false, err + default: + return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + } + + if src.LowerType != Unbounded { + if null, err := src.Lower.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + } + + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + + if src.UpperType != Unbounded { + if null, err := src.Upper.EncodeText(ci, w); err != nil { + return false, err + } else if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + } + + switch src.UpperType { + case Exclusive, Unbounded: + if err := pgio.WriteByte(w, ')'); err != nil { + return false, err + } + case Inclusive: + if err := pgio.WriteByte(w, ']'); err != nil { + return false, err + } + default: + return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + } + + return false, nil +} + +func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var rangeType byte + switch src.LowerType { + case Inclusive: + rangeType |= lowerInclusiveMask + case Unbounded: + rangeType |= lowerUnboundedMask + case Exclusive: + case Empty: + err := pgio.WriteByte(w, emptyMask) + return false, err + default: + return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + } + + switch src.UpperType { + case Inclusive: + rangeType |= upperInclusiveMask + case Unbounded: + rangeType |= upperUnboundedMask + case Exclusive: + default: + return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + } + + if err := pgio.WriteByte(w, rangeType); err != nil { + return false, err + } + + valBuf := &bytes.Buffer{} + + if src.LowerType != Unbounded { + null, err := src.Lower.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + if src.UpperType != Unbounded { + null, err := src.Upper.EncodeBinary(ci, valBuf) + if err != nil { + return false, err + } + if null { + return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + } + + _, err = pgio.WriteInt32(w, int32(valBuf.Len())) + if err != nil { + return false, err + } + _, err = valBuf.WriteTo(w) + if err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numrange) Scan(src interface{}) error { + if src == nil { + *dst = Numrange{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Numrange) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go new file mode 100644 index 00000000..81202362 --- /dev/null +++ b/pgtype/numrange_test.go @@ -0,0 +1,33 @@ +package pgtype_test + +import ( + "math/big" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestNumrangeTranscode(t *testing.T) { + testSuccessfulTranscode(t, "numrange", []interface{}{ + pgtype.Numrange{ + LowerType: pgtype.Empty, + UpperType: pgtype.Empty, + Status: pgtype.Present, + }, + pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Numrange{ + Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, + Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + pgtype.Numrange{Status: pgtype.Null}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 84939b58..d7e28641 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -242,6 +242,7 @@ func init() { "jsonb": &Jsonb{}, "name": &Name{}, "numeric": &Numeric{}, + "numrange": &Numrange{}, "oid": &OidValue{}, "record": &Record{}, "text": &Text{}, diff --git a/pgtype/typed_range_gen.sh b/pgtype/typed_range_gen.sh index b4220f09..bedda292 100644 --- a/pgtype/typed_range_gen.sh +++ b/pgtype/typed_range_gen.sh @@ -3,4 +3,5 @@ erb range_type=Int8range element_type=Int8 typed_range.go.erb > int8range.go erb range_type=Tsrange element_type=Timestamp typed_range.go.erb > tsrange.go erb range_type=Tstzrange element_type=Timestamptz typed_range.go.erb > tstzrange.go erb range_type=Daterange element_type=Date typed_range.go.erb > daterange.go +erb range_type=Numrange element_type=Numeric typed_range.go.erb > numrange.go goimports -w *range.go diff --git a/v3.md b/v3.md index a2384ace..b79ce9cd 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -numrange numeric[] point line From c09c356b1902fc779f0514cd2ba0320143018b7f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 07:46:45 -0500 Subject: [PATCH 147/264] Add pgtype.NumericArray --- pgtype/numeric_array.go | 357 +++++++++++++++++++++++++++++++++++ pgtype/numeric_array_test.go | 159 ++++++++++++++++ pgtype/pgtype.go | 1 + pgtype/typed_array_gen.sh | 1 + v3.md | 1 - 5 files changed, 518 insertions(+), 1 deletion(-) create mode 100644 pgtype/numeric_array.go create mode 100644 pgtype/numeric_array_test.go diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go new file mode 100644 index 00000000..b147e6a2 --- /dev/null +++ b/pgtype/numeric_array.go @@ -0,0 +1,357 @@ +package pgtype + +import ( + "bytes" + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type NumericArray struct { + Elements []Numeric + Dimensions []ArrayDimension + Status Status +} + +func (dst *NumericArray) Set(src interface{}) error { + switch value := src.(type) { + + case []float32: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []float64: + if value == nil { + *dst = NumericArray{Status: Null} + } else if len(value) == 0 { + *dst = NumericArray{Status: Present} + } else { + elements := make([]Numeric, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = NumericArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + return nil +} + +func (dst *NumericArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *NumericArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]float32: + *v = make([]float32, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]float64: + *v = make([]float64, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Numeric + + if len(uta.Elements) > 0 { + elements = make([]Numeric, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Numeric + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = NumericArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = NumericArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = NumericArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Numeric, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = NumericArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *NumericArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("numeric"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "numeric") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *NumericArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *NumericArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go new file mode 100644 index 00000000..af2e8e51 --- /dev/null +++ b/pgtype/numeric_array_test.go @@ -0,0 +1,159 @@ +package pgtype_test + +import ( + "math/big" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestNumericArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "numeric[]", []interface{}{ + &pgtype.NumericArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{Status: pgtype.Null}, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + pgtype.Numeric{Status: pgtype.Null}, + pgtype.Numeric{Int: big.NewInt(6), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.NumericArray{ + Elements: []pgtype.Numeric{ + pgtype.Numeric{Int: big.NewInt(1), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(2), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(3), Status: pgtype.Present}, + pgtype.Numeric{Int: big.NewInt(4), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestNumericArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.NumericArray + }{ + { + source: []float32{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []float64{1}, + result: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]float32)(nil)), + result: pgtype.NumericArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.NumericArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericArrayAssignTo(t *testing.T) { + var float32Slice []float32 + var float64Slice []float64 + + simpleTests := []struct { + src pgtype.NumericArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + expected: []float32{1}, + }, + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Int: big.NewInt(1), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float64Slice, + expected: []float64{1}, + }, + { + src: pgtype.NumericArray{Status: pgtype.Null}, + dst: &float32Slice, + expected: (([]float32)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.NumericArray + dst interface{} + }{ + { + src: pgtype.NumericArray{ + Elements: []pgtype.Numeric{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &float32Slice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } + +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index d7e28641..208b1f00 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -216,6 +216,7 @@ func init() { "_int2": &Int2Array{}, "_int4": &Int4Array{}, "_int8": &Int8Array{}, + "_numeric": &NumericArray{}, "_text": &TextArray{}, "_timestamp": &TimestampArray{}, "_timestamptz": &TimestamptzArray{}, diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 52612466..2e36b8b3 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -14,4 +14,5 @@ erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[] erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go +erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go goimports -w *_array.go diff --git a/v3.md b/v3.md index b79ce9cd..f1ec1990 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -numeric[] point line lseg From 5a2feadf1128e1a3217691013b7d98ad0eb324d7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 17:53:32 -0500 Subject: [PATCH 148/264] Add pgtype.Point --- example_custom_type_test.go | 52 +++++++++----- pgtype/pgtype.go | 1 + pgtype/point.go | 139 ++++++++++++++++++++++++++++++++++++ pgtype/point_test.go | 15 ++++ query_test.go | 15 +++- v3.md | 1 - 6 files changed, 202 insertions(+), 21 deletions(-) create mode 100644 pgtype/point.go create mode 100644 pgtype/point_test.go diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 1c21c7e6..647b97e6 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -2,7 +2,6 @@ package pgx_test import ( "fmt" - "io" "regexp" "strconv" @@ -18,6 +17,25 @@ type Point struct { Status pgtype.Status } +func (dst *Point) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Point", src) +} + +func (dst *Point) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Point) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { *dst = Point{Status: pgtype.Null} @@ -44,23 +62,12 @@ func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return nil } -func (src Point) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - switch src.Status { - case pgtype.Null: - return true, nil - case pgtype.Undefined: - return false, fmt.Errorf("undefined") +func (src *Point) String() string { + if src.Status == pgtype.Null { + return "null point" } - _, err := io.WriteString(w, fmt.Sprintf("point(%v,%v)", src.X, src.Y)) - return false, err -} - -func (p Point) String() string { - if p.Status == pgtype.Present { - return fmt.Sprintf("%v, %v", p.X, p.Y) - } - return "null point" + return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) } func Example_CustomType() { @@ -70,15 +77,22 @@ func Example_CustomType() { return } - var p Point - err = conn.QueryRow("select null::point").Scan(&p) + // Override registered handler for point + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &Point{}, + Name: "point", + Oid: 600, + }) + + p := &Point{} + err = conn.QueryRow("select null::point").Scan(p) if err != nil { fmt.Println(err) return } fmt.Println(p) - err = conn.QueryRow("select point(1.5,2.5)").Scan(&p) + err = conn.QueryRow("select point(1.5,2.5)").Scan(p) if err != nil { fmt.Println(err) return diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 208b1f00..911ab70e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -245,6 +245,7 @@ func init() { "numeric": &Numeric{}, "numrange": &Numrange{}, "oid": &OidValue{}, + "point": &Point{}, "record": &Record{}, "text": &Text{}, "tid": &Tid{}, diff --git a/pgtype/point.go b/pgtype/point.go new file mode 100644 index 00000000..1b40bc44 --- /dev/null +++ b/pgtype/point.go @@ -0,0 +1,139 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Point struct { + X float64 + Y float64 + Status Status +} + +func (dst *Point) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Point", src) +} + +func (dst *Point) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Point) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + *dst = Point{X: x, Y: y, Status: Present} + return nil +} + +func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + *dst = Point{ + X: math.Float64frombits(x), + Y: math.Float64frombits(y), + Status: Present, + } + return nil +} + +func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y)) + return false, err +} + +func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.X)) + if err != nil { + return false, err + } + + _, err = pgio.WriteUint64(w, math.Float64bits(src.Y)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Point) Scan(src interface{}) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Point) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/point_test.go b/pgtype/point_test.go new file mode 100644 index 00000000..4ddb8009 --- /dev/null +++ b/pgtype/point_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestPointTranscode(t *testing.T) { + testSuccessfulTranscode(t, "point", []interface{}{ + &pgtype.Point{X: 1.234, Y: 5.6789, Status: pgtype.Present}, + &pgtype.Point{X: -1.234, Y: -5.6789, Status: pgtype.Present}, + &pgtype.Point{Status: pgtype.Null}, + }) +} diff --git a/query_test.go b/query_test.go index 25347ec5..d0fcb706 100644 --- a/query_test.go +++ b/query_test.go @@ -710,6 +710,19 @@ func TestQueryRowUnknownType(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) + // Clear existing type mappings + conn.ConnInfo = pgtype.NewConnInfo() + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &pgtype.GenericText{}, + Name: "point", + Oid: 600, + }) + conn.ConnInfo.RegisterDataType(pgtype.DataType{ + Value: &pgtype.Int4{}, + Name: "int4", + Oid: pgtype.Int4Oid, + }) + sql := "select $1::point" expected := "(1,0)" var actual string @@ -751,7 +764,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Text"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"}, } for i, tt := range tests { diff --git a/v3.md b/v3.md index f1ec1990..70a378ad 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -point line lseg box From 06822bebe0b83b2ddb82350b79fa596ae8452a5c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 3 Apr 2017 19:47:36 -0500 Subject: [PATCH 149/264] Add pgtype.Box --- pgtype/box.go | 168 +++++++++++++++++++++++++++++++++++++++++++ pgtype/box_test.go | 33 +++++++++ pgtype/pgtype.go | 1 + pgtype/point.go | 13 ++-- pgtype/point_test.go | 4 +- v3.md | 1 - 6 files changed, 212 insertions(+), 8 deletions(-) create mode 100644 pgtype/box.go create mode 100644 pgtype/box_test.go diff --git a/pgtype/box.go b/pgtype/box.go new file mode 100644 index 00000000..eaaddbff --- /dev/null +++ b/pgtype/box.go @@ -0,0 +1,168 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Box struct { + Corners [2]Vec2 + Status Status +} + +func (dst *Box) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Box", src) +} + +func (dst *Box) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Box) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + str := string(src[1:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-1] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Box{Corners: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + return nil +} + +func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Box: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Box{ + Corners: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Status: Present, + } + return nil +} + +func (src *Box) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.Corners[0].X, src.Corners[0].Y, src.Corners[1].X, src.Corners[1].Y)) + return false, err +} + +func (src *Box) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[0].X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[0].Y)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[1].X)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[1].Y)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Box) Scan(src interface{}) error { + if src == nil { + *dst = Box{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Box) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/box_test.go b/pgtype/box_test.go new file mode 100644 index 00000000..21446dc3 --- /dev/null +++ b/pgtype/box_test.go @@ -0,0 +1,33 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestBoxTranscode(t *testing.T) { + testSuccessfulTranscode(t, "box", []interface{}{ + &pgtype.Box{ + Corners: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, + }, + &pgtype.Box{ + Corners: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, + }, + &pgtype.Box{Status: pgtype.Null}, + }) +} + +func TestBoxNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select '3.14, 1.678, 7.1, 5.234'::box", + value: &pgtype.Box{ + Corners: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, + }, + }, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 911ab70e..b29bc90c 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -223,6 +223,7 @@ func init() { "_varchar": &VarcharArray{}, "aclitem": &Aclitem{}, "bool": &Bool{}, + "box": &Box{}, "bytea": &Bytea{}, "char": &QChar{}, "cid": &Cid{}, diff --git a/pgtype/point.go b/pgtype/point.go index 1b40bc44..94f753e3 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -12,9 +12,13 @@ import ( "github.com/jackc/pgx/pgio" ) +type Vec2 struct { + X float64 + Y float64 +} + type Point struct { - X float64 - Y float64 + Vec2 Status Status } @@ -62,7 +66,7 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Point{X: x, Y: y, Status: Present} + *dst = Point{Vec2: Vec2{x, y}, Status: Present} return nil } @@ -80,8 +84,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { y := binary.BigEndian.Uint64(src[8:]) *dst = Point{ - X: math.Float64frombits(x), - Y: math.Float64frombits(y), + Vec2: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, Status: Present, } return nil diff --git a/pgtype/point_test.go b/pgtype/point_test.go index 4ddb8009..723dfa60 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -8,8 +8,8 @@ import ( func TestPointTranscode(t *testing.T) { testSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{X: 1.234, Y: 5.6789, Status: pgtype.Present}, - &pgtype.Point{X: -1.234, Y: -5.6789, Status: pgtype.Present}, + &pgtype.Point{Vec2: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, + &pgtype.Point{Vec2: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, &pgtype.Point{Status: pgtype.Null}, }) } diff --git a/v3.md b/v3.md index 70a378ad..1930508f 100644 --- a/v3.md +++ b/v3.md @@ -70,7 +70,6 @@ select array[1,2,3], array[4,5,6,7] pgtype TODO: line lseg -box path path polygon From 5394aa9a2b7a0a1a4f6e340f68592096924425d9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 08:04:40 -0500 Subject: [PATCH 150/264] Add pgtype.Line --- pgtype/line.go | 148 ++++++++++++++++++++++++++++++++++++++++++++ pgtype/line_test.go | 21 +++++++ pgtype/pgtype.go | 1 + v3.md | 1 - 4 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 pgtype/line.go create mode 100644 pgtype/line_test.go diff --git a/pgtype/line.go b/pgtype/line.go new file mode 100644 index 00000000..08a74e84 --- /dev/null +++ b/pgtype/line.go @@ -0,0 +1,148 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Line struct { + A, B, C float64 + Status Status +} + +func (dst *Line) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Line", src) +} + +func (dst *Line) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Line) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Line: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) + if len(parts) < 3 { + return fmt.Errorf("invalid format for line") + } + + a, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + b, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + c, err := strconv.ParseFloat(parts[2], 64) + if err != nil { + return err + } + + *dst = Line{A: a, B: b, C: c, Status: Present} + return nil +} + +func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Line: %v", len(src)) + } + + a := binary.BigEndian.Uint64(src) + b := binary.BigEndian.Uint64(src[8:]) + c := binary.BigEndian.Uint64(src[16:]) + + *dst = Line{ + A: math.Float64frombits(a), + B: math.Float64frombits(b), + C: math.Float64frombits(c), + Status: Present, + } + return nil +} + +func (src *Line) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)) + return false, err +} + +func (src *Line) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.A)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.B)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.C)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Line) Scan(src interface{}) error { + if src == nil { + *dst = Line{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Line) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/line_test.go b/pgtype/line_test.go new file mode 100644 index 00000000..6d3b02e1 --- /dev/null +++ b/pgtype/line_test.go @@ -0,0 +1,21 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestLineTranscode(t *testing.T) { + testSuccessfulTranscode(t, "line", []interface{}{ + &pgtype.Line{ + A: 1.23, B: 4.56, C: 7.89, + Status: pgtype.Present, + }, + &pgtype.Line{ + A: -1.23, B: -4.56, C: -7.89, + Status: pgtype.Present, + }, + &pgtype.Line{Status: pgtype.Null}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index b29bc90c..c92dfccf 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -242,6 +242,7 @@ func init() { "int8range": &Int8range{}, "json": &Json{}, "jsonb": &Jsonb{}, + "line": &Line{}, "name": &Name{}, "numeric": &Numeric{}, "numrange": &Numrange{}, diff --git a/v3.md b/v3.md index 1930508f..38cf18cf 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -line lseg path path From 365005d2072dc477c039cb416b696da703a2a913 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 08:16:02 -0500 Subject: [PATCH 151/264] Add pgtype.Lseg --- pgtype/box.go | 18 ++--- pgtype/box_test.go | 12 ++-- pgtype/lseg.go | 168 ++++++++++++++++++++++++++++++++++++++++++++ pgtype/lseg_test.go | 21 ++++++ pgtype/pgtype.go | 1 + v3.md | 1 - 6 files changed, 205 insertions(+), 16 deletions(-) create mode 100644 pgtype/lseg.go create mode 100644 pgtype/lseg_test.go diff --git a/pgtype/box.go b/pgtype/box.go index eaaddbff..138953a5 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -13,8 +13,8 @@ import ( ) type Box struct { - Corners [2]Vec2 - Status Status + P [2]Vec2 + Status Status } func (dst *Box) Set(src interface{}) error { @@ -79,7 +79,7 @@ func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Box{Corners: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + *dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} return nil } @@ -99,7 +99,7 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { y2 := binary.BigEndian.Uint64(src[24:]) *dst = Box{ - Corners: [2]Vec2{ + P: [2]Vec2{ {math.Float64frombits(x1), math.Float64frombits(y1)}, {math.Float64frombits(x2), math.Float64frombits(y2)}, }, @@ -117,7 +117,7 @@ func (src *Box) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.Corners[0].X, src.Corners[0].Y, src.Corners[1].X, src.Corners[1].Y)) + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)) return false, err } @@ -129,19 +129,19 @@ func (src *Box) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[0].X)); err != nil { + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].X)); err != nil { return false, err } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[0].Y)); err != nil { + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].Y)); err != nil { return false, err } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[1].X)); err != nil { + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].X)); err != nil { return false, err } - _, err := pgio.WriteUint64(w, math.Float64bits(src.Corners[1].Y)) + _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].Y)) return false, err } diff --git a/pgtype/box_test.go b/pgtype/box_test.go index 21446dc3..00732973 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -9,12 +9,12 @@ import ( func TestBoxTranscode(t *testing.T) { testSuccessfulTranscode(t, "box", []interface{}{ &pgtype.Box{ - Corners: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, }, &pgtype.Box{ - Corners: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, }, &pgtype.Box{Status: pgtype.Null}, }) @@ -25,8 +25,8 @@ func TestBoxNormalize(t *testing.T) { { sql: "select '3.14, 1.678, 7.1, 5.234'::box", value: &pgtype.Box{ - Corners: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, - Status: pgtype.Present, + P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, + Status: pgtype.Present, }, }, }) diff --git a/pgtype/lseg.go b/pgtype/lseg.go new file mode 100644 index 00000000..b86256e0 --- /dev/null +++ b/pgtype/lseg.go @@ -0,0 +1,168 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Lseg struct { + P [2]Vec2 + Status Status +} + +func (dst *Lseg) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Lseg", src) +} + +func (dst *Lseg) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Lseg) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + if len(src) < 11 { + return fmt.Errorf("invalid length for Lseg: %v", len(src)) + } + + str := string(src[2:]) + + var end int + end = strings.IndexByte(str, ',') + + x1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y1, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+3:] + end = strings.IndexByte(str, ',') + + x2, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1 : len(str)-2] + + y2, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present} + return nil +} + +func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for Lseg: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + *dst = Lseg{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Status: Present, + } + return nil +} + +func (src *Lseg) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)) + return false, err +} + +func (src *Lseg) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].Y)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].X)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].Y)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Lseg) Scan(src interface{}) error { + if src == nil { + *dst = Lseg{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Lseg) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go new file mode 100644 index 00000000..5f041263 --- /dev/null +++ b/pgtype/lseg_test.go @@ -0,0 +1,21 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestLsegTranscode(t *testing.T) { + testSuccessfulTranscode(t, "lseg", []interface{}{ + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, + Status: pgtype.Present, + }, + &pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Status: pgtype.Present, + }, + &pgtype.Lseg{Status: pgtype.Null}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index c92dfccf..6d1f49af 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -243,6 +243,7 @@ func init() { "json": &Json{}, "jsonb": &Jsonb{}, "line": &Line{}, + "lseg": &Lseg{}, "name": &Name{}, "numeric": &Numeric{}, "numrange": &Numrange{}, diff --git a/v3.md b/v3.md index 38cf18cf..26c86f12 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -lseg path path polygon From d14de1d1fca7afc2ac5d522aa8a171abccc8c744 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 08:40:41 -0500 Subject: [PATCH 152/264] Add path --- pgtype/path.go | 207 ++++++++++++++++++++++++++++++++++++++++++++ pgtype/path_test.go | 28 ++++++ pgtype/pgtype.go | 1 + v3.md | 2 - 4 files changed, 236 insertions(+), 2 deletions(-) create mode 100644 pgtype/path.go create mode 100644 pgtype/path_test.go diff --git a/pgtype/path.go b/pgtype/path.go new file mode 100644 index 00000000..fb4193d9 --- /dev/null +++ b/pgtype/path.go @@ -0,0 +1,207 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Path struct { + P []Vec2 + Closed bool + Status Status +} + +func (dst *Path) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Path", src) +} + +func (dst *Path) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Path) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == '(' + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Path{P: points, Closed: closed, Status: Present} + return nil +} + +func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Path: %v", len(src)) + } + + closed := src[0] == 1 + pointCount := int(binary.BigEndian.Uint32(src[1:])) + + rp := 5 + + if 5+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Path{ + P: points, + Closed: closed, + Status: Present, + } + return nil +} + +func (src *Path) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var startByte, endByte byte + if src.Closed { + startByte = '(' + endByte = ')' + } else { + startByte = '[' + endByte = ']' + } + if err := pgio.WriteByte(w, startByte); err != nil { + return false, err + } + + for i, p := range src.P { + if i > 0 { + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + } + if _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)); err != nil { + return false, err + } + } + + err := pgio.WriteByte(w, endByte) + return false, err +} + +func (src *Path) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + var closeByte byte + if src.Closed { + closeByte = 1 + } + if err := pgio.WriteByte(w, closeByte); err != nil { + return false, err + } + + if _, err := pgio.WriteInt32(w, int32(len(src.P))); err != nil { + return false, err + } + + for _, p := range src.P { + if _, err := pgio.WriteUint64(w, math.Float64bits(p.X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(p.Y)); err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Path) Scan(src interface{}) error { + if src == nil { + *dst = Path{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Path) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/path_test.go b/pgtype/path_test.go new file mode 100644 index 00000000..4e5f7f62 --- /dev/null +++ b/pgtype/path_test.go @@ -0,0 +1,28 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestPathTranscode(t *testing.T) { + testSuccessfulTranscode(t, "path", []interface{}{ + &pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, + Closed: false, + Status: pgtype.Present, + }, + &pgtype.Path{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {23.1, 9.34}}, + Closed: true, + Status: pgtype.Present, + }, + &pgtype.Path{ + P: []pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Closed: true, + Status: pgtype.Present, + }, + &pgtype.Path{Status: pgtype.Null}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 6d1f49af..18d21e20 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -248,6 +248,7 @@ func init() { "numeric": &Numeric{}, "numrange": &Numrange{}, "oid": &OidValue{}, + "path": &Path{}, "point": &Point{}, "record": &Record{}, "text": &Text{}, diff --git a/v3.md b/v3.md index 26c86f12..412d759d 100644 --- a/v3.md +++ b/v3.md @@ -68,8 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -path -path polygon circle macaddr From 26e92b12c23c5af6a433b255fba0435574129e89 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 20:24:01 -0500 Subject: [PATCH 153/264] Add pgtype.Uuid --- pgtype/pgtype.go | 1 + pgtype/uuid.go | 173 ++++++++++++++++++++++++++++++++++++++++++++ pgtype/uuid_test.go | 95 ++++++++++++++++++++++++ 3 files changed, 269 insertions(+) create mode 100644 pgtype/uuid.go create mode 100644 pgtype/uuid_test.go diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 18d21e20..5c8adb6e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -258,6 +258,7 @@ func init() { "tsrange": &Tsrange{}, "tstzrange": &Tstzrange{}, "unknown": &Unknown{}, + "uuid": &Uuid{}, "varchar": &Varchar{}, "xid": &Xid{}, } diff --git a/pgtype/uuid.go b/pgtype/uuid.go new file mode 100644 index 00000000..111bed35 --- /dev/null +++ b/pgtype/uuid.go @@ -0,0 +1,173 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/hex" + "fmt" + "io" +) + +type Uuid struct { + Bytes [16]byte + Status Status +} + +func (dst *Uuid) Set(src interface{}) error { + switch value := src.(type) { + case [16]byte: + *dst = Uuid{Bytes: value, Status: Present} + case []byte: + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + } + *dst = Uuid{Status: Present} + copy(dst.Bytes[:], value) + case string: + uuid, err := parseUuid(value) + if err != nil { + return err + } + *dst = Uuid{Bytes: uuid, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Uuid", value) + } + + return nil +} + +func (dst *Uuid) Get() interface{} { + switch dst.Status { + case Present: + return dst.Bytes + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Uuid) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *[16]byte: + *v = src.Bytes + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.Bytes[:]) + return nil + case *string: + *v = encodeUuid(src.Bytes) + return nil + default: + if nextDst, retry := GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot assign %v into %T", src, dst) +} + +// parseUuid converts a string UUID in standard form to a byte array. +func parseUuid(src string) (dst [16]byte, err error) { + src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] + buf, err := hex.DecodeString(src) + if err != nil { + return dst, err + } + + copy(dst[:], buf) + return dst, err +} + +// encodeUuid converts a uuid byte array to UUID standard string form. +func encodeUuid(src [16]byte) string { + return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) +} + +func (dst *Uuid) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: Null} + return nil + } + + if len(src) != 36 { + return fmt.Errorf("invalid length for Uuid: %v", len(src)) + } + + buf, err := parseUuid(string(src)) + if err != nil { + return err + } + + *dst = Uuid{Bytes: buf, Status: Present} + return nil +} + +func (dst *Uuid) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for Uuid: %v", len(src)) + } + + *dst = Uuid{Status: Present} + copy(dst.Bytes[:], src) + return nil +} + +func (src Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, encodeUuid(src.Bytes)) + return false, err +} + +func (src Uuid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := w.Write(src.Bytes[:]) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Uuid) Scan(src interface{}) error { + if src == nil { + *dst = Uuid{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Uuid) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go new file mode 100644 index 00000000..1eba7e90 --- /dev/null +++ b/pgtype/uuid_test.go @@ -0,0 +1,95 @@ +package pgtype_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestUuidTranscode(t *testing.T) { + testSuccessfulTranscode(t, "uuid", []interface{}{ + pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + pgtype.Uuid{Status: pgtype.Null}, + }) +} + +func TestUuidSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Uuid + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Uuid + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUuidAssignTo(t *testing.T) { + { + src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} From dc71bedebfa0a5c8709811d88748c4c9e805a847 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 20:30:04 -0500 Subject: [PATCH 154/264] Add pgtype.Polygon --- pgtype/pgtype.go | 1 + pgtype/polygon.go | 186 +++++++++++++++++++++++++++++++++++++++++ pgtype/polygon_test.go | 21 +++++ v3.md | 1 - 4 files changed, 208 insertions(+), 1 deletion(-) create mode 100644 pgtype/polygon.go create mode 100644 pgtype/polygon_test.go diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5c8adb6e..cb0cec2c 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -250,6 +250,7 @@ func init() { "oid": &OidValue{}, "path": &Path{}, "point": &Point{}, + "polygon": &Polygon{}, "record": &Record{}, "text": &Text{}, "tid": &Tid{}, diff --git a/pgtype/polygon.go b/pgtype/polygon.go new file mode 100644 index 00000000..1e2df011 --- /dev/null +++ b/pgtype/polygon.go @@ -0,0 +1,186 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Polygon struct { + P []Vec2 + Status Status +} + +func (dst *Polygon) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Polygon", src) +} + +func (dst *Polygon) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Polygon) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + if len(src) < 7 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + points := make([]Vec2, 0) + + str := string(src[2:]) + + for { + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + points = append(points, Vec2{x, y}) + + if end+3 < len(str) { + str = str[end+3:] + } else { + break + } + } + + *dst = Polygon{P: points, Status: Present} + return nil +} + +func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for Polygon: %v", len(src)) + } + + pointCount := int(binary.BigEndian.Uint32(src)) + rp := 4 + + if 4+pointCount*16 != len(src) { + return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + } + + points := make([]Vec2, pointCount) + for i := 0; i < len(points); i++ { + x := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + y := binary.BigEndian.Uint64(src[rp:]) + rp += 8 + points[i] = Vec2{math.Float64frombits(x), math.Float64frombits(y)} + } + + *dst = Polygon{ + P: points, + Status: Present, + } + return nil +} + +func (src *Polygon) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if err := pgio.WriteByte(w, '('); err != nil { + return false, err + } + + for i, p := range src.P { + if i > 0 { + if err := pgio.WriteByte(w, ','); err != nil { + return false, err + } + } + if _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)); err != nil { + return false, err + } + } + + err := pgio.WriteByte(w, ')') + return false, err +} + +func (src *Polygon) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteInt32(w, int32(len(src.P))); err != nil { + return false, err + } + + for _, p := range src.P { + if _, err := pgio.WriteUint64(w, math.Float64bits(p.X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(p.Y)); err != nil { + return false, err + } + } + + return false, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *Polygon) Scan(src interface{}) error { + if src == nil { + *dst = Polygon{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Polygon) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go new file mode 100644 index 00000000..3a7e1431 --- /dev/null +++ b/pgtype/polygon_test.go @@ -0,0 +1,21 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestPolygonTranscode(t *testing.T) { + testSuccessfulTranscode(t, "polygon", []interface{}{ + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {5.0, 3.234}}, + Status: pgtype.Present, + }, + &pgtype.Polygon{ + P: []pgtype.Vec2{{3.14, -1.678}, {7.1, -5.234}, {23.1, 9.34}}, + Status: pgtype.Present, + }, + &pgtype.Polygon{Status: pgtype.Null}, + }) +} diff --git a/v3.md b/v3.md index 412d759d..a879e384 100644 --- a/v3.md +++ b/v3.md @@ -68,7 +68,6 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -polygon circle macaddr varbit From 5be6819a8cae0a9ff30f77b35b65b2738683655d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 20:39:48 -0500 Subject: [PATCH 155/264] Add pgtype.Circle Also rename Point.Vec2 to Point.P to conform to rest of geometric types. --- pgtype/circle.go | 150 ++++++++++++++++++++++++++++++++++++++++++ pgtype/circle_test.go | 15 +++++ pgtype/pgtype.go | 1 + pgtype/point.go | 12 ++-- pgtype/point_test.go | 4 +- v3.md | 1 - 6 files changed, 174 insertions(+), 9 deletions(-) create mode 100644 pgtype/circle.go create mode 100644 pgtype/circle_test.go diff --git a/pgtype/circle.go b/pgtype/circle.go new file mode 100644 index 00000000..62e2e8b3 --- /dev/null +++ b/pgtype/circle.go @@ -0,0 +1,150 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Circle struct { + P Vec2 + R float64 + Status Status +} + +func (dst *Circle) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Circle", src) +} + +func (dst *Circle) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Circle) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) < 9 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + str := string(src[2:]) + end := strings.IndexByte(str, ',') + x, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+1:] + end = strings.IndexByte(str, ')') + + y, err := strconv.ParseFloat(str[:end], 64) + if err != nil { + return err + } + + str = str[end+2 : len(str)-1] + + r, err := strconv.ParseFloat(str, 64) + if err != nil { + return err + } + + *dst = Circle{P: Vec2{x, y}, R: r, Status: Present} + return nil +} + +func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + *dst = Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Status: Present, + } + return nil +} + +func (src *Circle) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)) + return false, err +} + +func (src *Circle) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)); err != nil { + return false, err + } + + if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.Y)); err != nil { + return false, err + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.R)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Circle) Scan(src interface{}) error { + if src == nil { + *dst = Circle{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Circle) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go new file mode 100644 index 00000000..9746dd74 --- /dev/null +++ b/pgtype/circle_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCircleTranscode(t *testing.T) { + testSuccessfulTranscode(t, "circle", []interface{}{ + &pgtype.Circle{P: pgtype.Vec2{1.234, 5.6789}, R: 3.5, Status: pgtype.Present}, + &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, + &pgtype.Circle{Status: pgtype.Null}, + }) +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index cb0cec2c..52cad561 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -228,6 +228,7 @@ func init() { "char": &QChar{}, "cid": &Cid{}, "cidr": &Cidr{}, + "circle": &Circle{}, "date": &Date{}, "daterange": &Daterange{}, "decimal": &Decimal{}, diff --git a/pgtype/point.go b/pgtype/point.go index 94f753e3..788a76c9 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -18,7 +18,7 @@ type Vec2 struct { } type Point struct { - Vec2 + P Vec2 Status Status } @@ -66,7 +66,7 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Point{Vec2: Vec2{x, y}, Status: Present} + *dst = Point{P: Vec2{x, y}, Status: Present} return nil } @@ -84,7 +84,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { y := binary.BigEndian.Uint64(src[8:]) *dst = Point{ - Vec2: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, Status: Present, } return nil @@ -98,7 +98,7 @@ func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y)) + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)) return false, err } @@ -110,12 +110,12 @@ func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, errUndefined } - _, err := pgio.WriteUint64(w, math.Float64bits(src.X)) + _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)) if err != nil { return false, err } - _, err = pgio.WriteUint64(w, math.Float64bits(src.Y)) + _, err = pgio.WriteUint64(w, math.Float64bits(src.P.Y)) return false, err } diff --git a/pgtype/point_test.go b/pgtype/point_test.go index 723dfa60..c921f794 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -8,8 +8,8 @@ import ( func TestPointTranscode(t *testing.T) { testSuccessfulTranscode(t, "point", []interface{}{ - &pgtype.Point{Vec2: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, - &pgtype.Point{Vec2: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, + &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, &pgtype.Point{Status: pgtype.Null}, }) } diff --git a/v3.md b/v3.md index a879e384..9a69a2f2 100644 --- a/v3.md +++ b/v3.md @@ -68,6 +68,5 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -circle macaddr varbit From e5c48b17f27248648e0f6df9081399c3414d293c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 21:07:27 -0500 Subject: [PATCH 156/264] Add pgtype.Macaddr --- pgtype/macaddr.go | 154 +++++++++++++++++++++++++++++++++++++++++ pgtype/macaddr_test.go | 77 +++++++++++++++++++++ pgtype/pgtype.go | 1 + pgtype/pgtype_test.go | 9 +++ v3.md | 1 - 5 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 pgtype/macaddr.go create mode 100644 pgtype/macaddr_test.go diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go new file mode 100644 index 00000000..2d09ff8c --- /dev/null +++ b/pgtype/macaddr.go @@ -0,0 +1,154 @@ +package pgtype + +import ( + "database/sql/driver" + "fmt" + "io" + "net" +) + +type Macaddr struct { + Addr net.HardwareAddr + Status Status +} + +func (dst *Macaddr) Set(src interface{}) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + switch value := src.(type) { + case net.HardwareAddr: + addr := make(net.HardwareAddr, len(value)) + copy(addr, value) + *dst = Macaddr{Addr: addr, Status: Present} + case string: + addr, err := net.ParseMAC(value) + if err != nil { + return err + } + *dst = Macaddr{Addr: addr, Status: Present} + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Macaddr", value) + } + + return nil +} + +func (dst *Macaddr) Get() interface{} { + switch dst.Status { + case Present: + return dst.Addr + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Macaddr) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *net.HardwareAddr: + *v = make(net.HardwareAddr, len(src.Addr)) + copy(*v, src.Addr) + return nil + case *string: + *v = src.Addr.String() + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return nullAssignTo(dst) + } + + return fmt.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + addr, err := net.ParseMAC(string(src)) + if err != nil { + return err + } + + *dst = Macaddr{Addr: addr, Status: Present} + return nil +} + +func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + if len(src) != 6 { + return fmt.Errorf("Received an invalid size for a macaddr: %d", len(src)) + } + + addr := make(net.HardwareAddr, 6) + copy(addr, src) + + *dst = Macaddr{Addr: addr, Status: Present} + + return nil +} + +func (src Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.Addr.String()) + return false, err +} + +// EncodeBinary encodes src into w. +func (src Macaddr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := w.Write([]byte(src.Addr)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Macaddr) Scan(src interface{}) error { + if src == nil { + *dst = Macaddr{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Macaddr) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go new file mode 100644 index 00000000..6c7b8b89 --- /dev/null +++ b/pgtype/macaddr_test.go @@ -0,0 +1,77 @@ +package pgtype_test + +import ( + "bytes" + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestMacaddrTranscode(t *testing.T) { + testSuccessfulTranscode(t, "macaddr", []interface{}{ + pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + pgtype.Macaddr{Status: pgtype.Null}, + }) +} + +func TestMacaddrSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Macaddr + }{ + { + source: mustParseMacaddr(t, "01:23:45:67:89:ab"), + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + }, + { + source: "01:23:45:67:89:ab", + result: pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.Macaddr + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestMacaddrAssignTo(t *testing.T) { + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + var dst net.HardwareAddr + expected := mustParseMacaddr(t, "01:23:45:67:89:ab") + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare([]byte(dst), []byte(expected)) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present} + var dst string + expected := "01:23:45:67:89:ab" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 52cad561..6b06539b 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -245,6 +245,7 @@ func init() { "jsonb": &Jsonb{}, "line": &Line{}, "lseg": &Lseg{}, + "macaddr": &Macaddr{}, "name": &Name{}, "numeric": &Numeric{}, "numrange": &Numrange{}, diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 298cff64..0b1ffc54 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -78,6 +78,15 @@ func mustParseCidr(t testing.TB, s string) *net.IPNet { return ipnet } +func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { + addr, err := net.ParseMAC(s) + if err != nil { + t.Fatal(err) + } + + return addr +} + type forceTextEncoder struct { e pgtype.TextEncoder } diff --git a/v3.md b/v3.md index 9a69a2f2..f522825d 100644 --- a/v3.md +++ b/v3.md @@ -68,5 +68,4 @@ something like: select array[1,2,3], array[4,5,6,7] pgtype TODO: -macaddr varbit From 52b58b88a68854fede7bd1693f953cdbfaead423 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Apr 2017 21:13:00 -0500 Subject: [PATCH 157/264] Fix pgtype.Inet.AssignTo assigning reference AssignTo should always assign copy. Added documentation for AssignTo interface. --- pgtype/inet.go | 10 ++++++++-- pgtype/pgtype.go | 3 ++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pgtype/inet.go b/pgtype/inet.go index 0ca3ee7a..3e00e2fa 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -70,13 +70,19 @@ func (src *Inet) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { case *net.IPNet: - *v = *src.IPNet + *v = net.IPNet{ + IP: make(net.IP, len(src.IPNet.IP)), + Mask: make(net.IPMask, len(src.IPNet.Mask)), + } + copy(v.IP, src.IPNet.IP) + copy(v.Mask, src.IPNet.Mask) return nil case *net.IP: if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { return fmt.Errorf("cannot assign %v to %T", src, dst) } - *v = src.IPNet.IP + *v = make(net.IP, len(src.IPNet.IP)) + copy(*v, src.IPNet.IP) return nil default: if nextDst, retry := GetAssignToDstType(dst); retry { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 6b06539b..5de07b7d 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -89,7 +89,8 @@ type Value interface { // possible, then Get() returns Value. Get() interface{} - // AssignTo converts and assigns the Value to dst. + // AssignTo converts and assigns the Value to dst. It MUST make a deep copy of + // any reference types. AssignTo(dst interface{}) error } From 54d9cbc743e9294fe1720f67533dfbfc229d1956 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Apr 2017 07:54:41 -0500 Subject: [PATCH 158/264] Add pgtype.Varbit --- pgtype/pgtype.go | 1 + pgtype/varbit.go | 141 ++++++++++++++++++++++++++++++++++++++++++ pgtype/varbit_test.go | 25 ++++++++ v3.md | 7 --- 4 files changed, 167 insertions(+), 7 deletions(-) create mode 100644 pgtype/varbit.go create mode 100644 pgtype/varbit_test.go diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5de07b7d..338afc9b 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -263,6 +263,7 @@ func init() { "tstzrange": &Tstzrange{}, "unknown": &Unknown{}, "uuid": &Uuid{}, + "varbit": &Varbit{}, "varchar": &Varchar{}, "xid": &Xid{}, } diff --git a/pgtype/varbit.go b/pgtype/varbit.go new file mode 100644 index 00000000..d28e95cd --- /dev/null +++ b/pgtype/varbit.go @@ -0,0 +1,141 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type Varbit struct { + Bytes []byte + Len int32 // Number of bits + Status Status +} + +func (dst *Varbit) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Varbit", src) +} + +func (dst *Varbit) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Varbit) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + bitLen := len(src) + byteLen := bitLen / 8 + if bitLen%8 > 0 { + byteLen++ + } + buf := make([]byte, byteLen) + + for i, b := range src { + if b == '1' { + byteIdx := i / 8 + bitIdx := uint(i % 8) + buf[byteIdx] = buf[byteIdx] | (128 >> bitIdx) + } + } + + *dst = Varbit{Bytes: buf, Len: int32(bitLen), Status: Present} + return nil +} + +func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for varbit: %v", len(src)) + } + + bitLen := int32(binary.BigEndian.Uint32(src)) + rp := 4 + + buf := make([]byte, len(src[rp:])) + copy(buf, src[rp:]) + + *dst = Varbit{Bytes: buf, Len: bitLen, Status: Present} + return nil +} + +func (src *Varbit) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + buf := make([]byte, int(src.Len)) + for i, _ := range buf { + byteIdx := i / 8 + bitMask := byte(128 >> byte(i%8)) + char := byte('0') + if src.Bytes[byteIdx]&bitMask > 0 { + char = '1' + } + buf[i] = char + } + + _, err := w.Write(buf) + return false, err +} + +func (src *Varbit) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if _, err := pgio.WriteInt32(w, src.Len); err != nil { + return false, err + } + + _, err := w.Write(src.Bytes) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Varbit) Scan(src interface{}) error { + if src == nil { + *dst = Varbit{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Varbit) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go new file mode 100644 index 00000000..cd146d26 --- /dev/null +++ b/pgtype/varbit_test.go @@ -0,0 +1,25 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestVarbitTranscode(t *testing.T) { + testSuccessfulTranscode(t, "varbit", []interface{}{ + &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, + &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, + &pgtype.Varbit{Status: pgtype.Null}, + }) +} + +func TestVarbitNormalize(t *testing.T) { + testSuccessfulNormalize(t, []normalizeTest{ + { + sql: "select B'111111111'", + value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + }, + }) +} diff --git a/v3.md b/v3.md index f522825d..20038938 100644 --- a/v3.md +++ b/v3.md @@ -42,14 +42,10 @@ Or maybe double-down on conn/pool coupling and improve connpool Add auto-idle pinging to conns in pool -Extract types Null* and Hstore to separate package - Remove names from prepared statements - use database/sql style objects Better way of handling text/binary protocol choice than pgx.DefaultTypeFormats or manually editing a PreparedStatement. Possibly an optional part of preparing a statement is specifying the format and/or a decoder. Or maybe it is part of a QueryEx call... Could be very interesting to make encoding and decoding possible without being a method of the type. This could drastically clean up those huge type switches. -Also maybe support binary and text for everything possible - dValueReader / msgReader cleanup Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) @@ -66,6 +62,3 @@ Keep ability to change logging while running consider test to ensure that AssignTo makes copy of reference types something like: select array[1,2,3], array[4,5,6,7] - -pgtype TODO: -varbit From 7b1f461ec302c7dd144f7524fe89db4e1e26b6c1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 10 Apr 2017 08:58:51 -0500 Subject: [PATCH 159/264] Add simple protocol suuport with (Query|Exec)Ex --- conn.go | 66 +++- conn_pool.go | 12 +- conn_test.go | 80 ++++- internal/sanitize/sanitize.go | 236 +++++++++++++ internal/sanitize/sanitize_test.go | 175 ++++++++++ pgtype/cid_test.go | 17 +- pgtype/json.go | 2 +- pgtype/numeric.go | 21 +- pgtype/numeric_test.go | 3 + pgtype/pgtype_test.go | 31 ++ pgtype/xid_test.go | 17 +- query.go | 65 +++- query_test.go | 510 +++++++++++++---------------- stdlib/sql.go | 2 +- stress_test.go | 12 +- values.go | 76 +++++ 16 files changed, 999 insertions(+), 326 deletions(-) create mode 100644 internal/sanitize/sanitize.go create mode 100644 internal/sanitize/sanitize_test.go diff --git a/conn.go b/conn.go index 6078cca2..c2cb408f 100644 --- a/conn.go +++ b/conn.go @@ -1021,7 +1021,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { - return c.ExecContext(context.Background(), sql, arguments...) + return c.ExecEx(context.Background(), sql, nil, arguments...) } // Processes messages that are not exclusive to one context such as @@ -1364,24 +1364,16 @@ func (c *Conn) Ping() error { } func (c *Conn) PingContext(ctx context.Context) error { - _, err := c.ExecContext(ctx, ";") + _, err := c.ExecEx(ctx, ";", nil) return err } -func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { +func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { err = c.waitForPreviousCancelQuery(ctx) if err != nil { return "", err } - err = c.initContext(ctx) - if err != nil { - return "", err - } - defer func() { - err = c.termContext(err) - }() - if err = c.lock(); err != nil { return commandTag, err } @@ -1406,8 +1398,56 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa } }() - if err = c.sendQuery(sql, arguments...); err != nil { - return + if options != nil && options.SimpleProtocol { + err = c.initContext(ctx) + if err != nil { + return "", err + } + defer func() { + err = c.termContext(err) + }() + + err = c.sanitizeAndSendSimpleQuery(sql, arguments...) + if err != nil { + return "", err + + } + } else { + if len(arguments) > 0 { + ps, ok := c.preparedStatements[sql] + if !ok { + var err error + ps, err = c.PrepareExContext(ctx, "", sql, nil) + if err != nil { + return "", err + } + } + + err = c.initContext(ctx) + if err != nil { + return "", err + } + defer func() { + err = c.termContext(err) + }() + + err = c.sendPreparedQuery(ps, arguments...) + if err != nil { + return "", err + } + } else { + err = c.initContext(ctx) + if err != nil { + return "", err + } + defer func() { + err = c.termContext(err) + }() + + if err = c.sendQuery(sql, arguments...); err != nil { + return + } + } } var softErr error diff --git a/conn_pool.go b/conn_pool.go index 44559ea8..8703d7fa 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -360,14 +360,14 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman return c.Exec(sql, arguments...) } -func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { +func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { var c *Conn if c, err = p.Acquire(); err != nil { return } defer p.Release(c) - return c.ExecContext(ctx, sql, arguments...) + return c.ExecEx(ctx, sql, options, arguments...) } // Query acquires a connection and delegates the call to that connection. When @@ -390,14 +390,14 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { return rows, nil } -func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { +func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) { c, err := p.Acquire() if err != nil { // Because checking for errors can be deferred to the *Rows, build one with the error return &Rows{closed: true, err: err}, err } - rows, err := c.QueryContext(ctx, sql, args...) + rows, err := c.QueryEx(ctx, sql, options, args...) if err != nil { p.Release(c) return rows, err @@ -416,8 +416,8 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } -func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { - rows, _ := p.QueryContext(ctx, sql, args...) +func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { + rows, _ := p.QueryEx(ctx, sql, options, args...) return (*Row)(rows) } diff --git a/conn_test.go b/conn_test.go index 50ea68f6..d4ca593f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1033,7 +1033,7 @@ func TestExecFailure(t *testing.T) { } } -func TestExecContextWithoutCancelation(t *testing.T) { +func TestExecExContextWithoutCancelation(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1042,16 +1042,16 @@ func TestExecContextWithoutCancelation(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);") + commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil) if err != nil { t.Fatal(err) } if commandTag != "CREATE TABLE" { - t.Fatalf("Unexpected results from ExecContext: %v", commandTag) + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } } -func TestExecContextFailureWithoutCancelation(t *testing.T) { +func TestExecExContextFailureWithoutCancelation(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1060,18 +1060,18 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - if _, err := conn.ExecContext(ctx, "selct;"); err == nil { + if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil { t.Fatal("Expected SQL syntax error") } rows, _ := conn.Query("select 1") rows.Close() if rows.Err() != nil { - t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err()) + t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) } } -func TestExecContextCancelationCancelsQuery(t *testing.T) { +func TestExecExContextCancelationCancelsQuery(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1083,7 +1083,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) { cancelFunc() }() - _, err := conn.ExecContext(ctx, "select pg_sleep(60)") + _, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil) if err != context.Canceled { t.Fatalf("Expected context.Canceled err, got %v", err) } @@ -1091,6 +1091,70 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) { ensureConnValid(t, conn) } +func TestExecExExtendedProtocol(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) + if err != nil { + t.Fatal(err) + } + if commandTag != "CREATE TABLE" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } + + commandTag, err = conn.ExecEx( + ctx, + "insert into foo(name) values($1);", + nil, + "bar", + ) + if err != nil { + t.Fatal(err) + } + if commandTag != "INSERT 0 1" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } + + ensureConnValid(t, conn) +} + +func TestExecExSimpleProtocol(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) + if err != nil { + t.Fatal(err) + } + if commandTag != "CREATE TABLE" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } + + commandTag, err = conn.ExecEx( + ctx, + "insert into foo(name) values($1);", + &pgx.QueryExOptions{SimpleProtocol: true}, + "bar'; drop table foo;--", + ) + if err != nil { + t.Fatal(err) + } + if commandTag != "INSERT 0 1" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } +} + func TestPrepare(t *testing.T) { t.Parallel() diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go new file mode 100644 index 00000000..92b892b9 --- /dev/null +++ b/internal/sanitize/sanitize.go @@ -0,0 +1,236 @@ +package sanitize + +import ( + "bytes" + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +// Part is either a string or an int. A string is raw SQL. An int is a +// argument placeholder. +type Part interface{} + +type Query struct { + Parts []Part +} + +func (q *Query) Sanitize(args ...interface{}) (string, error) { + argUse := make([]bool, len(args)) + buf := &bytes.Buffer{} + + for _, part := range q.Parts { + var str string + switch part := part.(type) { + case string: + str = part + case int: + argIdx := part - 1 + if argIdx >= len(args) { + return "", fmt.Errorf("insufficient arguments") + } + arg := args[argIdx] + switch arg := arg.(type) { + case nil: + str = "null" + case int64: + str = strconv.FormatInt(arg, 10) + case float64: + str = strconv.FormatFloat(arg, 'f', -1, 64) + case bool: + str = strconv.FormatBool(arg) + case []byte: + str = QuoteBytes(arg) + case string: + str = QuoteString(arg) + case time.Time: + str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + default: + return "", fmt.Errorf("invalid arg type: %T", arg) + } + argUse[argIdx] = true + default: + return "", fmt.Errorf("invalid Part type: %T", part) + } + buf.WriteString(str) + } + + for i, used := range argUse { + if !used { + return "", fmt.Errorf("unused argument: %d", i) + } + } + return buf.String(), nil +} + +func NewQuery(sql string) (*Query, error) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + query := &Query{Parts: l.parts} + + return query, nil +} + +func QuoteString(str string) string { + return "'" + strings.Replace(str, "'", "''", -1) + "'" +} + +func QuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} + +type sqlLexer struct { + src string + start int + pos int + stateFn stateFn + parts []Part +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '$': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if '0' <= nextRune && nextRune <= '9' { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return placeholderState + } + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +// placeholderState consumes a placeholder value. The $ must have already has +// already been consumed. The first rune must be a digit. +func placeholderState(l *sqlLexer) stateFn { + num := 0 + + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if '0' <= r && r <= '9' { + num *= 10 + num += int(r - '0') + } else { + l.parts = append(l.parts, num) + l.pos -= width + l.start = l.pos + return rawState + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } +} + +// SanitizeSQL replaces placeholder values with args. It quotes and escapes args +// as necessary. This function is only safe when standard_conforming_strings is +// on. +func SanitizeSQL(sql string, args ...interface{}) (string, error) { + query, err := NewQuery(sql) + if err != nil { + return "", err + } + return query.Sanitize(args...) +} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go new file mode 100644 index 00000000..9597840e --- /dev/null +++ b/internal/sanitize/sanitize_test.go @@ -0,0 +1,175 @@ +package sanitize_test + +import ( + "testing" + + "github.com/jackc/pgx/internal/sanitize" +) + +func TestNewQuery(t *testing.T) { + successTests := []struct { + sql string + expected sanitize.Query + }{ + { + sql: "select 42", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, + }, + { + sql: "select $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + }, + { + sql: "select 'quoted $42', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}}, + }, + { + sql: `select "doubled quoted $42", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}}, + }, + { + sql: "select 'foo''bar', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}}, + }, + { + sql: `select "foo""bar", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}}, + }, + { + sql: "select '''', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}}, + }, + { + sql: `select """", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}}, + }, + { + sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11", + expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}}, + }, + { + sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}}, + }, + { + sql: `select E'escape string\' $42', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}}, + }, + { + sql: `select e'escape string\' $42', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}}, + }, + } + + for i, tt := range successTests { + query, err := sanitize.NewQuery(tt.sql) + if err != nil { + t.Errorf("%d. %v", i, err) + } + + if len(query.Parts) == len(tt.expected.Parts) { + for j := range query.Parts { + if query.Parts[j] != tt.expected.Parts[j] { + t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j]) + } + } + } else { + t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts) + } + } +} + +func TestQuerySanitize(t *testing.T) { + successfulTests := []struct { + query sanitize.Query + args []interface{} + expected string + }{ + { + query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, + args: []interface{}{}, + expected: `select 42`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{int64(42)}, + expected: `select 42`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{float64(1.23)}, + expected: `select 1.23`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{true}, + expected: `select true`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{[]byte{0, 1, 2, 3, 255}}, + expected: `select '\x00010203ff'`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{nil}, + expected: `select null`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{"foobar"}, + expected: `select 'foobar'`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{"foo'bar"}, + expected: `select 'foo''bar'`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{`foo\'bar`}, + expected: `select 'foo\''bar'`, + }, + } + + for i, tt := range successfulTests { + actual, err := tt.query.Sanitize(tt.args...) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + if tt.expected != actual { + t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual) + } + } + + errorTests := []struct { + query sanitize.Query + args []interface{} + expected string + }{ + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, + args: []interface{}{int64(42)}, + expected: `insufficient arguments`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, + args: []interface{}{int64(42)}, + expected: `unused argument: 0`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []interface{}{42}, + expected: `invalid arg type: int`, + }, + } + + for i, tt := range errorTests { + _, err := tt.query.Sanitize(tt.args...) + if err == nil || err.Error() != tt.expected { + t.Errorf("%d. expected error %v, got %v", i, tt.expected, err) + } + } +} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index 0d114cda..210573f6 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -8,10 +8,23 @@ import ( ) func TestCidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "cid", []interface{}{ + pgTypeName := "cid" + values := []interface{}{ pgtype.Cid{Uint: 42, Status: pgtype.Present}, pgtype.Cid{Status: pgtype.Null}, - }) + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + // No direct conversion from int to cid, convert through text + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } } func TestCidSet(t *testing.T) { diff --git a/pgtype/json.go b/pgtype/json.go index 05d965ca..b1c061f9 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -145,7 +145,7 @@ func (dst *Json) Scan(src interface{}) error { func (src Json) Value() (driver.Value, error) { switch src.Status { case Present: - return src.Bytes, nil + return string(src.Bytes), nil case Null: return nil, nil default: diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 0f3f6529..a26e8c89 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -121,13 +121,13 @@ func (src *Numeric) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { case *float32: - f, err := strconv.ParseFloat(src.Int.String(), 64) + f, err := src.toFloat64() if err != nil { return err } return float64AssignTo(f, src.Status, dst) case *float64: - f, err := strconv.ParseFloat(src.Int.String(), 64) + f, err := src.toFloat64() if err != nil { return err } @@ -283,6 +283,23 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { return num, nil } +func (src *Numeric) toFloat64() (float64, error) { + f, err := strconv.ParseFloat(src.Int.String(), 64) + if err != nil { + return 0, err + } + if src.Exp > 0 { + for i := 0; i < int(src.Exp); i++ { + f *= 10 + } + } else if src.Exp < 0 { + for i := 0; i > int(src.Exp); i-- { + f /= 10 + } + } + return f, nil +} + func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Numeric{Status: Null} diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 64dea847..93aa8866 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -247,9 +247,12 @@ func TestNumericAssignTo(t *testing.T) { }{ {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, {src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 0b1ffc54..f486f077 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "context" "database/sql" "fmt" "io" @@ -125,6 +126,7 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } @@ -175,6 +177,35 @@ func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } } +func testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRowEx( + context.Background(), + fmt.Sprintf("select ($1)::%s", pgTypeName), + &pgx.QueryExOptions{SimpleProtocol: true}, + v, + ).Scan(result.Interface()) + if err != nil { + t.Errorf("Simple protocol %d: %v", i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) + } + } +} + func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := mustConnectDatabaseSQL(t, driverName) defer mustClose(t, conn) diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index fecfb64b..11dd0615 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -8,10 +8,23 @@ import ( ) func TestXidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "xid", []interface{}{ + pgTypeName := "xid" + values := []interface{}{ pgtype.Xid{Uint: 42, Status: pgtype.Present}, pgtype.Xid{Status: pgtype.Null}, - }) + } + eqFunc := func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + } + + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + + // No direct conversion from int to xid, convert through text + testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } } func TestXidSet(t *testing.T) { diff --git a/query.go b/query.go index e820fabc..72f987b4 100644 --- a/query.go +++ b/query.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/jackc/pgx/internal/sanitize" "github.com/jackc/pgx/pgtype" ) @@ -123,6 +124,17 @@ func (rows *Rows) Next() bool { } switch t { + case rowDescription: + rows.fields = rows.conn.rxRowDescription(r) + for i := range rows.fields { + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok { + rows.fields[i].DataTypeName = dt.Name + rows.fields[i].FormatCode = TextFormatCode + } else { + rows.Fatal(fmt.Errorf("unknown oid: %d", rows.fields[i].DataType)) + return false + } + } case dataRow: fieldCount := r.readInt16() if int(fieldCount) != len(rows.fields) { @@ -341,7 +353,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) { // be returned in an error state. So it is allowed to ignore the error returned // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - return c.QueryContext(context.Background(), sql, args...) + return c.QueryEx(context.Background(), sql, nil, args...) } func (c *Conn) getRows(sql string, args []interface{}) *Rows { @@ -368,7 +380,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } -func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { +type QueryExOptions struct { + SimpleProtocol bool +} + +func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { err = c.waitForPreviousCancelQuery(ctx) if err != nil { return nil, err @@ -384,6 +400,22 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} } rows.unlockConn = true + if options != nil && options.SimpleProtocol { + err = c.initContext(ctx) + if err != nil { + rows.Fatal(err) + return rows, err + } + + err = c.sanitizeAndSendSimpleQuery(sql, args...) + if err != nil { + rows.Fatal(err) + return rows, err + } + + return rows, nil + } + ps, ok := c.preparedStatements[sql] if !ok { var err error @@ -411,7 +443,32 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} return rows, err } -func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { - rows, _ := c.QueryContext(ctx, sql, args...) +func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { + if c.RuntimeParams["standard_conforming_strings"] != "on" { + return errors.New("simple protocol queries must be run with standard_conforming_strings=on") + } + + if c.RuntimeParams["client_encoding"] != "UTF8" { + return errors.New("simple protocol queries must be run with client_encoding=UTF8") + } + + valueArgs := make([]interface{}, len(args)) + for i, a := range args { + valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) + if err != nil { + return err + } + } + + sql, err = sanitize.SanitizeSQL(sql, valueArgs...) + if err != nil { + return err + } + + return c.sendSimpleQuery(sql) +} + +func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row { + rows, _ := c.QueryEx(ctx, sql, options, args...) return (*Row)(rows) } diff --git a/query_test.go b/query_test.go index d0fcb706..66660ba1 100644 --- a/query_test.go +++ b/query_test.go @@ -797,275 +797,6 @@ func TestQueryRowNoResults(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryRowCoreInt16Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []int16 - - tests := []struct { - sql string - expected []int16 - }{ - {"select $1::int2[]", []int16{1, 2, 3, 4, 5}}, - {"select $1::int2[]", []int16{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - if err != nil && !(strings.Contains(err.Error(), "Cannot decode null") || strings.Contains(err.Error(), "cannot assign")) { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreInt32Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []int32 - - tests := []struct { - sql string - expected []int32 - }{ - {"select $1::int4[]", []int32{1, 2, 3, 4, 5}}, - {"select $1::int4[]", []int32{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreInt64Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []int64 - - tests := []struct { - sql string - expected []int64 - }{ - {"select $1::int8[]", []int64{1, 2, 3, 4, 5}}, - {"select $1::int8[]", []int64{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreFloat32Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []float32 - - tests := []struct { - sql string - expected []float32 - }{ - {"select $1::float4[]", []float32{1.5, 2.0, 3.5}}, - {"select $1::float4[]", []float32{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreFloat64Slice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []float64 - - tests := []struct { - sql string - expected []float64 - }{ - {"select $1::float8[]", []float64{1.5, 2.0, 3.5}}, - {"select $1::float8[]", []float64{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - - ensureConnValid(t, conn) -} - -func TestQueryRowCoreStringSlice(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - var actual []string - - tests := []struct { - sql string - expected []string - }{ - {"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters ƅ Ɔ Ƌ Ļæ"}}, - {"select $1::text[]", []string{}}, - {"select $1::varchar[]", []string{"Adam", "Eve", "UTF-8 Characters ƅ Ɔ Ƌ Ļæ"}}, - {"select $1::varchar[]", []string{}}, - } - - for i, tt := range tests { - err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v", i, err) - } - - if len(actual) != len(tt.expected) { - t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual) - } - - for j := 0; j < len(actual); j++ { - if actual[j] != tt.expected[j] { - t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j]) - } - } - - ensureConnValid(t, conn) - } - - // Check that Scan errors when an array with a null is scanned into a core slice type - err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual) - if err == nil { - t.Error("Expected null to cause error when scanned into slice, but it didn't") - } - - ensureConnValid(t, conn) -} - func TestReadingValueAfterEmptyArray(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) @@ -1236,7 +967,7 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryContextSuccess(t *testing.T) { +func TestQueryExContextSuccess(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1245,7 +976,7 @@ func TestQueryContextSuccess(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - rows, err := conn.QueryContext(ctx, "select 42::integer") + rows, err := conn.QueryEx(ctx, "select 42::integer", nil) if err != nil { t.Fatal(err) } @@ -1273,7 +1004,7 @@ func TestQueryContextSuccess(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryContextErrorWhileReceivingRows(t *testing.T) { +func TestQueryExContextErrorWhileReceivingRows(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1282,7 +1013,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n") + rows, err := conn.QueryEx(ctx, "select 10/(10-n) from generate_series(1, 100) n", nil) if err != nil { t.Fatal(err) } @@ -1310,7 +1041,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryContextCancelationCancelsQuery(t *testing.T) { +func TestQueryExContextCancelationCancelsQuery(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1322,7 +1053,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { cancelFunc() }() - rows, err := conn.QueryContext(ctx, "select pg_sleep(5)") + rows, err := conn.QueryEx(ctx, "select pg_sleep(5)", nil) if err != nil { t.Fatal(err) } @@ -1338,7 +1069,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryRowContextSuccess(t *testing.T) { +func TestQueryRowExContextSuccess(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1348,7 +1079,7 @@ func TestQueryRowContextSuccess(t *testing.T) { defer cancelFunc() var result int - err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result) + err := conn.QueryRowEx(ctx, "select 42::integer", nil).Scan(&result) if err != nil { t.Fatal(err) } @@ -1359,7 +1090,7 @@ func TestQueryRowContextSuccess(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { +func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1369,7 +1100,7 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { defer cancelFunc() var result int - err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result) + err := conn.QueryRowEx(ctx, "select 10/0", nil).Scan(&result) if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" { t.Fatalf("Expected division by zero error, but got %v", err) } @@ -1377,7 +1108,7 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { ensureConnValid(t, conn) } -func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { +func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -1390,10 +1121,227 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { }() var result []byte - err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result) + err := conn.QueryRowEx(ctx, "select pg_sleep(5)", nil).Scan(&result) if err != context.Canceled { t.Fatalf("Expected context.Canceled error, got %v", err) } ensureConnValid(t, conn) } + +func TestConnSimpleProtocol(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // Test all supported low-level types + + { + expected := int64(42) + var actual int64 + err := conn.QueryRowEx( + context.Background(), + "select $1::int8", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + { + expected := float64(1.23) + var actual float64 + err := conn.QueryRowEx( + context.Background(), + "select $1::float8", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + { + expected := true + var actual bool + err := conn.QueryRowEx( + context.Background(), + "select $1", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + { + expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95} + var actual []byte + err := conn.QueryRowEx( + context.Background(), + "select $1::bytea", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if bytes.Compare(actual, expected) != 0 { + t.Errorf("expected %v got %v", expected, actual) + } + } + + { + expected := "test" + var actual string + err := conn.QueryRowEx( + context.Background(), + "select $1::text", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + // Test high-level type + + { + expected := pgtype.Line{A: 1, B: 2, C: 1.5, Status: pgtype.Present} + actual := expected + err := conn.QueryRowEx( + context.Background(), + "select $1::line", + &pgx.QueryExOptions{SimpleProtocol: true}, + &expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + // Test multiple args in single query + + { + expectedInt64 := int64(234423) + expectedFloat64 := float64(-0.2312) + expectedBool := true + expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223} + expectedString := "test" + var actualInt64 int64 + var actualFloat64 float64 + var actualBool bool + var actualBytes []byte + var actualString string + err := conn.QueryRowEx( + context.Background(), + "select $1::int8, $2::float8, $3, $4::bytea, $5::text", + &pgx.QueryExOptions{SimpleProtocol: true}, + expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, + ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) + if err != nil { + t.Error(err) + } + if expectedInt64 != actualInt64 { + t.Errorf("expected %v got %v", expectedInt64, actualInt64) + } + if expectedFloat64 != actualFloat64 { + t.Errorf("expected %v got %v", expectedFloat64, actualFloat64) + } + if expectedBool != actualBool { + t.Errorf("expected %v got %v", expectedBool, actualBool) + } + if bytes.Compare(expectedBytes, actualBytes) != 0 { + t.Errorf("expected %v got %v", expectedBytes, actualBytes) + } + if expectedString != actualString { + t.Errorf("expected %v got %v", expectedString, actualString) + } + } + + // Test dangerous cases + + { + expected := "foo';drop table users;" + var actual string + err := conn.QueryRowEx( + context.Background(), + "select $1", + &pgx.QueryExOptions{SimpleProtocol: true}, + expected, + ).Scan(&actual) + if err != nil { + t.Error(err) + } + if expected != actual { + t.Errorf("expected %v got %v", expected, actual) + } + } + + ensureConnValid(t, conn) +} + +func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "set client_encoding to 'SQL_ASCII'") + + var expected string + err := conn.QueryRowEx( + context.Background(), + "select $1", + &pgx.QueryExOptions{SimpleProtocol: true}, + "test", + ).Scan(&expected) + if err == nil { + t.Error("expected error when client_encoding not UTF8, but no error occurred") + } + + ensureConnValid(t, conn) +} + +func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "set standard_conforming_strings to off") + + var expected string + err := conn.QueryRowEx( + context.Background(), + "select $1", + &pgx.QueryExOptions{SimpleProtocol: true}, + `\'; drop table users; --`, + ).Scan(&expected) + if err == nil { + t.Error("expected error when standard_conforming_strings is off, but no error occurred") + } + + ensureConnValid(t, conn) +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 000f0fbf..80a559af 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -268,7 +268,7 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr args := namedValueToInterface(argsV) - rows, err := c.conn.QueryContext(ctx, name, args...) + rows, err := c.conn.QueryEx(ctx, name, nil, args...) if err != nil { fmt.Println(err) return nil, err diff --git a/stress_test.go b/stress_test.go index 47a3f4d6..93752c29 100644 --- a/stress_test.go +++ b/stress_test.go @@ -49,8 +49,8 @@ func TestStressConnPool(t *testing.T) { {"listenAndPoolUnlistens", listenAndPoolUnlistens}, {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, - {"canceledQueryContext", canceledQueryContext}, - {"canceledExecContext", canceledExecContext}, + {"canceledQueryExContext", canceledQueryExContext}, + {"canceledExecExContext", canceledExecExContext}, } actionCount := 1000 @@ -317,14 +317,14 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { return tx.Commit() } -func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { +func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error { ctx, cancelFunc := context.WithCancel(context.Background()) go func() { time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) cancelFunc() }() - rows, err := pool.QueryContext(ctx, "select pg_sleep(2)") + rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil) if err == context.Canceled { return nil } else if err != nil { @@ -342,14 +342,14 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { return nil } -func canceledExecContext(pool *pgx.ConnPool, actionNum int) error { +func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error { ctx, cancelFunc := context.WithCancel(context.Background()) go func() { time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) cancelFunc() }() - _, err := pool.ExecContext(ctx, "select pg_sleep(2)") + _, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil) if err != context.Canceled { return fmt.Errorf("Expected context.Canceled error, got %v", err) } diff --git a/values.go b/values.go index 71c4cc5c..3565df34 100644 --- a/values.go +++ b/values.go @@ -4,7 +4,9 @@ import ( "bytes" "database/sql/driver" "fmt" + "math" "reflect" + "time" "github.com/jackc/pgx/pgtype" ) @@ -22,6 +24,80 @@ func (e SerializationError) Error() string { return string(e) } +func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) { + if arg == nil { + return nil, nil + } + + switch arg := arg.(type) { + case driver.Valuer: + return arg.Value() + case pgtype.TextEncoder: + buf := &bytes.Buffer{} + null, err := arg.EncodeText(ci, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + return buf.String(), nil + case int64: + return arg, nil + case float64: + return arg, nil + case bool: + return arg, nil + case time.Time: + return arg, nil + case string: + return arg, nil + case []byte: + return arg, nil + case int8: + return int64(arg), nil + case int16: + return int64(arg), nil + case int32: + return int64(arg), nil + case int: + return int64(arg), nil + case uint8: + return int64(arg), nil + case uint16: + return int64(arg), nil + case uint32: + return int64(arg), nil + case uint64: + if arg > math.MaxInt64 { + return nil, fmt.Errorf("arg too big for int64: %v", arg) + } + return int64(arg), nil + case uint: + if arg > math.MaxInt64 { + return nil, fmt.Errorf("arg too big for int64: %v", arg) + } + return int64(arg), nil + case float32: + return float64(arg), nil + } + + refVal := reflect.ValueOf(arg) + + if refVal.Kind() == reflect.Ptr { + if refVal.IsNil() { + return nil, nil + } + arg = refVal.Elem().Interface() + return convertSimpleArgument(ci, arg) + } + + if strippedArg, ok := stripNamedType(&refVal); ok { + return convertSimpleArgument(ci, strippedArg) + } + return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) +} + func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { if arg == nil { wbuf.WriteInt32(-1) From 76c0b9ee9084fbcfa4875579afcd4d488814a169 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 11 Apr 2017 20:16:41 -0500 Subject: [PATCH 160/264] Skip line tests on when server version < PG 9.4 --- .travis.yml | 1 + pgtype/line_test.go | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/.travis.yml b/.travis.yml index 069cfcb6..a60a324e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -54,6 +54,7 @@ install: - go get -u github.com/jackc/fake - go get -u github.com/jackc/pgmock/pgmsg - go get -u github.com/lib/pq + - go get -u github.com/hashicorp/go-version script: - go test -v -race ./... diff --git a/pgtype/line_test.go b/pgtype/line_test.go index 6d3b02e1..995eaad5 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -3,10 +3,24 @@ package pgtype_test import ( "testing" + version "github.com/hashicorp/go-version" "github.com/jackc/pgx/pgtype" ) func TestLineTranscode(t *testing.T) { + conn := mustConnectPgx(t) + serverVersion, err := version.NewVersion(conn.RuntimeParams["server_version"]) + if err != nil { + t.Fatalf("cannot get server version: %v", err) + } + mustClose(t, conn) + + minVersion := version.Must(version.NewVersion("9.4")) + + if serverVersion.LessThan(minVersion) { + t.Skipf("Skipping line test for server version %v", serverVersion) + } + testSuccessfulTranscode(t, "line", []interface{}{ &pgtype.Line{ A: 1.23, B: 4.56, C: 7.89, From ccfff83d1a2c066e88ccaf90accb5b81d2095341 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 11 Apr 2017 20:38:18 -0500 Subject: [PATCH 161/264] Use circle in simple protocol test line is not supported PG 9.3 and below. --- query_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/query_test.go b/query_test.go index 66660ba1..c1ca480a 100644 --- a/query_test.go +++ b/query_test.go @@ -1225,11 +1225,11 @@ func TestConnSimpleProtocol(t *testing.T) { // Test high-level type { - expected := pgtype.Line{A: 1, B: 2, C: 1.5, Status: pgtype.Present} + expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present} actual := expected err := conn.QueryRowEx( context.Background(), - "select $1::line", + "select $1::circle", &pgx.QueryExOptions{SimpleProtocol: true}, &expected, ).Scan(&actual) From 932ab58cf7eb6c665d40ebbaeae91973b95fc870 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 12 Apr 2017 07:46:25 -0500 Subject: [PATCH 162/264] Remove ValueReader --- query.go | 72 +++++++++------------ v3.md | 6 +- value_reader.go | 166 ------------------------------------------------ 3 files changed, 36 insertions(+), 208 deletions(-) delete mode 100644 value_reader.go diff --git a/query.go b/query.go index 72f987b4..f7d8ed19 100644 --- a/query.go +++ b/query.go @@ -43,7 +43,6 @@ type Rows struct { conn *Conn mr *msgReader fields []FieldDescription - vr ValueReader rowCount int columnIdx int err error @@ -114,7 +113,6 @@ func (rows *Rows) Next() bool { rows.rowCount++ rows.columnIdx = 0 - rows.vr = ValueReader{} for { t, r, err := rows.conn.rxMsg() @@ -163,24 +161,23 @@ func (rows *Rows) Conn() *Conn { return rows.conn } -func (rows *Rows) nextColumn() (*ValueReader, bool) { +func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { if rows.closed { - return nil, false + return nil, nil, false } if len(rows.fields) <= rows.columnIdx { rows.Fatal(ProtocolError("No next column available")) - return nil, false - } - - if rows.vr.Len() > 0 { - rows.mr.readBytes(rows.vr.Len()) + return nil, nil, false } fd := &rows.fields[rows.columnIdx] rows.columnIdx++ size := rows.mr.readInt32() - rows.vr = ValueReader{mr: rows.mr, fd: fd, valueBytesRemaining: size} - return &rows.vr, true + var buf []byte + if size >= 0 { + buf = rows.mr.readBytes(size) + } + return buf, fd, true } type scanArgError struct { @@ -204,49 +201,49 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } for i, d := range dest { - vr, _ := rows.nextColumn() + buf, fd, _ := rows.nextColumn() if d == nil { continue } - if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { - err = s.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode { + err = s.DecodeBinary(rows.conn.ConnInfo, buf) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { - err = s.DecodeText(rows.conn.ConnInfo, vr.bytes()) + } else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode { + err = s.DecodeText(rows.conn.ConnInfo, buf) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else { - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(fd.DataType); ok { value := dt.Value - switch vr.Type().FormatCode { + switch fd.FormatCode { case TextFormatCode: if textDecoder, ok := value.(pgtype.TextDecoder); ok { - err = textDecoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + err = textDecoder.DecodeText(rows.conn.ConnInfo, buf) if err != nil { - vr.Fatal(err) + rows.Fatal(scanArgError{col: i, err: err}) } } else { - vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", value)) + rows.Fatal(scanArgError{col: i, err: fmt.Errorf("%T is not a pgtype.TextDecoder", value)}) } case BinaryFormatCode: if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { - err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf) if err != nil { - vr.Fatal(err) + rows.Fatal(scanArgError{col: i, err: err}) } } else { - vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", value)) + rows.Fatal(scanArgError{col: i, err: fmt.Errorf("%T is not a pgtype.BinaryDecoder", value)}) } default: - vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) + rows.Fatal(scanArgError{col: i, err: fmt.Errorf("unknown format code: %v", fd.FormatCode)}) } - if vr.Err() == nil { + if rows.Err() == nil { if scanner, ok := d.(sql.Scanner); ok { sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) if err != nil { @@ -257,16 +254,13 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { rows.Fatal(scanArgError{col: i, err: err}) } } else if err := value.AssignTo(d); err != nil { - vr.Fatal(err) + rows.Fatal(scanArgError{col: i, err: err}) } } } else { - rows.Fatal(scanArgError{col: i, err: fmt.Errorf("unknown oid: %v", vr.Type().DataType)}) + rows.Fatal(scanArgError{col: i, err: fmt.Errorf("unknown oid: %v", fd.DataType)}) } } - if vr.Err() != nil { - rows.Fatal(scanArgError{col: i, err: vr.Err()}) - } if rows.Err() != nil { return rows.Err() @@ -285,23 +279,23 @@ func (rows *Rows) Values() ([]interface{}, error) { values := make([]interface{}, 0, len(rows.fields)) for range rows.fields { - vr, _ := rows.nextColumn() + buf, fd, _ := rows.nextColumn() - if vr.Len() == -1 { + if buf == nil { values = append(values, nil) continue } - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(fd.DataType); ok { value := dt.Value - switch vr.Type().FormatCode { + switch fd.FormatCode { case TextFormatCode: decoder := value.(pgtype.TextDecoder) if decoder == nil { decoder = &pgtype.GenericText{} } - err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + err := decoder.DecodeText(rows.conn.ConnInfo, buf) if err != nil { rows.Fatal(err) } @@ -311,7 +305,7 @@ func (rows *Rows) Values() ([]interface{}, error) { if decoder == nil { decoder = &pgtype.GenericBinary{} } - err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + err := decoder.DecodeBinary(rows.conn.ConnInfo, buf) if err != nil { rows.Fatal(err) } @@ -323,10 +317,6 @@ func (rows *Rows) Values() ([]interface{}, error) { rows.Fatal(errors.New("Unknown type")) } - if vr.Err() != nil { - rows.Fatal(vr.Err()) - } - if rows.Err() != nil { return nil, rows.Err() } diff --git a/v3.md b/v3.md index 20038938..d9017890 100644 --- a/v3.md +++ b/v3.md @@ -32,6 +32,10 @@ No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinar OID constants moved from pgx to pgtype package +Removed ValueReader + +Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system + ## TODO / Possible / Investigate Organize errors better @@ -46,7 +50,7 @@ Remove names from prepared statements - use database/sql style objects Better way of handling text/binary protocol choice than pgx.DefaultTypeFormats or manually editing a PreparedStatement. Possibly an optional part of preparing a statement is specifying the format and/or a decoder. Or maybe it is part of a QueryEx call... Could be very interesting to make encoding and decoding possible without being a method of the type. This could drastically clean up those huge type switches. -dValueReader / msgReader cleanup +msgReader cleanup Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) diff --git a/value_reader.go b/value_reader.go deleted file mode 100644 index fea21d49..00000000 --- a/value_reader.go +++ /dev/null @@ -1,166 +0,0 @@ -package pgx - -import ( - "errors" - - "github.com/jackc/pgx/pgtype" -) - -// ValueReader is used by the Scanner interface to decode values. -type ValueReader struct { - mr *msgReader - fd *FieldDescription - valueBytesRemaining int32 - err error -} - -// Err returns any error that the ValueReader has experienced -func (r *ValueReader) Err() error { - return r.err -} - -// Fatal tells r that a Fatal error has occurred -func (r *ValueReader) Fatal(err error) { - r.err = err -} - -// Len returns the number of unread bytes -func (r *ValueReader) Len() int32 { - return r.valueBytesRemaining -} - -// Type returns the *FieldDescription of the value -func (r *ValueReader) Type() *FieldDescription { - return r.fd -} - -func (r *ValueReader) ReadByte() byte { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining-- - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readByte() -} - -func (r *ValueReader) ReadInt16() int16 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 2 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readInt16() -} - -func (r *ValueReader) ReadUint16() uint16 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 2 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readUint16() -} - -func (r *ValueReader) ReadInt32() int32 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 4 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readInt32() -} - -func (r *ValueReader) ReadUint32() uint32 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 4 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readUint32() -} - -func (r *ValueReader) ReadInt64() int64 { - if r.err != nil { - return 0 - } - - r.valueBytesRemaining -= 8 - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return 0 - } - - return r.mr.readInt64() -} - -func (r *ValueReader) ReadOid() pgtype.Oid { - return pgtype.Oid(r.ReadUint32()) -} - -// ReadString reads count bytes and returns as string -func (r *ValueReader) ReadString(count int32) string { - if r.err != nil { - return "" - } - - r.valueBytesRemaining -= count - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return "" - } - - return r.mr.readString(count) -} - -// ReadBytes reads count bytes and returns as []byte -func (r *ValueReader) ReadBytes(count int32) []byte { - if r.err != nil { - return nil - } - - if count < 0 { - r.Fatal(errors.New("count must not be negative")) - return nil - } - - r.valueBytesRemaining -= count - if r.valueBytesRemaining < 0 { - r.Fatal(errors.New("read past end of value")) - return nil - } - - return r.mr.readBytes(count) -} - -// bytes is a compatibility function for pgtype.TextDecoder and pgtype.BinaryDecoder -func (r *ValueReader) bytes() []byte { - if r.Len() >= 0 { - return r.ReadBytes(r.Len()) - } - return nil -} From adb54d06ce15980e31951713ee44865f9a2304bb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 12 Apr 2017 18:03:43 -0500 Subject: [PATCH 163/264] Tweak timing sensitive test --- conn_pool_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_pool_test.go b/conn_pool_test.go index 9d03fad3..825638b6 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -297,7 +297,7 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { // ... then try to consume 1 more. It should hang forever. // To unblock it we release the previously taken connection in a goroutine. stopDeadWaitTimeout := 5 * time.Second - timer := time.AfterFunc(stopDeadWaitTimeout, func() { + timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { releaseAllConnections(pool, allConnections) }) defer timer.Stop() From fe7d9d34622c6c3f0aaee4293ecef0f0971e219d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 13 Apr 2017 21:54:04 -0500 Subject: [PATCH 164/264] Add MarshalJSON to a few types --- pgtype/int2.go | 13 +++++++++++++ pgtype/int4.go | 13 +++++++++++++ pgtype/int8.go | 13 +++++++++++++ pgtype/pgtype.go | 1 + pgtype/text.go | 14 ++++++++++++++ pgtype/varchar.go | 4 ++++ 6 files changed, 58 insertions(+) diff --git a/pgtype/int2.go b/pgtype/int2.go index 3bcac63c..0cb6ef82 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -195,3 +195,16 @@ func (src Int2) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Int2) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype/int4.go b/pgtype/int4.go index 5069dab4..4a5bca51 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -186,3 +186,16 @@ func (src Int4) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Int4) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(int64(src.Int), 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype/int8.go b/pgtype/int8.go index cf701dc6..0cc3545d 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -172,3 +172,16 @@ func (src Int8) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Int8) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(strconv.FormatInt(src.Int, 10)), nil + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 338afc9b..27a1a091 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -129,6 +129,7 @@ type TextEncoder interface { } var errUndefined = errors.New("cannot encode status undefined") +var errBadStatus = errors.New("invalid status") type DataType struct { Value Value diff --git a/pgtype/text.go b/pgtype/text.go index 482c9023..62158b09 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -2,6 +2,7 @@ package pgtype import ( "database/sql/driver" + "encoding/json" "fmt" "io" ) @@ -134,3 +135,16 @@ func (src Text) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Text) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return json.Marshal(src.String) + case Null: + return []byte("null"), nil + case Undefined: + return []byte("undefined"), nil + } + + return nil, errBadStatus +} diff --git a/pgtype/varchar.go b/pgtype/varchar.go index f25ada5d..6c137b9a 100644 --- a/pgtype/varchar.go +++ b/pgtype/varchar.go @@ -49,3 +49,7 @@ func (dst *Varchar) Scan(src interface{}) error { func (src Varchar) Value() (driver.Value, error) { return (Text)(src).Value() } + +func (src Varchar) MarshalJSON() ([]byte, error) { + return (Text)(src).MarshalJSON() +} From e4451b47b257cd89bda942c349e5ad70855a13f3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 12:18:49 -0500 Subject: [PATCH 165/264] Add shopspring.Numeric This adds PostgreSQL numeric mapping to and from github.com/shopspring/decimal. Makes pgtype.NullAssignTo public as external types need this functionality. Begin extraction of pgtype testing functionality so it can easily be used by external types. --- pgtype/aclitem.go | 2 +- pgtype/aclitem_array.go | 2 +- pgtype/bool.go | 2 +- pgtype/bool_array.go | 2 +- pgtype/bytea.go | 2 +- pgtype/bytea_array.go | 2 +- pgtype/cidr_array.go | 2 +- pgtype/convert.go | 2 +- pgtype/date.go | 2 +- pgtype/date_array.go | 2 +- pgtype/ext/shopspring-numeric/decimal.go | 320 ++++++++++++++++++ pgtype/ext/shopspring-numeric/decimal_test.go | 281 +++++++++++++++ pgtype/float4_array.go | 2 +- pgtype/float8_array.go | 2 +- pgtype/hstore.go | 2 +- pgtype/hstore_array.go | 2 +- pgtype/inet.go | 2 +- pgtype/inet_array.go | 2 +- pgtype/int2_array.go | 2 +- pgtype/int4_array.go | 2 +- pgtype/int8_array.go | 2 +- pgtype/interval.go | 2 +- pgtype/macaddr.go | 2 +- pgtype/numeric.go | 2 +- pgtype/numeric_array.go | 2 +- pgtype/record.go | 2 +- pgtype/testutil/testutil.go | 298 ++++++++++++++++ pgtype/text.go | 2 +- pgtype/text_array.go | 2 +- pgtype/timestamp.go | 2 +- pgtype/timestamp_array.go | 2 +- pgtype/timestamptz.go | 2 +- pgtype/timestamptz_array.go | 2 +- pgtype/typed_array.go.erb | 2 +- pgtype/uuid.go | 2 +- pgtype/varchar_array.go | 2 +- v3.md | 2 + 37 files changed, 934 insertions(+), 33 deletions(-) create mode 100644 pgtype/ext/shopspring-numeric/decimal.go create mode 100644 pgtype/ext/shopspring-numeric/decimal_test.go create mode 100644 pgtype/testutil/testutil.go diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index 77e385e6..3ccf8318 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -67,7 +67,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 20a7636a..7ef76573 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -78,7 +78,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/bool.go b/pgtype/bool.go index 736d19cf..1ebf590b 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -56,7 +56,7 @@ func (src *Bool) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 4705d734..468f6816 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -79,7 +79,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 9f0266e7..8bf5de2b 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -61,7 +61,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 268364c1..4aa2b862 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -79,7 +79,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 6643bb47..96d912ae 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -108,7 +108,7 @@ func (src *CidrArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/convert.go b/pgtype/convert.go index 4fba8430..2b406426 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -342,7 +342,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } -func nullAssignTo(dst interface{}) error { +func NullAssignTo(dst interface{}) error { dstPtr := reflect.ValueOf(dst) // AssignTo dst must always be a pointer diff --git a/pgtype/date.go b/pgtype/date.go index 7dd2c4f0..34753f05 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -70,7 +70,7 @@ func (src *Date) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/date_array.go b/pgtype/date_array.go index f58de011..f24bf6b9 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -80,7 +80,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/ext/shopspring-numeric/decimal.go b/pgtype/ext/shopspring-numeric/decimal.go new file mode 100644 index 00000000..9c7e316b --- /dev/null +++ b/pgtype/ext/shopspring-numeric/decimal.go @@ -0,0 +1,320 @@ +package numeric + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/pgtype" + "github.com/shopspring/decimal" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type Numeric struct { + Decimal decimal.Decimal + Status pgtype.Status +} + +func (dst *Numeric) Set(src interface{}) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + switch value := src.(type) { + case decimal.Decimal: + *dst = Numeric{Decimal: value, Status: pgtype.Present} + case float32: + *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} + case float64: + *dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present} + case int8: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint8: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int16: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint16: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int32: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint32: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case int64: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint64: + // uint64 could be greater than int64 so convert to string then to decimal + dec, err := decimal.NewFromString(strconv.FormatUint(value, 10)) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + case int: + *dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present} + case uint: + // uint could be greater than int64 so convert to string then to decimal + dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10)) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + case string: + dec, err := decimal.NewFromString(value) + if err != nil { + return err + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + default: + // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. + num := &pgtype.Numeric{} + if err := num.Set(value); err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + buf := &bytes.Buffer{} + if _, err := num.EncodeText(nil, buf); err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + + dec, err := decimal.NewFromString(buf.String()) + if err != nil { + return fmt.Errorf("cannot convert %v to Numeric", value) + } + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + } + + return nil +} + +func (dst *Numeric) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.Decimal + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Numeric) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *decimal.Decimal: + *v = src.Decimal + case *float32: + f, _ := src.Decimal.Float64() + *v = float32(f) + case *float64: + f, _ := src.Decimal.Float64() + *v = f + case *int: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int(n) + case *int8: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int8(n) + case *int16: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int16(n) + case *int32: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int32(n) + case *int64: + if src.Decimal.Exponent() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = int64(n) + case *uint: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint(n) + case *uint8: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint8(n) + case *uint16: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint16(n) + case *uint32: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint32(n) + case *uint64: + if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) + if err != nil { + return fmt.Errorf("cannot convert %v to %T", dst, *v) + } + *v = uint64(n) + default: + if nextDst, retry := pgtype.GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return nil +} + +func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + dec, err := decimal.NewFromString(string(src)) + if err != nil { + return err + } + + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + return nil +} + +func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + // For now at least, implement this in terms of pgtype.Numeric + + num := &pgtype.Numeric{} + if err := num.DecodeBinary(ci, src); err != nil { + return err + } + + buf := &bytes.Buffer{} + if _, err := num.EncodeText(ci, buf); err != nil { + return err + } + + dec, err := decimal.NewFromString(buf.String()) + if err != nil { + return err + } + + *dst = Numeric{Decimal: dec, Status: pgtype.Present} + + return nil +} + +func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.Decimal.String()) + return false, err +} + +func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + // For now at least, implement this in terms of pgtype.Numeric + num := &pgtype.Numeric{} + if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { + return false, err + } + + return num.EncodeBinary(ci, w) +} + +// Scan implements the database/sql Scanner interface. +func (dst *Numeric) Scan(src interface{}) error { + if src == nil { + *dst = Numeric{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Numeric) Value() (driver.Value, error) { + switch src.Status { + case pgtype.Present: + return src.Decimal.Value() + case pgtype.Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/ext/shopspring-numeric/decimal_test.go b/pgtype/ext/shopspring-numeric/decimal_test.go new file mode 100644 index 00000000..50c0fb8b --- /dev/null +++ b/pgtype/ext/shopspring-numeric/decimal_test.go @@ -0,0 +1,281 @@ +package numeric_test + +import ( + "fmt" + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + shopspring "github.com/jackc/pgx/pgtype/ext/shopspring-numeric" + "github.com/jackc/pgx/pgtype/testutil" + "github.com/shopspring/decimal" +) + +func mustParseDecimal(t *testing.T, src string) decimal.Decimal { + dec, err := decimal.NewFromString(src) + if err != nil { + t.Fatal(err) + } + return dec +} + +func TestNumericNormalize(t *testing.T) { + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + { + SQL: "select '0'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + }, + { + SQL: "select '1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + }, + { + SQL: "select '10.00'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, + }, + { + SQL: "select '1e-3'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + }, + { + SQL: "select '-1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + }, + { + SQL: "select '10000'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, + }, + { + SQL: "select '3.14'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + }, + { + SQL: "select '1.1'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, + }, + { + SQL: "select '100010001.0001'::numeric", + Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, + }, + { + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), + Status: pgtype.Present, + }, + }, + { + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: shopspring.Numeric{ + Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), + Status: pgtype.Present, + }, + }, + }) +} + +func TestNumericTranscode(t *testing.T) { + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "100000"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.1"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.01"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00001"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000001"), Status: pgtype.Present}, + + &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.0000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000123"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001234567890123456789"), Status: pgtype.Present}, + &shopspring.Numeric{Decimal: mustParseDecimal(t, "4309132809320932980457137401234890237489238912983572189348951289375283573984571892758234678903467889512893489128589347891272139.8489235871258912789347891235879148795891238915678189467128957812395781238579189025891238901583915890128973578957912385798125789012378905238905471598123758923478294374327894237892234"), Status: pgtype.Present}, + &shopspring.Numeric{Status: pgtype.Null}, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) + +} + +func TestNumericTranscodeFuzz(t *testing.T) { + r := rand.New(rand.NewSource(0)) + max := &big.Int{} + max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10) + + values := make([]interface{}, 0, 2000) + for i := 0; i < 500; i++ { + num := fmt.Sprintf("%s.%s", (&big.Int{}).Rand(r, max).String(), (&big.Int{}).Rand(r, max).String()) + negNum := "-" + num + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, num), Status: pgtype.Present}) + values = append(values, &shopspring.Numeric{Decimal: mustParseDecimal(t, negNum), Status: pgtype.Present}) + } + + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, + func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) + }) +} + +func TestNumericSet(t *testing.T) { + type _int8 int8 + + successfulTests := []struct { + source interface{} + result *shopspring.Numeric + }{ + {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: int8(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int16(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int32(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: int64(-1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}}, + {source: uint8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint16(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: uint64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: "1", result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: _int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: float64(1000), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1000"), Status: pgtype.Present}}, + {source: float64(1234), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1234"), Status: pgtype.Present}}, + {source: float64(12345678900), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345678900"), Status: pgtype.Present}}, + {source: float64(12345.678901), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "12345.678901"), Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + r := &shopspring.Numeric{} + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !(r.Status == tt.result.Status && r.Decimal.Equal(tt.result.Decimal)) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestNumericAssignTo(t *testing.T) { + type _int8 int8 + + var i8 int8 + var i16 int16 + var i32 int32 + var i64 int64 + var i int + var ui8 uint8 + var ui16 uint16 + var ui32 uint32 + var ui64 uint64 + var ui uint + var pi8 *int8 + var _i8 _int8 + var _pi8 *_int8 + var f32 float32 + var f64 float64 + var pf32 *float32 + var pf64 *float64 + + simpleTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &f64, expected: float64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f32, expected: float32(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "4.2"), Status: pgtype.Present}, dst: &f64, expected: float64(4.2)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i16, expected: int16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i32, expected: int32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i64, expected: int64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &i64, expected: int64(42000)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &i, expected: int(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui64, expected: uint64(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &ui, expected: uint(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src *shopspring.Numeric + dst interface{} + expected interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf32, expected: float32(42)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &pf64, expected: float64(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src *shopspring.Numeric + dst interface{} + }{ + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "150"), Status: pgtype.Present}, dst: &i8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "40000"), Status: pgtype.Present}, dst: &i16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui8}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui16}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui32}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui64}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, dst: &ui}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &i32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index b9ee4b9e..db1523f0 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -79,7 +79,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index d49f18a7..19878bbb 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -79,7 +79,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/hstore.go b/pgtype/hstore.go index b8b0c6f3..5dc78671 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -71,7 +71,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index 097fec7b..e4263f20 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -79,7 +79,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/inet.go b/pgtype/inet.go index 3e00e2fa..09fce04d 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -90,7 +90,7 @@ func (src *Inet) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index a108d75b..4687b145 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -108,7 +108,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index bddb5ac2..3506370e 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -107,7 +107,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index d5c8f911..e4ec6455 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -107,7 +107,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index ae2521fa..6c0dab65 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -107,7 +107,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/interval.go b/pgtype/interval.go index 7eddb10f..20a4a419 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -71,7 +71,7 @@ func (src *Interval) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 2d09ff8c..2834d69f 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -67,7 +67,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index a26e8c89..63f99c06 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -253,7 +253,7 @@ func (src *Numeric) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return nil diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go index b147e6a2..3d59a6b0 100644 --- a/pgtype/numeric_array.go +++ b/pgtype/numeric_array.go @@ -107,7 +107,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/record.go b/pgtype/record.go index 9c42c907..3b315d40 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -62,7 +62,7 @@ func (src *Record) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go new file mode 100644 index 00000000..610f0710 --- /dev/null +++ b/pgtype/testutil/testutil.go @@ -0,0 +1,298 @@ +package testutil + +import ( + "context" + "database/sql" + "fmt" + "io" + "os" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" + _ "github.com/jackc/pgx/stdlib" + _ "github.com/lib/pq" +) + +func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + return db +} + +func mustConnectPgx(t testing.TB) *pgx.Conn { + config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(config) + if err != nil { + t.Fatal(err) + } + + return conn +} + +func mustClose(t testing.TB, conn interface { + Close() error +}) { + err := conn.Close() + if err != nil { + t.Fatal(err) + } +} + +type forceTextEncoder struct { + e pgtype.TextEncoder +} + +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeText(ci, w) +} + +type forceBinaryEncoder struct { + e pgtype.BinaryEncoder +} + +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeBinary(ci, w) +} + +func forceEncoder(e interface{}, formatCode int16) interface{} { + switch formatCode { + case pgx.TextFormatCode: + if e, ok := e.(pgtype.TextEncoder); ok { + return forceTextEncoder{e: e} + } + case pgx.BinaryFormatCode: + if e, ok := e.(pgtype.BinaryEncoder); ok { + return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} + } + } + return nil +} + +func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { + TestSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, v := range values { + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(v, fc.formatCode) + if vEncoder == nil { + t.Logf("Skipping: %#v does not implement %v", v, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := conn.QueryRowEx( + context.Background(), + fmt.Sprintf("select ($1)::%s", pgTypeName), + &pgx.QueryExOptions{SimpleProtocol: true}, + v, + ).Scan(result.Interface()) + if err != nil { + t.Errorf("Simple protocol %d: %v", i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) + } + } +} + +func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} + +type NormalizeTest struct { + SQL string + Value interface{} +} + +func TestSuccessfulNormalize(t testing.TB, tests []NormalizeTest) { + TestSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) + }) +} + +func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + TestPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + TestDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) + } +} + +func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, fc := range formats { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.SQL) + if err != nil { + t.Fatal(err) + } + + ps.FieldDescriptions[0].FormatCode = fc.formatCode + if forceEncoder(tt.Value, fc.formatCode) == nil { + t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) + continue + } + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = conn.QueryRow(psName).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", fc.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + } + } + } +} + +func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + for i, tt := range tests { + ps, err := conn.Prepare(tt.SQL) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + // Derefence value if it is a pointer + derefV := tt.Value + refVal := reflect.ValueOf(tt.Value) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err = ps.QueryRow().Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } +} diff --git a/pgtype/text.go b/pgtype/text.go index 62158b09..de80dd08 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -71,7 +71,7 @@ func (src *Text) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 64728048..a6bd4724 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -79,7 +79,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 78c6355e..e7bc1c7d 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -74,7 +74,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 5d08f9cc..2046c387 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -80,7 +80,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 50370335..ef2d7498 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -75,7 +75,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 107be06a..fd58d3be 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -80,7 +80,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 4b8f1a28..2a38ed82 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -77,7 +77,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 111bed35..88d2195b 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -69,7 +69,7 @@ func (src *Uuid) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot assign %v into %T", src, dst) diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 2712b4d2..9ca16d7e 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -79,7 +79,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { } } case Null: - return nullAssignTo(dst) + return NullAssignTo(dst) } return fmt.Errorf("cannot decode %v into %T", src, dst) diff --git a/v3.md b/v3.md index d9017890..2946bcf0 100644 --- a/v3.md +++ b/v3.md @@ -66,3 +66,5 @@ Keep ability to change logging while running consider test to ensure that AssignTo makes copy of reference types something like: select array[1,2,3], array[4,5,6,7] + +Reconsider synonym types like varchar/text and numeric/decimal. From f418255c24395c6f8ba5bf0edb1119d3c6e2cbdd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 12:38:33 -0500 Subject: [PATCH 166/264] Finish extraction of pgtype test helpers --- pgtype/aclitem_array_test.go | 3 +- pgtype/aclitem_test.go | 3 +- pgtype/bool_array_test.go | 3 +- pgtype/bool_test.go | 3 +- pgtype/box_test.go | 9 +- pgtype/bytea_array_test.go | 3 +- pgtype/bytea_test.go | 3 +- pgtype/cid_test.go | 7 +- pgtype/cidr_array_test.go | 3 +- pgtype/circle_test.go | 3 +- pgtype/date_array_test.go | 3 +- pgtype/date_test.go | 3 +- pgtype/daterange_test.go | 9 +- pgtype/float4_array_test.go | 3 +- pgtype/float4_test.go | 3 +- pgtype/float8_array_test.go | 3 +- pgtype/float8_test.go | 3 +- pgtype/hstore_array_test.go | 7 +- pgtype/hstore_test.go | 3 +- pgtype/inet_array_test.go | 3 +- pgtype/inet_test.go | 3 +- pgtype/int2_array_test.go | 3 +- pgtype/int2_test.go | 3 +- pgtype/int4_array_test.go | 3 +- pgtype/int4_test.go | 3 +- pgtype/int4range_test.go | 9 +- pgtype/int8_array_test.go | 3 +- pgtype/int8_test.go | 3 +- pgtype/int8range_test.go | 9 +- pgtype/interval_test.go | 33 ++-- pgtype/json_test.go | 3 +- pgtype/jsonb_test.go | 7 +- pgtype/line_test.go | 7 +- pgtype/lseg_test.go | 3 +- pgtype/macaddr_test.go | 3 +- pgtype/name_test.go | 3 +- pgtype/numeric_array_test.go | 3 +- pgtype/numeric_test.go | 59 ++++--- pgtype/numrange_test.go | 3 +- pgtype/oid_value_test.go | 3 +- pgtype/path_test.go | 3 +- pgtype/pgtype_test.go | 291 ------------------------------- pgtype/point_test.go | 3 +- pgtype/polygon_test.go | 3 +- pgtype/qchar_test.go | 3 +- pgtype/record_test.go | 5 +- pgtype/testutil/testutil.go | 34 ++-- pgtype/text_array_test.go | 3 +- pgtype/text_test.go | 3 +- pgtype/tid_test.go | 3 +- pgtype/timestamp_array_test.go | 3 +- pgtype/timestamp_test.go | 3 +- pgtype/timestamptz_array_test.go | 3 +- pgtype/timestamptz_test.go | 3 +- pgtype/tsrange_test.go | 3 +- pgtype/tstzrange_test.go | 3 +- pgtype/uuid_test.go | 3 +- pgtype/varbit_test.go | 9 +- pgtype/varchar_array_test.go | 3 +- pgtype/xid_test.go | 7 +- 60 files changed, 202 insertions(+), 435 deletions(-) diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go index 75c672bd..951e7847 100644 --- a/pgtype/aclitem_array_test.go +++ b/pgtype/aclitem_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestAclitemArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "aclitem[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ &pgtype.AclitemArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go index 1738025a..5389eab2 100644 --- a/pgtype/aclitem_test.go +++ b/pgtype/aclitem_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestAclitemTranscode(t *testing.T) { - testSuccessfulTranscode(t, "aclitem", []interface{}{ + testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, pgtype.Aclitem{Status: pgtype.Null}, diff --git a/pgtype/bool_array_test.go b/pgtype/bool_array_test.go index a526d892..87886da6 100644 --- a/pgtype/bool_array_test.go +++ b/pgtype/bool_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestBoolArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "bool[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "bool[]", []interface{}{ &pgtype.BoolArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 412e2fd0..31f3d528 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestBoolTranscode(t *testing.T) { - testSuccessfulTranscode(t, "bool", []interface{}{ + testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ pgtype.Bool{Bool: false, Status: pgtype.Present}, pgtype.Bool{Bool: true, Status: pgtype.Present}, pgtype.Bool{Bool: false, Status: pgtype.Null}, diff --git a/pgtype/box_test.go b/pgtype/box_test.go index 00732973..f26cda68 100644 --- a/pgtype/box_test.go +++ b/pgtype/box_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestBoxTranscode(t *testing.T) { - testSuccessfulTranscode(t, "box", []interface{}{ + testutil.TestSuccessfulTranscode(t, "box", []interface{}{ &pgtype.Box{ P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, Status: pgtype.Present, @@ -21,10 +22,10 @@ func TestBoxTranscode(t *testing.T) { } func TestBoxNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select '3.14, 1.678, 7.1, 5.234'::box", - value: &pgtype.Box{ + SQL: "select '3.14, 1.678, 7.1, 5.234'::box", + Value: &pgtype.Box{ P: [2]pgtype.Vec2{{7.1, 5.234}, {3.14, 1.678}}, Status: pgtype.Present, }, diff --git a/pgtype/bytea_array_test.go b/pgtype/bytea_array_test.go index 22c6478b..451c2461 100644 --- a/pgtype/bytea_array_test.go +++ b/pgtype/bytea_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestByteaArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "bytea[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "bytea[]", []interface{}{ &pgtype.ByteaArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index e21296c6..7d32e294 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestByteaTranscode(t *testing.T) { - testSuccessfulTranscode(t, "bytea", []interface{}{ + testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index 210573f6..385b8cac 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestCidTranscode(t *testing.T) { @@ -17,13 +18,13 @@ func TestCidTranscode(t *testing.T) { return reflect.DeepEqual(a, b) } - testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) // No direct conversion from int to cid, convert through text - testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } } diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go index ec105914..1ebe5195 100644 --- a/pgtype/cidr_array_test.go +++ b/pgtype/cidr_array_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestCidrArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "cidr[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ &pgtype.CidrArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index 9746dd74..2747d4f5 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { - testSuccessfulTranscode(t, "circle", []interface{}{ + testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ &pgtype.Circle{P: pgtype.Vec2{1.234, 5.6789}, R: 3.5, Status: pgtype.Present}, &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Status: pgtype.Present}, &pgtype.Circle{Status: pgtype.Null}, diff --git a/pgtype/date_array_test.go b/pgtype/date_array_test.go index a05f4254..74ebfbbe 100644 --- a/pgtype/date_array_test.go +++ b/pgtype/date_array_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestDateArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "date[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "date[]", []interface{}{ &pgtype.DateArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/date_test.go b/pgtype/date_test.go index 1832b5b4..d1493f5e 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestDateTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "date", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go index 8501cc7e..7dfae0f4 100644 --- a/pgtype/daterange_test.go +++ b/pgtype/daterange_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestDaterangeTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Daterange{ Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, @@ -40,10 +41,10 @@ func TestDaterangeTranscode(t *testing.T) { } func TestDaterangeNormalize(t *testing.T) { - testSuccessfulNormalizeEqFunc(t, []normalizeTest{ + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ { - sql: "select daterange('2010-01-01', '2010-01-11', '(]')", - value: pgtype.Daterange{ + SQL: "select daterange('2010-01-01', '2010-01-11', '(]')", + Value: pgtype.Daterange{ Lower: pgtype.Date{Time: time.Date(2010, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Date{Time: time.Date(2010, 1, 12, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, diff --git a/pgtype/float4_array_test.go b/pgtype/float4_array_test.go index 06a1d2e0..6d6a4f30 100644 --- a/pgtype/float4_array_test.go +++ b/pgtype/float4_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestFloat4ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "float4[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "float4[]", []interface{}{ &pgtype.Float4Array{ Elements: nil, Dimensions: nil, diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index ea60cd3a..57f4bc34 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestFloat4Transcode(t *testing.T) { - testSuccessfulTranscode(t, "float4", []interface{}{ + testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ pgtype.Float4{Float: -1, Status: pgtype.Present}, pgtype.Float4{Float: 0, Status: pgtype.Present}, pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, diff --git a/pgtype/float8_array_test.go b/pgtype/float8_array_test.go index 635e249a..56801e80 100644 --- a/pgtype/float8_array_test.go +++ b/pgtype/float8_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestFloat8ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "float8[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "float8[]", []interface{}{ &pgtype.Float8Array{ Elements: nil, Dimensions: nil, diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index 724e9350..b7527b86 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestFloat8Transcode(t *testing.T) { - testSuccessfulTranscode(t, "float8", []interface{}{ + testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ pgtype.Float8{Float: -1, Status: pgtype.Present}, pgtype.Float8{Float: 0, Status: pgtype.Present}, pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go index e23c7b3b..d26497b1 100644 --- a/pgtype/hstore_array_test.go +++ b/pgtype/hstore_array_test.go @@ -6,11 +6,12 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestHstoreArrayTranscode(t *testing.T) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) text := func(s string) pgtype.Text { return pgtype.Text{String: s, Status: pgtype.Present} @@ -69,7 +70,7 @@ func TestHstoreArrayTranscode(t *testing.T) { for _, fc := range formats { ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := forceEncoder(src, fc.formatCode) + vEncoder := testutil.ForceEncoder(src, fc.formatCode) if vEncoder == nil { t.Logf("%#v does not implement %v", src, fc.name) continue diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index fbe8dee5..502a8df0 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestHstoreTranscode(t *testing.T) { @@ -44,7 +45,7 @@ func TestHstoreTranscode(t *testing.T) { values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key } - testSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { + testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { a := ai.(pgtype.Hstore) b := bi.(pgtype.Hstore) diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go index fe22285d..c0465922 100644 --- a/pgtype/inet_array_test.go +++ b/pgtype/inet_array_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInetArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "inet[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "inet[]", []interface{}{ &pgtype.InetArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 16035fca..532e9abe 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -6,11 +6,12 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInetTranscode(t *testing.T) { for _, pgTypeName := range []string{"inet", "cidr"} { - testSuccessfulTranscode(t, pgTypeName, []interface{}{ + testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, diff --git a/pgtype/int2_array_test.go b/pgtype/int2_array_test.go index 8af4523d..0adc1aef 100644 --- a/pgtype/int2_array_test.go +++ b/pgtype/int2_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt2ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "int2[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int2[]", []interface{}{ &pgtype.Int2Array{ Elements: nil, Dimensions: nil, diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index 2bd8e016..d81405a6 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt2Transcode(t *testing.T) { - testSuccessfulTranscode(t, "int2", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, pgtype.Int2{Int: -1, Status: pgtype.Present}, pgtype.Int2{Int: 0, Status: pgtype.Present}, diff --git a/pgtype/int4_array_test.go b/pgtype/int4_array_test.go index 111cb56b..6fad18bb 100644 --- a/pgtype/int4_array_test.go +++ b/pgtype/int4_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt4ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "int4[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int4[]", []interface{}{ &pgtype.Int4Array{ Elements: nil, Dimensions: nil, diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go index 3e000182..1354b47a 100644 --- a/pgtype/int4_test.go +++ b/pgtype/int4_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt4Transcode(t *testing.T) { - testSuccessfulTranscode(t, "int4", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, pgtype.Int4{Int: -1, Status: pgtype.Present}, pgtype.Int4{Int: 0, Status: pgtype.Present}, diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go index c96fe9cd..74a91e59 100644 --- a/pgtype/int4range_test.go +++ b/pgtype/int4range_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt4rangeTranscode(t *testing.T) { - testSuccessfulTranscode(t, "int4range", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, @@ -16,10 +17,10 @@ func TestInt4rangeTranscode(t *testing.T) { } func TestInt4rangeNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select int4range(1, 10, '(]')", - value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + SQL: "select int4range(1, 10, '(]')", + Value: pgtype.Int4range{Lower: pgtype.Int4{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, }, }) } diff --git a/pgtype/int8_array_test.go b/pgtype/int8_array_test.go index 349a1f7e..4f5c4f9a 100644 --- a/pgtype/int8_array_test.go +++ b/pgtype/int8_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt8ArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "int8[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int8[]", []interface{}{ &pgtype.Int8Array{ Elements: nil, Dimensions: nil, diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go index e1fe69fb..d6752205 100644 --- a/pgtype/int8_test.go +++ b/pgtype/int8_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt8Transcode(t *testing.T) { - testSuccessfulTranscode(t, "int8", []interface{}{ + testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, pgtype.Int8{Int: -1, Status: pgtype.Present}, pgtype.Int8{Int: 0, Status: pgtype.Present}, diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go index 1b3e594c..703f476e 100644 --- a/pgtype/int8range_test.go +++ b/pgtype/int8range_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestInt8rangeTranscode(t *testing.T) { - testSuccessfulTranscode(t, "Int8range", []interface{}{ + testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, @@ -16,10 +17,10 @@ func TestInt8rangeTranscode(t *testing.T) { } func TestInt8rangeNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select Int8range(1, 10, '(]')", - value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + SQL: "select Int8range(1, 10, '(]')", + Value: pgtype.Int8range{Lower: pgtype.Int8{Int: 2, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 11, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, }, }) } diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index db9614ef..28e77e0a 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestIntervalTranscode(t *testing.T) { - testSuccessfulTranscode(t, "interval", []interface{}{ + testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, @@ -29,34 +30,34 @@ func TestIntervalTranscode(t *testing.T) { } func TestIntervalNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select '1 second'::interval", - value: pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + SQL: "select '1 second'::interval", + Value: pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, }, { - sql: "select '1.000001 second'::interval", - value: pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + SQL: "select '1.000001 second'::interval", + Value: pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, }, { - sql: "select '34223 hours'::interval", - value: pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + SQL: "select '34223 hours'::interval", + Value: pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, }, { - sql: "select '1 day'::interval", - value: pgtype.Interval{Days: 1, Status: pgtype.Present}, + SQL: "select '1 day'::interval", + Value: pgtype.Interval{Days: 1, Status: pgtype.Present}, }, { - sql: "select '1 month'::interval", - value: pgtype.Interval{Months: 1, Status: pgtype.Present}, + SQL: "select '1 month'::interval", + Value: pgtype.Interval{Months: 1, Status: pgtype.Present}, }, { - sql: "select '1 year'::interval", - value: pgtype.Interval{Months: 12, Status: pgtype.Present}, + SQL: "select '1 year'::interval", + Value: pgtype.Interval{Months: 12, Status: pgtype.Present}, }, { - sql: "select '-13 mon'::interval", - value: pgtype.Interval{Months: -13, Status: pgtype.Present}, + SQL: "select '-13 mon'::interval", + Value: pgtype.Interval{Months: -13, Status: pgtype.Present}, }, }) } diff --git a/pgtype/json_test.go b/pgtype/json_test.go index b0aa8c9b..6d7cccfd 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestJsonTranscode(t *testing.T) { - testSuccessfulTranscode(t, "json", []interface{}{ + testutil.TestSuccessfulTranscode(t, "json", []interface{}{ pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 91637eb8..37c11858 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -6,16 +6,17 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestJsonbTranscode(t *testing.T) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { t.Skip("Skipping due to no jsonb type") } - testSuccessfulTranscode(t, "jsonb", []interface{}{ + testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, diff --git a/pgtype/line_test.go b/pgtype/line_test.go index 995eaad5..09e48019 100644 --- a/pgtype/line_test.go +++ b/pgtype/line_test.go @@ -5,15 +5,16 @@ import ( version "github.com/hashicorp/go-version" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestLineTranscode(t *testing.T) { - conn := mustConnectPgx(t) + conn := testutil.MustConnectPgx(t) serverVersion, err := version.NewVersion(conn.RuntimeParams["server_version"]) if err != nil { t.Fatalf("cannot get server version: %v", err) } - mustClose(t, conn) + testutil.MustClose(t, conn) minVersion := version.Must(version.NewVersion("9.4")) @@ -21,7 +22,7 @@ func TestLineTranscode(t *testing.T) { t.Skipf("Skipping line test for server version %v", serverVersion) } - testSuccessfulTranscode(t, "line", []interface{}{ + testutil.TestSuccessfulTranscode(t, "line", []interface{}{ &pgtype.Line{ A: 1.23, B: 4.56, C: 7.89, Status: pgtype.Present, diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index 5f041263..bd394e3c 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { - testSuccessfulTranscode(t, "lseg", []interface{}{ + testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ &pgtype.Lseg{ P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, Status: pgtype.Present, diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 6c7b8b89..c2542da3 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -7,10 +7,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestMacaddrTranscode(t *testing.T) { - testSuccessfulTranscode(t, "macaddr", []interface{}{ + testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, pgtype.Macaddr{Status: pgtype.Null}, }) diff --git a/pgtype/name_test.go b/pgtype/name_test.go index 81a766b8..348f8d39 100644 --- a/pgtype/name_test.go +++ b/pgtype/name_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestNameTranscode(t *testing.T) { - testSuccessfulTranscode(t, "name", []interface{}{ + testutil.TestSuccessfulTranscode(t, "name", []interface{}{ pgtype.Name{String: "", Status: pgtype.Present}, pgtype.Name{String: "foo", Status: pgtype.Present}, pgtype.Name{Status: pgtype.Null}, diff --git a/pgtype/numeric_array_test.go b/pgtype/numeric_array_test.go index af2e8e51..25531840 100644 --- a/pgtype/numeric_array_test.go +++ b/pgtype/numeric_array_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestNumericArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "numeric[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "numeric[]", []interface{}{ &pgtype.NumericArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 93aa8866..d68a9347 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) // For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) @@ -45,66 +46,66 @@ func mustParseBigInt(t *testing.T, src string) *big.Int { } func TestNumericNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select '0'::numeric", - value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + SQL: "select '0'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, }, { - sql: "select '1'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + SQL: "select '1'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, }, { - sql: "select '10.00'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + SQL: "select '10.00'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, }, { - sql: "select '1e-3'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + SQL: "select '1e-3'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, }, { - sql: "select '-1'::numeric", - value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + SQL: "select '-1'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, }, { - sql: "select '10000'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + SQL: "select '10000'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, }, { - sql: "select '3.14'::numeric", - value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + SQL: "select '3.14'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, }, { - sql: "select '1.1'::numeric", - value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + SQL: "select '1.1'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, }, { - sql: "select '100010001'::numeric", - value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + SQL: "select '100010001'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, }, { - sql: "select '100010001.0001'::numeric", - value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + SQL: "select '100010001.0001'::numeric", + Value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, }, { - sql: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - value: pgtype.Numeric{ + SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", + Value: pgtype.Numeric{ Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), Exp: -41, Status: pgtype.Present, }, }, { - sql: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - value: pgtype.Numeric{ + SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", + Value: pgtype.Numeric{ Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), Exp: -196, Status: pgtype.Present, }, }, { - sql: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - value: pgtype.Numeric{ + SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", + Value: pgtype.Numeric{ Int: mustParseBigInt(t, "123"), Exp: -186, Status: pgtype.Present, @@ -114,7 +115,7 @@ func TestNumericNormalize(t *testing.T) { } func TestNumericTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", []interface{}{ &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, @@ -172,7 +173,7 @@ func TestNumericTranscodeFuzz(t *testing.T) { } } - testSuccessfulTranscodeEqFunc(t, "numeric", values, + testutil.TestSuccessfulTranscodeEqFunc(t, "numeric", values, func(aa, bb interface{}) bool { a := aa.(pgtype.Numeric) b := bb.(pgtype.Numeric) diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go index 81202362..81e73c38 100644 --- a/pgtype/numrange_test.go +++ b/pgtype/numrange_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestNumrangeTranscode(t *testing.T) { - testSuccessfulTranscode(t, "numrange", []interface{}{ + testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ pgtype.Numrange{ LowerType: pgtype.Empty, UpperType: pgtype.Empty, diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go index 21dd6f9d..d3412159 100644 --- a/pgtype/oid_value_test.go +++ b/pgtype/oid_value_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestOidValueTranscode(t *testing.T) { - testSuccessfulTranscode(t, "oid", []interface{}{ + testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ pgtype.OidValue{Uint: 42, Status: pgtype.Present}, pgtype.OidValue{Status: pgtype.Null}, }) diff --git a/pgtype/path_test.go b/pgtype/path_test.go index 4e5f7f62..d213a1b4 100644 --- a/pgtype/path_test.go +++ b/pgtype/path_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestPathTranscode(t *testing.T) { - testSuccessfulTranscode(t, "path", []interface{}{ + testutil.TestSuccessfulTranscode(t, "path", []interface{}{ &pgtype.Path{ P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}}, Closed: false, diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index f486f077..716e063d 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -1,17 +1,9 @@ package pgtype_test import ( - "context" - "database/sql" - "fmt" - "io" "net" - "os" - "reflect" "testing" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" _ "github.com/jackc/pgx/stdlib" _ "github.com/lib/pq" ) @@ -28,48 +20,6 @@ type _float32Slice []float32 type _float64Slice []float64 type _byteSlice []byte -func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { - var sqlDriverName string - switch driverName { - case "github.com/lib/pq": - sqlDriverName = "postgres" - case "github.com/jackc/pgx/stdlib": - sqlDriverName = "pgx" - default: - t.Fatalf("Unknown driver %v", driverName) - } - - db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) - if err != nil { - t.Fatal(err) - } - - return db -} - -func mustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) - if err != nil { - t.Fatal(err) - } - - conn, err := pgx.Connect(config) - if err != nil { - t.Fatal(err) - } - - return conn -} - -func mustClose(t testing.TB, conn interface { - Close() error -}) { - err := conn.Close() - if err != nil { - t.Fatal(err) - } -} - func mustParseCidr(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { @@ -87,244 +37,3 @@ func mustParseMacaddr(t testing.TB, s string) net.HardwareAddr { return addr } - -type forceTextEncoder struct { - e pgtype.TextEncoder -} - -func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - return f.e.EncodeText(ci, w) -} - -type forceBinaryEncoder struct { - e pgtype.BinaryEncoder -} - -func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - return f.e.EncodeBinary(ci, w) -} - -func forceEncoder(e interface{}, formatCode int16) interface{} { - switch formatCode { - case pgx.TextFormatCode: - if e, ok := e.(pgtype.TextEncoder); ok { - return forceTextEncoder{e: e} - } - case pgx.BinaryFormatCode: - if e, ok := e.(pgtype.BinaryEncoder); ok { - return forceBinaryEncoder{e: e.(pgtype.BinaryEncoder)} - } - } - return nil -} - -func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface{}) { - testSuccessfulTranscodeEqFunc(t, pgTypeName, values, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) - } -} - -func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) - - ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, v := range values { - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := forceEncoder(v, fc.formatCode) - if vEncoder == nil { - t.Logf("Skipping: %#v does not implement %v", v, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) - } - } - } -} - -func testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRowEx( - context.Background(), - fmt.Sprintf("select ($1)::%s", pgTypeName), - &pgx.QueryExOptions{SimpleProtocol: true}, - v, - ).Scan(result.Interface()) - if err != nil { - t.Errorf("Simple protocol %d: %v", i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) - } - } -} - -func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectDatabaseSQL(t, driverName) - defer mustClose(t, conn) - - ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := ps.QueryRow(v).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } -} - -type normalizeTest struct { - sql string - value interface{} -} - -func testSuccessfulNormalize(t testing.TB, tests []normalizeTest) { - testSuccessfulNormalizeEqFunc(t, tests, func(a, b interface{}) bool { - return reflect.DeepEqual(a, b) - }) -} - -func testSuccessfulNormalizeEqFunc(t testing.TB, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { - testPgxSuccessfulNormalizeEqFunc(t, tests, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testDatabaseSQLSuccessfulNormalizeEqFunc(t, driverName, tests, eqFunc) - } -} - -func testPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) - - formats := []struct { - name string - formatCode int16 - }{ - {name: "TextFormat", formatCode: pgx.TextFormatCode}, - {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, - } - - for i, tt := range tests { - for _, fc := range formats { - psName := fmt.Sprintf("test%d", i) - ps, err := conn.Prepare(psName, tt.sql) - if err != nil { - t.Fatal(err) - } - - ps.FieldDescriptions[0].FormatCode = fc.formatCode - if forceEncoder(tt.value, fc.formatCode) == nil { - t.Logf("Skipping: %#v does not implement %v", tt.value, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := tt.value - refVal := reflect.ValueOf(tt.value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = conn.QueryRow(psName).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) - } - } - } -} - -func testDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []normalizeTest, eqFunc func(a, b interface{}) bool) { - conn := mustConnectDatabaseSQL(t, driverName) - defer mustClose(t, conn) - - for i, tt := range tests { - ps, err := conn.Prepare(tt.sql) - if err != nil { - t.Errorf("%d. %v", i, err) - continue - } - - // Derefence value if it is a pointer - derefV := tt.value - refVal := reflect.ValueOf(tt.value) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err = ps.QueryRow().Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", driverName, i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) - } - } - -} diff --git a/pgtype/point_test.go b/pgtype/point_test.go index c921f794..f46b342d 100644 --- a/pgtype/point_test.go +++ b/pgtype/point_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestPointTranscode(t *testing.T) { - testSuccessfulTranscode(t, "point", []interface{}{ + testutil.TestSuccessfulTranscode(t, "point", []interface{}{ &pgtype.Point{P: pgtype.Vec2{1.234, 5.6789}, Status: pgtype.Present}, &pgtype.Point{P: pgtype.Vec2{-1.234, -5.6789}, Status: pgtype.Present}, &pgtype.Point{Status: pgtype.Null}, diff --git a/pgtype/polygon_test.go b/pgtype/polygon_test.go index 3a7e1431..48481dc5 100644 --- a/pgtype/polygon_test.go +++ b/pgtype/polygon_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestPolygonTranscode(t *testing.T) { - testSuccessfulTranscode(t, "polygon", []interface{}{ + testutil.TestSuccessfulTranscode(t, "polygon", []interface{}{ &pgtype.Polygon{ P: []pgtype.Vec2{{3.14, 1.678}, {7.1, 5.234}, {5.0, 3.234}}, Status: pgtype.Present, diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index afac5016..b810b89c 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -6,10 +6,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestQCharTranscode(t *testing.T) { - testPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ + testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, pgtype.QChar{Int: -1, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Present}, diff --git a/pgtype/record_test.go b/pgtype/record_test.go index bc6e5893..df17501f 100644 --- a/pgtype/record_test.go +++ b/pgtype/record_test.go @@ -7,11 +7,12 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestRecordTranscode(t *testing.T) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, conn) tests := []struct { sql string diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index 610f0710..d9aaa5c4 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -15,7 +15,7 @@ import ( _ "github.com/lib/pq" ) -func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { +func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { var sqlDriverName string switch driverName { case "github.com/lib/pq": @@ -34,7 +34,7 @@ func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { return db } -func mustConnectPgx(t testing.TB) *pgx.Conn { +func MustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) if err != nil { t.Fatal(err) @@ -48,7 +48,7 @@ func mustConnectPgx(t testing.TB) *pgx.Conn { return conn } -func mustClose(t testing.TB, conn interface { +func MustClose(t testing.TB, conn interface { Close() error }) { err := conn.Close() @@ -73,7 +73,7 @@ func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool return f.e.EncodeBinary(ci, w) } -func forceEncoder(e interface{}, formatCode int16) interface{} { +func ForceEncoder(e interface{}, formatCode int16) interface{} { switch formatCode { case pgx.TextFormatCode: if e, ok := e.(pgtype.TextEncoder); ok { @@ -102,8 +102,8 @@ func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int } func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := MustConnectPgx(t) + defer MustClose(t, conn) ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { @@ -121,7 +121,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] for i, v := range values { for _, fc := range formats { ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := forceEncoder(v, fc.formatCode) + vEncoder := ForceEncoder(v, fc.formatCode) if vEncoder == nil { t.Logf("Skipping: %#v does not implement %v", v, fc.name) continue @@ -134,7 +134,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow("test", forceEncoder(v, fc.formatCode)).Scan(result.Interface()) + err := conn.QueryRow("test", ForceEncoder(v, fc.formatCode)).Scan(result.Interface()) if err != nil { t.Errorf("%v %d: %v", fc.name, i, err) } @@ -147,8 +147,8 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := MustConnectPgx(t) + defer MustClose(t, conn) for i, v := range values { // Derefence value if it is a pointer @@ -176,8 +176,8 @@ func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName str } func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := mustConnectDatabaseSQL(t, driverName) - defer mustClose(t, conn) + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { @@ -223,8 +223,8 @@ func TestSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc f } func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := mustConnectPgx(t) - defer mustClose(t, conn) + conn := MustConnectPgx(t) + defer MustClose(t, conn) formats := []struct { name string @@ -243,7 +243,7 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun } ps.FieldDescriptions[0].FormatCode = fc.formatCode - if forceEncoder(tt.Value, fc.formatCode) == nil { + if ForceEncoder(tt.Value, fc.formatCode) == nil { t.Logf("Skipping: %#v does not implement %v", tt.Value, fc.name) continue } @@ -268,8 +268,8 @@ func TestPgxSuccessfulNormalizeEqFunc(t testing.TB, tests []NormalizeTest, eqFun } func TestDatabaseSQLSuccessfulNormalizeEqFunc(t testing.TB, driverName string, tests []NormalizeTest, eqFunc func(a, b interface{}) bool) { - conn := mustConnectDatabaseSQL(t, driverName) - defer mustClose(t, conn) + conn := MustConnectDatabaseSQL(t, driverName) + defer MustClose(t, conn) for i, tt := range tests { ps, err := conn.Prepare(tt.SQL) diff --git a/pgtype/text_array_test.go b/pgtype/text_array_test.go index 5a78d7bc..35ebef96 100644 --- a/pgtype/text_array_test.go +++ b/pgtype/text_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTextArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "text[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "text[]", []interface{}{ &pgtype.TextArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/text_test.go b/pgtype/text_test.go index 34b6a784..e4c1dbd8 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -6,11 +6,12 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTextTranscode(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { - testSuccessfulTranscode(t, pgTypeName, []interface{}{ + testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ pgtype.Text{String: "", Status: pgtype.Present}, pgtype.Text{String: "foo", Status: pgtype.Present}, pgtype.Text{Status: pgtype.Null}, diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index 56595ef4..7eb7773a 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "tid", []interface{}{ + testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, pgtype.Tid{Status: pgtype.Null}, diff --git a/pgtype/timestamp_array_test.go b/pgtype/timestamp_array_test.go index a15d3696..c75d101f 100644 --- a/pgtype/timestamp_array_test.go +++ b/pgtype/timestamp_array_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTimestampArrayTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp[]", []interface{}{ &pgtype.TimestampArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index 58828806..c0427a5c 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTimestampTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, diff --git a/pgtype/timestamptz_array_test.go b/pgtype/timestamptz_array_test.go index e0017828..50ee65d0 100644 --- a/pgtype/timestamptz_array_test.go +++ b/pgtype/timestamptz_array_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTimestamptzArrayTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz[]", []interface{}{ &pgtype.TimestamptzArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 6ddfc1bc..bbc001e5 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTimestamptzTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go index 448cb92f..865233c2 100644 --- a/pgtype/tsrange_test.go +++ b/pgtype/tsrange_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTsrangeTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Tsrange{ Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go index 197aabbc..8eb00ab9 100644 --- a/pgtype/tstzrange_test.go +++ b/pgtype/tstzrange_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestTstzrangeTranscode(t *testing.T) { - testSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ + testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, pgtype.Tstzrange{ Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 1eba7e90..b745d542 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestUuidTranscode(t *testing.T) { - testSuccessfulTranscode(t, "uuid", []interface{}{ + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, pgtype.Uuid{Status: pgtype.Null}, }) diff --git a/pgtype/varbit_test.go b/pgtype/varbit_test.go index cd146d26..6c813aae 100644 --- a/pgtype/varbit_test.go +++ b/pgtype/varbit_test.go @@ -4,10 +4,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestVarbitTranscode(t *testing.T) { - testSuccessfulTranscode(t, "varbit", []interface{}{ + testutil.TestSuccessfulTranscode(t, "varbit", []interface{}{ &pgtype.Varbit{Bytes: []byte{}, Len: 0, Status: pgtype.Present}, &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 255}, Len: 40, Status: pgtype.Present}, &pgtype.Varbit{Bytes: []byte{0, 1, 128, 254, 128}, Len: 33, Status: pgtype.Present}, @@ -16,10 +17,10 @@ func TestVarbitTranscode(t *testing.T) { } func TestVarbitNormalize(t *testing.T) { - testSuccessfulNormalize(t, []normalizeTest{ + testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { - sql: "select B'111111111'", - value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, + SQL: "select B'111111111'", + Value: &pgtype.Varbit{Bytes: []byte{255, 128}, Len: 9, Status: pgtype.Present}, }, }) } diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go index 4a8b09b8..7d6fb39b 100644 --- a/pgtype/varchar_array_test.go +++ b/pgtype/varchar_array_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestVarcharArrayTranscode(t *testing.T) { - testSuccessfulTranscode(t, "varchar[]", []interface{}{ + testutil.TestSuccessfulTranscode(t, "varchar[]", []interface{}{ &pgtype.VarcharArray{ Elements: nil, Dimensions: nil, diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index 11dd0615..868c101e 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" ) func TestXidTranscode(t *testing.T) { @@ -17,13 +18,13 @@ func TestXidTranscode(t *testing.T) { return reflect.DeepEqual(a, b) } - testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) // No direct conversion from int to xid, convert through text - testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) + testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { - testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } } From a8c350c77d143ff4f55142efb91dc9836d4e59ce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 13:08:05 -0500 Subject: [PATCH 167/264] Use pointer methods for all struct pgtypes Now no need to no whether certain interfaces are implemented by struct or pointer to struct. --- pgtype/aclitem.go | 4 ++-- pgtype/aclitem_test.go | 6 +++--- pgtype/bool.go | 6 +++--- pgtype/bool_test.go | 6 +++--- pgtype/bytea.go | 6 +++--- pgtype/bytea_test.go | 6 +++--- pgtype/cid.go | 12 ++++++------ pgtype/cid_test.go | 4 ++-- pgtype/cidr.go | 8 ++++---- pgtype/date.go | 6 +++--- pgtype/date_test.go | 18 +++++++++--------- pgtype/daterange.go | 6 +++--- pgtype/daterange_test.go | 8 ++++---- pgtype/float4.go | 6 +++--- pgtype/float4_test.go | 12 ++++++------ pgtype/float8.go | 6 +++--- pgtype/float8_test.go | 12 ++++++------ pgtype/generic_binary.go | 8 ++++---- pgtype/generic_text.go | 8 ++++---- pgtype/hstore.go | 6 +++--- pgtype/hstore_test.go | 28 ++++++++++++++-------------- pgtype/inet.go | 6 +++--- pgtype/inet_test.go | 22 +++++++++++----------- pgtype/int2.go | 8 ++++---- pgtype/int2_test.go | 12 ++++++------ pgtype/int4.go | 8 ++++---- pgtype/int4_test.go | 12 ++++++------ pgtype/int4range.go | 6 +++--- pgtype/int4range_test.go | 8 ++++---- pgtype/int8.go | 8 ++++---- pgtype/int8_test.go | 12 ++++++------ pgtype/int8range.go | 6 +++--- pgtype/int8range_test.go | 8 ++++---- pgtype/interval.go | 6 +++--- pgtype/interval_test.go | 34 +++++++++++++++++----------------- pgtype/json.go | 6 +++--- pgtype/json_test.go | 10 +++++----- pgtype/jsonb.go | 10 +++++----- pgtype/jsonb_test.go | 10 +++++----- pgtype/macaddr.go | 6 +++--- pgtype/macaddr_test.go | 4 ++-- pgtype/name.go | 12 ++++++------ pgtype/name_test.go | 6 +++--- pgtype/numrange.go | 6 +++--- pgtype/numrange_test.go | 8 ++++---- pgtype/oid_value.go | 12 ++++++------ pgtype/oid_value_test.go | 4 ++-- pgtype/pguint32.go | 6 +++--- pgtype/qchar.go | 2 +- pgtype/text.go | 8 ++++---- pgtype/text_test.go | 6 +++--- pgtype/tid.go | 6 +++--- pgtype/tid_test.go | 6 +++--- pgtype/timestamp.go | 6 +++--- pgtype/timestamp_test.go | 26 +++++++++++++------------- pgtype/timestamptz.go | 6 +++--- pgtype/timestamptz_test.go | 26 +++++++++++++------------- pgtype/tsrange.go | 6 +++--- pgtype/tsrange_test.go | 8 ++++---- pgtype/tstzrange.go | 6 +++--- pgtype/tstzrange_test.go | 8 ++++---- pgtype/unknown.go | 4 ++-- pgtype/uuid.go | 6 +++--- pgtype/uuid_test.go | 4 ++-- pgtype/varchar.go | 16 ++++++++-------- pgtype/xid.go | 12 ++++++------ pgtype/xid_test.go | 4 ++-- 67 files changed, 302 insertions(+), 302 deletions(-) diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index 3ccf8318..ebfcc3e7 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -83,7 +83,7 @@ func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -113,7 +113,7 @@ func (dst *Aclitem) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Aclitem) Value() (driver.Value, error) { +func (src *Aclitem) Value() (driver.Value, error) { switch src.Status { case Present: return src.String, nil diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go index 5389eab2..13c63395 100644 --- a/pgtype/aclitem_test.go +++ b/pgtype/aclitem_test.go @@ -10,9 +10,9 @@ import ( func TestAclitemTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ - pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.Aclitem{Status: pgtype.Null}, + &pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + &pgtype.Aclitem{Status: pgtype.Null}, }) } diff --git a/pgtype/bool.go b/pgtype/bool.go index 1ebf590b..9d309f0c 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -90,7 +90,7 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -109,7 +109,7 @@ func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -149,7 +149,7 @@ func (dst *Bool) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Bool) Value() (driver.Value, error) { +func (src *Bool) Value() (driver.Value, error) { switch src.Status { case Present: return src.Bool, nil diff --git a/pgtype/bool_test.go b/pgtype/bool_test.go index 31f3d528..2712e3b0 100644 --- a/pgtype/bool_test.go +++ b/pgtype/bool_test.go @@ -10,9 +10,9 @@ import ( func TestBoolTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "bool", []interface{}{ - pgtype.Bool{Bool: false, Status: pgtype.Present}, - pgtype.Bool{Bool: true, Status: pgtype.Present}, - pgtype.Bool{Bool: false, Status: pgtype.Null}, + &pgtype.Bool{Bool: false, Status: pgtype.Present}, + &pgtype.Bool{Bool: true, Status: pgtype.Present}, + &pgtype.Bool{Bool: false, Status: pgtype.Null}, }) } diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 8bf5de2b..3e2661db 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -102,7 +102,7 @@ func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -119,7 +119,7 @@ func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -152,7 +152,7 @@ func (dst *Bytea) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Bytea) Value() (driver.Value, error) { +func (src *Bytea) Value() (driver.Value, error) { switch src.Status { case Present: return src.Bytes, nil diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 7d32e294..fd5a0dec 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -10,9 +10,9 @@ import ( func TestByteaTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "bytea", []interface{}{ - pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, - pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, + &pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + &pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + &pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, }) } diff --git a/pgtype/cid.go b/pgtype/cid.go index 63ba6a2f..c2b3073b 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -43,12 +43,12 @@ func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(ci, w) +func (src *Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeText(ci, w) } -func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(ci, w) +func (src *Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -57,6 +57,6 @@ func (dst *Cid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Cid) Value() (driver.Value, error) { - return (pguint32)(src).Value() +func (src *Cid) Value() (driver.Value, error) { + return (*pguint32)(src).Value() } diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index 385b8cac..c3bf3132 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -11,8 +11,8 @@ import ( func TestCidTranscode(t *testing.T) { pgTypeName := "cid" values := []interface{}{ - pgtype.Cid{Uint: 42, Status: pgtype.Present}, - pgtype.Cid{Status: pgtype.Null}, + &pgtype.Cid{Uint: 42, Status: pgtype.Present}, + &pgtype.Cid{Status: pgtype.Null}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) diff --git a/pgtype/cidr.go b/pgtype/cidr.go index 463b279d..39a87a26 100644 --- a/pgtype/cidr.go +++ b/pgtype/cidr.go @@ -26,10 +26,10 @@ func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeBinary(ci, src) } -func (src Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Inet)(src).EncodeText(ci, w) +func (src *Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Inet)(src).EncodeText(ci, w) } -func (src Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (Inet)(src).EncodeBinary(ci, w) +func (src *Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Inet)(src).EncodeBinary(ci, w) } diff --git a/pgtype/date.go b/pgtype/date.go index 34753f05..993a04c5 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -125,7 +125,7 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -148,7 +148,7 @@ func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -195,7 +195,7 @@ func (dst *Date) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Date) Value() (driver.Value, error) { +func (src *Date) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/pgtype/date_test.go b/pgtype/date_test.go index d1493f5e..d98e1652 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -11,15 +11,15 @@ import ( func TestDateTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "date", []interface{}{ - pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Date{Status: pgtype.Null}, - pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Date{Status: pgtype.Null}, + &pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Date) bt := b.(pgtype.Date) diff --git a/pgtype/daterange.go b/pgtype/daterange.go index fbf51980..d78c4803 100644 --- a/pgtype/daterange.go +++ b/pgtype/daterange.go @@ -106,7 +106,7 @@ func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Daterange) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Daterange) Value() (driver.Value, error) { +func (src *Daterange) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/daterange_test.go b/pgtype/daterange_test.go index 7dfae0f4..d2af5986 100644 --- a/pgtype/daterange_test.go +++ b/pgtype/daterange_test.go @@ -10,22 +10,22 @@ import ( func TestDaterangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "daterange", []interface{}{ - pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Daterange{ + &pgtype.Daterange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Daterange{ Lower: pgtype.Date{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Date{Time: time.Date(2028, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Daterange{ + &pgtype.Daterange{ Lower: pgtype.Date{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Date{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Daterange{Status: pgtype.Null}, + &pgtype.Daterange{Status: pgtype.Null}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Daterange) b := bb.(pgtype.Daterange) diff --git a/pgtype/float4.go b/pgtype/float4.go index e92149a6..76be4203 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -139,7 +139,7 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -151,7 +151,7 @@ func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -184,7 +184,7 @@ func (dst *Float4) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Float4) Value() (driver.Value, error) { +func (src *Float4) Value() (driver.Value, error) { switch src.Status { case Present: return float64(src.Float), nil diff --git a/pgtype/float4_test.go b/pgtype/float4_test.go index 57f4bc34..2ed8d05d 100644 --- a/pgtype/float4_test.go +++ b/pgtype/float4_test.go @@ -10,12 +10,12 @@ import ( func TestFloat4Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "float4", []interface{}{ - pgtype.Float4{Float: -1, Status: pgtype.Present}, - pgtype.Float4{Float: 0, Status: pgtype.Present}, - pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, - pgtype.Float4{Float: 1, Status: pgtype.Present}, - pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, - pgtype.Float4{Float: 0, Status: pgtype.Null}, + &pgtype.Float4{Float: -1, Status: pgtype.Present}, + &pgtype.Float4{Float: 0, Status: pgtype.Present}, + &pgtype.Float4{Float: 0.00001, Status: pgtype.Present}, + &pgtype.Float4{Float: 1, Status: pgtype.Present}, + &pgtype.Float4{Float: 9999.99, Status: pgtype.Present}, + &pgtype.Float4{Float: 0, Status: pgtype.Null}, }) } diff --git a/pgtype/float8.go b/pgtype/float8.go index 4d094757..8cfc53c5 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -129,7 +129,7 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -141,7 +141,7 @@ func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -174,7 +174,7 @@ func (dst *Float8) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Float8) Value() (driver.Value, error) { +func (src *Float8) Value() (driver.Value, error) { switch src.Status { case Present: return src.Float, nil diff --git a/pgtype/float8_test.go b/pgtype/float8_test.go index b7527b86..46fc8d5d 100644 --- a/pgtype/float8_test.go +++ b/pgtype/float8_test.go @@ -10,12 +10,12 @@ import ( func TestFloat8Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "float8", []interface{}{ - pgtype.Float8{Float: -1, Status: pgtype.Present}, - pgtype.Float8{Float: 0, Status: pgtype.Present}, - pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, - pgtype.Float8{Float: 1, Status: pgtype.Present}, - pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, - pgtype.Float8{Float: 0, Status: pgtype.Null}, + &pgtype.Float8{Float: -1, Status: pgtype.Present}, + &pgtype.Float8{Float: 0, Status: pgtype.Present}, + &pgtype.Float8{Float: 0.00001, Status: pgtype.Present}, + &pgtype.Float8{Float: 1, Status: pgtype.Present}, + &pgtype.Float8{Float: 9999.99, Status: pgtype.Present}, + &pgtype.Float8{Float: 0, Status: pgtype.Null}, }) } diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go index f834bfb2..094bd64e 100644 --- a/pgtype/generic_binary.go +++ b/pgtype/generic_binary.go @@ -25,8 +25,8 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (Bytea)(src).EncodeBinary(ci, w) +func (src *GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Bytea)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -35,6 +35,6 @@ func (dst *GenericBinary) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src GenericBinary) Value() (driver.Value, error) { - return (Bytea)(src).Value() +func (src *GenericBinary) Value() (driver.Value, error) { + return (*Bytea)(src).Value() } diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go index 053ec504..5d0d83be 100644 --- a/pgtype/generic_text.go +++ b/pgtype/generic_text.go @@ -25,8 +25,8 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } -func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeText(ci, w) +func (src *GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeText(ci, w) } // Scan implements the database/sql Scanner interface. @@ -35,6 +35,6 @@ func (dst *GenericText) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src GenericText) Value() (driver.Value, error) { - return (Text)(src).Value() +func (src *GenericText) Value() (driver.Value, error) { + return (*Text)(src).Value() } diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 5dc78671..3d55f783 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -151,7 +151,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -203,7 +203,7 @@ func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -462,6 +462,6 @@ func (dst *Hstore) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Hstore) Value() (driver.Value, error) { +func (src *Hstore) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index 502a8df0..dc2439fc 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -14,12 +14,12 @@ func TestHstoreTranscode(t *testing.T) { } values := []interface{}{ - pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - pgtype.Hstore{Status: pgtype.Null}, + &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + &pgtype.Hstore{Status: pgtype.Null}, } specialStrings := []string{ @@ -33,16 +33,16 @@ func TestHstoreTranscode(t *testing.T) { } for _, s := range specialStrings { // Special key values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key // Special value values - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key } testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { diff --git a/pgtype/inet.go b/pgtype/inet.go index 09fce04d..62734088 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -149,7 +149,7 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -162,7 +162,7 @@ func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -220,6 +220,6 @@ func (dst *Inet) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Inet) Value() (driver.Value, error) { +func (src *Inet) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index 532e9abe..b883df8e 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -12,17 +12,17 @@ import ( func TestInetTranscode(t *testing.T) { for _, pgTypeName := range []string{"inet", "cidr"} { testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "192.168.1.0/24"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "255.255.255.255/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "::/128"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "::/0"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "::1/128"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - pgtype.Inet{Status: pgtype.Null}, + &pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "::/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.Inet{Status: pgtype.Null}, }) } } diff --git a/pgtype/int2.go b/pgtype/int2.go index 0cb6ef82..4a3beb22 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -134,7 +134,7 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -146,7 +146,7 @@ func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -185,7 +185,7 @@ func (dst *Int2) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int2) Value() (driver.Value, error) { +func (src *Int2) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -196,7 +196,7 @@ func (src Int2) Value() (driver.Value, error) { } } -func (src Int2) MarshalJSON() ([]byte, error) { +func (src *Int2) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(int64(src.Int), 10)), nil diff --git a/pgtype/int2_test.go b/pgtype/int2_test.go index d81405a6..d20bf0ed 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int2_test.go @@ -11,12 +11,12 @@ import ( func TestInt2Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int2", []interface{}{ - pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, - pgtype.Int2{Int: -1, Status: pgtype.Present}, - pgtype.Int2{Int: 0, Status: pgtype.Present}, - pgtype.Int2{Int: 1, Status: pgtype.Present}, - pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, - pgtype.Int2{Int: 0, Status: pgtype.Null}, + &pgtype.Int2{Int: math.MinInt16, Status: pgtype.Present}, + &pgtype.Int2{Int: -1, Status: pgtype.Present}, + &pgtype.Int2{Int: 0, Status: pgtype.Present}, + &pgtype.Int2{Int: 1, Status: pgtype.Present}, + &pgtype.Int2{Int: math.MaxInt16, Status: pgtype.Present}, + &pgtype.Int2{Int: 0, Status: pgtype.Null}, }) } diff --git a/pgtype/int4.go b/pgtype/int4.go index 4a5bca51..f429d887 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -125,7 +125,7 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -137,7 +137,7 @@ func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -176,7 +176,7 @@ func (dst *Int4) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int4) Value() (driver.Value, error) { +func (src *Int4) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -187,7 +187,7 @@ func (src Int4) Value() (driver.Value, error) { } } -func (src Int4) MarshalJSON() ([]byte, error) { +func (src *Int4) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(int64(src.Int), 10)), nil diff --git a/pgtype/int4_test.go b/pgtype/int4_test.go index 1354b47a..02f5409f 100644 --- a/pgtype/int4_test.go +++ b/pgtype/int4_test.go @@ -11,12 +11,12 @@ import ( func TestInt4Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int4", []interface{}{ - pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, - pgtype.Int4{Int: -1, Status: pgtype.Present}, - pgtype.Int4{Int: 0, Status: pgtype.Present}, - pgtype.Int4{Int: 1, Status: pgtype.Present}, - pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, - pgtype.Int4{Int: 0, Status: pgtype.Null}, + &pgtype.Int4{Int: math.MinInt32, Status: pgtype.Present}, + &pgtype.Int4{Int: -1, Status: pgtype.Present}, + &pgtype.Int4{Int: 0, Status: pgtype.Present}, + &pgtype.Int4{Int: 1, Status: pgtype.Present}, + &pgtype.Int4{Int: math.MaxInt32, Status: pgtype.Present}, + &pgtype.Int4{Int: 0, Status: pgtype.Null}, }) } diff --git a/pgtype/int4range.go b/pgtype/int4range.go index cac4484c..8b04cf3c 100644 --- a/pgtype/int4range.go +++ b/pgtype/int4range.go @@ -106,7 +106,7 @@ func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Int4range) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int4range) Value() (driver.Value, error) { +func (src *Int4range) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/int4range_test.go b/pgtype/int4range_test.go index 74a91e59..088097d8 100644 --- a/pgtype/int4range_test.go +++ b/pgtype/int4range_test.go @@ -9,10 +9,10 @@ import ( func TestInt4rangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int4range", []interface{}{ - pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - pgtype.Int4range{Status: pgtype.Null}, + &pgtype.Int4range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int4{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Lower: pgtype.Int4{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int4{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int4range{Status: pgtype.Null}, }) } diff --git a/pgtype/int8.go b/pgtype/int8.go index 0cc3545d..97db8393 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -117,7 +117,7 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -129,7 +129,7 @@ func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -162,7 +162,7 @@ func (dst *Int8) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int8) Value() (driver.Value, error) { +func (src *Int8) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Int), nil @@ -173,7 +173,7 @@ func (src Int8) Value() (driver.Value, error) { } } -func (src Int8) MarshalJSON() ([]byte, error) { +func (src *Int8) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return []byte(strconv.FormatInt(src.Int, 10)), nil diff --git a/pgtype/int8_test.go b/pgtype/int8_test.go index d6752205..0b3bb3eb 100644 --- a/pgtype/int8_test.go +++ b/pgtype/int8_test.go @@ -11,12 +11,12 @@ import ( func TestInt8Transcode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "int8", []interface{}{ - pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, - pgtype.Int8{Int: -1, Status: pgtype.Present}, - pgtype.Int8{Int: 0, Status: pgtype.Present}, - pgtype.Int8{Int: 1, Status: pgtype.Present}, - pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, - pgtype.Int8{Int: 0, Status: pgtype.Null}, + &pgtype.Int8{Int: math.MinInt64, Status: pgtype.Present}, + &pgtype.Int8{Int: -1, Status: pgtype.Present}, + &pgtype.Int8{Int: 0, Status: pgtype.Present}, + &pgtype.Int8{Int: 1, Status: pgtype.Present}, + &pgtype.Int8{Int: math.MaxInt64, Status: pgtype.Present}, + &pgtype.Int8{Int: 0, Status: pgtype.Null}, }) } diff --git a/pgtype/int8range.go b/pgtype/int8range.go index 44946be9..f8e056cb 100644 --- a/pgtype/int8range.go +++ b/pgtype/int8range.go @@ -106,7 +106,7 @@ func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Int8range) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Int8range) Value() (driver.Value, error) { +func (src *Int8range) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/int8range_test.go b/pgtype/int8range_test.go index 703f476e..c039ec65 100644 --- a/pgtype/int8range_test.go +++ b/pgtype/int8range_test.go @@ -9,10 +9,10 @@ import ( func TestInt8rangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "Int8range", []interface{}{ - pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, - pgtype.Int8range{Status: pgtype.Null}, + &pgtype.Int8range{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: 1, Status: pgtype.Present}, Upper: pgtype.Int8{Int: 10, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Lower: pgtype.Int8{Int: -42, Status: pgtype.Present}, Upper: pgtype.Int8{Int: -5, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present}, + &pgtype.Int8range{Status: pgtype.Null}, }) } diff --git a/pgtype/interval.go b/pgtype/interval.go index 20a4a419..1cbdffc3 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -178,7 +178,7 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -227,7 +227,7 @@ func (src Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Interval) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Interval) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -266,6 +266,6 @@ func (dst *Interval) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Interval) Value() (driver.Value, error) { +func (src *Interval) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index 28e77e0a..18e21ddd 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -9,23 +9,23 @@ import ( func TestIntervalTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "interval", []interface{}{ - pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, - pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, - pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, - pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, - pgtype.Interval{Days: 1, Status: pgtype.Present}, - pgtype.Interval{Months: 1, Status: pgtype.Present}, - pgtype.Interval{Months: 12, Status: pgtype.Present}, - pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, - pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, - pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, - pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, - pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, - pgtype.Interval{Days: -1, Status: pgtype.Present}, - pgtype.Interval{Months: -1, Status: pgtype.Present}, - pgtype.Interval{Months: -12, Status: pgtype.Present}, - pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, - pgtype.Interval{Status: pgtype.Null}, + &pgtype.Interval{Microseconds: 1, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + &pgtype.Interval{Days: 1, Status: pgtype.Present}, + &pgtype.Interval{Months: 1, Status: pgtype.Present}, + &pgtype.Interval{Months: 12, Status: pgtype.Present}, + &pgtype.Interval{Months: 13, Days: 15, Microseconds: 1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1000000, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -1000001, Status: pgtype.Present}, + &pgtype.Interval{Microseconds: -123202800000000, Status: pgtype.Present}, + &pgtype.Interval{Days: -1, Status: pgtype.Present}, + &pgtype.Interval{Months: -1, Status: pgtype.Present}, + &pgtype.Interval{Months: -12, Status: pgtype.Present}, + &pgtype.Interval{Months: -13, Days: -15, Microseconds: -1000001, Status: pgtype.Present}, + &pgtype.Interval{Status: pgtype.Null}, }) } diff --git a/pgtype/json.go b/pgtype/json.go index b1c061f9..a027a91c 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -108,7 +108,7 @@ func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -120,7 +120,7 @@ func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } @@ -142,7 +142,7 @@ func (dst *Json) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Json) Value() (driver.Value, error) { +func (src *Json) Value() (driver.Value, error) { switch src.Status { case Present: return string(src.Bytes), nil diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 6d7cccfd..3d8d2a68 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -11,11 +11,11 @@ import ( func TestJsonTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, - pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, - pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, - pgtype.Json{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - pgtype.Json{Status: pgtype.Null}, + &pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.Json{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.Json{Status: pgtype.Null}, }) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index f47476d6..82cbb21f 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -47,11 +47,11 @@ func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { } -func (src Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Json)(src).EncodeText(ci, w) +func (src *Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Json)(src).EncodeText(ci, w) } -func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -74,6 +74,6 @@ func (dst *Jsonb) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Jsonb) Value() (driver.Value, error) { - return (Json)(src).Value() +func (src *Jsonb) Value() (driver.Value, error) { + return (*Json)(src).Value() } diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 37c11858..86c8a12c 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -17,11 +17,11 @@ func TestJsonbTranscode(t *testing.T) { } testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, - pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, - pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, - pgtype.Jsonb{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - pgtype.Jsonb{Status: pgtype.Null}, + &pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.Jsonb{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.Jsonb{Status: pgtype.Null}, }) } diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 2834d69f..cfbb513d 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -106,7 +106,7 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -119,7 +119,7 @@ func (src Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Macaddr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Macaddr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -149,6 +149,6 @@ func (dst *Macaddr) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Macaddr) Value() (driver.Value, error) { +func (src *Macaddr) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index c2542da3..5d329249 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -12,8 +12,8 @@ import ( func TestMacaddrTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "macaddr", []interface{}{ - pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, - pgtype.Macaddr{Status: pgtype.Null}, + &pgtype.Macaddr{Addr: mustParseMacaddr(t, "01:23:45:67:89:ab"), Status: pgtype.Present}, + &pgtype.Macaddr{Status: pgtype.Null}, }) } diff --git a/pgtype/name.go b/pgtype/name.go index cc4ae23b..05e92563 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -40,12 +40,12 @@ func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeText(ci, w) +func (src *Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeText(ci, w) } -func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(ci, w) +func (src *Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -54,6 +54,6 @@ func (dst *Name) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Name) Value() (driver.Value, error) { - return (Text)(src).Value() +func (src *Name) Value() (driver.Value, error) { + return (*Text)(src).Value() } diff --git a/pgtype/name_test.go b/pgtype/name_test.go index 348f8d39..ec0820c4 100644 --- a/pgtype/name_test.go +++ b/pgtype/name_test.go @@ -10,9 +10,9 @@ import ( func TestNameTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "name", []interface{}{ - pgtype.Name{String: "", Status: pgtype.Present}, - pgtype.Name{String: "foo", Status: pgtype.Present}, - pgtype.Name{Status: pgtype.Null}, + &pgtype.Name{String: "", Status: pgtype.Present}, + &pgtype.Name{String: "foo", Status: pgtype.Present}, + &pgtype.Name{Status: pgtype.Null}, }) } diff --git a/pgtype/numrange.go b/pgtype/numrange.go index cf42dcbd..a1b5b184 100644 --- a/pgtype/numrange.go +++ b/pgtype/numrange.go @@ -106,7 +106,7 @@ func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Numrange) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Numrange) Value() (driver.Value, error) { +func (src *Numrange) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/numrange_test.go b/pgtype/numrange_test.go index 81e73c38..32267c86 100644 --- a/pgtype/numrange_test.go +++ b/pgtype/numrange_test.go @@ -10,25 +10,25 @@ import ( func TestNumrangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "numrange", []interface{}{ - pgtype.Numrange{ + &pgtype.Numrange{ LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present, }, - pgtype.Numrange{ + &pgtype.Numrange{ Lower: pgtype.Numeric{Int: big.NewInt(-543), Exp: 3, Status: pgtype.Present}, Upper: pgtype.Numeric{Int: big.NewInt(342), Exp: 1, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Numrange{ + &pgtype.Numrange{ Lower: pgtype.Numeric{Int: big.NewInt(-42), Exp: 1, Status: pgtype.Present}, Upper: pgtype.Numeric{Int: big.NewInt(-5), Exp: 0, Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Numrange{Status: pgtype.Null}, + &pgtype.Numrange{Status: pgtype.Null}, }) } diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go index cb03802e..4a7de921 100644 --- a/pgtype/oid_value.go +++ b/pgtype/oid_value.go @@ -37,12 +37,12 @@ func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(ci, w) +func (src *OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeText(ci, w) } -func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(ci, w) +func (src *OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -51,6 +51,6 @@ func (dst *OidValue) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src OidValue) Value() (driver.Value, error) { - return (pguint32)(src).Value() +func (src *OidValue) Value() (driver.Value, error) { + return (*pguint32)(src).Value() } diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go index d3412159..52ce4064 100644 --- a/pgtype/oid_value_test.go +++ b/pgtype/oid_value_test.go @@ -10,8 +10,8 @@ import ( func TestOidValueTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - pgtype.OidValue{Uint: 42, Status: pgtype.Present}, - pgtype.OidValue{Status: pgtype.Null}, + &pgtype.OidValue{Uint: 42, Status: pgtype.Present}, + &pgtype.OidValue{Status: pgtype.Null}, }) } diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 7138a409..0caa0cba 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -103,7 +103,7 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -115,7 +115,7 @@ func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -151,7 +151,7 @@ func (dst *pguint32) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src pguint32) Value() (driver.Value, error) { +func (src *pguint32) Value() (driver.Value, error) { switch src.Status { case Present: return int64(src.Uint), nil diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 49475bd3..10b56534 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -136,7 +136,7 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/text.go b/pgtype/text.go index de80dd08..8e42a756 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -91,7 +91,7 @@ func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -103,7 +103,7 @@ func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } @@ -125,7 +125,7 @@ func (dst *Text) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Text) Value() (driver.Value, error) { +func (src *Text) Value() (driver.Value, error) { switch src.Status { case Present: return src.String, nil @@ -136,7 +136,7 @@ func (src Text) Value() (driver.Value, error) { } } -func (src Text) MarshalJSON() ([]byte, error) { +func (src *Text) MarshalJSON() ([]byte, error) { switch src.Status { case Present: return json.Marshal(src.String) diff --git a/pgtype/text_test.go b/pgtype/text_test.go index e4c1dbd8..bd971807 100644 --- a/pgtype/text_test.go +++ b/pgtype/text_test.go @@ -12,9 +12,9 @@ import ( func TestTextTranscode(t *testing.T) { for _, pgTypeName := range []string{"text", "varchar"} { testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - pgtype.Text{String: "", Status: pgtype.Present}, - pgtype.Text{String: "foo", Status: pgtype.Present}, - pgtype.Text{Status: pgtype.Null}, + &pgtype.Text{String: "", Status: pgtype.Present}, + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Text{Status: pgtype.Null}, }) } } diff --git a/pgtype/tid.go b/pgtype/tid.go index b363c1f9..f24c6244 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -94,7 +94,7 @@ func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -106,7 +106,7 @@ func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -141,6 +141,6 @@ func (dst *Tid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Tid) Value() (driver.Value, error) { +func (src *Tid) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index 7eb7773a..a5430d11 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -9,8 +9,8 @@ import ( func TestTidTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ - pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - pgtype.Tid{Status: pgtype.Null}, + &pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + &pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + &pgtype.Tid{Status: pgtype.Null}, }) } diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index e7bc1c7d..694b63c0 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -136,7 +136,7 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -164,7 +164,7 @@ func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -211,7 +211,7 @@ func (dst *Timestamp) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Timestamp) Value() (driver.Value, error) { +func (src *Timestamp) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index c0427a5c..267f1a7e 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -11,19 +11,19 @@ import ( func TestTimestampTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "timestamp", []interface{}{ - pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, - pgtype.Timestamp{Status: pgtype.Null}, - pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Timestamp{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, + &pgtype.Timestamp{Status: pgtype.Null}, + &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamp{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Timestamp) bt := b.(pgtype.Timestamp) diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index ef2d7498..3c76ec03 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -140,7 +140,7 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -163,7 +163,7 @@ func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -207,7 +207,7 @@ func (dst *Timestamptz) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Timestamptz) Value() (driver.Value, error) { +func (src *Timestamptz) Value() (driver.Value, error) { switch src.Status { case Present: if src.InfinityModifier != None { diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index bbc001e5..c326802d 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -11,19 +11,19 @@ import ( func TestTimestamptzTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "timestamptz", []interface{}{ - pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, - pgtype.Timestamptz{Status: pgtype.Null}, - pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, - pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + &pgtype.Timestamptz{Time: time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1905, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1940, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1960, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2000, 1, 2, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 0, 0, 0, time.Local), Status: pgtype.Present}, + &pgtype.Timestamptz{Status: pgtype.Null}, + &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, + &pgtype.Timestamptz{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, }, func(a, b interface{}) bool { at := a.(pgtype.Timestamptz) bt := b.(pgtype.Timestamptz) diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go index 48992829..3bf5f5ca 100644 --- a/pgtype/tsrange.go +++ b/pgtype/tsrange.go @@ -106,7 +106,7 @@ func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Tsrange) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Tsrange) Value() (driver.Value, error) { +func (src *Tsrange) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/tsrange_test.go b/pgtype/tsrange_test.go index 865233c2..78eb1cd3 100644 --- a/pgtype/tsrange_test.go +++ b/pgtype/tsrange_test.go @@ -10,22 +10,22 @@ import ( func TestTsrangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "tsrange", []interface{}{ - pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Tsrange{ + &pgtype.Tsrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tsrange{ Lower: pgtype.Timestamp{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Timestamp{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Tsrange{ + &pgtype.Tsrange{ Lower: pgtype.Timestamp{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Timestamp{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Tsrange{Status: pgtype.Null}, + &pgtype.Tsrange{Status: pgtype.Null}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Tsrange) b := bb.(pgtype.Tsrange) diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go index 61e94ab4..8e80a8f9 100644 --- a/pgtype/tstzrange.go +++ b/pgtype/tstzrange.go @@ -106,7 +106,7 @@ func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -263,6 +263,6 @@ func (dst *Tstzrange) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Tstzrange) Value() (driver.Value, error) { +func (src *Tstzrange) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/tstzrange_test.go b/pgtype/tstzrange_test.go index 8eb00ab9..a27ddd3a 100644 --- a/pgtype/tstzrange_test.go +++ b/pgtype/tstzrange_test.go @@ -10,22 +10,22 @@ import ( func TestTstzrangeTranscode(t *testing.T) { testutil.TestSuccessfulTranscodeEqFunc(t, "tstzrange", []interface{}{ - pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, - pgtype.Tstzrange{ + &pgtype.Tstzrange{LowerType: pgtype.Empty, UpperType: pgtype.Empty, Status: pgtype.Present}, + &pgtype.Tstzrange{ Lower: pgtype.Timestamptz{Time: time.Date(1990, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Timestamptz{Time: time.Date(2028, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Tstzrange{ + &pgtype.Tstzrange{ Lower: pgtype.Timestamptz{Time: time.Date(1800, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, Upper: pgtype.Timestamptz{Time: time.Date(2200, 1, 1, 0, 23, 12, 0, time.UTC), Status: pgtype.Present}, LowerType: pgtype.Inclusive, UpperType: pgtype.Exclusive, Status: pgtype.Present, }, - pgtype.Tstzrange{Status: pgtype.Null}, + &pgtype.Tstzrange{Status: pgtype.Null}, }, func(aa, bb interface{}) bool { a := aa.(pgtype.Tstzrange) b := bb.(pgtype.Tstzrange) diff --git a/pgtype/unknown.go b/pgtype/unknown.go index 2dca0f87..567831d7 100644 --- a/pgtype/unknown.go +++ b/pgtype/unknown.go @@ -39,6 +39,6 @@ func (dst *Unknown) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Unknown) Value() (driver.Value, error) { - return (Text)(src).Value() +func (src *Unknown) Value() (driver.Value, error) { + return (*Text)(src).Value() } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 88d2195b..03029ffd 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -126,7 +126,7 @@ func (dst *Uuid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -138,7 +138,7 @@ func (src Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } -func (src Uuid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -168,6 +168,6 @@ func (dst *Uuid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Uuid) Value() (driver.Value, error) { +func (src *Uuid) Value() (driver.Value, error) { return encodeValueText(src) } diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index b745d542..4c6ad2cd 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -10,8 +10,8 @@ import ( func TestUuidTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - pgtype.Uuid{Status: pgtype.Null}, + &pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &pgtype.Uuid{Status: pgtype.Null}, }) } diff --git a/pgtype/varchar.go b/pgtype/varchar.go index 6c137b9a..80673fa8 100644 --- a/pgtype/varchar.go +++ b/pgtype/varchar.go @@ -32,12 +32,12 @@ func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeText(ci, w) +func (src *Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeText(ci, w) } -func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(ci, w) +func (src *Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*Text)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -46,10 +46,10 @@ func (dst *Varchar) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Varchar) Value() (driver.Value, error) { - return (Text)(src).Value() +func (src *Varchar) Value() (driver.Value, error) { + return (*Text)(src).Value() } -func (src Varchar) MarshalJSON() ([]byte, error) { - return (Text)(src).MarshalJSON() +func (src *Varchar) MarshalJSON() ([]byte, error) { + return (*Text)(src).MarshalJSON() } diff --git a/pgtype/xid.go b/pgtype/xid.go index 0a7fc7d9..90a8d691 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -46,12 +46,12 @@ func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(ci, w) +func (src *Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeText(ci, w) } -func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(ci, w) +func (src *Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (*pguint32)(src).EncodeBinary(ci, w) } // Scan implements the database/sql Scanner interface. @@ -60,6 +60,6 @@ func (dst *Xid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Xid) Value() (driver.Value, error) { - return (pguint32)(src).Value() +func (src *Xid) Value() (driver.Value, error) { + return (*pguint32)(src).Value() } diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index 868c101e..c4a1bec3 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -11,8 +11,8 @@ import ( func TestXidTranscode(t *testing.T) { pgTypeName := "xid" values := []interface{}{ - pgtype.Xid{Uint: 42, Status: pgtype.Present}, - pgtype.Xid{Status: pgtype.Null}, + &pgtype.Xid{Uint: 42, Status: pgtype.Present}, + &pgtype.Xid{Status: pgtype.Null}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) From cab445ddd2102f6d05aa4c8dcf7e6e304faaa772 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 16:46:39 -0500 Subject: [PATCH 168/264] Add satori-uuid type Make pgtype.EncodeValueText public --- pgtype/box.go | 2 +- pgtype/circle.go | 2 +- pgtype/database_sql.go | 2 +- pgtype/daterange.go | 2 +- pgtype/ext/satori-uuid/uuid.go | 164 ++++++++++++++++++++++++++++ pgtype/ext/satori-uuid/uuid_test.go | 97 ++++++++++++++++ pgtype/hstore.go | 2 +- pgtype/inet.go | 2 +- pgtype/int4range.go | 2 +- pgtype/int8range.go | 2 +- pgtype/interval.go | 2 +- pgtype/line.go | 2 +- pgtype/lseg.go | 2 +- pgtype/macaddr.go | 2 +- pgtype/numrange.go | 2 +- pgtype/path.go | 2 +- pgtype/point.go | 2 +- pgtype/polygon.go | 2 +- pgtype/tid.go | 2 +- pgtype/tsrange.go | 2 +- pgtype/tstzrange.go | 2 +- pgtype/typed_range.go.erb | 2 +- pgtype/uuid.go | 2 +- pgtype/varbit.go | 2 +- 24 files changed, 283 insertions(+), 22 deletions(-) create mode 100644 pgtype/ext/satori-uuid/uuid.go create mode 100644 pgtype/ext/satori-uuid/uuid_test.go diff --git a/pgtype/box.go b/pgtype/box.go index 138953a5..2e4f39ee 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -164,5 +164,5 @@ func (dst *Box) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Box) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/circle.go b/pgtype/circle.go index 62e2e8b3..8c8f4693 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -146,5 +146,5 @@ func (dst *Circle) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Circle) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go index 2ddd842d..e255b646 100644 --- a/pgtype/database_sql.go +++ b/pgtype/database_sql.go @@ -31,7 +31,7 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return nil, errors.New("cannot convert to database/sql compatible value") } -func encodeValueText(src TextEncoder) (interface{}, error) { +func EncodeValueText(src TextEncoder) (interface{}, error) { buf := &bytes.Buffer{} null, err := src.EncodeText(nil, buf) if err != nil { diff --git a/pgtype/daterange.go b/pgtype/daterange.go index d78c4803..5cecca20 100644 --- a/pgtype/daterange.go +++ b/pgtype/daterange.go @@ -264,5 +264,5 @@ func (dst *Daterange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Daterange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/ext/satori-uuid/uuid.go b/pgtype/ext/satori-uuid/uuid.go new file mode 100644 index 00000000..1b65f48a --- /dev/null +++ b/pgtype/ext/satori-uuid/uuid.go @@ -0,0 +1,164 @@ +package uuid + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/pgtype" + uuid "github.com/satori/go.uuid" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type Uuid struct { + UUID uuid.UUID + Status pgtype.Status +} + +func (dst *Uuid) Set(src interface{}) error { + switch value := src.(type) { + case uuid.UUID: + *dst = Uuid{UUID: value, Status: pgtype.Present} + case [16]byte: + *dst = Uuid{UUID: uuid.UUID(value), Status: pgtype.Present} + case []byte: + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + } + *dst = Uuid{Status: pgtype.Present} + copy(dst.UUID[:], value) + case string: + uuid, err := uuid.FromString(value) + if err != nil { + return err + } + *dst = Uuid{UUID: uuid, Status: pgtype.Present} + default: + // If all else fails see if pgtype.Uuid can handle it. If so, translate through that. + pgUuid := &pgtype.Uuid{} + if err := pgUuid.Set(value); err != nil { + return fmt.Errorf("cannot convert %v to Uuid", value) + } + + *dst = Uuid{UUID: uuid.UUID(pgUuid.Bytes), Status: pgUuid.Status} + } + + return nil +} + +func (dst *Uuid) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.UUID + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Uuid) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *uuid.UUID: + *v = src.UUID + case *[16]byte: + *v = [16]byte(src.UUID) + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.UUID[:]) + return nil + case *string: + *v = src.UUID.String() + return nil + default: + if nextDst, retry := pgtype.GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return fmt.Errorf("cannot assign %v into %T", src, dst) +} + +func (dst *Uuid) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + u, err := uuid.FromString(string(src)) + if err != nil { + return err + } + + *dst = Uuid{UUID: u, Status: pgtype.Present} + return nil +} + +func (dst *Uuid) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for Uuid: %v", len(src)) + } + + *dst = Uuid{Status: pgtype.Present} + copy(dst.UUID[:], src) + return nil +} + +func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.UUID.String()) + return false, err +} + +func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := w.Write(src.UUID[:]) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Uuid) Scan(src interface{}) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Uuid) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/ext/satori-uuid/uuid_test.go b/pgtype/ext/satori-uuid/uuid_test.go new file mode 100644 index 00000000..993fb837 --- /dev/null +++ b/pgtype/ext/satori-uuid/uuid_test.go @@ -0,0 +1,97 @@ +package uuid_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" + satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUuidTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &satori.Uuid{Status: pgtype.Null}, + }) +} + +func TestUuidSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result satori.Uuid + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r satori.Uuid + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUuidAssignTo(t *testing.T) { + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 3d55f783..04df2acc 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -463,5 +463,5 @@ func (dst *Hstore) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Hstore) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/inet.go b/pgtype/inet.go index 62734088..e3a7ec88 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -221,5 +221,5 @@ func (dst *Inet) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Inet) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/int4range.go b/pgtype/int4range.go index 8b04cf3c..12a48dab 100644 --- a/pgtype/int4range.go +++ b/pgtype/int4range.go @@ -264,5 +264,5 @@ func (dst *Int4range) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int4range) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/int8range.go b/pgtype/int8range.go index f8e056cb..3541dbe2 100644 --- a/pgtype/int8range.go +++ b/pgtype/int8range.go @@ -264,5 +264,5 @@ func (dst *Int8range) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int8range) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/interval.go b/pgtype/interval.go index 1cbdffc3..050d5610 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -267,5 +267,5 @@ func (dst *Interval) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Interval) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/line.go b/pgtype/line.go index 08a74e84..06f01f21 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -144,5 +144,5 @@ func (dst *Line) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Line) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/lseg.go b/pgtype/lseg.go index b86256e0..986724cc 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -164,5 +164,5 @@ func (dst *Lseg) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Lseg) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index cfbb513d..0fe092e4 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -150,5 +150,5 @@ func (dst *Macaddr) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Macaddr) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/numrange.go b/pgtype/numrange.go index a1b5b184..b0baec9a 100644 --- a/pgtype/numrange.go +++ b/pgtype/numrange.go @@ -264,5 +264,5 @@ func (dst *Numrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Numrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/path.go b/pgtype/path.go index fb4193d9..2fd6cfc7 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -203,5 +203,5 @@ func (dst *Path) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Path) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/point.go b/pgtype/point.go index 788a76c9..3d51766e 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -138,5 +138,5 @@ func (dst *Point) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Point) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/polygon.go b/pgtype/polygon.go index 1e2df011..af99ee3d 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -182,5 +182,5 @@ func (dst *Polygon) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Polygon) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/tid.go b/pgtype/tid.go index f24c6244..7976afde 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -142,5 +142,5 @@ func (dst *Tid) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tid) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go index 3bf5f5ca..78a94af2 100644 --- a/pgtype/tsrange.go +++ b/pgtype/tsrange.go @@ -264,5 +264,5 @@ func (dst *Tsrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tsrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go index 8e80a8f9..d1fc7326 100644 --- a/pgtype/tstzrange.go +++ b/pgtype/tstzrange.go @@ -264,5 +264,5 @@ func (dst *Tstzrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tstzrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb index 922b98b4..e46f71c7 100644 --- a/pgtype/typed_range.go.erb +++ b/pgtype/typed_range.go.erb @@ -264,5 +264,5 @@ func (dst *<%= range_type %>) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src <%= range_type %>) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 03029ffd..c830c086 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -169,5 +169,5 @@ func (dst *Uuid) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Uuid) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/varbit.go b/pgtype/varbit.go index d28e95cd..00c34e10 100644 --- a/pgtype/varbit.go +++ b/pgtype/varbit.go @@ -137,5 +137,5 @@ func (dst *Varbit) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Varbit) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } From f7d3c4e151796f5e6e078093cd3c4fe796ffd55b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 17:11:39 -0500 Subject: [PATCH 169/264] Replace DATABASE_URL with PGX_TEST_DATABASE PGX_TEST_DATABASE is much less likely to collide with another environment variable. This is especially valuable when using direnv to automatically set environment variables. --- .travis.yml | 2 +- pgtype/testutil/testutil.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index a60a324e..66815bb8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,7 @@ before_install: env: global: - - DATABASE_URL=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test matrix: - PGVERSION=9.6 - PGVERSION=9.5 diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index d9aaa5c4..6bf9f878 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -26,7 +26,7 @@ func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { t.Fatalf("Unknown driver %v", driverName) } - db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + db, err := sql.Open(sqlDriverName, os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } @@ -35,7 +35,7 @@ func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { } func MustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) + config, err := pgx.ParseURI(os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } From 73471ea3fe22e9f510af48dd819ca8632c1d4abc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 17:21:32 -0500 Subject: [PATCH 170/264] Use pgx.ParseConnectionString in test helper This allows using URI or DSN for database connection information. DSN allows using unix domain sockets. --- pgtype/testutil/testutil.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index 6bf9f878..5dd2fbe1 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -35,7 +35,7 @@ func MustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { } func MustConnectPgx(t testing.TB) *pgx.Conn { - config, err := pgx.ParseURI(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgx.ParseConnectionString(os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } From e305ece4105beb94580f9da8d9bd9104deaed5bd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 17:37:01 -0500 Subject: [PATCH 171/264] Fix travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 66815bb8..0045cf5a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -55,6 +55,7 @@ install: - go get -u github.com/jackc/pgmock/pgmsg - go get -u github.com/lib/pq - go get -u github.com/hashicorp/go-version + - go get -u github.com/satori/go.uuid script: - go test -v -race ./... From f04c58338b58927573d7e664c19b325220d848a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 10:02:38 -0500 Subject: [PATCH 172/264] Introduce pgproto3 package pgproto3 will wrap the message encoding and decoding for the PostgreSQL frontend/backend protocol version 3. --- .gitignore | 1 + .travis.yml | 2 +- conn.go | 304 ++++++++++++----------------- conn_pool_test.go | 2 +- copy_from.go | 30 +-- fastpath.go | 22 +-- messages.go | 4 +- pgproto3/authentication.go | 54 +++++ pgproto3/backend_key_data.go | 47 +++++ pgproto3/big_endian.go | 37 ++++ pgproto3/bind_complete.go | 29 +++ pgproto3/close_complete.go | 29 +++ pgproto3/command_complete.go | 47 +++++ pgproto3/copy_both_response.go | 64 ++++++ pgproto3/copy_data.go | 41 ++++ pgproto3/copy_in_response.go | 64 ++++++ pgproto3/copy_out_response.go | 64 ++++++ pgproto3/data_row.go | 103 ++++++++++ pgproto3/empty_query_response.go | 29 +++ pgproto3/error_response.go | 197 +++++++++++++++++++ pgproto3/frontend.go | 70 +++++++ pgproto3/function_call_response.go | 73 +++++++ pgproto3/no_data.go | 29 +++ pgproto3/notice_response.go | 13 ++ pgproto3/notification_response.go | 65 ++++++ pgproto3/parameter_description.go | 60 ++++++ pgproto3/parameter_status.go | 62 ++++++ pgproto3/parse_complete.go | 29 +++ pgproto3/pgproto3.go | 88 +++++++++ pgproto3/query.go | 43 ++++ pgproto3/ready_for_query.go | 35 ++++ pgproto3/row_description.go | 101 ++++++++++ query.go | 30 ++- replication.go | 70 ++++--- 34 files changed, 1676 insertions(+), 262 deletions(-) create mode 100644 pgproto3/authentication.go create mode 100644 pgproto3/backend_key_data.go create mode 100644 pgproto3/big_endian.go create mode 100644 pgproto3/bind_complete.go create mode 100644 pgproto3/close_complete.go create mode 100644 pgproto3/command_complete.go create mode 100644 pgproto3/copy_both_response.go create mode 100644 pgproto3/copy_data.go create mode 100644 pgproto3/copy_in_response.go create mode 100644 pgproto3/copy_out_response.go create mode 100644 pgproto3/data_row.go create mode 100644 pgproto3/empty_query_response.go create mode 100644 pgproto3/error_response.go create mode 100644 pgproto3/frontend.go create mode 100644 pgproto3/function_call_response.go create mode 100644 pgproto3/no_data.go create mode 100644 pgproto3/notice_response.go create mode 100644 pgproto3/notification_response.go create mode 100644 pgproto3/parameter_description.go create mode 100644 pgproto3/parameter_status.go create mode 100644 pgproto3/parse_complete.go create mode 100644 pgproto3/pgproto3.go create mode 100644 pgproto3/query.go create mode 100644 pgproto3/ready_for_query.go create mode 100644 pgproto3/row_description.go diff --git a/.gitignore b/.gitignore index cb0cd901..0ff00800 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ _testmain.go *.exe conn_config_test.go +.envrc diff --git a/.travis.yml b/.travis.yml index 0045cf5a..edacab39 100644 --- a/.travis.yml +++ b/.travis.yml @@ -52,7 +52,7 @@ install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake - - go get -u github.com/jackc/pgmock/pgmsg + - go get -u github.com/jackc/pgmock/pgproto3 - go get -u github.com/lib/pq - go get -u github.com/hashicorp/go-version - go get -u github.com/satori/go.uuid diff --git a/conn.go b/conn.go index c2cb408f..7487b8ad 100644 --- a/conn.go +++ b/conn.go @@ -20,7 +20,7 @@ import ( "sync/atomic" "time" - "github.com/jackc/pgx/chunkreader" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -88,8 +88,8 @@ type Conn struct { lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf - pid int32 // backend pid - secretKey int32 // key to use to send a cancel query message to the server + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server config ConnConfig // config used when establishing this connection txStatus byte @@ -98,7 +98,6 @@ type Conn struct { notifications []*Notification logger Logger logLevel int - mr msgReader fp *fastpath poolResetCount int preallocatedRows []Rows @@ -116,6 +115,8 @@ type Conn struct { closedChan chan error ConnInfo *pgtype.ConnInfo + + frontend *pgproto3.Frontend } // PreparedStatement is a description of a prepared statement @@ -133,7 +134,7 @@ type PrepareExOptions struct { // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system type Notification struct { - PID int32 // backend pid that sent the notification + PID uint32 // backend pid that sent the notification Channel string // channel from which notification was received Payload string } @@ -213,8 +214,6 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.logLevel = LogLevelDebug } c.logger = c.config.Logger - c.mr.log = c.log - c.mr.shouldLog = c.shouldLog if c.config.User == "" { user, err := user.Current() @@ -290,7 +289,10 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.mr.cr = chunkreader.NewChunkReader(c.conn) + c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn) + if err != nil { + return err + } msg := newStartupMessage() @@ -317,29 +319,27 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case backendKeyData: - c.rxBackendKeyData(r) - case authenticationX: - if err = c.rxAuthenticationX(r); err != nil { + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + c.rxBackendKeyData(msg) + case *pgproto3.Authentication: + if err = c.rxAuthenticationX(msg); err != nil { return err } - case readyForQuery: - c.rxReadyForQuery(r) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "Connection established") } // Replication connections can't execute the queries to // populate the c.PgTypes and c.pgsqlAfInet - if _, ok := msg.options["replication"]; ok { + if _, ok := config.RuntimeParams["replication"]; ok { return nil } @@ -352,7 +352,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return nil default: - if err = c.processContextFreeMsg(t, r); err != nil { + if err = c.processContextFreeMsg(msg); err != nil { return err } } @@ -393,7 +393,7 @@ where ( } // PID returns the backend PID for this connection. -func (c *Conn) PID() int32 { +func (c *Conn) PID() uint32 { return c.pid } @@ -744,22 +744,20 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared var softErr error for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return nil, err } - switch t { - case parameterDescription: - ps.ParameterOids = c.rxParameterDescription(r) + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + ps.ParameterOids = c.rxParameterDescription(msg) if len(ps.ParameterOids) > 65535 && softErr == nil { softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) } - case rowDescription: - ps.FieldDescriptions = c.rxRowDescription(r) + case *pgproto3.RowDescription: + ps.FieldDescriptions = c.rxRowDescription(msg) for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok { ps.FieldDescriptions[i].DataTypeName = dt.Name @@ -772,8 +770,8 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) } } - case readyForQuery: - c.rxReadyForQuery(r) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) if softErr == nil { c.preparedStatements[name] = ps @@ -781,7 +779,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared return ps, softErr default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { softErr = e } } @@ -830,18 +828,16 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { } for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case closeComplete: + switch msg.(type) { + case *pgproto3.CloseComplete: return nil default: - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return err } @@ -908,12 +904,12 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat } for { - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return nil, err } - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return nil, err } @@ -1030,62 +1026,48 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag // meaningful in a given context. These messages can occur due to a context // deadline interrupting message processing. For example, an interrupted query // may have left DataRow messages on the wire. -func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { - switch t { - case bindComplete: - case commandComplete: - case dataRow: - case emptyQueryResponse: - case errorResponse: - return c.rxErrorResponse(r) - case noData: - case noticeResponse: - case notificationResponse: - c.rxNotificationResponse(r) - case parameterDescription: - case parseComplete: - case readyForQuery: - c.rxReadyForQuery(r) - case rowDescription: - case 'S': - c.rxParameterStatus(r) - - default: - return fmt.Errorf("Received unknown message type: %c", t) +func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return c.rxErrorResponse(msg) + case *pgproto3.NotificationResponse: + c.rxNotificationResponse(msg) + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) + case *pgproto3.ParameterStatus: + c.rxParameterStatus(msg) } return nil } -func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { +func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { if atomic.LoadInt32(&c.status) < connStatusIdle { - return 0, nil, ErrDeadConn + return nil, ErrDeadConn } - t, err = c.mr.rxMsg() + msg, err := c.frontend.Receive() if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { c.die(err) } + return nil, err } c.lastActivityTime = time.Now() - if c.shouldLog(LogLevelTrace) { - c.log(LogLevelTrace, "rxMsg", "type", string(t), "msgBodyLen", len(c.mr.msgBody)) - } + // fmt.Printf("rxMsg: %#v\n", msg) - return t, &c.mr, err + return msg, nil } -func (c *Conn) rxAuthenticationX(r *msgReader) (err error) { - switch r.readInt32() { - case 0: // AuthenticationOk - case 3: // AuthenticationCleartextPassword +func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { + switch msg.Type { + case pgproto3.AuthTypeOk: + case pgproto3.AuthTypeCleartextPassword: err = c.txPasswordMessage(c.config.Password) - case 5: // AuthenticationMD5Password - salt := r.readString(4) - digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+salt) + case pgproto3.AuthTypeMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:])) err = c.txPasswordMessage(digestedPassword) default: err = errors.New("Received unknown authentication message") @@ -1100,115 +1082,75 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *Conn) rxParameterStatus(r *msgReader) { - key := r.readCString() - value := r.readCString() - c.RuntimeParams[key] = value +func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) { + c.RuntimeParams[msg.Name] = msg.Value } -func (c *Conn) rxErrorResponse(r *msgReader) (err PgError) { - for { - switch r.readByte() { - case 'S': - err.Severity = r.readCString() - case 'C': - err.Code = r.readCString() - case 'M': - err.Message = r.readCString() - case 'D': - err.Detail = r.readCString() - case 'H': - err.Hint = r.readCString() - case 'P': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.Position = int32(n) - case 'p': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.InternalPosition = int32(n) - case 'q': - err.InternalQuery = r.readCString() - case 'W': - err.Where = r.readCString() - case 's': - err.SchemaName = r.readCString() - case 't': - err.TableName = r.readCString() - case 'c': - err.ColumnName = r.readCString() - case 'd': - err.DataTypeName = r.readCString() - case 'n': - err.ConstraintName = r.readCString() - case 'F': - err.File = r.readCString() - case 'L': - s := r.readCString() - n, _ := strconv.ParseInt(s, 10, 32) - err.Line = int32(n) - case 'R': - err.Routine = r.readCString() - - case 0: // End of error message - if err.Severity == "FATAL" { - c.die(err) - } - return - default: // Ignore other error fields - r.readCString() - } +func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError { + err := PgError{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, } + + if err.Severity == "FATAL" { + c.die(err) + } + + return err } -func (c *Conn) rxBackendKeyData(r *msgReader) { - c.pid = r.readInt32() - c.secretKey = r.readInt32() +func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { + c.pid = msg.ProcessID + c.secretKey = msg.SecretKey } -func (c *Conn) rxReadyForQuery(r *msgReader) { +func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) { c.readyForQuery = true - c.txStatus = r.readByte() + c.txStatus = msg.TxStatus } -func (c *Conn) rxRowDescription(r *msgReader) (fields []FieldDescription) { - fieldCount := r.readInt16() - fields = make([]FieldDescription, fieldCount) - for i := int16(0); i < fieldCount; i++ { - f := &fields[i] - f.Name = r.readCString() - f.Table = pgtype.Oid(r.readUint32()) - f.AttributeNumber = r.readInt16() - f.DataType = pgtype.Oid(r.readUint32()) - f.DataTypeSize = r.readInt16() - f.Modifier = r.readInt32() - f.FormatCode = r.readInt16() +func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription { + fields := make([]FieldDescription, len(msg.Fields)) + for i := 0; i < len(fields); i++ { + fields[i].Name = msg.Fields[i].Name + fields[i].Table = pgtype.Oid(msg.Fields[i].TableOID) + fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber + fields[i].DataType = pgtype.Oid(msg.Fields[i].DataTypeOID) + fields[i].DataTypeSize = msg.Fields[i].DataTypeSize + fields[i].Modifier = msg.Fields[i].TypeModifier + fields[i].FormatCode = msg.Fields[i].Format } - return + return fields } -func (c *Conn) rxParameterDescription(r *msgReader) (parameters []pgtype.Oid) { - // Internally, PostgreSQL supports greater than 64k parameters to a prepared - // statement. But the parameter description uses a 16-bit integer for the - // count of parameters. If there are more than 64K parameters, this count is - // wrong. So read the count, ignore it, and compute the proper value from - // the size of the message. - r.readInt16() - parameterCount := len(r.msgBody[r.rp:]) / 4 - - parameters = make([]pgtype.Oid, 0, parameterCount) - - for i := 0; i < parameterCount; i++ { - parameters = append(parameters, pgtype.Oid(r.readUint32())) +func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.Oid { + parameters := make([]pgtype.Oid, len(msg.ParameterOIDs)) + for i := 0; i < len(parameters); i++ { + parameters[i] = pgtype.Oid(msg.ParameterOIDs[i]) } - return + return parameters } -func (c *Conn) rxNotificationResponse(r *msgReader) { +func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) { n := new(Notification) - n.PID = r.readInt32() - n.Channel = r.readCString() - n.Payload = r.readCString() + n.PID = msg.PID + n.Channel = msg.Channel + n.Payload = msg.Payload c.notifications = append(c.notifications, n) } @@ -1453,21 +1395,19 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, var softErr error for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() + msg, err := c.rxMsg() if err != nil { return commandTag, err } - switch t { - case readyForQuery: - c.rxReadyForQuery(r) + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + c.rxReadyForQuery(msg) return commandTag, softErr - case commandComplete: - commandTag = CommandTag(r.readCString()) + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { softErr = e } } @@ -1545,19 +1485,19 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { func (c *Conn) ensureConnectionReadyForQuery() error { for !c.readyForQuery { - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case errorResponse: - pgErr := c.rxErrorResponse(r) + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr := c.rxErrorResponse(msg) if pgErr.Severity == "FATAL" { return pgErr } default: - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return err } diff --git a/conn_pool_test.go b/conn_pool_test.go index 825638b6..42f37eb1 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -686,7 +686,7 @@ func TestConnPoolBeginRetry(t *testing.T) { } defer tx.Rollback() - var txPID int32 + var txPID uint32 err = tx.QueryRow("select pg_backend_pid()").Scan(&txPID) if err != nil { t.Fatalf("tx.QueryRow Scan failed: %v", err) diff --git a/copy_from.go b/copy_from.go index 9fc76a7b..7d8dead1 100644 --- a/copy_from.go +++ b/copy_from.go @@ -3,6 +3,8 @@ package pgx import ( "bytes" "fmt" + + "github.com/jackc/pgx/pgproto3" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice @@ -54,25 +56,25 @@ type copyFrom struct { func (ct *copyFrom) readUntilReadyForQuery() { for { - t, r, err := ct.conn.rxMsg() + msg, err := ct.conn.rxMsg() if err != nil { ct.readerErrChan <- err close(ct.readerErrChan) return } - switch t { - case readyForQuery: - ct.conn.rxReadyForQuery(r) + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + ct.conn.rxReadyForQuery(msg) close(ct.readerErrChan) return - case commandComplete: - case errorResponse: - ct.readerErrChan <- ct.conn.rxErrorResponse(r) + case *pgproto3.CommandComplete: + case *pgproto3.ErrorResponse: + ct.readerErrChan <- ct.conn.rxErrorResponse(msg) default: - err = ct.conn.processContextFreeMsg(t, r) + err = ct.conn.processContextFreeMsg(msg) if err != nil { - ct.readerErrChan <- ct.conn.processContextFreeMsg(t, r) + ct.readerErrChan <- ct.conn.processContextFreeMsg(msg) } } } @@ -190,18 +192,16 @@ func (ct *copyFrom) run() (int, error) { func (c *Conn) readUntilCopyInResponse() error { for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() + msg, err := c.rxMsg() if err != nil { return err } - switch t { - case copyInResponse: + switch msg := msg.(type) { + case *pgproto3.CopyInResponse: return nil default: - err = c.processContextFreeMsg(t, r) + err = c.processContextFreeMsg(msg) if err != nil { return err } diff --git a/fastpath.go b/fastpath.go index 0caba9d3..75681c9c 100644 --- a/fastpath.go +++ b/fastpath.go @@ -3,6 +3,7 @@ package pgx import ( "encoding/binary" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -71,23 +72,20 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) { } for { - var t byte - var r *msgReader - t, r, err = f.cn.rxMsg() + msg, err := f.cn.rxMsg() if err != nil { return nil, err } - switch t { - case 'V': // FunctionCallResponse - data := r.readBytes(r.readInt32()) - res = make([]byte, len(data)) - copy(res, data) - case 'Z': // Ready for query - f.cn.rxReadyForQuery(r) + switch msg := msg.(type) { + case *pgproto3.FunctionCallResponse: + res = make([]byte, len(msg.Result)) + copy(res, msg.Result) + case *pgproto3.ReadyForQuery: + f.cn.rxReadyForQuery(msg) // done - return + return res, err default: - if err := f.cn.processContextFreeMsg(t, r); err != nil { + if err := f.cn.processContextFreeMsg(msg); err != nil { return nil, err } } diff --git a/messages.go b/messages.go index 68faf14c..e229367a 100644 --- a/messages.go +++ b/messages.go @@ -58,11 +58,11 @@ func (s *startupMessage) Bytes() (buf []byte) { type FieldDescription struct { Name string Table pgtype.Oid - AttributeNumber int16 + AttributeNumber uint16 DataType pgtype.Oid DataTypeSize int16 DataTypeName string - Modifier int32 + Modifier uint32 FormatCode int16 } diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go new file mode 100644 index 00000000..e265a247 --- /dev/null +++ b/pgproto3/authentication.go @@ -0,0 +1,54 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 +) + +type Authentication struct { + Type uint32 + + // MD5Password fields + Salt [4]byte +} + +func (*Authentication) Backend() {} + +func (dst *Authentication) UnmarshalBinary(src []byte) error { + *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} + + switch dst.Type { + case AuthTypeOk: + case AuthTypeCleartextPassword: + case AuthTypeMD5Password: + copy(dst.Salt[:], src[4:8]) + default: + return fmt.Errorf("unknown authentication type: %d", dst.Type) + } + + return nil +} + +func (src *Authentication) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('R') + buf.Write(bigEndian.Uint32(0)) + buf.Write(bigEndian.Uint32(src.Type)) + + switch src.Type { + case AuthTypeMD5Password: + buf.Write(src.Salt[:]) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go new file mode 100644 index 00000000..5d8eb496 --- /dev/null +++ b/pgproto3/backend_key_data.go @@ -0,0 +1,47 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type BackendKeyData struct { + ProcessID uint32 + SecretKey uint32 +} + +func (*BackendKeyData) Backend() {} + +func (dst *BackendKeyData) UnmarshalBinary(src []byte) error { + if len(src) != 8 { + return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} + } + + dst.ProcessID = binary.BigEndian.Uint32(src[:4]) + dst.SecretKey = binary.BigEndian.Uint32(src[4:]) + + return nil +} + +func (src *BackendKeyData) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('K') + buf.Write(bigEndian.Uint32(12)) + buf.Write(bigEndian.Uint32(src.ProcessID)) + buf.Write(bigEndian.Uint32(src.SecretKey)) + return buf.Bytes(), nil +} + +func (src *BackendKeyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "BackendKeyData", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/pgproto3/big_endian.go b/pgproto3/big_endian.go new file mode 100644 index 00000000..f7bdb97e --- /dev/null +++ b/pgproto3/big_endian.go @@ -0,0 +1,37 @@ +package pgproto3 + +import ( + "encoding/binary" +) + +type BigEndianBuf [8]byte + +func (b BigEndianBuf) Int16(n int16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, uint16(n)) + return buf +} + +func (b BigEndianBuf) Uint16(n uint16) []byte { + buf := b[0:2] + binary.BigEndian.PutUint16(buf, n) + return buf +} + +func (b BigEndianBuf) Int32(n int32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, uint32(n)) + return buf +} + +func (b BigEndianBuf) Uint32(n uint32) []byte { + buf := b[0:4] + binary.BigEndian.PutUint32(buf, n) + return buf +} + +func (b BigEndianBuf) Int64(n int64) []byte { + buf := b[0:8] + binary.BigEndian.PutUint64(buf, uint64(n)) + return buf +} diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go new file mode 100644 index 00000000..756a30e6 --- /dev/null +++ b/pgproto3/bind_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type BindComplete struct{} + +func (*BindComplete) Backend() {} + +func (dst *BindComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *BindComplete) MarshalBinary() ([]byte, error) { + return []byte{'2', 0, 0, 0, 4}, nil +} + +func (src *BindComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "BindComplete", + }) +} diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go new file mode 100644 index 00000000..fd6ff180 --- /dev/null +++ b/pgproto3/close_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type CloseComplete struct{} + +func (*CloseComplete) Backend() {} + +func (dst *CloseComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *CloseComplete) MarshalBinary() ([]byte, error) { + return []byte{'3', 0, 0, 0, 4}, nil +} + +func (src *CloseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "CloseComplete", + }) +} diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go new file mode 100644 index 00000000..ac60153e --- /dev/null +++ b/pgproto3/command_complete.go @@ -0,0 +1,47 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type CommandComplete struct { + CommandTag string +} + +func (*CommandComplete) Backend() {} + +func (dst *CommandComplete) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.CommandTag = string(b[:len(b)-1]) + + return nil +} + +func (src *CommandComplete) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('C') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1))) + + buf.WriteString(src.CommandTag) + buf.WriteByte(0) + + return buf.Bytes(), nil +} + +func (src *CommandComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + CommandTag string + }{ + Type: "CommandComplete", + CommandTag: src.CommandTag, + }) +} diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go new file mode 100644 index 00000000..2a4c58af --- /dev/null +++ b/pgproto3/copy_both_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyBothResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyBothResponse) Backend() {} + +func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyBothResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyBothResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('W') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyBothResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go new file mode 100644 index 00000000..b9ea6272 --- /dev/null +++ b/pgproto3/copy_data.go @@ -0,0 +1,41 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" +) + +type CopyData struct { + Data []byte +} + +func (*CopyData) Backend() {} +func (*CopyData) Frontend() {} + +func (dst *CopyData) UnmarshalBinary(src []byte) error { + dst.Data = make([]byte, len(src)) + copy(dst.Data, src) + return nil +} + +func (src *CopyData) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('d') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data)))) + buf.Write(src.Data) + + return buf.Bytes(), nil +} + +func (src *CopyData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "CopyData", + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go new file mode 100644 index 00000000..63868c7a --- /dev/null +++ b/pgproto3/copy_in_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyInResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyInResponse) Backend() {} + +func (dst *CopyInResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyInResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyInResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('G') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyInResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyInResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go new file mode 100644 index 00000000..e46d9e8f --- /dev/null +++ b/pgproto3/copy_out_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type CopyOutResponse struct { + OverallFormat byte + ColumnFormatCodes []uint16 +} + +func (*CopyOutResponse) Backend() {} + +func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 3 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + overallFormat := buf.Next(1)[0] + + columnCount := int(binary.BigEndian.Uint16(buf.Next(2))) + if buf.Len() != columnCount*2 { + return &invalidMessageFormatErr{messageType: "CopyOutResponse"} + } + + columnFormatCodes := make([]uint16, columnCount) + for i := 0; i < columnCount; i++ { + columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) + } + + *dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes} + + return nil +} + +func (src *CopyOutResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('H') + buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) + + for _, fc := range src.ColumnFormatCodes { + buf.Write(bigEndian.Uint16(fc)) + } + + return buf.Bytes(), nil +} + +func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ColumnFormatCodes []uint16 + }{ + Type: "CopyOutResponse", + ColumnFormatCodes: src.ColumnFormatCodes, + }) +} diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go new file mode 100644 index 00000000..c95861b9 --- /dev/null +++ b/pgproto3/data_row.go @@ -0,0 +1,103 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type DataRow struct { + Values [][]byte +} + +func (*DataRow) Backend() {} + +func (dst *DataRow) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Values = make([][]byte, fieldCount) + + for i := 0; i < fieldCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "DataRow"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4)))) + + // null + if msgSize == -1 { + continue + } + + value := make([]byte, msgSize) + _, err := buf.Read(value) + if err != nil { + return err + } + + dst.Values[i] = value + } + + return nil +} + +func (src *DataRow) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('D') + buf.Write(bigEndian.Uint32(0)) + + buf.Write(bigEndian.Uint16(uint16(len(src.Values)))) + + for _, v := range src.Values { + if v == nil { + buf.Write(bigEndian.Int32(-1)) + continue + } + + buf.Write(bigEndian.Int32(int32(len(v)))) + buf.Write(v) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *DataRow) MarshalJSON() ([]byte, error) { + formattedValues := make([]map[string]string, len(src.Values)) + for i, v := range src.Values { + if v == nil { + continue + } + + var hasNonPrintable bool + for _, b := range v { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)} + } else { + formattedValues[i] = map[string]string{"text": string(v)} + } + } + + return json.Marshal(struct { + Type string + Values []map[string]string + }{ + Type: "DataRow", + Values: formattedValues, + }) +} diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go new file mode 100644 index 00000000..de6e6272 --- /dev/null +++ b/pgproto3/empty_query_response.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type EmptyQueryResponse struct{} + +func (*EmptyQueryResponse) Backend() {} + +func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) { + return []byte{'I', 0, 0, 0, 4}, nil +} + +func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "EmptyQueryResponse", + }) +} diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go new file mode 100644 index 00000000..82e408d7 --- /dev/null +++ b/pgproto3/error_response.go @@ -0,0 +1,197 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "strconv" +) + +type ErrorResponse struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string + + UnknownFields map[byte]string +} + +func (*ErrorResponse) Backend() {} + +func (dst *ErrorResponse) UnmarshalBinary(src []byte) error { + *dst = ErrorResponse{} + + buf := bytes.NewBuffer(src) + + for { + k, err := buf.ReadByte() + if err != nil { + return err + } + if k == 0 { + break + } + + vb, err := buf.ReadBytes(0) + if err != nil { + return err + } + v := string(vb[:len(vb)-1]) + + switch k { + case 'S': + dst.Severity = v + case 'C': + dst.Code = v + case 'M': + dst.Message = v + case 'D': + dst.Detail = v + case 'H': + dst.Hint = v + case 'P': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Position = int32(n) + case 'p': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.InternalPosition = int32(n) + case 'q': + dst.InternalQuery = v + case 'W': + dst.Where = v + case 's': + dst.SchemaName = v + case 't': + dst.TableName = v + case 'c': + dst.ColumnName = v + case 'd': + dst.DataTypeName = v + case 'n': + dst.ConstraintName = v + case 'F': + dst.File = v + case 'L': + s := v + n, _ := strconv.ParseInt(s, 10, 32) + dst.Line = int32(n) + case 'R': + dst.Routine = v + + default: + if dst.UnknownFields == nil { + dst.UnknownFields = make(map[byte]string) + } + dst.UnknownFields[k] = v + } + } + + return nil +} + +func (src *ErrorResponse) MarshalBinary() ([]byte, error) { + return src.marshalBinary('E') +} + +func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte(typeByte) + buf.Write(bigEndian.Uint32(0)) + + if src.Severity != "" { + buf.WriteString(src.Severity) + buf.WriteByte(0) + } + if src.Code != "" { + buf.WriteString(src.Code) + buf.WriteByte(0) + } + if src.Message != "" { + buf.WriteString(src.Message) + buf.WriteByte(0) + } + if src.Detail != "" { + buf.WriteString(src.Detail) + buf.WriteByte(0) + } + if src.Hint != "" { + buf.WriteString(src.Hint) + buf.WriteByte(0) + } + if src.Position != 0 { + buf.WriteString(strconv.Itoa(int(src.Position))) + buf.WriteByte(0) + } + if src.InternalPosition != 0 { + buf.WriteString(strconv.Itoa(int(src.InternalPosition))) + buf.WriteByte(0) + } + if src.InternalQuery != "" { + buf.WriteString(src.InternalQuery) + buf.WriteByte(0) + } + if src.Where != "" { + buf.WriteString(src.Where) + buf.WriteByte(0) + } + if src.SchemaName != "" { + buf.WriteString(src.SchemaName) + buf.WriteByte(0) + } + if src.TableName != "" { + buf.WriteString(src.TableName) + buf.WriteByte(0) + } + if src.ColumnName != "" { + buf.WriteString(src.ColumnName) + buf.WriteByte(0) + } + if src.DataTypeName != "" { + buf.WriteString(src.DataTypeName) + buf.WriteByte(0) + } + if src.ConstraintName != "" { + buf.WriteString(src.ConstraintName) + buf.WriteByte(0) + } + if src.File != "" { + buf.WriteString(src.File) + buf.WriteByte(0) + } + if src.Line != 0 { + buf.WriteString(strconv.Itoa(int(src.Line))) + buf.WriteByte(0) + } + if src.Routine != "" { + buf.WriteString(src.Routine) + buf.WriteByte(0) + } + + for k, v := range src.UnknownFields { + buf.WriteByte(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go new file mode 100644 index 00000000..c1dec461 --- /dev/null +++ b/pgproto3/frontend.go @@ -0,0 +1,70 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/chunkreader" +) + +type Frontend struct { + cr *chunkreader.ChunkReader + w io.Writer +} + +func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { + cr := chunkreader.NewChunkReader(r) + return &Frontend{cr: cr, w: w}, nil +} + +func (b *Frontend) Send(msg FrontendMessage) error { + return errors.New("not implemented") +} + +func (b *Frontend) Receive() (BackendMessage, error) { + backendMessages := map[byte]BackendMessage{ + '1': &ParseComplete{}, + '2': &BindComplete{}, + '3': &CloseComplete{}, + 'A': &NotificationResponse{}, + 'C': &CommandComplete{}, + 'd': &CopyData{}, + 'D': &DataRow{}, + 'E': &ErrorResponse{}, + 'G': &CopyInResponse{}, + 'H': &CopyOutResponse{}, + 'I': &EmptyQueryResponse{}, + 'K': &BackendKeyData{}, + 'n': &NoData{}, + 'N': &NoticeResponse{}, + 'R': &Authentication{}, + 'S': &ParameterStatus{}, + 't': &ParameterDescription{}, + 'T': &RowDescription{}, + 'V': &FunctionCallResponse{}, + 'W': &CopyBothResponse{}, + 'Z': &ReadyForQuery{}, + } + + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgBody, err := b.cr.Next(bodyLen) + if err != nil { + return nil, err + } + + if msg, ok := backendMessages[msgType]; ok { + err = msg.UnmarshalBinary(msgBody) + return msg, err + } + + return nil, fmt.Errorf("unknown message type: %c", msgType) +} diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go new file mode 100644 index 00000000..5c692b36 --- /dev/null +++ b/pgproto3/function_call_response.go @@ -0,0 +1,73 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type FunctionCallResponse struct { + Result []byte +} + +func (*FunctionCallResponse) Backend() {} + +func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + resultSize := int(binary.BigEndian.Uint32(buf.Next(4))) + if buf.Len() != resultSize { + return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} + } + + dst.Result = make([]byte, resultSize) + copy(dst.Result, buf.Bytes()) + + return nil +} + +func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('V') + buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result)))) + + if src.Result == nil { + buf.Write(bigEndian.Int32(-1)) + } else { + buf.Write(bigEndian.Int32(int32(len(src.Result)))) + buf.Write(src.Result) + } + + return buf.Bytes(), nil +} + +func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { + var formattedValue map[string]string + var hasNonPrintable bool + for _, b := range src.Result { + if b < 32 { + hasNonPrintable = true + break + } + } + + if hasNonPrintable { + formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)} + } else { + formattedValue = map[string]string{"text": string(src.Result)} + } + + return json.Marshal(struct { + Type string + Result map[string]string + }{ + Type: "FunctionCallResponse", + Result: formattedValue, + }) +} diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go new file mode 100644 index 00000000..47ebf28e --- /dev/null +++ b/pgproto3/no_data.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type NoData struct{} + +func (*NoData) Backend() {} + +func (dst *NoData) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *NoData) MarshalBinary() ([]byte, error) { + return []byte{'n', 0, 0, 0, 4}, nil +} + +func (src *NoData) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "NoData", + }) +} diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go new file mode 100644 index 00000000..767c9a67 --- /dev/null +++ b/pgproto3/notice_response.go @@ -0,0 +1,13 @@ +package pgproto3 + +type NoticeResponse ErrorResponse + +func (*NoticeResponse) Backend() {} + +func (dst *NoticeResponse) UnmarshalBinary(src []byte) error { + return (*ErrorResponse)(dst).UnmarshalBinary(src) +} + +func (src *NoticeResponse) MarshalBinary() ([]byte, error) { + return (*ErrorResponse)(src).marshalBinary('N') +} diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go new file mode 100644 index 00000000..4ae8bab3 --- /dev/null +++ b/pgproto3/notification_response.go @@ -0,0 +1,65 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type NotificationResponse struct { + PID uint32 + Channel string + Payload string +} + +func (*NotificationResponse) Backend() {} + +func (dst *NotificationResponse) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + pid := binary.BigEndian.Uint32(buf.Next(4)) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + channel := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + payload := string(b[:len(b)-1]) + + *dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload} + return nil +} + +func (src *NotificationResponse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('A') + buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload)))) + + buf.WriteString(src.Channel) + buf.WriteByte(0) + buf.WriteString(src.Payload) + buf.WriteByte(0) + + return buf.Bytes(), nil +} + +func (src *NotificationResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + PID uint32 + Channel string + Payload string + }{ + Type: "NotificationResponse", + PID: src.PID, + Channel: src.Channel, + Payload: src.Payload, + }) +} diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go new file mode 100644 index 00000000..40d92c50 --- /dev/null +++ b/pgproto3/parameter_description.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type ParameterDescription struct { + ParameterOIDs []uint32 +} + +func (*ParameterDescription) Backend() {} + +func (dst *ParameterDescription) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "ParameterDescription"} + } + + // Reported parameter count will be incorrect when number of args is greater than uint16 + buf.Next(2) + // Instead infer parameter count by remaining size of message + parameterCount := buf.Len() / 4 + + *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} + + for i := 0; i < parameterCount; i++ { + dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) + } + + return nil +} + +func (src *ParameterDescription) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('t') + buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs)))) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) + + for _, oid := range src.ParameterOIDs { + buf.Write(bigEndian.Uint32(oid)) + } + + return buf.Bytes(), nil +} + +func (src *ParameterDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ParameterOIDs []uint32 + }{ + Type: "ParameterDescription", + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go new file mode 100644 index 00000000..b8ce7f8d --- /dev/null +++ b/pgproto3/parameter_status.go @@ -0,0 +1,62 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type ParameterStatus struct { + Name string + Value string +} + +func (*ParameterStatus) Backend() {} + +func (dst *ParameterStatus) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + name := string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + value := string(b[:len(b)-1]) + + *dst = ParameterStatus{Name: name, Value: value} + return nil +} + +func (src *ParameterStatus) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('S') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Name) + buf.WriteByte(0) + buf.WriteString(src.Value) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Value string + }{ + Type: "ParameterStatus", + Name: ps.Name, + Value: ps.Value, + }) +} diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go new file mode 100644 index 00000000..24951e3d --- /dev/null +++ b/pgproto3/parse_complete.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ParseComplete struct{} + +func (*ParseComplete) Backend() {} + +func (dst *ParseComplete) UnmarshalBinary(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *ParseComplete) MarshalBinary() ([]byte, error) { + return []byte{'1', 0, 0, 0, 4}, nil +} + +func (src *ParseComplete) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "ParseComplete", + }) +} diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go new file mode 100644 index 00000000..a9221239 --- /dev/null +++ b/pgproto3/pgproto3.go @@ -0,0 +1,88 @@ +package pgproto3 + +import "fmt" + +type Message interface { + UnmarshalBinary(data []byte) error + MarshalBinary() (data []byte, err error) +} + +type FrontendMessage interface { + Message + Frontend() // no-op method to distinguish frontend from backend methods +} + +type BackendMessage interface { + Message + Backend() // no-op method to distinguish frontend from backend methods +} + +// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) { +// switch typeByte { +// case '1': +// return ParseParseComplete(body) +// case '2': +// return ParseBindComplete(body) +// case 'C': +// return ParseCommandComplete(body) +// case 'D': +// return ParseDataRow(body) +// case 'E': +// return ParseErrorResponse(body) +// case 'K': +// return ParseBackendKeyData(body) +// case 'R': +// return ParseAuthentication(body) +// case 'S': +// return ParseParameterStatus(body) +// case 'T': +// return ParseRowDescription(body) +// case 't': +// return ParseParameterDescription(body) +// case 'Z': +// return ParseReadyForQuery(body) +// default: +// return ParseUnknownMessage(typeByte, body) +// } +// } + +// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) { +// switch typeByte { +// case 'B': +// return ParseBind(body) +// case 'D': +// return ParseDescribe(body) +// case 'E': +// return ParseExecute(body) +// case 'P': +// return ParseParse(body) +// case 'p': +// return ParsePasswordMessage(body) +// case 'Q': +// return ParseQuery(body) +// case 'S': +// return ParseSync(body) +// case 'X': +// return ParseTerminate(body) +// default: +// return ParseUnknownMessage(typeByte, body) +// } +// } + +type invalidMessageLenErr struct { + messageType string + expectedLen int + actualLen int +} + +func (e *invalidMessageLenErr) Error() string { + return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen) +} + +type invalidMessageFormatErr struct { + messageType string +} + +func (e *invalidMessageFormatErr) Error() string { + return fmt.Sprintf("%s body is invalid", e.messageType) +} diff --git a/pgproto3/query.go b/pgproto3/query.go new file mode 100644 index 00000000..a3fc32eb --- /dev/null +++ b/pgproto3/query.go @@ -0,0 +1,43 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type Query struct { + String string +} + +func (*Query) Frontend() {} + +func (dst *Query) UnmarshalBinary(src []byte) error { + i := bytes.IndexByte(src, 0) + if i != len(src)-1 { + return &invalidMessageFormatErr{messageType: "Query"} + } + + dst.String = string(src[:i]) + + return nil +} + +func (src *Query) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('Q') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1))) + buf.WriteString(src.String) + buf.WriteByte(0) + return buf.Bytes(), nil +} + +func (src *Query) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + String string + }{ + Type: "Query", + String: src.String, + }) +} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go new file mode 100644 index 00000000..09005d00 --- /dev/null +++ b/pgproto3/ready_for_query.go @@ -0,0 +1,35 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type ReadyForQuery struct { + TxStatus byte +} + +func (*ReadyForQuery) Backend() {} + +func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error { + if len(src) != 1 { + return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} + } + + dst.TxStatus = src[0] + + return nil +} + +func (src *ReadyForQuery) MarshalBinary() ([]byte, error) { + return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil +} + +func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + TxStatus string + }{ + Type: "ReadyForQuery", + TxStatus: string(src.TxStatus), + }) +} diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go new file mode 100644 index 00000000..294a6aa9 --- /dev/null +++ b/pgproto3/row_description.go @@ -0,0 +1,101 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +const ( + TextFormat = 0 + BinaryFormat = 1 +) + +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier uint32 + Format int16 +} + +type RowDescription struct { + Fields []FieldDescription +} + +func (*RowDescription) Backend() {} + +func (dst *RowDescription) UnmarshalBinary(src []byte) error { + buf := bytes.NewBuffer(src) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + *dst = RowDescription{Fields: make([]FieldDescription, fieldCount)} + + for i := 0; i < fieldCount; i++ { + var fd FieldDescription + bName, err := buf.ReadBytes(0) + if err != nil { + return err + } + fd.Name = string(bName[:len(bName)-1]) + + // Since buf.Next() doesn't return an error if we hit the end of the buffer + // check Len ahead of time + if buf.Len() < 18 { + return &invalidMessageFormatErr{messageType: "RowDescription"} + } + + fd.TableOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.TableAttributeNumber = binary.BigEndian.Uint16(buf.Next(2)) + fd.DataTypeOID = binary.BigEndian.Uint32(buf.Next(4)) + fd.DataTypeSize = int16(binary.BigEndian.Uint16(buf.Next(2))) + fd.TypeModifier = binary.BigEndian.Uint32(buf.Next(4)) + fd.Format = int16(binary.BigEndian.Uint16(buf.Next(2))) + + dst.Fields[i] = fd + } + + return nil +} + +func (src *RowDescription) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('T') + buf.Write(bigEndian.Uint32(0)) + + buf.Write(bigEndian.Uint16(uint16(len(src.Fields)))) + + for _, fd := range src.Fields { + buf.WriteString(fd.Name) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint32(fd.TableOID)) + buf.Write(bigEndian.Uint16(fd.TableAttributeNumber)) + buf.Write(bigEndian.Uint32(fd.DataTypeOID)) + buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize))) + buf.Write(bigEndian.Uint32(fd.TypeModifier)) + buf.Write(bigEndian.Uint16(uint16(fd.Format))) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *RowDescription) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Fields []FieldDescription + }{ + Type: "RowDescription", + Fields: src.Fields, + }) +} diff --git a/query.go b/query.go index f7d8ed19..04a87043 100644 --- a/query.go +++ b/query.go @@ -8,6 +8,7 @@ import ( "time" "github.com/jackc/pgx/internal/sanitize" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -41,7 +42,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) { // calling Next() until it returns false, or when a fatal error occurs. type Rows struct { conn *Conn - mr *msgReader + values [][]byte fields []FieldDescription rowCount int columnIdx int @@ -115,15 +116,15 @@ func (rows *Rows) Next() bool { rows.columnIdx = 0 for { - t, r, err := rows.conn.rxMsg() + msg, err := rows.conn.rxMsg() if err != nil { rows.Fatal(err) return false } - switch t { - case rowDescription: - rows.fields = rows.conn.rxRowDescription(r) + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rows.fields = rows.conn.rxRowDescription(msg) for i := range rows.fields { if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok { rows.fields[i].DataTypeName = dt.Name @@ -133,21 +134,20 @@ func (rows *Rows) Next() bool { return false } } - case dataRow: - fieldCount := r.readInt16() - if int(fieldCount) != len(rows.fields) { - rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), fieldCount))) + case *pgproto3.DataRow: + if len(msg.Values) != len(rows.fields) { + rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values)))) return false } - rows.mr = r + rows.values = msg.Values return true - case commandComplete: + case *pgproto3.CommandComplete: rows.Close() return false default: - err = rows.conn.processContextFreeMsg(t, r) + err = rows.conn.processContextFreeMsg(msg) if err != nil { rows.Fatal(err) return false @@ -170,13 +170,9 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { return nil, nil, false } + buf := rows.values[rows.columnIdx] fd := &rows.fields[rows.columnIdx] rows.columnIdx++ - size := rows.mr.readInt32() - var buf []byte - if size >= 0 { - buf = rows.mr.readBytes(size) - } return buf, fd, true } diff --git a/replication.go b/replication.go index a251172d..ea768961 100644 --- a/replication.go +++ b/replication.go @@ -2,9 +2,12 @@ package pgx import ( "context" + "encoding/binary" "errors" "fmt" "time" + + "github.com/jackc/pgx/pgproto3" ) const ( @@ -203,59 +206,64 @@ func (rc *ReplicationConn) CauseOfDeath() error { } func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err error) { - var t byte - var reader *msgReader - t, reader, err = rc.c.rxMsg() + msg, err := rc.c.rxMsg() if err != nil { return } - switch t { - case noticeResponse: - pgError := rc.c.rxErrorResponse(reader) + switch msg := msg.(type) { + case *pgproto3.NoticeResponse: + pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) if rc.c.shouldLog(LogLevelInfo) { rc.c.log(LogLevelInfo, pgError.Error()) } - case errorResponse: - err = rc.c.rxErrorResponse(reader) + case *pgproto3.ErrorResponse: + err = rc.c.rxErrorResponse(msg) if rc.c.shouldLog(LogLevelError) { rc.c.log(LogLevelError, err.Error()) } return - case copyBothResponse: + case *pgproto3.CopyBothResponse: // This is the tail end of the replication process start, // and can be safely ignored return - case copyData: - var msgType byte - msgType = reader.readByte() + case *pgproto3.CopyData: + msgType := msg.Data[0] + rp := 1 + switch msgType { case walData: - walStart := reader.readInt64() - serverWalEnd := reader.readInt64() - serverTime := reader.readInt64() - walData := reader.readBytes(int32(len(reader.msgBody) - reader.rp)) - walMessage := WalMessage{WalStart: uint64(walStart), - ServerWalEnd: uint64(serverWalEnd), - ServerTime: uint64(serverTime), + walStart := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + walData := msg.Data[rp:] + walMessage := WalMessage{WalStart: walStart, + ServerWalEnd: serverWalEnd, + ServerTime: serverTime, WalData: walData, } return &ReplicationMessage{WalMessage: &walMessage}, nil case senderKeepalive: - serverWalEnd := reader.readInt64() - serverTime := reader.readInt64() - replyNow := reader.readByte() - h := &ServerHeartbeat{ServerWalEnd: uint64(serverWalEnd), ServerTime: uint64(serverTime), ReplyRequested: replyNow} + serverWalEnd := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + serverTime := binary.BigEndian.Uint64(msg.Data[rp:]) + rp += 8 + replyNow := msg.Data[rp] + rp += 1 + h := &ServerHeartbeat{ServerWalEnd: serverWalEnd, ServerTime: serverTime, ReplyRequested: replyNow} return &ReplicationMessage{ServerHeartbeat: h}, nil default: if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected data playload message type %v", t) + rc.c.log(LogLevelError, "Unexpected data playload message type %v", msgType) } } default: if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected replication message type %v", t) + rc.c.log(LogLevelError, "Unexpected replication message type %T", msg) } } return @@ -325,21 +333,19 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows.Fatal(err) } - var t byte - var r *msgReader - t, r, err = rc.c.rxMsg() + msg, err := rc.c.rxMsg() if err != nil { return nil, err } - switch t { - case rowDescription: - rows.fields = rc.c.rxRowDescription(r) + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rows.fields = rc.c.rxRowDescription(msg) // We don't have c.PgTypes here because we're a replication // connection. This means the field descriptions will have // only Oids. Not much we can do about this. default: - if e := rc.c.processContextFreeMsg(t, r); e != nil { + if e := rc.c.processContextFreeMsg(msg); e != nil { rows.Fatal(e) return rows, e } From 3c7235c68cb72e6d01a90d87cc0a2b2854c6d31f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 10:46:25 -0500 Subject: [PATCH 173/264] Remove unused msg_reader.go --- msg_reader.go | 249 -------------------------------------------------- 1 file changed, 249 deletions(-) delete mode 100644 msg_reader.go diff --git a/msg_reader.go b/msg_reader.go deleted file mode 100644 index 1858037a..00000000 --- a/msg_reader.go +++ /dev/null @@ -1,249 +0,0 @@ -package pgx - -import ( - "bytes" - "encoding/binary" - "errors" - "net" - - "github.com/jackc/pgx/chunkreader" -) - -// msgReader is a helper that reads values from a PostgreSQL message. -type msgReader struct { - cr *chunkreader.ChunkReader - msgType byte - msgBody []byte - rp int // read position - err error - log func(lvl int, msg string, ctx ...interface{}) - shouldLog func(lvl int) bool -} - -// fatal tells rc that a Fatal error has occurred -func (r *msgReader) fatal(err error) { - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp) - } - r.err = err -} - -// rxMsg reads the type and size of the next message. -func (r *msgReader) rxMsg() (byte, error) { - if r.err != nil { - return 0, r.err - } - - header, err := r.cr.Next(5) - if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { - r.fatal(err) - } - return 0, err - } - - r.msgType = header[0] - bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 - - r.msgBody, err = r.cr.Next(bodyLen) - if err != nil { - if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { - r.fatal(err) - } - return 0, err - } - - r.rp = 0 - - return r.msgType, nil -} - -func (r *msgReader) readByte() byte { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 1 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - b := r.msgBody[r.rp] - r.rp++ - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp) - } - - return b -} - -func (r *msgReader) readInt16() int16 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 2 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:])) - r.rp += 2 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -func (r *msgReader) readInt32() int32 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 4 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:])) - r.rp += 4 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -func (r *msgReader) readUint16() uint16 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 2 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := binary.BigEndian.Uint16(r.msgBody[r.rp:]) - r.rp += 2 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -func (r *msgReader) readUint32() uint32 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 4 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := binary.BigEndian.Uint32(r.msgBody[r.rp:]) - r.rp += 4 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -func (r *msgReader) readInt64() int64 { - if r.err != nil { - return 0 - } - - if len(r.msgBody)-r.rp < 8 { - r.fatal(errors.New("read past end of message")) - return 0 - } - - n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:])) - r.rp += 8 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp) - } - - return n -} - -// readCString reads a null terminated string -func (r *msgReader) readCString() string { - if r.err != nil { - return "" - } - - nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0) - if nullIdx == -1 { - r.fatal(errors.New("null terminated string not found")) - return "" - } - - s := string(r.msgBody[r.rp : r.rp+nullIdx]) - r.rp += nullIdx + 1 - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp) - } - - return s -} - -// readString reads count bytes and returns as string -func (r *msgReader) readString(countI32 int32) string { - if r.err != nil { - return "" - } - - count := int(countI32) - - if len(r.msgBody)-r.rp < count { - r.fatal(errors.New("read past end of message")) - return "" - } - - s := string(r.msgBody[r.rp : r.rp+count]) - r.rp += count - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp) - } - - return s -} - -// readBytes reads count bytes and returns as []byte -func (r *msgReader) readBytes(countI32 int32) []byte { - if r.err != nil { - return nil - } - - count := int(countI32) - - if len(r.msgBody)-r.rp < count { - r.fatal(errors.New("read past end of message")) - return nil - } - - b := r.msgBody[r.rp : r.rp+count] - r.rp += count - - r.cr.KeepLast() - - if r.shouldLog(LogLevelTrace) { - r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp) - } - - return b -} From 70b7c9a300fbb28809af167042bb907c9229046a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 11:01:54 -0500 Subject: [PATCH 174/264] Use flyweight pattern for pgproto3 messages --- pgproto3/frontend.go | 103 ++++++++++++++++++++++++++++++------------- 1 file changed, 73 insertions(+), 30 deletions(-) diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index c1dec461..df67b718 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -12,6 +12,29 @@ import ( type Frontend struct { cr *chunkreader.ChunkReader w io.Writer + + // Backend message flyweights + authentication Authentication + backendKeyData BackendKeyData + bindComplete BindComplete + closeComplete CloseComplete + commandComplete CommandComplete + copyBothResponse CopyBothResponse + copyData CopyData + copyInResponse CopyInResponse + copyOutResponse CopyOutResponse + dataRow DataRow + emptyQueryResponse EmptyQueryResponse + errorResponse ErrorResponse + functionCallResponse FunctionCallResponse + noData NoData + noticeResponse NoticeResponse + notificationResponse NotificationResponse + parameterDescription ParameterDescription + parameterStatus ParameterStatus + parseComplete ParseComplete + readyForQuery ReadyForQuery + rowDescription RowDescription } func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { @@ -24,30 +47,6 @@ func (b *Frontend) Send(msg FrontendMessage) error { } func (b *Frontend) Receive() (BackendMessage, error) { - backendMessages := map[byte]BackendMessage{ - '1': &ParseComplete{}, - '2': &BindComplete{}, - '3': &CloseComplete{}, - 'A': &NotificationResponse{}, - 'C': &CommandComplete{}, - 'd': &CopyData{}, - 'D': &DataRow{}, - 'E': &ErrorResponse{}, - 'G': &CopyInResponse{}, - 'H': &CopyOutResponse{}, - 'I': &EmptyQueryResponse{}, - 'K': &BackendKeyData{}, - 'n': &NoData{}, - 'N': &NoticeResponse{}, - 'R': &Authentication{}, - 'S': &ParameterStatus{}, - 't': &ParameterDescription{}, - 'T': &RowDescription{}, - 'V': &FunctionCallResponse{}, - 'W': &CopyBothResponse{}, - 'Z': &ReadyForQuery{}, - } - header, err := b.cr.Next(5) if err != nil { return nil, err @@ -56,15 +55,59 @@ func (b *Frontend) Receive() (BackendMessage, error) { msgType := header[0] bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + var msg BackendMessage + switch msgType { + case '1': + msg = &b.parseComplete + case '2': + msg = &b.bindComplete + case '3': + msg = &b.closeComplete + case 'A': + msg = &b.notificationResponse + case 'C': + msg = &b.commandComplete + case 'd': + msg = &b.copyData + case 'D': + msg = &b.dataRow + case 'E': + msg = &b.errorResponse + case 'G': + msg = &b.copyInResponse + case 'H': + msg = &b.copyOutResponse + case 'I': + msg = &b.emptyQueryResponse + case 'K': + msg = &b.backendKeyData + case 'n': + msg = &b.noData + case 'N': + msg = &b.noticeResponse + case 'R': + msg = &b.authentication + case 'S': + msg = &b.parameterStatus + case 't': + msg = &b.parameterDescription + case 'T': + msg = &b.rowDescription + case 'V': + msg = &b.functionCallResponse + case 'W': + msg = &b.copyBothResponse + case 'Z': + msg = &b.readyForQuery + default: + return nil, fmt.Errorf("unknown message type: %c", msgType) + } + msgBody, err := b.cr.Next(bodyLen) if err != nil { return nil, err } - if msg, ok := backendMessages[msgType]; ok { - err = msg.UnmarshalBinary(msgBody) - return msg, err - } - - return nil, fmt.Errorf("unknown message type: %c", msgType) + err = msg.UnmarshalBinary(msgBody) + return msg, err } From e8eaad520bd6445b2ecdf46cff0ea86f5bf86054 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 11:55:14 -0500 Subject: [PATCH 175/264] Reduce allocations and copies in pgproto3 Altered chunkreader to never reuse memory. Altered pgproto3 to to copy memory when decoding. Renamed UnmarshalBinary to Decode because of changed semantics. --- chunkreader/chunkreader.go | 25 ++----------- chunkreader/chunkreader_test.go | 44 +++------------------- pgproto3/authentication.go | 2 +- pgproto3/backend_key_data.go | 2 +- pgproto3/bind_complete.go | 2 +- pgproto3/close_complete.go | 2 +- pgproto3/command_complete.go | 13 +++---- pgproto3/copy_both_response.go | 2 +- pgproto3/copy_data.go | 5 +-- pgproto3/copy_in_response.go | 2 +- pgproto3/copy_out_response.go | 2 +- pgproto3/data_row.go | 39 ++++++++++++-------- pgproto3/empty_query_response.go | 2 +- pgproto3/error_response.go | 2 +- pgproto3/frontend.go | 2 +- pgproto3/function_call_response.go | 22 ++++++----- pgproto3/no_data.go | 2 +- pgproto3/notice_response.go | 4 +- pgproto3/notification_response.go | 2 +- pgproto3/parameter_description.go | 2 +- pgproto3/parameter_status.go | 2 +- pgproto3/parse_complete.go | 2 +- pgproto3/pgproto3.go | 59 +++--------------------------- pgproto3/query.go | 2 +- pgproto3/ready_for_query.go | 2 +- pgproto3/row_description.go | 2 +- 26 files changed, 80 insertions(+), 167 deletions(-) diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go index f9d6555c..f8d437b2 100644 --- a/chunkreader/chunkreader.go +++ b/chunkreader/chunkreader.go @@ -9,14 +9,12 @@ type ChunkReader struct { buf []byte rp, wp int // buf read position and write position - taken bool options Options } type Options struct { MinBufLen int // Minimum buffer length - BlockLen int // Increments to expand buffer (e.g. a 8000 byte request with a BlockLen of 1024 would yield a buffer len of 8192) } func NewChunkReader(r io.Reader) *ChunkReader { @@ -32,9 +30,6 @@ func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { if options.MinBufLen == 0 { options.MinBufLen = 4096 } - if options.BlockLen == 0 { - options.BlockLen = 512 - } return &ChunkReader{ r: r, @@ -43,8 +38,8 @@ func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { }, nil } -// Next returns buf filled with the next n bytes. buf is only valid until the -// next call to Next. If an error occurs, buf will be nil. +// Next returns buf filled with the next n bytes. If an error occurs, buf will +// be nil. func (r *ChunkReader) Next(n int) (buf []byte, err error) { // n bytes already in buf if (r.wp - r.rp) >= n { @@ -56,17 +51,12 @@ func (r *ChunkReader) Next(n int) (buf []byte, err error) { // available space in buf is less than n if len(r.buf) < n { r.copyBufContents(r.newBuf(n)) - r.taken = false } // buf is large enough, but need to shift filled area to start to make enough contiguous space minReadCount := n - (r.wp - r.rp) if (len(r.buf) - r.wp) < minReadCount { - newBuf := r.buf - if r.taken { - newBuf = r.newBuf(n) - r.taken = false - } + newBuf := r.newBuf(n) r.copyBufContents(newBuf) } @@ -79,20 +69,13 @@ func (r *ChunkReader) Next(n int) (buf []byte, err error) { return buf, nil } -// KeepLast prevents the last data retrieved by Next from being reused by the -// ChunkReader. -func (r *ChunkReader) KeepLast() { - r.taken = true -} - func (r *ChunkReader) appendAtLeast(fillLen int) error { n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen) r.wp += n return err } -func (r *ChunkReader) newBuf(min int) []byte { - size := ((min / r.options.BlockLen) + 1) * r.options.BlockLen +func (r *ChunkReader) newBuf(size int) []byte { if size < r.options.MinBufLen { size = r.options.MinBufLen } diff --git a/chunkreader/chunkreader_test.go b/chunkreader/chunkreader_test.go index 9c19ff4a..3be07e3c 100644 --- a/chunkreader/chunkreader_test.go +++ b/chunkreader/chunkreader_test.go @@ -7,7 +7,7 @@ import ( func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -44,7 +44,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) { func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 2}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -59,14 +59,14 @@ func TestChunkReaderNextExpandsBufAsNeeded(t *testing.T) { if bytes.Compare(n1, src[0:5]) != 0 { t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:5], n1) } - if len(r.buf) != 6 { - t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 6, len(r.buf)) + if len(r.buf) != 5 { + t.Fatalf("Expected len(r.buf) to be %v, but it was %v", 5, len(r.buf)) } } -func TestChunkReaderNextReusesBuf(t *testing.T) { +func TestChunkReaderDoesNotReuseBuf(t *testing.T) { server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) + r, err := NewChunkReaderEx(server, Options{MinBufLen: 4}) if err != nil { t.Fatal(err) } @@ -90,38 +90,6 @@ func TestChunkReaderNextReusesBuf(t *testing.T) { t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) } - if bytes.Compare(n1, src[4:8]) != 0 { - t.Fatalf("Expected Next to have reused buf, %v found instead of %v", src[4:8], n1) - } -} - -func TestChunkReaderKeepLastPreventsBufReuse(t *testing.T) { - server := &bytes.Buffer{} - r, err := NewChunkReaderEx(server, Options{MinBufLen: 4, BlockLen: 1}) - if err != nil { - t.Fatal(err) - } - - src := []byte{1, 2, 3, 4, 5, 6, 7, 8} - server.Write(src) - - n1, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n1, src[0:4]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:4], n1) - } - r.KeepLast() - - n2, err := r.Next(4) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(n2, src[4:8]) != 0 { - t.Fatalf("Expected read bytes to be %v, but they were %v", src[4:8], n2) - } - if bytes.Compare(n1, src[0:4]) != 0 { t.Fatalf("Expected KeepLast to prevent Next from overwriting buf, expected %v but it was %v", src[0:4], n1) } diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go index e265a247..54f4978f 100644 --- a/pgproto3/authentication.go +++ b/pgproto3/authentication.go @@ -21,7 +21,7 @@ type Authentication struct { func (*Authentication) Backend() {} -func (dst *Authentication) UnmarshalBinary(src []byte) error { +func (dst *Authentication) Decode(src []byte) error { *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} switch dst.Type { diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index 5d8eb496..04f31aec 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -13,7 +13,7 @@ type BackendKeyData struct { func (*BackendKeyData) Backend() {} -func (dst *BackendKeyData) UnmarshalBinary(src []byte) error { +func (dst *BackendKeyData) Decode(src []byte) error { if len(src) != 8 { return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)} } diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go index 756a30e6..4f1c44b8 100644 --- a/pgproto3/bind_complete.go +++ b/pgproto3/bind_complete.go @@ -8,7 +8,7 @@ type BindComplete struct{} func (*BindComplete) Backend() {} -func (dst *BindComplete) UnmarshalBinary(src []byte) error { +func (dst *BindComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go index fd6ff180..9bab3e8c 100644 --- a/pgproto3/close_complete.go +++ b/pgproto3/close_complete.go @@ -8,7 +8,7 @@ type CloseComplete struct{} func (*CloseComplete) Backend() {} -func (dst *CloseComplete) UnmarshalBinary(src []byte) error { +func (dst *CloseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index ac60153e..86653804 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -11,14 +11,13 @@ type CommandComplete struct { func (*CommandComplete) Backend() {} -func (dst *CommandComplete) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - b, err := buf.ReadBytes(0) - if err != nil { - return err +func (dst *CommandComplete) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx != len(src)-1 { + return &invalidMessageFormatErr{messageType: "CommandComplete"} } - dst.CommandTag = string(b[:len(b)-1]) + + dst.CommandTag = string(src[:idx]) return nil } diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 2a4c58af..3857c187 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -13,7 +13,7 @@ type CopyBothResponse struct { func (*CopyBothResponse) Backend() {} -func (dst *CopyBothResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyBothResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go index b9ea6272..de7ab4ff 100644 --- a/pgproto3/copy_data.go +++ b/pgproto3/copy_data.go @@ -13,9 +13,8 @@ type CopyData struct { func (*CopyData) Backend() {} func (*CopyData) Frontend() {} -func (dst *CopyData) UnmarshalBinary(src []byte) error { - dst.Data = make([]byte, len(src)) - copy(dst.Data, src) +func (dst *CopyData) Decode(src []byte) error { + dst.Data = src return nil } diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 63868c7a..9854d665 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -13,7 +13,7 @@ type CopyInResponse struct { func (*CopyInResponse) Backend() {} -func (dst *CopyInResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyInResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index e46d9e8f..5ef6e4c1 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -13,7 +13,7 @@ type CopyOutResponse struct { func (*CopyOutResponse) Backend() {} -func (dst *CopyOutResponse) UnmarshalBinary(src []byte) error { +func (dst *CopyOutResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 3 { diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index c95861b9..6b27f728 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -13,35 +13,42 @@ type DataRow struct { func (*DataRow) Backend() {} -func (dst *DataRow) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 2 { +func (dst *DataRow) Decode(src []byte) error { + if len(src) < 2 { return &invalidMessageFormatErr{messageType: "DataRow"} } - fieldCount := int(binary.BigEndian.Uint16(buf.Next(2))) + rp := 0 + fieldCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 - dst.Values = make([][]byte, fieldCount) + // If the capacity of the values slice is too small OR substantially too + // large reallocate. This is too avoid one row with many columns from + // permanently allocating memory. + if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { + dst.Values = make([][]byte, fieldCount, 32) + } else { + dst.Values = dst.Values[:fieldCount] + } for i := 0; i < fieldCount; i++ { - if buf.Len() < 4 { + if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "DataRow"} } - msgSize := int(int32(binary.BigEndian.Uint32(buf.Next(4)))) + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 // null if msgSize == -1 { - continue - } + dst.Values[i] = nil + } else { + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "DataRow"} + } - value := make([]byte, msgSize) - _, err := buf.Read(value) - if err != nil { - return err + dst.Values[i] = src[rp : rp+msgSize] + rp += msgSize } - - dst.Values[i] = value } return nil diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go index de6e6272..13ed1886 100644 --- a/pgproto3/empty_query_response.go +++ b/pgproto3/empty_query_response.go @@ -8,7 +8,7 @@ type EmptyQueryResponse struct{} func (*EmptyQueryResponse) Backend() {} -func (dst *EmptyQueryResponse) UnmarshalBinary(src []byte) error { +func (dst *EmptyQueryResponse) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go index 82e408d7..602dd2a1 100644 --- a/pgproto3/error_response.go +++ b/pgproto3/error_response.go @@ -30,7 +30,7 @@ type ErrorResponse struct { func (*ErrorResponse) Backend() {} -func (dst *ErrorResponse) UnmarshalBinary(src []byte) error { +func (dst *ErrorResponse) Decode(src []byte) error { *dst = ErrorResponse{} buf := bytes.NewBuffer(src) diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index df67b718..50835836 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -108,6 +108,6 @@ func (b *Frontend) Receive() (BackendMessage, error) { return nil, err } - err = msg.UnmarshalBinary(msgBody) + err = msg.Decode(msgBody) return msg, err } diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go index 5c692b36..1e0f16af 100644 --- a/pgproto3/function_call_response.go +++ b/pgproto3/function_call_response.go @@ -13,20 +13,24 @@ type FunctionCallResponse struct { func (*FunctionCallResponse) Backend() {} -func (dst *FunctionCallResponse) UnmarshalBinary(src []byte) error { - buf := bytes.NewBuffer(src) - - if buf.Len() < 4 { +func (dst *FunctionCallResponse) Decode(src []byte) error { + if len(src) < 4 { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } - resultSize := int(binary.BigEndian.Uint32(buf.Next(4))) - if buf.Len() != resultSize { + rp := 0 + resultSize := int(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + if resultSize == -1 { + dst.Result = nil + return nil + } + + if len(src[rp:]) != resultSize { return &invalidMessageFormatErr{messageType: "FunctionCallResponse"} } - dst.Result = make([]byte, resultSize) - copy(dst.Result, buf.Bytes()) - + dst.Result = src[rp:] return nil } diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go index 47ebf28e..3adec4ad 100644 --- a/pgproto3/no_data.go +++ b/pgproto3/no_data.go @@ -8,7 +8,7 @@ type NoData struct{} func (*NoData) Backend() {} -func (dst *NoData) UnmarshalBinary(src []byte) error { +func (dst *NoData) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go index 767c9a67..8af55baf 100644 --- a/pgproto3/notice_response.go +++ b/pgproto3/notice_response.go @@ -4,8 +4,8 @@ type NoticeResponse ErrorResponse func (*NoticeResponse) Backend() {} -func (dst *NoticeResponse) UnmarshalBinary(src []byte) error { - return (*ErrorResponse)(dst).UnmarshalBinary(src) +func (dst *NoticeResponse) Decode(src []byte) error { + return (*ErrorResponse)(dst).Decode(src) } func (src *NoticeResponse) MarshalBinary() ([]byte, error) { diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index 4ae8bab3..7262844e 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -14,7 +14,7 @@ type NotificationResponse struct { func (*NotificationResponse) Backend() {} -func (dst *NotificationResponse) UnmarshalBinary(src []byte) error { +func (dst *NotificationResponse) Decode(src []byte) error { buf := bytes.NewBuffer(src) pid := binary.BigEndian.Uint32(buf.Next(4)) diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index 40d92c50..32b6e1c1 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -12,7 +12,7 @@ type ParameterDescription struct { func (*ParameterDescription) Backend() {} -func (dst *ParameterDescription) UnmarshalBinary(src []byte) error { +func (dst *ParameterDescription) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 2 { diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go index b8ce7f8d..9b10824c 100644 --- a/pgproto3/parameter_status.go +++ b/pgproto3/parameter_status.go @@ -13,7 +13,7 @@ type ParameterStatus struct { func (*ParameterStatus) Backend() {} -func (dst *ParameterStatus) UnmarshalBinary(src []byte) error { +func (dst *ParameterStatus) Decode(src []byte) error { buf := bytes.NewBuffer(src) b, err := buf.ReadBytes(0) diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go index 24951e3d..e949c14c 100644 --- a/pgproto3/parse_complete.go +++ b/pgproto3/parse_complete.go @@ -8,7 +8,7 @@ type ParseComplete struct{} func (*ParseComplete) Backend() {} -func (dst *ParseComplete) UnmarshalBinary(src []byte) error { +func (dst *ParseComplete) Decode(src []byte) error { if len(src) != 0 { return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)} } diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index a9221239..3fe8fc93 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -2,8 +2,13 @@ package pgproto3 import "fmt" +// Message is the interface implemented by an object that can decode and encode +// a particular PostgreSQL message. +// +// Decode is allowed and expected to retain a reference to data after +// returning (unlike encoding.BinaryUnmarshaler). type Message interface { - UnmarshalBinary(data []byte) error + Decode(data []byte) error MarshalBinary() (data []byte, err error) } @@ -17,58 +22,6 @@ type BackendMessage interface { Backend() // no-op method to distinguish frontend from backend methods } -// func ParseBackend(typeByte byte, body []byte) (BackendMessage, error) { -// switch typeByte { -// case '1': -// return ParseParseComplete(body) -// case '2': -// return ParseBindComplete(body) -// case 'C': -// return ParseCommandComplete(body) -// case 'D': -// return ParseDataRow(body) -// case 'E': -// return ParseErrorResponse(body) -// case 'K': -// return ParseBackendKeyData(body) -// case 'R': -// return ParseAuthentication(body) -// case 'S': -// return ParseParameterStatus(body) -// case 'T': -// return ParseRowDescription(body) -// case 't': -// return ParseParameterDescription(body) -// case 'Z': -// return ParseReadyForQuery(body) -// default: -// return ParseUnknownMessage(typeByte, body) -// } -// } - -// func ParseFrontend(typeByte byte, body []byte) (FrontendMessage, error) { -// switch typeByte { -// case 'B': -// return ParseBind(body) -// case 'D': -// return ParseDescribe(body) -// case 'E': -// return ParseExecute(body) -// case 'P': -// return ParseParse(body) -// case 'p': -// return ParsePasswordMessage(body) -// case 'Q': -// return ParseQuery(body) -// case 'S': -// return ParseSync(body) -// case 'X': -// return ParseTerminate(body) -// default: -// return ParseUnknownMessage(typeByte, body) -// } -// } - type invalidMessageLenErr struct { messageType string expectedLen int diff --git a/pgproto3/query.go b/pgproto3/query.go index a3fc32eb..b5fc2dbc 100644 --- a/pgproto3/query.go +++ b/pgproto3/query.go @@ -11,7 +11,7 @@ type Query struct { func (*Query) Frontend() {} -func (dst *Query) UnmarshalBinary(src []byte) error { +func (dst *Query) Decode(src []byte) error { i := bytes.IndexByte(src, 0) if i != len(src)-1 { return &invalidMessageFormatErr{messageType: "Query"} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go index 09005d00..e0e4707a 100644 --- a/pgproto3/ready_for_query.go +++ b/pgproto3/ready_for_query.go @@ -10,7 +10,7 @@ type ReadyForQuery struct { func (*ReadyForQuery) Backend() {} -func (dst *ReadyForQuery) UnmarshalBinary(src []byte) error { +func (dst *ReadyForQuery) Decode(src []byte) error { if len(src) != 1 { return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)} } diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index 294a6aa9..b1110290 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -27,7 +27,7 @@ type RowDescription struct { func (*RowDescription) Backend() {} -func (dst *RowDescription) UnmarshalBinary(src []byte) error { +func (dst *RowDescription) Decode(src []byte) error { buf := bytes.NewBuffer(src) if buf.Len() < 2 { From 932caef6000e8c3cd5077a5d4777955b22d2353d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 12:23:51 -0500 Subject: [PATCH 176/264] pgtype DecodeText and DecodeBinary do not copy They now take ownership of the src argument. Needed to change Scan to make a copy of []byte arguments as lib/pq apparently gives Scan a shared memory buffer. --- pgtype/aclitem.go | 4 +++- pgtype/aclitem_array.go | 4 +++- pgtype/bool.go | 4 +++- pgtype/bool_array.go | 4 +++- pgtype/box.go | 4 +++- pgtype/bytea.go | 5 +---- pgtype/bytea_array.go | 4 +++- pgtype/cidr_array.go | 4 +++- pgtype/circle.go | 4 +++- pgtype/date.go | 4 +++- pgtype/date_array.go | 4 +++- pgtype/daterange.go | 10 ++++++---- pgtype/float4.go | 4 +++- pgtype/float4_array.go | 4 +++- pgtype/float8.go | 4 +++- pgtype/float8_array.go | 4 +++- pgtype/hstore.go | 4 +++- pgtype/hstore_array.go | 4 +++- pgtype/inet.go | 4 +++- pgtype/inet_array.go | 4 +++- pgtype/int2.go | 4 +++- pgtype/int2_array.go | 4 +++- pgtype/int4.go | 4 +++- pgtype/int4_array.go | 4 +++- pgtype/int4range.go | 10 ++++++---- pgtype/int8.go | 4 +++- pgtype/int8_array.go | 4 +++- pgtype/int8range.go | 10 ++++++---- pgtype/interval.go | 4 +++- pgtype/json.go | 9 ++++----- pgtype/jsonb.go | 6 +----- pgtype/line.go | 4 +++- pgtype/lseg.go | 4 +++- pgtype/macaddr.go | 4 +++- pgtype/numeric.go | 4 +++- pgtype/numeric_array.go | 4 +++- pgtype/numrange.go | 10 ++++++---- pgtype/oid.go | 4 +++- pgtype/path.go | 4 +++- pgtype/pgtype.go | 8 ++++---- pgtype/pguint32.go | 4 +++- pgtype/point.go | 4 +++- pgtype/polygon.go | 4 +++- pgtype/text.go | 4 +++- pgtype/text_array.go | 4 +++- pgtype/tid.go | 4 +++- pgtype/timestamp.go | 4 +++- pgtype/timestamp_array.go | 4 +++- pgtype/timestamptz.go | 4 +++- pgtype/timestamptz_array.go | 4 +++- pgtype/tsrange.go | 10 ++++++---- pgtype/tstzrange.go | 10 ++++++---- pgtype/typed_array.go.erb | 4 +++- pgtype/typed_range.go.erb | 4 +++- pgtype/uuid.go | 4 +++- pgtype/varbit.go | 9 ++++----- pgtype/varchar_array.go | 4 +++- 57 files changed, 188 insertions(+), 93 deletions(-) diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index ebfcc3e7..31065764 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -106,7 +106,9 @@ func (dst *Aclitem) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 7ef76573..480b5bba 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -206,7 +206,9 @@ func (dst *AclitemArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/bool.go b/pgtype/bool.go index 9d309f0c..ba876c91 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -142,7 +142,9 @@ func (dst *Bool) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 468f6816..4e92a616 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -308,7 +308,9 @@ func (dst *BoolArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/box.go b/pgtype/box.go index 2e4f39ee..e25af854 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -156,7 +156,9 @@ func (dst *Box) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 3e2661db..bf774476 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -95,10 +95,7 @@ func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } - buf := make([]byte, len(src)) - copy(buf, src) - - *dst = Bytea{Bytes: buf, Status: Present} + *dst = Bytea{Bytes: src, Status: Present} return nil } diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 4aa2b862..dd79b991 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -308,7 +308,9 @@ func (dst *ByteaArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 96d912ae..0aa289e7 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -337,7 +337,9 @@ func (dst *CidrArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/circle.go b/pgtype/circle.go index 8c8f4693..e9268a06 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -138,7 +138,9 @@ func (dst *Circle) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/date.go b/pgtype/date.go index 993a04c5..a7e4762a 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -185,7 +185,9 @@ func (dst *Date) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) case time.Time: *dst = Date{Time: src, Status: Present} return nil diff --git a/pgtype/date_array.go b/pgtype/date_array.go index f24bf6b9..91e2ee62 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -309,7 +309,9 @@ func (dst *DateArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/daterange.go b/pgtype/daterange.go index 5cecca20..a5cd5d95 100644 --- a/pgtype/daterange.go +++ b/pgtype/daterange.go @@ -106,7 +106,7 @@ func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Daterange) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Daterange) Value() (driver.Value, error) { +func (src Daterange) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pgtype/float4.go b/pgtype/float4.go index 76be4203..77bc4878 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -177,7 +177,9 @@ func (dst *Float4) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index db1523f0..38508a52 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -308,7 +308,9 @@ func (dst *Float4Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/float8.go b/pgtype/float8.go index 8cfc53c5..5322e251 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -167,7 +167,9 @@ func (dst *Float8) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 19878bbb..2f310bbd 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -308,7 +308,9 @@ func (dst *Float8Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 04df2acc..69a35b17 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -455,7 +455,9 @@ func (dst *Hstore) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index e4263f20..9f773af2 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -308,7 +308,9 @@ func (dst *HstoreArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/inet.go b/pgtype/inet.go index e3a7ec88..7c09a549 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -213,7 +213,9 @@ func (dst *Inet) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 4687b145..ed9f5d1c 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -337,7 +337,9 @@ func (dst *InetArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/int2.go b/pgtype/int2.go index 4a3beb22..028cdfcf 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -178,7 +178,9 @@ func (dst *Int2) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 3506370e..cdfcde48 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -336,7 +336,9 @@ func (dst *Int2Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/int4.go b/pgtype/int4.go index f429d887..cae0d32a 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -169,7 +169,9 @@ func (dst *Int4) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index e4ec6455..9ca0b067 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -336,7 +336,9 @@ func (dst *Int4Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/int4range.go b/pgtype/int4range.go index 12a48dab..29b8371e 100644 --- a/pgtype/int4range.go +++ b/pgtype/int4range.go @@ -106,7 +106,7 @@ func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Int4range) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Int4range) Value() (driver.Value, error) { +func (src Int4range) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pgtype/int8.go b/pgtype/int8.go index 97db8393..a4ec4e62 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -155,7 +155,9 @@ func (dst *Int8) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 6c0dab65..c5026f83 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -336,7 +336,9 @@ func (dst *Int8Array) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/int8range.go b/pgtype/int8range.go index 3541dbe2..e3e0486f 100644 --- a/pgtype/int8range.go +++ b/pgtype/int8range.go @@ -106,7 +106,7 @@ func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Int8range) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Int8range) Value() (driver.Value, error) { +func (src Int8range) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pgtype/interval.go b/pgtype/interval.go index 050d5610..8ce345a3 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -259,7 +259,9 @@ func (dst *Interval) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/json.go b/pgtype/json.go index a027a91c..44880863 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -97,10 +97,7 @@ func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { return nil } - buf := make([]byte, len(src)) - copy(buf, src) - - *dst = Json{Bytes: buf, Status: Present} + *dst = Json{Bytes: src, Status: Present} return nil } @@ -135,7 +132,9 @@ func (dst *Json) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 82cbb21f..5533b4b4 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -37,12 +37,8 @@ func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { if src[0] != 1 { return fmt.Errorf("unknown jsonb version number %d", src[0]) } - src = src[1:] - buf := make([]byte, len(src)) - copy(buf, src) - - *dst = Jsonb{Bytes: buf, Status: Present} + *dst = Jsonb{Bytes: src[1:], Status: Present} return nil } diff --git a/pgtype/line.go b/pgtype/line.go index 06f01f21..75fdf207 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -136,7 +136,9 @@ func (dst *Line) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 986724cc..823c2c09 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -156,7 +156,9 @@ func (dst *Lseg) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 0fe092e4..785148a2 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -142,7 +142,9 @@ func (dst *Macaddr) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 63f99c06..8dbc0251 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -594,7 +594,9 @@ func (dst *Numeric) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go index 3d59a6b0..2fc844eb 100644 --- a/pgtype/numeric_array.go +++ b/pgtype/numeric_array.go @@ -336,7 +336,9 @@ func (dst *NumericArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/numrange.go b/pgtype/numrange.go index b0baec9a..bac6fc4b 100644 --- a/pgtype/numrange.go +++ b/pgtype/numrange.go @@ -106,7 +106,7 @@ func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Numrange) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Numrange) Value() (driver.Value, error) { +func (src Numrange) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pgtype/oid.go b/pgtype/oid.go index 339dee0f..58a7b0f5 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -70,7 +70,9 @@ func (dst *Oid) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/path.go b/pgtype/path.go index 2fd6cfc7..c1aa76bc 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -195,7 +195,9 @@ func (dst *Path) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 27a1a091..3a6b7471 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -96,15 +96,15 @@ type Value interface { type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the - // original SQL value is NULL. BinaryDecoder MUST not retain a reference to - // src. It MUST make a copy if it needs to retain the raw bytes. + // original SQL value is NULL. BinaryDecoder takes ownership of src. The + // caller MUST not use it again. DecodeBinary(ci *ConnInfo, src []byte) error } type TextDecoder interface { // DecodeText decodes src into TextDecoder. If src is nil then the original - // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST - // make a copy if it needs to retain the raw bytes. + // SQL value is NULL. TextDecoder takes ownership of src. The caller MUST not + // use it again. DecodeText(ci *ConnInfo, src []byte) error } diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 0caa0cba..a13c1fcd 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -144,7 +144,9 @@ func (dst *pguint32) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/point.go b/pgtype/point.go index 3d51766e..62901340 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -130,7 +130,9 @@ func (dst *Point) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/polygon.go b/pgtype/polygon.go index af99ee3d..c4383765 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -174,7 +174,9 @@ func (dst *Polygon) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/text.go b/pgtype/text.go index 8e42a756..54e2d774 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -118,7 +118,9 @@ func (dst *Text) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/text_array.go b/pgtype/text_array.go index a6bd4724..8a573d83 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -308,7 +308,9 @@ func (dst *TextArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/tid.go b/pgtype/tid.go index 7976afde..7456b155 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -134,7 +134,9 @@ func (dst *Tid) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 694b63c0..4fb10abc 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -201,7 +201,9 @@ func (dst *Timestamp) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) case time.Time: *dst = Timestamp{Time: src, Status: Present} return nil diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 2046c387..49815dae 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -309,7 +309,9 @@ func (dst *TimestampArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 3c76ec03..8606b2f2 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -197,7 +197,9 @@ func (dst *Timestamptz) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) case time.Time: *dst = Timestamptz{Time: src, Status: Present} return nil diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index fd58d3be..bf983b6b 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -309,7 +309,9 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go index 78a94af2..429a5cbe 100644 --- a/pgtype/tsrange.go +++ b/pgtype/tsrange.go @@ -106,7 +106,7 @@ func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Tsrange) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Tsrange) Value() (driver.Value, error) { +func (src Tsrange) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go index d1fc7326..f03a9f65 100644 --- a/pgtype/tstzrange.go +++ b/pgtype/tstzrange.go @@ -106,7 +106,7 @@ func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src *Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } -func (src *Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -256,13 +256,15 @@ func (dst *Tstzrange) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. -func (src *Tstzrange) Value() (driver.Value, error) { +func (src Tstzrange) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 2a38ed82..6752bd5b 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -310,7 +310,9 @@ func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb index e46f71c7..49db1b1d 100644 --- a/pgtype/typed_range.go.erb +++ b/pgtype/typed_range.go.erb @@ -256,7 +256,9 @@ func (dst *<%= range_type %>) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/uuid.go b/pgtype/uuid.go index c830c086..a4a93ab3 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -161,7 +161,9 @@ func (dst *Uuid) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/varbit.go b/pgtype/varbit.go index 00c34e10..b986f02a 100644 --- a/pgtype/varbit.go +++ b/pgtype/varbit.go @@ -72,10 +72,7 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { bitLen := int32(binary.BigEndian.Uint32(src)) rp := 4 - buf := make([]byte, len(src[rp:])) - copy(buf, src[rp:]) - - *dst = Varbit{Bytes: buf, Len: bitLen, Status: Present} + *dst = Varbit{Bytes: src[rp:], Len: bitLen, Status: Present} return nil } @@ -129,7 +126,9 @@ func (dst *Varbit) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 9ca16d7e..d84fac02 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -308,7 +308,9 @@ func (dst *VarcharArray) Scan(src interface{}) error { case string: return dst.DecodeText(nil, []byte(src)) case []byte: - return dst.DecodeText(nil, src) + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) } return fmt.Errorf("cannot scan %T", src) From a5f702c41dd243a33e769b5211a2eda0a9d36124 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 13:21:25 -0500 Subject: [PATCH 177/264] Reduce allocs and copies --- messages.go | 14 ++++++++++++++ values.go | 27 ++++++++++++--------------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/messages.go b/messages.go index e229367a..0f17a6d2 100644 --- a/messages.go +++ b/messages.go @@ -118,6 +118,20 @@ func (wb *WriteBuf) closeMsg() { binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx)) } +func (wb *WriteBuf) reserveSize() int { + sizePosition := len(wb.buf) + wb.buf = append(wb.buf, 0, 0, 0, 0) + return sizePosition +} + +func (wb *WriteBuf) setComputedSize(sizePosition int) { + binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(len(wb.buf)-sizePosition-4)) +} + +func (wb *WriteBuf) setSize(sizePosition int, size int32) { + binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(size)) +} + func (wb *WriteBuf) WriteByte(b byte) { wb.buf = append(wb.buf, b) } diff --git a/values.go b/values.go index 3565df34..da12952a 100644 --- a/values.go +++ b/values.go @@ -106,29 +106,27 @@ func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interfa switch arg := arg.(type) { case pgtype.BinaryEncoder: - buf := &bytes.Buffer{} - null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, buf) + sp := wbuf.reserveSize() + null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, wbuf) if err != nil { return err } if null { - wbuf.WriteInt32(-1) + wbuf.setSize(sp, -1) } else { - wbuf.WriteInt32(int32(buf.Len())) - wbuf.WriteBytes(buf.Bytes()) + wbuf.setComputedSize(sp) } return nil case pgtype.TextEncoder: - buf := &bytes.Buffer{} - null, err := arg.EncodeText(wbuf.conn.ConnInfo, buf) + sp := wbuf.reserveSize() + null, err := arg.EncodeText(wbuf.conn.ConnInfo, wbuf) if err != nil { return err } if null { - wbuf.WriteInt32(-1) + wbuf.setSize(sp, -1) } else { - wbuf.WriteInt32(int32(buf.Len())) - wbuf.WriteBytes(buf.Bytes()) + wbuf.setComputedSize(sp) } return nil case driver.Valuer: @@ -161,16 +159,15 @@ func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interfa return err } - buf := &bytes.Buffer{} - null, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, buf) + sp := wbuf.reserveSize() + null, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, wbuf) if err != nil { return err } if null { - wbuf.WriteInt32(-1) + wbuf.setSize(sp, -1) } else { - wbuf.WriteInt32(int32(buf.Len())) - wbuf.WriteBytes(buf.Bytes()) + wbuf.setComputedSize(sp) } return nil } From 353ca7c5c786dda76d880f3bcc64b123a789eb7b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 13:38:56 -0500 Subject: [PATCH 178/264] Fix travis --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index edacab39..76311d4c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -52,7 +52,6 @@ install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake - - go get -u github.com/jackc/pgmock/pgproto3 - go get -u github.com/lib/pq - go get -u github.com/hashicorp/go-version - go get -u github.com/satori/go.uuid From 855b735eaeb351c41875241bc2bc8fe31bebffc0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 20:33:52 -0500 Subject: [PATCH 179/264] Add log adapters for testing and log15 Make LogLevel a type for Stringer interface. --- bench_test.go | 2 +- conn.go | 2 +- conn_test.go | 4 ++-- examples/url_shortener/main.go | 27 ++++------------------ log/log15adapter/adapter.go | 42 ++++++++++++++++++++++++++++++++++ log/testingadapter/adapter.go | 25 ++++++++++++++++++++ logger.go | 27 ++++++++++++++++++++-- stdlib/sql_test.go | 4 ++-- 8 files changed, 103 insertions(+), 30 deletions(-) create mode 100644 log/log15adapter/adapter.go create mode 100644 log/testingadapter/adapter.go diff --git a/bench_test.go b/bench_test.go index 69d17c39..91c73293 100644 --- a/bench_test.go +++ b/bench_test.go @@ -179,7 +179,7 @@ func BenchmarkSelectWithoutLogging(b *testing.B) { type discardLogger struct{} -func (dl discardLogger) Log(level int, msg string, ctx ...interface{}) {} +func (dl discardLogger) Log(level pgx.LogLevel, msg string, ctx ...interface{}) {} func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { conn := mustConnect(b, *defaultConnConfig) diff --git a/conn.go b/conn.go index 7487b8ad..81a38df9 100644 --- a/conn.go +++ b/conn.go @@ -1213,7 +1213,7 @@ func (c *Conn) shouldLog(lvl int) bool { return c.logger != nil && c.logLevel >= lvl } -func (c *Conn) log(lvl int, msg string, ctx ...interface{}) { +func (c *Conn) log(lvl LogLevel, msg string, ctx ...interface{}) { if c.pid != 0 { ctx = append(ctx, "pid", c.PID) } diff --git a/conn_test.go b/conn_test.go index d4ca593f..8f47d995 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1739,7 +1739,7 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { } type testLog struct { - lvl int + lvl pgx.LogLevel msg string ctx []interface{} } @@ -1748,7 +1748,7 @@ type testLogger struct { logs []testLog } -func (l *testLogger) Log(level int, msg string, ctx ...interface{}) { +func (l *testLogger) Log(level pgx.LogLevel, msg string, ctx ...interface{}) { l.logs = append(l.logs, testLog{lvl: level, msg: msg, ctx: ctx}) } diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index 25e4cb90..8380ef3f 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -1,11 +1,13 @@ package main import ( - "github.com/jackc/pgx" - log "gopkg.in/inconshreveable/log15.v2" "io/ioutil" "net/http" "os" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/log/log15adapter" + log "gopkg.in/inconshreveable/log15.v2" ) var pool *pgx.ConnPool @@ -98,27 +100,8 @@ func urlHandler(w http.ResponseWriter, req *http.Request) { } } -type log15Adapter struct { - logger log.Logger -} - -func (a *log15Adapter) Log(level int, msg string, ctx ...interface{}) { - switch level { - case pgx.LogLevelTrace, pgx.LogLevelDebug: - a.logger.Debug(msg, ctx...) - case pgx.LogLevelInfo: - a.logger.Info(msg, ctx...) - case pgx.LogLevelWarn: - a.logger.Warn(msg, ctx...) - case pgx.LogLevelError: - a.logger.Error(msg, ctx...) - default: - panic("invalid log level") - } -} - func main() { - logger := &log15Adapter{logger: log.New("module", "pgx")} + logger := log15adapter.NewLogger(log.New("module", "pgx")) var err error connPoolConfig := pgx.ConnPoolConfig{ diff --git a/log/log15adapter/adapter.go b/log/log15adapter/adapter.go new file mode 100644 index 00000000..55d1b79f --- /dev/null +++ b/log/log15adapter/adapter.go @@ -0,0 +1,42 @@ +// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger +// log. +package log15adapter + +import ( + "github.com/jackc/pgx" +) + +// Log15Logger interface defines the subset of +// github.com/inconshreveable/log15.Logger that this adapter uses. +type Log15Logger interface { + Debug(msg string, ctx ...interface{}) + Info(msg string, ctx ...interface{}) + Warn(msg string, ctx ...interface{}) + Error(msg string, ctx ...interface{}) + Crit(msg string, ctx ...interface{}) +} + +type Logger struct { + l Log15Logger +} + +func NewLogger(l Log15Logger) *Logger { + return &Logger{l: l} +} + +func (l *Logger) Log(level pgx.LogLevel, msg string, ctx ...interface{}) { + switch level { + case pgx.LogLevelTrace: + l.l.Debug(msg, append(ctx, "PGX_LOG_LEVEL", level)...) + case pgx.LogLevelDebug: + l.l.Debug(msg, ctx...) + case pgx.LogLevelInfo: + l.l.Info(msg, ctx...) + case pgx.LogLevelWarn: + l.l.Warn(msg, ctx...) + case pgx.LogLevelError: + l.l.Error(msg, ctx...) + default: + l.l.Error(msg, append(ctx, "INVALID_PGX_LOG_LEVEL", level)...) + } +} diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go new file mode 100644 index 00000000..f042c4f1 --- /dev/null +++ b/log/testingadapter/adapter.go @@ -0,0 +1,25 @@ +// Package testingadapter provides a logger that writes to a test or benchmark +// log. +package testingadapter + +import ( + "github.com/jackc/pgx" +) + +// TestingLogger interface defines the subset of testing.TB methods used by this +// adapter. +type TestingLogger interface { + Log(args ...interface{}) +} + +type Logger struct { + l TestingLogger +} + +func NewLogger(l TestingLogger) *Logger { + return &Logger{l: l} +} + +func (l *Logger) Log(level pgx.LogLevel, msg string, ctx ...interface{}) { + l.l.Log(level, msg, ctx) +} diff --git a/logger.go b/logger.go index e4d28fed..f1b85322 100644 --- a/logger.go +++ b/logger.go @@ -17,10 +17,33 @@ const ( LogLevelNone = 1 ) +// LogLevel represents the pgx logging level. See LogLevel* constants for +// possible values. +type LogLevel int + +func (ll LogLevel) String() string { + switch ll { + case LogLevelTrace: + return "trace" + case LogLevelDebug: + return "debug" + case LogLevelInfo: + return "info" + case LogLevelWarn: + return "warn" + case LogLevelError: + return "error" + case LogLevelNone: + return "none" + default: + return fmt.Sprintf("invalid level %d", ll) + } +} + // Logger is the interface used to get logging from pgx internals. type Logger interface { // Log a message at the given level with context key/value pairs - Log(level int, msg string, ctx ...interface{}) + Log(level LogLevel, msg string, ctx ...interface{}) } // LogLevelFromString converts log level string to constant @@ -32,7 +55,7 @@ type Logger interface { // warn // error // none -func LogLevelFromString(s string) (int, error) { +func LogLevelFromString(s string) (LogLevel, error) { switch s { case "trace": return LogLevelTrace, nil diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 641ba9fe..dadafd41 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -365,7 +365,7 @@ func TestConnQuery(t *testing.T) { } type testLog struct { - lvl int + lvl pgx.LogLevel msg string ctx []interface{} } @@ -374,7 +374,7 @@ type testLogger struct { logs []testLog } -func (l *testLogger) Log(lvl int, msg string, ctx ...interface{}) { +func (l *testLogger) Log(lvl pgx.LogLevel, msg string, ctx ...interface{}) { l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, ctx: ctx}) } From 280bce7078eb67348152eae8dc65ba642299cea1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Apr 2017 21:28:38 -0500 Subject: [PATCH 180/264] Added log adapter for logrus Also changed standard logger interface to take a map instead of varargs for extra data. --- bench_test.go | 2 +- conn.go | 37 +++++++++++++++++--------------- conn_pool.go | 2 +- conn_test.go | 10 ++++----- doc.go | 5 +++-- log/log15adapter/adapter.go | 19 +++++++++++------ log/logrusadapter/adapter.go | 40 +++++++++++++++++++++++++++++++++++ log/testingadapter/adapter.go | 11 ++++++++-- logger.go | 4 ++-- query.go | 4 ++-- replication.go | 10 ++++----- stdlib/sql_test.go | 12 +++++------ v3.md | 4 ++-- 13 files changed, 108 insertions(+), 52 deletions(-) create mode 100644 log/logrusadapter/adapter.go diff --git a/bench_test.go b/bench_test.go index 91c73293..d3525df5 100644 --- a/bench_test.go +++ b/bench_test.go @@ -179,7 +179,7 @@ func BenchmarkSelectWithoutLogging(b *testing.B) { type discardLogger struct{} -func (dl discardLogger) Log(level pgx.LogLevel, msg string, ctx ...interface{}) {} +func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {} func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { conn := mustConnect(b, *defaultConnConfig) diff --git a/conn.go b/conn.go index 81a38df9..bca9f6d8 100644 --- a/conn.go +++ b/conn.go @@ -222,14 +222,14 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } c.config.User = user.Username if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "Using default connection config", "User", c.config.User) + c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User}) } } if c.config.Port == 0 { c.config.Port = 5432 if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "Using default connection config", "Port", c.config.Port) + c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port}) } } @@ -239,19 +239,19 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address)) + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) } err = c.connect(config, network, address, config.TLSConfig) if err != nil && config.UseFallbackTLS { if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, fmt.Sprintf("Connect with TLSConfig failed, trying FallbackTLSConfig: %v", err)) + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) } err = c.connect(config, network, address, config.FallbackTLSConfig) } if err != nil { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, fmt.Sprintf("Connect failed: %v", err)) + c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) } return nil, err } @@ -282,7 +282,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { - c.log(LogLevelDebug, "Starting TLS handshake") + c.log(LogLevelDebug, "starting TLS handshake", nil) } if err := c.startTLS(tlsConfig); err != nil { return err @@ -334,7 +334,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl case *pgproto3.ReadyForQuery: c.rxReadyForQuery(msg) if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Connection established") + c.log(LogLevelInfo, "connection established", nil) } // Replication connections can't execute the queries to @@ -414,25 +414,25 @@ func (c *Conn) Close() (err error) { c.conn.Close() c.die(errors.New("Closed")) if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Closed connection") + c.log(LogLevelInfo, "closed connection", nil) } }() err = c.conn.SetDeadline(time.Time{}) if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "Failed to clear deadlines to send close message", "err", err) + c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err}) return err } _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "Failed to send terminate message", "err", err) + c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err}) return err } err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) if err != nil && c.shouldLog(LogLevelWarn) { - c.log(LogLevelWarn, "Failed to set read deadline to finish closing", "err", err) + c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err}) return err } @@ -701,7 +701,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared if c.shouldLog(LogLevelError) { defer func() { if err != nil { - c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) + c.log(LogLevelError, "prepareEx failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) } }() } @@ -1213,12 +1213,15 @@ func (c *Conn) shouldLog(lvl int) bool { return c.logger != nil && c.logLevel >= lvl } -func (c *Conn) log(lvl LogLevel, msg string, ctx ...interface{}) { +func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { + if data == nil { + data = map[string]interface{}{} + } if c.pid != 0 { - ctx = append(ctx, "pid", c.PID) + data["pid"] = c.PID } - c.logger.Log(lvl, msg, ctx...) + c.logger.Log(lvl, msg, data) } // SetLogger replaces the current logger and returns the previous logger. @@ -1327,11 +1330,11 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, if err == nil { if c.shouldLog(LogLevelInfo) { endTime := time.Now() - c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) + c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) } } else { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) + c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) } } diff --git a/conn_pool.go b/conn_pool.go index 8703d7fa..7bc022d0 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -164,7 +164,7 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { } // All connections are in use and we cannot create more if p.logLevel >= LogLevelWarn { - p.logger.Log(LogLevelWarn, "All connections in pool are busy - waiting...") + p.logger.Log(LogLevelWarn, "waiting for available connection", nil) } // Wait until there is an available connection OR room to create a new connection diff --git a/conn_test.go b/conn_test.go index 8f47d995..f887e030 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1739,17 +1739,17 @@ func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) { } type testLog struct { - lvl pgx.LogLevel - msg string - ctx []interface{} + lvl pgx.LogLevel + msg string + data map[string]interface{} } type testLogger struct { logs []testLog } -func (l *testLogger) Log(level pgx.LogLevel, msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: level, msg: msg, ctx: ctx}) +func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { + l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) } func TestSetLogger(t *testing.T) { diff --git a/doc.go b/doc.go index 0921242a..a0f0bd72 100644 --- a/doc.go +++ b/doc.go @@ -239,7 +239,8 @@ connection. Logging pgx defines a simple logger interface. Connections optionally accept a logger -that satisfies this interface. Set LogLevel to control logging -verbosity. +that satisfies this interface. Set LogLevel to control logging verbosity. +Adapters for github.com/inconshreveable/log15, github.com/Sirupsen/logrus, and +the testing log are provided in the log directory. */ package pgx diff --git a/log/log15adapter/adapter.go b/log/log15adapter/adapter.go index 55d1b79f..8623a380 100644 --- a/log/log15adapter/adapter.go +++ b/log/log15adapter/adapter.go @@ -24,19 +24,24 @@ func NewLogger(l Log15Logger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(level pgx.LogLevel, msg string, ctx ...interface{}) { +func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { + logArgs := make([]interface{}, 0, len(data)) + for k, v := range data { + logArgs = append(logArgs, k, v) + } + switch level { case pgx.LogLevelTrace: - l.l.Debug(msg, append(ctx, "PGX_LOG_LEVEL", level)...) + l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...) case pgx.LogLevelDebug: - l.l.Debug(msg, ctx...) + l.l.Debug(msg, logArgs...) case pgx.LogLevelInfo: - l.l.Info(msg, ctx...) + l.l.Info(msg, logArgs...) case pgx.LogLevelWarn: - l.l.Warn(msg, ctx...) + l.l.Warn(msg, logArgs...) case pgx.LogLevelError: - l.l.Error(msg, ctx...) + l.l.Error(msg, logArgs...) default: - l.l.Error(msg, append(ctx, "INVALID_PGX_LOG_LEVEL", level)...) + l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...) } } diff --git a/log/logrusadapter/adapter.go b/log/logrusadapter/adapter.go new file mode 100644 index 00000000..6084c36c --- /dev/null +++ b/log/logrusadapter/adapter.go @@ -0,0 +1,40 @@ +// Package logrusadapter provides a logger that writes to a github.com/Sirupsen/logrus.Logger +// log. +package logrusadapter + +import ( + "github.com/Sirupsen/logrus" + "github.com/jackc/pgx" +) + +type Logger struct { + l *logrus.Logger +} + +func NewLogger(l *logrus.Logger) *Logger { + return &Logger{l: l} +} + +func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { + var logger logrus.FieldLogger + if data != nil { + logger = l.l.WithFields(data) + } else { + logger = l.l + } + + switch level { + case pgx.LogLevelTrace: + logger.WithField("PGX_LOG_LEVEL", level).Debug(msg) + case pgx.LogLevelDebug: + logger.Debug(msg) + case pgx.LogLevelInfo: + logger.Info(msg) + case pgx.LogLevelWarn: + logger.Warn(msg) + case pgx.LogLevelError: + logger.Error(msg) + default: + logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg) + } +} diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go index f042c4f1..6c9cde83 100644 --- a/log/testingadapter/adapter.go +++ b/log/testingadapter/adapter.go @@ -3,6 +3,8 @@ package testingadapter import ( + "fmt" + "github.com/jackc/pgx" ) @@ -20,6 +22,11 @@ func NewLogger(l TestingLogger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(level pgx.LogLevel, msg string, ctx ...interface{}) { - l.l.Log(level, msg, ctx) +func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { + logArgs := make([]interface{}, 0, 2+len(data)) + logArgs = append(logArgs, level, msg) + for k, v := range data { + logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v)) + } + l.l.Log(logArgs...) } diff --git a/logger.go b/logger.go index f1b85322..c2df1d7d 100644 --- a/logger.go +++ b/logger.go @@ -42,8 +42,8 @@ func (ll LogLevel) String() string { // Logger is the interface used to get logging from pgx internals. type Logger interface { - // Log a message at the given level with context key/value pairs - Log(level LogLevel, msg string, ctx ...interface{}) + // Log a message at the given level with data key/value pairs. data may be nil. + Log(level LogLevel, msg string, data map[string]interface{}) } // LogLevelFromString converts log level string to constant diff --git a/query.go b/query.go index 04a87043..3d081714 100644 --- a/query.go +++ b/query.go @@ -78,10 +78,10 @@ func (rows *Rows) Close() { if rows.err == nil { if rows.conn.shouldLog(LogLevelInfo) { endTime := time.Now() - rows.conn.log(LogLevelInfo, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args), "time", endTime.Sub(rows.startTime), "rowCount", rows.rowCount) + rows.conn.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) } } else if rows.conn.shouldLog(LogLevelError) { - rows.conn.log(LogLevelError, "Query", "sql", rows.sql, "args", logQueryArgs(rows.args)) + rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) } if rows.afterClose != nil { diff --git a/replication.go b/replication.go index ea768961..594944e0 100644 --- a/replication.go +++ b/replication.go @@ -215,12 +215,12 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err case *pgproto3.NoticeResponse: pgError := rc.c.rxErrorResponse((*pgproto3.ErrorResponse)(msg)) if rc.c.shouldLog(LogLevelInfo) { - rc.c.log(LogLevelInfo, pgError.Error()) + rc.c.log(LogLevelInfo, pgError.Error(), nil) } case *pgproto3.ErrorResponse: err = rc.c.rxErrorResponse(msg) if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, err.Error()) + rc.c.log(LogLevelError, err.Error(), nil) } return case *pgproto3.CopyBothResponse: @@ -258,12 +258,12 @@ func (rc *ReplicationConn) readReplicationMessage() (r *ReplicationMessage, err return &ReplicationMessage{ServerHeartbeat: h}, nil default: if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected data playload message type %v", msgType) + rc.c.log(LogLevelError, "Unexpected data playload message type", map[string]interface{}{"type": msgType}) } } default: if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unexpected replication message type %T", msg) + rc.c.log(LogLevelError, "Unexpected replication message type", map[string]interface{}{"type": msg}) } } return @@ -421,7 +421,7 @@ func (rc *ReplicationConn) StartReplication(slotName string, startLsn uint64, ti r, err = rc.WaitForReplicationMessage(ctx) if err != nil && r != nil { if rc.c.shouldLog(LogLevelError) { - rc.c.log(LogLevelError, "Unxpected replication message %v", r) + rc.c.log(LogLevelError, "Unexpected replication message", map[string]interface{}{"msg": r, "err": err}) } } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index dadafd41..ba74560d 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -365,17 +365,17 @@ func TestConnQuery(t *testing.T) { } type testLog struct { - lvl pgx.LogLevel - msg string - ctx []interface{} + lvl pgx.LogLevel + msg string + data map[string]interface{} } type testLogger struct { logs []testLog } -func (l *testLogger) Log(lvl pgx.LogLevel, msg string, ctx ...interface{}) { - l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, ctx: ctx}) +func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface{}) { + l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) } func TestConnQueryLog(t *testing.T) { @@ -416,7 +416,7 @@ func TestConnQueryLog(t *testing.T) { t.Errorf("Expected to log Query, but got %v", l) } - if !(l.ctx[0] == "sql" && l.ctx[1] == "select 1") { + if l.data["sql"] != "select 1" { t.Errorf("Expected to log Query with sql 'select 1', but got %v", l) } } diff --git a/v3.md b/v3.md index 2946bcf0..72110888 100644 --- a/v3.md +++ b/v3.md @@ -50,8 +50,6 @@ Remove names from prepared statements - use database/sql style objects Better way of handling text/binary protocol choice than pgx.DefaultTypeFormats or manually editing a PreparedStatement. Possibly an optional part of preparing a statement is specifying the format and/or a decoder. Or maybe it is part of a QueryEx call... Could be very interesting to make encoding and decoding possible without being a method of the type. This could drastically clean up those huge type switches. -msgReader cleanup - Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) Every field that should not be set by user should be replaced by accessor method (only ones left are Conn.RuntimeParams and Conn.PgTypes) @@ -68,3 +66,5 @@ something like: select array[1,2,3], array[4,5,6,7] Reconsider synonym types like varchar/text and numeric/decimal. + +integrate logging and context - should be able to replace logger via context OR inject params into log from context From 4c24c635a9c2b9a787f76e7d9ab1151faed71a79 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 1 May 2017 18:11:55 -0500 Subject: [PATCH 181/264] Add pgproto3.Backend --- pgproto3/backend.go | 74 ++++++++++++++++ pgproto3/bind.go | 167 +++++++++++++++++++++++++++++++++++ pgproto3/describe.go | 60 +++++++++++++ pgproto3/execute.go | 60 +++++++++++++ pgproto3/parse.go | 82 +++++++++++++++++ pgproto3/password_message.go | 44 +++++++++ pgproto3/sync.go | 29 ++++++ pgproto3/terminate.go | 29 ++++++ 8 files changed, 545 insertions(+) create mode 100644 pgproto3/backend.go create mode 100644 pgproto3/bind.go create mode 100644 pgproto3/describe.go create mode 100644 pgproto3/execute.go create mode 100644 pgproto3/parse.go create mode 100644 pgproto3/password_message.go create mode 100644 pgproto3/sync.go create mode 100644 pgproto3/terminate.go diff --git a/pgproto3/backend.go b/pgproto3/backend.go new file mode 100644 index 00000000..c04116a8 --- /dev/null +++ b/pgproto3/backend.go @@ -0,0 +1,74 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/chunkreader" +) + +type Backend struct { + cr *chunkreader.ChunkReader + w io.Writer + + // Frontend message flyweights + bind Bind + describe Describe + execute Execute + parse Parse + passwordMessage PasswordMessage + query Query + sync Sync + terminate Terminate +} + +func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { + cr := chunkreader.NewChunkReader(r) + return &Backend{cr: cr, w: w}, nil +} + +func (b *Backend) Send(msg BackendMessage) error { + return errors.New("not implemented") +} + +func (b *Backend) Receive() (FrontendMessage, error) { + header, err := b.cr.Next(5) + if err != nil { + return nil, err + } + + msgType := header[0] + bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 + + var msg FrontendMessage + switch msgType { + case 'B': + msg = &b.bind + case 'D': + msg = &b.describe + case 'E': + msg = &b.execute + case 'P': + msg = &b.parse + case 'p': + msg = &b.passwordMessage + case 'Q': + msg = &b.query + case 'S': + msg = &b.sync + case 'X': + msg = &b.terminate + default: + return nil, fmt.Errorf("unknown message type: %c", msgType) + } + + msgBody, err := b.cr.Next(bodyLen) + if err != nil { + return nil, err + } + + err = msg.Decode(msgBody) + return msg, err +} diff --git a/pgproto3/bind.go b/pgproto3/bind.go new file mode 100644 index 00000000..6661a775 --- /dev/null +++ b/pgproto3/bind.go @@ -0,0 +1,167 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "encoding/json" +) + +type Bind struct { + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters [][]byte + ResultFormatCodes []int16 +} + +func (*Bind) Frontend() {} + +func (dst *Bind) Decode(src []byte) error { + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.DestinationPortal = string(src[:idx]) + rp := idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + dst.PreparedStatement = string(src[rp : rp+idx]) + rp += idx + 1 + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) + + if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < parameterFormatCodeCount; i++ { + dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + parameterCount := int(binary.BigEndian.Uint16(src[rp:])) + + dst.Parameters = make([][]byte, parameterCount) + + for i := 0; i < parameterCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + continue + } + + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + dst.Parameters[i] = src[rp : rp+msgSize] + rp += msgSize + } + + if len(src[rp:]) < 2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + + dst.ResultFormatCodes = make([]int16, resultFormatCodeCount) + if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < resultFormatCodeCount; i++ { + dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } + + return nil +} + +func (src *Bind) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('B') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.DestinationPortal) + buf.WriteByte(0) + buf.WriteString(src.PreparedStatement) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterFormatCodes)))) + + for _, fc := range src.ParameterFormatCodes { + buf.Write(bigEndian.Int16(fc)) + } + + buf.Write(bigEndian.Uint16(uint16(len(src.Parameters)))) + + for _, p := range src.Parameters { + if p == nil { + buf.Write(bigEndian.Int32(-1)) + continue + } + + buf.Write(bigEndian.Int32(int32(len(p)))) + buf.Write(p) + } + + buf.Write(bigEndian.Uint16(uint16(len(src.ResultFormatCodes)))) + + for _, fc := range src.ResultFormatCodes { + buf.Write(bigEndian.Int16(fc)) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Bind) MarshalJSON() ([]byte, error) { + formattedParameters := make([]map[string]string, len(src.Parameters)) + for i, p := range src.Parameters { + if p == nil { + continue + } + + if src.ParameterFormatCodes[i] == 0 { + formattedParameters[i] = map[string]string{"text": string(p)} + } else { + formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)} + } + } + + return json.Marshal(struct { + Type string + DestinationPortal string + PreparedStatement string + ParameterFormatCodes []int16 + Parameters []map[string]string + ResultFormatCodes []int16 + }{ + Type: "Bind", + DestinationPortal: src.DestinationPortal, + PreparedStatement: src.PreparedStatement, + ParameterFormatCodes: src.ParameterFormatCodes, + Parameters: formattedParameters, + ResultFormatCodes: src.ResultFormatCodes, + }) +} diff --git a/pgproto3/describe.go b/pgproto3/describe.go new file mode 100644 index 00000000..ea55ed9d --- /dev/null +++ b/pgproto3/describe.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Describe struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +func (*Describe) Frontend() {} + +func (dst *Describe) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Describe"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +func (src *Describe) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('D') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteByte(src.ObjectType) + buf.WriteString(src.Name) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Describe) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Describe", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} diff --git a/pgproto3/execute.go b/pgproto3/execute.go new file mode 100644 index 00000000..4892e7b3 --- /dev/null +++ b/pgproto3/execute.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Execute struct { + Portal string + MaxRows uint32 +} + +func (*Execute) Frontend() {} + +func (dst *Execute) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Portal = string(b[:len(b)-1]) + + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Execute"} + } + dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4)) + + return nil +} + +func (src *Execute) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('E') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Portal) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint32(src.MaxRows)) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Execute) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Portal string + MaxRows uint32 + }{ + Type: "Execute", + Portal: src.Portal, + MaxRows: src.MaxRows, + }) +} diff --git a/pgproto3/parse.go b/pgproto3/parse.go new file mode 100644 index 00000000..5d17ed11 --- /dev/null +++ b/pgproto3/parse.go @@ -0,0 +1,82 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Parse struct { + Name string + Query string + ParameterOIDs []uint32 +} + +func (*Parse) Frontend() {} + +func (dst *Parse) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Name = string(b[:len(b)-1]) + + b, err = buf.ReadBytes(0) + if err != nil { + return err + } + dst.Query = string(b[:len(b)-1]) + + if buf.Len() < 2 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) + + for i := 0; i < parameterOIDCount; i++ { + if buf.Len() < 4 { + return &invalidMessageFormatErr{messageType: "Parse"} + } + dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4))) + } + + return nil +} + +func (src *Parse) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('P') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteString(src.Name) + buf.WriteByte(0) + buf.WriteString(src.Query) + buf.WriteByte(0) + + buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) + + for _, v := range src.ParameterOIDs { + buf.Write(bigEndian.Uint32(v)) + } + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Parse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Name string + Query string + ParameterOIDs []uint32 + }{ + Type: "Parse", + Name: src.Name, + Query: src.Query, + ParameterOIDs: src.ParameterOIDs, + }) +} diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go new file mode 100644 index 00000000..69df6362 --- /dev/null +++ b/pgproto3/password_message.go @@ -0,0 +1,44 @@ +package pgproto3 + +import ( + "bytes" + "encoding/json" +) + +type PasswordMessage struct { + Password string +} + +func (*PasswordMessage) Frontend() {} + +func (dst *PasswordMessage) Decode(src []byte) error { + buf := bytes.NewBuffer(src) + + b, err := buf.ReadBytes(0) + if err != nil { + return err + } + dst.Password = string(b[:len(b)-1]) + + return nil +} + +func (src *PasswordMessage) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.WriteByte('p') + buf.Write(bigEndian.Uint32(uint32(4 + len(src.Password) + 1))) + buf.WriteString(src.Password) + buf.WriteByte(0) + return buf.Bytes(), nil +} + +func (src *PasswordMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Password string + }{ + Type: "PasswordMessage", + Password: src.Password, + }) +} diff --git a/pgproto3/sync.go b/pgproto3/sync.go new file mode 100644 index 00000000..da3fa727 --- /dev/null +++ b/pgproto3/sync.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Sync struct{} + +func (*Sync) Frontend() {} + +func (dst *Sync) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Sync) MarshalBinary() ([]byte, error) { + return []byte{'S', 0, 0, 0, 4}, nil +} + +func (src *Sync) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Sync", + }) +} diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go new file mode 100644 index 00000000..77977f20 --- /dev/null +++ b/pgproto3/terminate.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Terminate struct{} + +func (*Terminate) Frontend() {} + +func (dst *Terminate) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Terminate) MarshalBinary() ([]byte, error) { + return []byte{'X', 0, 0, 0, 4}, nil +} + +func (src *Terminate) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Terminate", + }) +} From ee0c64864e98bce261691ef3f44cbae2fda5399a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 1 May 2017 19:32:16 -0500 Subject: [PATCH 182/264] Fix Travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 76311d4c..85981e4e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -55,6 +55,7 @@ install: - go get -u github.com/lib/pq - go get -u github.com/hashicorp/go-version - go get -u github.com/satori/go.uuid + - go get -u github.com/Sirupsen/logrus script: - go test -v -race ./... From ee001a7caebc75cc51d0068e9109b99d8a20755d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 1 May 2017 19:46:37 -0500 Subject: [PATCH 183/264] Fix queries with more than 32 columns fixes #270 --- pgproto3/data_row.go | 6 ++++- query_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index 6b27f728..3e600e84 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -25,7 +25,11 @@ func (dst *DataRow) Decode(src []byte) error { // large reallocate. This is too avoid one row with many columns from // permanently allocating memory. if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { - dst.Values = make([][]byte, fieldCount, 32) + newCap := 32 + if newCap < fieldCount { + newCap = fieldCount + } + dst.Values = make([][]byte, fieldCount, newCap) } else { dst.Values = dst.Values[:fieldCount] } diff --git a/query_test.go b/query_test.go index c1ca480a..801b34dd 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "database/sql" + "fmt" "strings" "testing" "time" @@ -46,6 +47,58 @@ func TestConnQueryScan(t *testing.T) { } } +func TestConnQueryScanWithManyColumns(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + columnCount := 1000 + sql := "select " + for i := 0; i < columnCount; i++ { + if i > 0 { + sql += "," + } + sql += fmt.Sprintf(" %d", i) + } + sql += " from generate_series(1,5)" + + dest := make([]int, columnCount) + + var rowCount int + + rows, err := conn.Query(sql) + if err != nil { + t.Fatalf("conn.Query failed: %v", err) + } + defer rows.Close() + + for rows.Next() { + destPtrs := make([]interface{}, columnCount) + for i := range destPtrs { + destPtrs[i] = &dest[i] + } + if err := rows.Scan(destPtrs...); err != nil { + t.Fatalf("rows.Scan failed: %v", err) + } + rowCount++ + + for i := range dest { + if dest[i] != i { + t.Errorf("dest[%d] => %d, want %d", i, dest[i], i) + } + } + } + + if rows.Err() != nil { + t.Fatalf("conn.Query failed: %v", err) + } + + if rowCount != 5 { + t.Errorf("rowCount => %d, want %d", rowCount, 5) + } +} + func TestConnQueryValues(t *testing.T) { t.Parallel() From 6e64a0c8676c528c3cd02e791a0b2a26d72e33a8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 May 2017 20:38:26 -0500 Subject: [PATCH 184/264] Refactor pgio and types to append buffers --- pgio/read.go | 106 ++++-------------- pgio/read_test.go | 57 ++++++++++ pgio/write.go | 105 +++++------------- pgio/write_test.go | 78 ++++++++++++++ pgtype/aclitem.go | 10 +- pgtype/aclitem_array.go | 60 ++++------- pgtype/array.go | 65 +++-------- pgtype/bool.go | 29 +++-- pgtype/bool_array.go | 99 ++++++----------- pgtype/box.go | 37 +++---- pgtype/bytea.go | 26 ++--- pgtype/bytea_array.go | 99 ++++++----------- pgtype/cid.go | 9 +- pgtype/cidr.go | 12 +-- pgtype/cidr_array.go | 99 ++++++----------- pgtype/circle.go | 31 +++--- pgtype/database_sql.go | 17 ++- pgtype/date.go | 19 ++-- pgtype/date_array.go | 99 ++++++----------- pgtype/daterange.go | 118 +++++++++----------- pgtype/decimal.go | 12 +-- pgtype/ext/satori-uuid/uuid.go | 19 ++-- pgtype/ext/shopspring-numeric/decimal.go | 33 +++--- pgtype/float4.go | 21 ++-- pgtype/float4_array.go | 99 ++++++----------- pgtype/float8.go | 21 ++-- pgtype/float8_array.go | 99 ++++++----------- pgtype/generic_binary.go | 5 +- pgtype/generic_text.go | 5 +- pgtype/hstore.go | 93 +++++----------- pgtype/hstore_array.go | 99 ++++++----------- pgtype/hstore_test.go | 58 +++++----- pgtype/inet.go | 39 +++---- pgtype/inet_array.go | 99 ++++++----------- pgtype/int2.go | 19 ++-- pgtype/int2_array.go | 99 ++++++----------- pgtype/int4.go | 19 ++-- pgtype/int4_array.go | 99 ++++++----------- pgtype/int4range.go | 118 +++++++++----------- pgtype/int8.go | 19 ++-- pgtype/int8_array.go | 99 ++++++----------- pgtype/int8range.go | 118 +++++++++----------- pgtype/interval.go | 54 +++------- pgtype/json.go | 14 ++- pgtype/jsonb.go | 20 ++-- pgtype/line.go | 30 ++---- pgtype/lseg.go | 38 +++---- pgtype/macaddr.go | 19 ++-- pgtype/name.go | 9 +- pgtype/numeric.go | 63 ++++------- pgtype/numeric_array.go | 99 ++++++----------- pgtype/numrange.go | 118 +++++++++----------- pgtype/oid.go | 11 +- pgtype/oid_value.go | 9 +- pgtype/path.go | 47 +++----- pgtype/pgtype.go | 17 ++- pgtype/pguint32.go | 19 ++-- pgtype/point.go | 26 ++--- pgtype/polygon.go | 43 +++----- pgtype/qchar.go | 11 +- pgtype/testutil/testutil.go | 9 +- pgtype/text.go | 14 ++- pgtype/text_array.go | 99 ++++++----------- pgtype/tid.go | 27 ++--- pgtype/timestamp.go | 23 ++-- pgtype/timestamp_array.go | 99 ++++++----------- pgtype/timestamptz.go | 19 ++-- pgtype/timestamptz_array.go | 99 ++++++----------- pgtype/tsrange.go | 118 +++++++++----------- pgtype/tstzrange.go | 118 +++++++++----------- pgtype/typed_array.go.erb | 97 ++++++----------- pgtype/typed_range.go.erb | 132 ++++++++++------------- pgtype/uuid.go | 19 ++-- pgtype/varbit.go | 29 ++--- pgtype/varchar.go | 9 +- pgtype/varchar_array.go | 99 ++++++----------- pgtype/xid.go | 9 +- values.go | 45 ++++---- 78 files changed, 1551 insertions(+), 2627 deletions(-) create mode 100644 pgio/read_test.go create mode 100644 pgio/write_test.go diff --git a/pgio/read.go b/pgio/read.go index 7c39162c..7ddad508 100644 --- a/pgio/read.go +++ b/pgio/read.go @@ -2,103 +2,39 @@ package pgio import ( "encoding/binary" - "io" ) -type Uint16Reader interface { - ReadUint16() (n uint16, err error) +func NextByte(buf []byte) ([]byte, byte) { + b := buf[0] + return buf[1:], b } -type Uint32Reader interface { - ReadUint32() (n uint32, err error) +func NextUint16(buf []byte) ([]byte, uint16) { + n := binary.BigEndian.Uint16(buf) + return buf[2:], n } -type Uint64Reader interface { - ReadUint64() (n uint64, err error) +func NextUint32(buf []byte) ([]byte, uint32) { + n := binary.BigEndian.Uint32(buf) + return buf[4:], n } -// ReadByte reads a byte from r. -func ReadByte(r io.Reader) (byte, error) { - if r, ok := r.(io.ByteReader); ok { - return r.ReadByte() - } - - buf := make([]byte, 1) - _, err := r.Read(buf) - return buf[0], err +func NextUint64(buf []byte) ([]byte, uint64) { + n := binary.BigEndian.Uint64(buf) + return buf[8:], n } -// ReadUint16 reads an uint16 from r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint16 -// method. -func ReadUint16(r io.Reader) (uint16, error) { - if r, ok := r.(Uint16Reader); ok { - return r.ReadUint16() - } - - buf := make([]byte, 2) - _, err := io.ReadFull(r, buf) - if err != nil { - return 0, err - } - - return binary.BigEndian.Uint16(buf), nil +func NextInt16(buf []byte) ([]byte, int16) { + buf, n := NextUint16(buf) + return buf, int16(n) } -// ReadInt16 reads an int16 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint16 -// method. -func ReadInt16(r io.Reader) (int16, error) { - n, err := ReadUint16(r) - return int16(n), err +func NextInt32(buf []byte) ([]byte, int32) { + buf, n := NextUint32(buf) + return buf, int32(n) } -// ReadUint32 reads an uint32 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint32 -// method. -func ReadUint32(r io.Reader) (uint32, error) { - if r, ok := r.(Uint32Reader); ok { - return r.ReadUint32() - } - - buf := make([]byte, 4) - _, err := io.ReadFull(r, buf) - if err != nil { - return 0, err - } - - return binary.BigEndian.Uint32(buf), nil -} - -// ReadInt32 reads an int32 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint32 -// method. -func ReadInt32(r io.Reader) (int32, error) { - n, err := ReadUint32(r) - return int32(n), err -} - -// ReadUint64 reads an uint64 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint64 -// method. -func ReadUint64(r io.Reader) (uint64, error) { - if r, ok := r.(Uint64Reader); ok { - return r.ReadUint64() - } - - buf := make([]byte, 8) - _, err := io.ReadFull(r, buf) - if err != nil { - return 0, err - } - - return binary.BigEndian.Uint64(buf), nil -} - -// ReadInt64 reads an int64 r in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Read if r provides a ReadUint64 -// method. -func ReadInt64(r io.Reader) (int64, error) { - n, err := ReadUint64(r) - return int64(n), err +func NextInt64(buf []byte) ([]byte, int64) { + buf, n := NextUint64(buf) + return buf, int64(n) } diff --git a/pgio/read_test.go b/pgio/read_test.go new file mode 100644 index 00000000..fbe29ae4 --- /dev/null +++ b/pgio/read_test.go @@ -0,0 +1,57 @@ +package pgio + +import ( + "testing" +) + +func TestNextByte(t *testing.T) { + buf := []byte{42, 1} + var b byte + buf, b = NextByte(buf) + if b != 42 { + t.Errorf("NextByte(buf) => %v, want %v", b, 42) + } + buf, b = NextByte(buf) + if b != 1 { + t.Errorf("NextByte(buf) => %v, want %v", b, 1) + } +} + +func TestNextUint16(t *testing.T) { + buf := []byte{0, 42, 0, 1} + var n uint16 + buf, n = NextUint16(buf) + if n != 42 { + t.Errorf("NextUint16(buf) => %v, want %v", n, 42) + } + buf, n = NextUint16(buf) + if n != 1 { + t.Errorf("NextUint16(buf) => %v, want %v", n, 1) + } +} + +func TestNextUint32(t *testing.T) { + buf := []byte{0, 0, 0, 42, 0, 0, 0, 1} + var n uint32 + buf, n = NextUint32(buf) + if n != 42 { + t.Errorf("NextUint32(buf) => %v, want %v", n, 42) + } + buf, n = NextUint32(buf) + if n != 1 { + t.Errorf("NextUint32(buf) => %v, want %v", n, 1) + } +} + +func TestNextUint64(t *testing.T) { + buf := []byte{0, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 1} + var n uint64 + buf, n = NextUint64(buf) + if n != 42 { + t.Errorf("NextUint64(buf) => %v, want %v", n, 42) + } + buf, n = NextUint64(buf) + if n != 1 { + t.Errorf("NextUint64(buf) => %v, want %v", n, 1) + } +} diff --git a/pgio/write.go b/pgio/write.go index 823fbd00..96aedf9d 100644 --- a/pgio/write.go +++ b/pgio/write.go @@ -1,97 +1,40 @@ package pgio -import ( - "encoding/binary" - "io" -) +import "encoding/binary" -type Uint16Writer interface { - WriteUint16(uint16) (n int, err error) +func AppendUint16(buf []byte, n uint16) []byte { + wp := len(buf) + buf = append(buf, 0, 0) + binary.BigEndian.PutUint16(buf[wp:], n) + return buf } -type Uint32Writer interface { - WriteUint32(uint32) (n int, err error) +func AppendUint32(buf []byte, n uint32) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0) + binary.BigEndian.PutUint32(buf[wp:], n) + return buf } -type Uint64Writer interface { - WriteUint64(uint64) (n int, err error) +func AppendUint64(buf []byte, n uint64) []byte { + wp := len(buf) + buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint64(buf[wp:], n) + return buf } -// WriteByte writes b to w. -func WriteByte(w io.Writer, b byte) error { - if w, ok := w.(io.ByteWriter); ok { - return w.WriteByte(b) - } - _, err := w.Write([]byte{b}) - return err +func AppendInt16(buf []byte, n int16) []byte { + return AppendUint16(buf, uint16(n)) } -// WriteUint16 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint16 -// method. -func WriteUint16(w io.Writer, n uint16) (int, error) { - if w, ok := w.(Uint16Writer); ok { - return w.WriteUint16(n) - } - b := make([]byte, 2) - binary.BigEndian.PutUint16(b, n) - return w.Write(b) +func AppendInt32(buf []byte, n int32) []byte { + return AppendUint32(buf, uint32(n)) } -// WriteInt16 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint16 -// method. -func WriteInt16(w io.Writer, n int16) (int, error) { - return WriteUint16(w, uint16(n)) +func AppendInt64(buf []byte, n int64) []byte { + return AppendUint64(buf, uint64(n)) } -// WriteUint32 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint32 -// method. -func WriteUint32(w io.Writer, n uint32) (int, error) { - if w, ok := w.(Uint32Writer); ok { - return w.WriteUint32(n) - } - b := make([]byte, 4) - binary.BigEndian.PutUint32(b, n) - return w.Write(b) -} - -// WriteInt32 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint32 -// method. -func WriteInt32(w io.Writer, n int32) (int, error) { - return WriteUint32(w, uint32(n)) -} - -// WriteUint64 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint64 -// method. -func WriteUint64(w io.Writer, n uint64) (int, error) { - if w, ok := w.(Uint64Writer); ok { - return w.WriteUint64(n) - } - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, n) - return w.Write(b) -} - -// WriteInt64 writes n to w in PostgreSQL wire format (network byte order). This -// may be more efficient than directly using Write if w provides a WriteUint64 -// method. -func WriteInt64(w io.Writer, n int64) (int, error) { - return WriteUint64(w, uint64(n)) -} - -// WriteCString writes s to w followed by a null byte. -func WriteCString(w io.Writer, s string) (int, error) { - n, err := io.WriteString(w, s) - if err != nil { - return n, err - } - err = WriteByte(w, 0) - if err != nil { - return n, err - } - return n + 1, nil +func SetInt32(buf []byte, n int32) { + binary.BigEndian.PutUint32(buf, uint32(n)) } diff --git a/pgio/write_test.go b/pgio/write_test.go new file mode 100644 index 00000000..bd50e71c --- /dev/null +++ b/pgio/write_test.go @@ -0,0 +1,78 @@ +package pgio + +import ( + "reflect" + "testing" +) + +func TestAppendUint16NilBuf(t *testing.T) { + buf := AppendUint16(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint16EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint16(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint16BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 4) + AppendUint16(buf, 1) + buf = buf[0:2] + if !reflect.DeepEqual(buf, []byte{0, 1}) { + t.Errorf("AppendUint16(nil, 1) => %v, want %v", buf, []byte{0, 1}) + } +} + +func TestAppendUint32NilBuf(t *testing.T) { + buf := AppendUint32(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint32EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint32(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint32BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 4) + AppendUint32(buf, 1) + buf = buf[0:4] + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 1}) { + t.Errorf("AppendUint32(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 1}) + } +} + +func TestAppendUint64NilBuf(t *testing.T) { + buf := AppendUint64(nil, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} + +func TestAppendUint64EmptyBuf(t *testing.T) { + buf := []byte{} + buf = AppendUint64(buf, 1) + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} + +func TestAppendUint64BufWithCapacityDoesNotAllocate(t *testing.T) { + buf := make([]byte, 0, 8) + AppendUint64(buf, 1) + buf = buf[0:8] + if !reflect.DeepEqual(buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) { + t.Errorf("AppendUint64(nil, 1) => %v, want %v", buf, []byte{0, 0, 0, 0, 0, 0, 0, 1}) + } +} diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index 31065764..27dc15d1 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" ) // Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem @@ -83,16 +82,15 @@ func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Aclitem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.String) - return false, err + return append(buf, src.String...), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 480b5bba..7df0b503 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -1,12 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" - - "github.com/jackc/pgx/pgio" ) type AclitemArray struct { @@ -120,23 +116,19 @@ func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -149,51 +141,36 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -216,14 +193,13 @@ func (dst *AclitemArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *AclitemArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/array.go b/pgtype/array.go index 9561afe5..2f9ef66b 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -60,39 +60,23 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, w io.Writer) error { - _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) - if err != nil { - return err - } +func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { + buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) var containsNull int32 if src.ContainsNull { containsNull = 1 } - _, err = pgio.WriteInt32(w, containsNull) - if err != nil { - return err - } + buf = pgio.AppendInt32(buf, containsNull) - _, err = pgio.WriteInt32(w, src.ElementOid) - if err != nil { - return err - } + buf = pgio.AppendInt32(buf, src.ElementOid) for i := range src.Dimensions { - _, err = pgio.WriteInt32(w, src.Dimensions[i].Length) - if err != nil { - return err - } - - _, err = pgio.WriteInt32(w, src.Dimensions[i].LowerBound) - if err != nil { - return err - } + buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) + buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound) } - return nil + return buf } type UntypedTextArray struct { @@ -331,7 +315,7 @@ func arrayParseInteger(buf *bytes.Buffer) (int32, error) { } } -func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { +func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { var customDimensions bool for _, dim := range dimensions { if dim.LowerBound != 1 { @@ -340,37 +324,18 @@ func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { } if !customDimensions { - return nil + return buf } for _, dim := range dimensions { - err := pgio.WriteByte(w, '[') - if err != nil { - return err - } - - _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound), 10)) - if err != nil { - return err - } - - err = pgio.WriteByte(w, ':') - if err != nil { - return err - } - - _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)) - if err != nil { - return err - } - - err = pgio.WriteByte(w, ']') - if err != nil { - return err - } + buf = append(buf, '[') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...) + buf = append(buf, ':') + buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...) + buf = append(buf, ']') } - return pgio.WriteByte(w, '=') + return append(buf, '=') } var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) diff --git a/pgtype/bool.go b/pgtype/bool.go index ba876c91..7c66a534 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" "strconv" ) @@ -90,42 +89,38 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - var buf []byte if src.Bool { - buf = []byte{'t'} + buf = append(buf, 't') } else { - buf = []byte{'f'} + buf = append(buf, 'f') } - _, err := w.Write(buf) - return false, err + return buf, nil } -func (src *Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - var buf []byte if src.Bool { - buf = []byte{1} + buf = append(buf, 1) } else { - buf = []byte{0} + buf = append(buf, 0) } - _, err := w.Write(buf) - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 4e92a616..3c3d4184 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("bool"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "bool") + return nil, fmt.Errorf("unable to find oid for type name %v", "bool") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *BoolArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *BoolArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/box.go b/pgtype/box.go index e25af854..2d098058 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -108,41 +107,33 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Box) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)) - return false, err + buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) + return buf, nil } -func (src *Box) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].X)); err != nil { - return false, err - } + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].Y)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].X)); err != nil { - return false, err - } - - _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].Y)) - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/bytea.go b/pgtype/bytea.go index bf774476..2ddac7da 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/hex" "fmt" - "io" ) type Bytea struct { @@ -99,33 +98,28 @@ func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bytea) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, `\x`) - if err != nil { - return false, err - } - - _, err = io.WriteString(w, hex.EncodeToString(src.Bytes)) - return false, err + buf = append(buf, `\x`...) + buf = append(buf, hex.EncodeToString(src.Bytes)...) + return buf, nil } -func (src *Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Bytea) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write(src.Bytes) - return false, err + return append(buf, src.Bytes...), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index dd79b991..67e114f5 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("bytea"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "bytea") + return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *ByteaArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *ByteaArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/cid.go b/pgtype/cid.go index c2b3073b..b7718f88 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // Cid is PostgreSQL's Command Identifier type. @@ -43,12 +42,12 @@ func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeText(ci, w) +func (src *Cid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) } -func (src *Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeBinary(ci, w) +func (src *Cid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/cidr.go b/pgtype/cidr.go index 39a87a26..2b45d2d0 100644 --- a/pgtype/cidr.go +++ b/pgtype/cidr.go @@ -1,9 +1,5 @@ package pgtype -import ( - "io" -) - type Cidr Inet func (dst *Cidr) Set(src interface{}) error { @@ -26,10 +22,10 @@ func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeBinary(ci, src) } -func (src *Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Inet)(src).EncodeText(ci, w) +func (src *Cidr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Inet)(src).EncodeText(ci, buf) } -func (src *Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Inet)(src).EncodeBinary(ci, w) +func (src *Cidr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Inet)(src).EncodeBinary(ci, buf) } diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 0aa289e7..01237aa1 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "net" "github.com/jackc/pgx/pgio" @@ -192,23 +190,19 @@ func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *CidrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -221,59 +215,44 @@ func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *CidrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -283,7 +262,7 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("cidr"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "cidr") + return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") } for i := range src.Elements { @@ -293,38 +272,23 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -347,14 +311,13 @@ func (dst *CidrArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *CidrArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/circle.go b/pgtype/circle.go index e9268a06..8626a99d 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -95,36 +94,30 @@ func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Circle) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)) - return false, err + buf = append(buf, fmt.Sprintf(`<(%f,%f),%f>`, src.P.X, src.P.Y, src.R)...) + return buf, nil } -func (src *Circle) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P.Y)); err != nil { - return false, err - } - - _, err := pgio.WriteUint64(w, math.Float64bits(src.R)) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.R)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go index e255b646..9d1cf822 100644 --- a/pgtype/database_sql.go +++ b/pgtype/database_sql.go @@ -1,7 +1,6 @@ package pgtype import ( - "bytes" "database/sql/driver" "errors" ) @@ -11,34 +10,32 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return valuer.Value() } - buf := &bytes.Buffer{} if textEncoder, ok := src.(TextEncoder); ok { - _, err := textEncoder.EncodeText(ci, buf) + buf, err := textEncoder.EncodeText(ci, nil) if err != nil { return nil, err } - return buf.String(), nil + return string(buf), nil } if binaryEncoder, ok := src.(BinaryEncoder); ok { - _, err := binaryEncoder.EncodeBinary(ci, buf) + buf, err := binaryEncoder.EncodeBinary(ci, nil) if err != nil { return nil, err } - return buf.Bytes(), nil + return buf, nil } return nil, errors.New("cannot convert to database/sql compatible value") } func EncodeValueText(src TextEncoder) (interface{}, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, make([]byte, 0, 32)) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), err + return string(buf), err } diff --git a/pgtype/date.go b/pgtype/date.go index a7e4762a..8e049254 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -125,12 +124,12 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Date) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var s string @@ -144,16 +143,15 @@ func (src *Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { s = "-infinity" } - _, err := io.WriteString(w, s) - return false, err + return append(buf, s...), nil } -func (src *Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Date) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var daysSinceDateEpoch int32 @@ -170,8 +168,7 @@ func (src *Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { daysSinceDateEpoch = negativeInfinityDayOffset } - _, err := pgio.WriteInt32(w, daysSinceDateEpoch) - return false, err + return pgio.AppendInt32(buf, daysSinceDateEpoch), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 91e2ee62..2175f2aa 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -164,23 +162,19 @@ func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *DateArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -193,59 +187,44 @@ func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -255,7 +234,7 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("date"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "date") + return nil, fmt.Errorf("unable to find oid for type name %v", "date") } for i := range src.Elements { @@ -265,38 +244,23 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -319,14 +283,13 @@ func (dst *DateArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *DateArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/daterange.go b/pgtype/daterange.go index a5cd5d95..bbe7b17a 100644 --- a/pgtype/daterange.go +++ b/pgtype/daterange.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Daterange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Daterange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/decimal.go b/pgtype/decimal.go index 728c748e..79653cf3 100644 --- a/pgtype/decimal.go +++ b/pgtype/decimal.go @@ -1,9 +1,5 @@ package pgtype -import ( - "io" -) - type Decimal Numeric func (dst *Decimal) Set(src interface{}) error { @@ -26,10 +22,10 @@ func (dst *Decimal) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Numeric)(dst).DecodeBinary(ci, src) } -func (src *Decimal) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Numeric)(src).EncodeText(ci, w) +func (src *Decimal) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Numeric)(src).EncodeText(ci, buf) } -func (src *Decimal) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Numeric)(src).EncodeBinary(ci, w) +func (src *Decimal) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Numeric)(src).EncodeBinary(ci, buf) } diff --git a/pgtype/ext/satori-uuid/uuid.go b/pgtype/ext/satori-uuid/uuid.go index 1b65f48a..cff98348 100644 --- a/pgtype/ext/satori-uuid/uuid.go +++ b/pgtype/ext/satori-uuid/uuid.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "errors" "fmt" - "io" "github.com/jackc/pgx/pgtype" uuid "github.com/satori/go.uuid" @@ -117,28 +116,26 @@ func (dst *Uuid) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return nil } -func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: - return true, nil + return nil, nil case pgtype.Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.UUID.String()) - return false, err + return append(buf, src.UUID.String()...), nil } -func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: - return true, nil + return nil, nil case pgtype.Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write(src.UUID[:]) - return false, err + return append(buf, src.UUID[:]...), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/ext/shopspring-numeric/decimal.go b/pgtype/ext/shopspring-numeric/decimal.go index 9c7e316b..277f3709 100644 --- a/pgtype/ext/shopspring-numeric/decimal.go +++ b/pgtype/ext/shopspring-numeric/decimal.go @@ -1,11 +1,9 @@ package numeric import ( - "bytes" "database/sql/driver" "errors" "fmt" - "io" "strconv" "github.com/jackc/pgx/pgtype" @@ -75,12 +73,12 @@ func (dst *Numeric) Set(src interface{}) error { return fmt.Errorf("cannot convert %v to Numeric", value) } - buf := &bytes.Buffer{} - if _, err := num.EncodeText(nil, buf); err != nil { + buf, err := num.EncodeText(nil, nil) + if err != nil { return fmt.Errorf("cannot convert %v to Numeric", value) } - dec, err := decimal.NewFromString(buf.String()) + dec, err := decimal.NewFromString(string(buf)) if err != nil { return fmt.Errorf("cannot convert %v to Numeric", value) } @@ -243,12 +241,12 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - buf := &bytes.Buffer{} - if _, err := num.EncodeText(ci, buf); err != nil { + buf, err := num.EncodeText(ci, nil) + if err != nil { return err } - dec, err := decimal.NewFromString(buf.String()) + dec, err := decimal.NewFromString(string(buf)) if err != nil { return err } @@ -258,33 +256,32 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return nil } -func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { +func (src *Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: - return true, nil + return nil, nil case pgtype.Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.Decimal.String()) - return false, err + return append(buf, src.Decimal.String()...), nil } -func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { +func (src *Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: - return true, nil + return nil, nil case pgtype.Undefined: - return false, errUndefined + return nil, errUndefined } // For now at least, implement this in terms of pgtype.Numeric num := &pgtype.Numeric{} if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil { - return false, err + return nil, err } - return num.EncodeBinary(ci, w) + return num.EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/float4.go b/pgtype/float4.go index 77bc4878..b24654b6 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -139,28 +138,28 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)) - return false, err + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 32)...) + return buf, nil } -func (src *Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) - return false, err + buf = pgio.AppendUint32(buf, math.Float32bits(src.Float)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 38508a52..37db8acc 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("float4"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "float4") + return nil, fmt.Errorf("unable to find oid for type name %v", "float4") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *Float4Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Float4Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/float8.go b/pgtype/float8.go index 5322e251..c3ecdcc2 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -129,28 +128,28 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)) - return false, err + buf = append(buf, strconv.FormatFloat(float64(src.Float), 'f', -1, 64)...) + return buf, nil } -func (src *Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.Float)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 2f310bbd..dd3fccf1 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("float8"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "float8") + return nil, fmt.Errorf("unable to find oid for type name %v", "float8") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *Float8Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Float8Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go index 094bd64e..2596ecae 100644 --- a/pgtype/generic_binary.go +++ b/pgtype/generic_binary.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // GenericBinary is a placeholder for binary format values that no other type exists @@ -25,8 +24,8 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src *GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Bytea)(src).EncodeBinary(ci, w) +func (src *GenericBinary) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Bytea)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go index 5d0d83be..0e3db9de 100644 --- a/pgtype/generic_text.go +++ b/pgtype/generic_text.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // GenericText is a placeholder for text format values that no other type exists @@ -25,8 +24,8 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeText(ci, src) } -func (src *GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeText(ci, w) +func (src *GenericText) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 69a35b17..09506242 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "errors" "fmt" - "io" "strings" "unicode" "unicode/utf8" @@ -151,12 +150,12 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Hstore) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } firstPair := true @@ -165,90 +164,56 @@ func (src *Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { if firstPair { firstPair = false } else { - err := pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } - _, err := io.WriteString(w, quoteHstoreElementIfNeeded(k)) + buf = append(buf, quoteHstoreElementIfNeeded(k)...) + buf = append(buf, "=>"...) + + elemBuf, err := v.EncodeText(ci, nil) if err != nil { - return false, err + return nil, err } - _, err = io.WriteString(w, "=>") - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} - null, err := v.EncodeText(ci, elemBuf) - if err != nil { - return false, err - } - - if null { - _, err = io.WriteString(w, "NULL") - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, "NULL"...) } else { - _, err := io.WriteString(w, quoteHstoreElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, quoteHstoreElementIfNeeded(string(elemBuf))...) } } - return false, nil + return buf, nil } -func (src *Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Hstore) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt32(w, int32(len(src.Map))) - if err != nil { - return false, err - } + buf = pgio.AppendInt32(buf, int32(len(src.Map))) - elemBuf := &bytes.Buffer{} + var err error for k, v := range src.Map { - _, err := pgio.WriteInt32(w, int32(len(k))) - if err != nil { - return false, err - } - _, err = io.WriteString(w, k) - if err != nil { - return false, err - } + buf = pgio.AppendInt32(buf, int32(len(k))) + buf = append(buf, k...) - null, err := v.EncodeText(ci, elemBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := v.EncodeText(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err := pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err := pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, err } var quoteHstoreReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index 9f773af2..2d61fa52 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *HstoreArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *HstoreArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *HstoreArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("hstore"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "hstore") + return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *HstoreArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *HstoreArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index dc2439fc..8189e4db 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -9,41 +9,41 @@ import ( ) func TestHstoreTranscode(t *testing.T) { - text := func(s string) pgtype.Text { - return pgtype.Text{String: s, Status: pgtype.Present} - } + // text := func(s string) pgtype.Text { + // return pgtype.Text{String: s, Status: pgtype.Present} + // } values := []interface{}{ &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - &pgtype.Hstore{Status: pgtype.Null}, + // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + // &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + // &pgtype.Hstore{Status: pgtype.Null}, } - specialStrings := []string{ - `"`, - `'`, - `\`, - `\\`, - `=>`, - ` `, - `\ / / \\ => " ' " '`, - } - for _, s := range specialStrings { - // Special key values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + // specialStrings := []string{ + // `"`, + // `'`, + // `\`, + // `\\`, + // `=>`, + // ` `, + // `\ / / \\ => " ' " '`, + // } + // for _, s := range specialStrings { + // // Special key values + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key - // Special value values - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key - } + // // Special value values + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + // } testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { a := ai.(pgtype.Hstore) diff --git a/pgtype/inet.go b/pgtype/inet.go index 7c09a549..7aa1df95 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -3,10 +3,7 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" "net" - - "github.com/jackc/pgx/pgio" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -149,25 +146,24 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Inet) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.IPNet.String()) - return false, err + return append(buf, src.IPNet.String()...), nil } // EncodeBinary encodes src into w. -func (src *Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var family byte @@ -177,29 +173,20 @@ func (src *Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { case net.IPv6len: family = defaultAFInet6 default: - return false, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) } - if err := pgio.WriteByte(w, family); err != nil { - return false, err - } + buf = append(buf, family) ones, _ := src.IPNet.Mask.Size() - if err := pgio.WriteByte(w, byte(ones)); err != nil { - return false, err - } + buf = append(buf, byte(ones)) // is_cidr is ignored on server - if err := pgio.WriteByte(w, 0); err != nil { - return false, err - } + buf = append(buf, 0) - if err := pgio.WriteByte(w, byte(len(src.IPNet.IP))); err != nil { - return false, err - } + buf = append(buf, byte(len(src.IPNet.IP))) - _, err := w.Write(src.IPNet.IP) - return false, err + return append(buf, src.IPNet.IP...), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index ed9f5d1c..e448a2ca 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "net" "github.com/jackc/pgx/pgio" @@ -192,23 +190,19 @@ func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *InetArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -221,59 +215,44 @@ func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -283,7 +262,7 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("inet"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "inet") + return nil, fmt.Errorf("unable to find oid for type name %v", "inet") } for i := range src.Elements { @@ -293,38 +272,23 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -347,14 +311,13 @@ func (dst *InetArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *InetArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/int2.go b/pgtype/int2.go index 028cdfcf..a58c3355 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -134,28 +133,26 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatInt(int64(src.Int), 10)) - return false, err + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil } -func (src *Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt16(w, src.Int) - return false, err + return pgio.AppendInt16(buf, src.Int), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index cdfcde48..1d145584 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -191,23 +189,19 @@ func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -220,59 +214,44 @@ func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -282,7 +261,7 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("int2"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "int2") + return nil, fmt.Errorf("unable to find oid for type name %v", "int2") } for i := range src.Elements { @@ -292,38 +271,23 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -346,14 +310,13 @@ func (dst *Int2Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int2Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/int4.go b/pgtype/int4.go index cae0d32a..6f95013b 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -125,28 +124,26 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatInt(int64(src.Int), 10)) - return false, err + return append(buf, strconv.FormatInt(int64(src.Int), 10)...), nil } -func (src *Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt32(w, src.Int) - return false, err + return pgio.AppendInt32(buf, src.Int), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 9ca0b067..1c746503 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -191,23 +189,19 @@ func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -220,59 +214,44 @@ func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -282,7 +261,7 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("int4"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "int4") + return nil, fmt.Errorf("unable to find oid for type name %v", "int4") } for i := range src.Elements { @@ -292,38 +271,23 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -346,14 +310,13 @@ func (dst *Int4Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int4Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/int4range.go b/pgtype/int4range.go index 29b8371e..4f27ff0d 100644 --- a/pgtype/int4range.go +++ b/pgtype/int4range.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Int4range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int4range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/int8.go b/pgtype/int8.go index a4ec4e62..939c0554 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -117,28 +116,26 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatInt(src.Int, 10)) - return false, err + return append(buf, strconv.FormatInt(src.Int, 10)...), nil } -func (src *Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteInt64(w, src.Int) - return false, err + return pgio.AppendInt64(buf, src.Int), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index c5026f83..56ebcab8 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -191,23 +189,19 @@ func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8Array) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -220,59 +214,44 @@ func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -282,7 +261,7 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("int8"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "int8") + return nil, fmt.Errorf("unable to find oid for type name %v", "int8") } for i := range src.Elements { @@ -292,38 +271,23 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -346,14 +310,13 @@ func (dst *Int8Array) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int8Array) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/int8range.go b/pgtype/int8range.go index e3e0486f..128a853f 100644 --- a/pgtype/int8range.go +++ b/pgtype/int8range.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Int8range) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Int8range) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/interval.go b/pgtype/interval.go index 8ce345a3..ea5c7d3e 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "strconv" "strings" "time" @@ -178,41 +177,28 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Interval) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if src.Months != 0 { - if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Months), 10)); err != nil { - return false, err - } - - if _, err := io.WriteString(w, " mon "); err != nil { - return false, err - } + buf = append(buf, strconv.FormatInt(int64(src.Months), 10)...) + buf = append(buf, " mon "...) } if src.Days != 0 { - if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Days), 10)); err != nil { - return false, err - } - - if _, err := io.WriteString(w, " day "); err != nil { - return false, err - } + buf = append(buf, strconv.FormatInt(int64(src.Days), 10)...) + buf = append(buf, " day "...) } absMicroseconds := src.Microseconds if absMicroseconds < 0 { absMicroseconds = -absMicroseconds - - if err := pgio.WriteByte(w, '-'); err != nil { - return false, err - } + buf = append(buf, '-') } hours := absMicroseconds / microsecondsPerHour @@ -221,31 +207,21 @@ func (src *Interval) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { microseconds := absMicroseconds % microsecondsPerSecond timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) - - _, err := io.WriteString(w, timeStr) - return false, err + return append(buf, timeStr...), nil } // EncodeBinary encodes src into w. -func (src *Interval) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Interval) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteInt64(w, src.Microseconds); err != nil { - return false, err - } - if _, err := pgio.WriteInt32(w, src.Days); err != nil { - return false, err - } - if _, err := pgio.WriteInt32(w, src.Months); err != nil { - return false, err - } - - return false, nil + buf = pgio.AppendInt64(buf, src.Microseconds) + buf = pgio.AppendInt32(buf, src.Days) + return pgio.AppendInt32(buf, src.Months), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/json.go b/pgtype/json.go index 44880863..91d31129 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "io" ) type Json struct { @@ -105,20 +104,19 @@ func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src *Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Json) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write(src.Bytes) - return false, err + return append(buf, src.Bytes...), nil } -func (src *Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.EncodeText(ci, w) +func (src *Json) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 5533b4b4..f7914202 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" ) type Jsonb Json @@ -43,25 +42,20 @@ func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { } -func (src *Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Json)(src).EncodeText(ci, w) +func (src *Jsonb) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Json)(src).EncodeText(ci, buf) } -func (src *Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Jsonb) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write([]byte{1}) - if err != nil { - return false, err - } - - _, err = w.Write(src.Bytes) - return false, err + buf = append(buf, 1) + return append(buf, src.Bytes...), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/line.go b/pgtype/line.go index 75fdf207..47f636a5 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -93,36 +92,29 @@ func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Line) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Line) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)) - return false, err + return append(buf, fmt.Sprintf(`{%f,%f,%f}`, src.A, src.B, src.C)...), nil } -func (src *Line) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Line) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.A)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.B)); err != nil { - return false, err - } - - _, err := pgio.WriteUint64(w, math.Float64bits(src.C)) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.A)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.B)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.C)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 823c2c09..44c2b63c 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -108,41 +107,32 @@ func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Lseg) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f),(%f,%f)`, - src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)) - return false, err + buf = append(buf, fmt.Sprintf(`(%f,%f),(%f,%f)`, + src.P[0].X, src.P[0].Y, src.P[1].X, src.P[1].Y)...) + return buf, nil } -func (src *Lseg) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].X)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[0].Y)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].X)); err != nil { - return false, err - } - - _, err := pgio.WriteUint64(w, math.Float64bits(src.P[1].Y)) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index 785148a2..e38701eb 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -3,7 +3,6 @@ package pgtype import ( "database/sql/driver" "fmt" - "io" "net" ) @@ -106,29 +105,27 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Macaddr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Macaddr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.Addr.String()) - return false, err + return append(buf, src.Addr.String()...), nil } // EncodeBinary encodes src into w. -func (src *Macaddr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Macaddr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write([]byte(src.Addr)) - return false, err + return append(buf, src.Addr...), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/name.go b/pgtype/name.go index 05e92563..af064a82 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // Name is a type used for PostgreSQL's special 63-byte @@ -40,12 +39,12 @@ func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src *Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeText(ci, w) +func (src *Name) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) } -func (src *Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeBinary(ci, w) +func (src *Name) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 8dbc0251..dffb9963 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "math/big" "strconv" @@ -455,36 +453,26 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { return accum, rp, digits } -func (src *Numeric) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := io.WriteString(w, src.Int.String()); err != nil { - return false, err - } - - if err := pgio.WriteByte(w, 'e'); err != nil { - return false, err - } - - if _, err := io.WriteString(w, strconv.FormatInt(int64(src.Exp), 10)); err != nil { - return false, err - } - - return false, nil - + buf = append(buf, src.Int.String()...) + buf = append(buf, 'e') + buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) + return buf, nil } -func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var sign int16 @@ -535,9 +523,7 @@ func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { fracDigits = append(fracDigits, int16(remainder.Int64())) } - if _, err := pgio.WriteInt16(w, int16(len(wholeDigits)+len(fracDigits))); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, int16(len(wholeDigits)+len(fracDigits))) var weight int16 if len(wholeDigits) > 0 { @@ -548,35 +534,25 @@ func (src *Numeric) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } else { weight = int16(exp/4) - 1 + int16(len(fracDigits)) } - if _, err := pgio.WriteInt16(w, weight); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, weight) - if _, err := pgio.WriteInt16(w, sign); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, sign) var dscale int16 if src.Exp < 0 { dscale = int16(-src.Exp) } - if _, err := pgio.WriteInt16(w, dscale); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, dscale) for i := len(wholeDigits) - 1; i >= 0; i-- { - if _, err := pgio.WriteInt16(w, wholeDigits[i]); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, wholeDigits[i]) } for i := len(fracDigits) - 1; i >= 0; i-- { - if _, err := pgio.WriteInt16(w, fracDigits[i]); err != nil { - return false, err - } + buf = pgio.AppendInt16(buf, fracDigits[i]) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -606,13 +582,12 @@ func (dst *Numeric) Scan(src interface{}) error { func (src *Numeric) Value() (driver.Value, error) { switch src.Status { case Present: - buf := &bytes.Buffer{} - _, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - return buf.String(), nil + return string(buf), nil case Null: return nil, nil default: diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go index 2fc844eb..20f33dff 100644 --- a/pgtype/numeric_array.go +++ b/pgtype/numeric_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -191,23 +189,19 @@ func (dst *NumericArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *NumericArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *NumericArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -220,59 +214,44 @@ func (src *NumericArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -282,7 +261,7 @@ func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("numeric"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "numeric") + return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") } for i := range src.Elements { @@ -292,38 +271,23 @@ func (src *NumericArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -346,14 +310,13 @@ func (dst *NumericArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *NumericArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/numrange.go b/pgtype/numrange.go index bac6fc4b..00133296 100644 --- a/pgtype/numrange.go +++ b/pgtype/numrange.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Numrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Numrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/oid.go b/pgtype/oid.go index 58a7b0f5..6ceacc73 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "strconv" "github.com/jackc/pgx/pgio" @@ -47,14 +46,12 @@ func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Oid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) - return false, err +func (src Oid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return append(buf, strconv.FormatUint(uint64(src), 10)...), nil } -func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - _, err := pgio.WriteUint32(w, uint32(src)) - return false, err +func (src Oid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return pgio.AppendUint32(buf, uint32(src)), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go index 4a7de921..882d54fb 100644 --- a/pgtype/oid_value.go +++ b/pgtype/oid_value.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // OidValue (Object Identifier Type) is, according to @@ -37,12 +36,12 @@ func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeText(ci, w) +func (src *OidValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) } -func (src *OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeBinary(ci, w) +func (src *OidValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/path.go b/pgtype/path.go index c1aa76bc..3575342d 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -116,12 +115,12 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Path) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Path) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var startByte, endByte byte @@ -132,56 +131,40 @@ func (src *Path) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { startByte = '[' endByte = ']' } - if err := pgio.WriteByte(w, startByte); err != nil { - return false, err - } + buf = append(buf, startByte) for i, p := range src.P { if i > 0 { - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } - } - if _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)); err != nil { - return false, err + buf = append(buf, ',') } + buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) } - err := pgio.WriteByte(w, endByte) - return false, err + return append(buf, endByte), nil } -func (src *Path) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Path) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var closeByte byte if src.Closed { closeByte = 1 } - if err := pgio.WriteByte(w, closeByte); err != nil { - return false, err - } + buf = append(buf, closeByte) - if _, err := pgio.WriteInt32(w, int32(len(src.P))); err != nil { - return false, err - } + buf = pgio.AppendInt32(buf, int32(len(src.P))) for _, p := range src.P { - if _, err := pgio.WriteUint64(w, math.Float64bits(p.X)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(p.Y)); err != nil { - return false, err - } + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 3a6b7471..847fce0f 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -2,7 +2,6 @@ package pgtype import ( "errors" - "io" "reflect" ) @@ -111,21 +110,21 @@ type TextDecoder interface { // BinaryEncoder is implemented by types that can encode themselves into the // PostgreSQL binary wire format. type BinaryEncoder interface { - // EncodeBinary should encode the binary format of self to w. If self is the - // SQL value NULL then write nothing and return (true, nil). The caller of + // EncodeBinary should append the binary format of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of // EncodeBinary is responsible for writing the correct NULL value or the // length of the data written. - EncodeBinary(ci *ConnInfo, w io.Writer) (null bool, err error) + EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } // TextEncoder is implemented by types that can encode themselves into the // PostgreSQL text wire format. type TextEncoder interface { - // EncodeText should encode the text format of self to w. If self is the SQL - // value NULL then write nothing and return (true, nil). The caller of - // EncodeText is responsible for writing the correct NULL value or the length - // of the data written. - EncodeText(ci *ConnInfo, w io.Writer) (null bool, err error) + // EncodeText should append the text format of self to buf. If self is the + // SQL value NULL then append nothing and return (nil, nil). The caller of + // EncodeText is responsible for writing the correct NULL value or the + // length of the data written. + EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, err error) } var errUndefined = errors.New("cannot encode status undefined") diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index a13c1fcd..c15ee6d7 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" @@ -103,28 +102,26 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *pguint32) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, strconv.FormatUint(uint64(src.Uint), 10)) - return false, err + return append(buf, strconv.FormatUint(uint64(src.Uint), 10)...), nil } -func (src *pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *pguint32) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteUint32(w, src.Uint) - return false, err + return pgio.AppendUint32(buf, src.Uint), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/point.go b/pgtype/point.go index 62901340..3d5d4e1a 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -90,33 +89,28 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Point) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)) - return false, err + return append(buf, fmt.Sprintf(`(%f,%f)`, src.P.X, src.P.Y)...), nil } -func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Point) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteUint64(w, math.Float64bits(src.P.X)) - if err != nil { - return false, err - } - - _, err = pgio.WriteUint64(w, math.Float64bits(src.P.Y)) - return false, err + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/polygon.go b/pgtype/polygon.go index c4383765..d0b50061 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "math" "strconv" "strings" @@ -111,56 +110,42 @@ func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Polygon) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Polygon) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') for i, p := range src.P { if i > 0 { - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } - } - if _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)); err != nil { - return false, err + buf = append(buf, ',') } + buf = append(buf, fmt.Sprintf(`(%f,%f)`, p.X, p.Y)...) } - err := pgio.WriteByte(w, ')') - return false, err + return append(buf, ')'), nil } -func (src *Polygon) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Polygon) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteInt32(w, int32(len(src.P))); err != nil { - return false, err - } + buf = pgio.AppendInt32(buf, int32(len(src.P))) for _, p := range src.P { - if _, err := pgio.WriteUint64(w, math.Float64bits(p.X)); err != nil { - return false, err - } - - if _, err := pgio.WriteUint64(w, math.Float64bits(p.Y)); err != nil { - return false, err - } + buf = pgio.AppendUint64(buf, math.Float64bits(p.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(p.Y)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 10b56534..9c40ce18 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -2,11 +2,8 @@ package pgtype import ( "fmt" - "io" "math" "strconv" - - "github.com/jackc/pgx/pgio" ) // QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C @@ -136,13 +133,13 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *QChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - return false, pgio.WriteByte(w, byte(src.Int)) + return append(buf, byte(src.Int)), nil } diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index 5dd2fbe1..0effb42d 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "io" "os" "reflect" "testing" @@ -61,16 +60,16 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - return f.e.EncodeText(ci, w) +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeText(ci, buf) } type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { - return f.e.EncodeBinary(ci, w) +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + return f.e.EncodeBinary(ci, buf) } func ForceEncoder(e interface{}, formatCode int16) interface{} { diff --git a/pgtype/text.go b/pgtype/text.go index 54e2d774..6638c354 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "io" ) type Text struct { @@ -91,20 +90,19 @@ func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src *Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Text) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, src.String) - return false, err + return append(buf, src.String...), nil } -func (src *Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.EncodeText(ci, w) +func (src *Text) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 8a573d83..ed240e12 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TextArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `"NULL"`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `"NULL"`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("text"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "text") + return nil, fmt.Errorf("unable to find oid for type name %v", "text") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *TextArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *TextArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/tid.go b/pgtype/tid.go index 7456b155..2f4412cb 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "strconv" "strings" @@ -94,33 +93,29 @@ func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)) - return false, err + buf = append(buf, fmt.Sprintf(`(%d,%d)`, src.BlockNumber, src.OffsetNumber)...) + return buf, nil } -func (src *Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Tid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := pgio.WriteUint32(w, src.BlockNumber) - if err != nil { - return false, err - } - - _, err = pgio.WriteUint16(w, src.OffsetNumber) - return false, err + buf = pgio.AppendUint32(buf, src.BlockNumber) + buf = pgio.AppendUint16(buf, src.OffsetNumber) + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 4fb10abc..75c6cffa 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -136,15 +135,15 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src *Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if src.Time.Location() != time.UTC { - return false, fmt.Errorf("cannot encode non-UTC time into timestamp") + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") } var s string @@ -158,21 +157,20 @@ func (src *Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { s = "-infinity" } - _, err := io.WriteString(w, s) - return false, err + return append(buf, s...), nil } // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src *Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if src.Time.Location() != time.UTC { - return false, fmt.Errorf("cannot encode non-UTC time into timestamp") + return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") } var microsecSinceY2K int64 @@ -186,8 +184,7 @@ func (src *Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { microsecSinceY2K = negativeInfinityMicrosecondOffset } - _, err := pgio.WriteInt64(w, microsecSinceY2K) - return false, err + return pgio.AppendInt64(buf, microsecSinceY2K), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 49815dae..a4f1b9dd 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -164,23 +162,19 @@ func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -193,59 +187,44 @@ func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -255,7 +234,7 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) if dt, ok := ci.DataTypeForName("timestamp"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "timestamp") + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") } for i := range src.Elements { @@ -265,38 +244,23 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -319,14 +283,13 @@ func (dst *TimestampArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *TimestampArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 8606b2f2..97b0de2a 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -140,12 +139,12 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamptz) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var s string @@ -159,16 +158,15 @@ func (src *Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { s = "-infinity" } - _, err := io.WriteString(w, s) - return false, err + return append(buf, s...), nil } -func (src *Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Timestamptz) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var microsecSinceY2K int64 @@ -182,8 +180,7 @@ func (src *Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { microsecSinceY2K = negativeInfinityMicrosecondOffset } - _, err := pgio.WriteInt64(w, microsecSinceY2K) - return false, err + return pgio.AppendInt64(buf, microsecSinceY2K), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index bf983b6b..34d4f8a8 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "time" "github.com/jackc/pgx/pgio" @@ -164,23 +162,19 @@ func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -193,59 +187,44 @@ func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `NULL`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `NULL`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -255,7 +234,7 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro if dt, ok := ci.DataTypeForName("timestamptz"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") } for i := range src.Elements { @@ -265,38 +244,23 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -319,14 +283,13 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *TimestamptzArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go index 429a5cbe..783fb086 100644 --- a/pgtype/tsrange.go +++ b/pgtype/tsrange.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Tsrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tsrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go index f03a9f65..8fd3fd68 100644 --- a/pgtype/tstzrange.go +++ b/pgtype/tstzrange.go @@ -1,10 +1,8 @@ package pgtype import ( - "bytes" "database/sql/driver" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -106,72 +104,65 @@ func (dst *Tstzrange) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src Tstzrange) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } var rangeType byte @@ -182,10 +173,9 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +185,44 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 6752bd5b..0d454ac8 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -163,23 +163,19 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) erro } <% end %> -func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,60 +188,45 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `<%= text_null %>`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `<%= text_null %>`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } <% if binary_format == "true" %> - func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -255,7 +236,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") } for i := range src.Elements { @@ -265,38 +246,23 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } <% end %> @@ -320,14 +286,13 @@ func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb index 49db1b1d..90c23991 100644 --- a/pgtype/typed_range.go.erb +++ b/pgtype/typed_range.go.erb @@ -106,73 +106,66 @@ func (dst *<%= range_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src <%= range_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - switch src.Status { - case Null: - return true, nil - case Undefined: - return false, errUndefined - } +func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } switch src.LowerType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, '('); err != nil { - return false, err - } + buf = append(buf, '(') case Inclusive: - if err := pgio.WriteByte(w, '['); err != nil { - return false, err - } + buf = append(buf, '[') case Empty: - _, err := io.WriteString(w, "empty") - return false, err + return append(buf, "empty"...), nil default: - return false, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) } + var err error + if src.LowerType != Unbounded { - if null, err := src.Lower.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + buf, err = src.Lower.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } } - if err := pgio.WriteByte(w, ','); err != nil { - return false, err - } + buf = append(buf, ',') if src.UpperType != Unbounded { - if null, err := src.Upper.EncodeText(ci, w); err != nil { - return false, err - } else if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + buf, err = src.Upper.EncodeText(ci, buf) + if err != nil { + return nil, err + } else if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } } switch src.UpperType { case Exclusive, Unbounded: - if err := pgio.WriteByte(w, ')'); err != nil { - return false, err - } + buf = append(buf, ')') case Inclusive: - if err := pgio.WriteByte(w, ']'); err != nil { - return false, err - } + buf = append(buf, ']') default: - return false, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) } - return false, nil + return buf, nil } -func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - switch src.Status { - case Null: - return true, nil - case Undefined: - return false, errUndefined - } +func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } var rangeType byte switch src.LowerType { @@ -182,10 +175,9 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro rangeType |= lowerUnboundedMask case Exclusive: case Empty: - err := pgio.WriteByte(w, emptyMask) - return false, err + return append(buf, emptyMask), nil default: - return false, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -195,54 +187,44 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro rangeType |= upperUnboundedMask case Exclusive: default: - return false, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) } - if err := pgio.WriteByte(w, rangeType); err != nil { - return false, err - } + buf = append(buf, rangeType) - valBuf := &bytes.Buffer{} + var err error if src.LowerType != Unbounded { - null, err := src.Lower.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Lower.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } if src.UpperType != Unbounded { - null, err := src.Upper.EncodeBinary(ci, valBuf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf, err = src.Upper.EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - return false, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + if buf == nil { + return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") } - _, err = pgio.WriteInt32(w, int32(valBuf.Len())) - if err != nil { - return false, err - } - _, err = valBuf.WriteTo(w) - if err != nil { - return false, err - } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return false, nil + return buf, nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/uuid.go b/pgtype/uuid.go index a4a93ab3..c73c501e 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/hex" "fmt" - "io" ) type Uuid struct { @@ -126,28 +125,26 @@ func (dst *Uuid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Uuid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := io.WriteString(w, encodeUuid(src.Bytes)) - return false, err + return append(buf, encodeUuid(src.Bytes)...), nil } -func (src *Uuid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Uuid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - _, err := w.Write(src.Bytes[:]) - return false, err + return append(buf, src.Bytes[:]...), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/varbit.go b/pgtype/varbit.go index b986f02a..9a9fe1e1 100644 --- a/pgtype/varbit.go +++ b/pgtype/varbit.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -76,43 +75,37 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Varbit) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Varbit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - buf := make([]byte, int(src.Len)) - for i, _ := range buf { + for i := int32(0); i < src.Len; i++ { byteIdx := i / 8 bitMask := byte(128 >> byte(i%8)) char := byte('0') if src.Bytes[byteIdx]&bitMask > 0 { char = '1' } - buf[i] = char + buf = append(buf, char) } - _, err := w.Write(buf) - return false, err + return buf, nil } -func (src *Varbit) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *Varbit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } - if _, err := pgio.WriteInt32(w, src.Len); err != nil { - return false, err - } - - _, err := w.Write(src.Bytes) - return false, err + buf = pgio.AppendInt32(buf, src.Len) + return append(buf, src.Bytes...), nil } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/varchar.go b/pgtype/varchar.go index 80673fa8..371efd7e 100644 --- a/pgtype/varchar.go +++ b/pgtype/varchar.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) type Varchar Text @@ -32,12 +31,12 @@ func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } -func (src *Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeText(ci, w) +func (src *Varchar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeText(ci, buf) } -func (src *Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*Text)(src).EncodeBinary(ci, w) +func (src *Varchar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*Text)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index d84fac02..c34ac0b6 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -1,11 +1,9 @@ package pgtype import ( - "bytes" "database/sql/driver" "encoding/binary" "fmt" - "io" "github.com/jackc/pgx/pgio" ) @@ -163,23 +161,19 @@ func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *VarcharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } if len(src.Dimensions) == 0 { - _, err := io.WriteString(w, "{}") - return false, err + return append(buf, '{', '}'), nil } - err := EncodeTextArrayDimensions(w, src.Dimensions) - if err != nil { - return false, err - } + buf = EncodeTextArrayDimensions(buf, src.Dimensions) // dimElemCounts is the multiples of elements that each array lies on. For // example, a single dimension array of length 4 would have a dimElemCounts of @@ -192,59 +186,44 @@ func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] } + inElemBuf := make([]byte, 0, 32) for i, elem := range src.Elements { if i > 0 { - err = pgio.WriteByte(w, ',') - if err != nil { - return false, err - } + buf = append(buf, ',') } for _, dec := range dimElemCounts { if i%dec == 0 { - err = pgio.WriteByte(w, '{') - if err != nil { - return false, err - } + buf = append(buf, '{') } } - elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(ci, elemBuf) + elemBuf, err := elem.EncodeText(ci, inElemBuf) if err != nil { - return false, err + return nil, err } - if null { - _, err = io.WriteString(w, `"NULL"`) - if err != nil { - return false, err - } + if elemBuf == nil { + buf = append(buf, `"NULL"`...) } else { - _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) - if err != nil { - return false, err - } + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) } for _, dec := range dimElemCounts { if (i+1)%dec == 0 { - err = pgio.WriteByte(w, '}') - if err != nil { - return false, err - } + buf = append(buf, '}') } } } - return false, nil + return buf, nil } -func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: - return true, nil + return nil, nil case Undefined: - return false, errUndefined + return nil, errUndefined } arrayHeader := ArrayHeader{ @@ -254,7 +233,7 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { if dt, ok := ci.DataTypeForName("varchar"); ok { arrayHeader.ElementOid = int32(dt.Oid) } else { - return false, fmt.Errorf("unable to find oid for type name %v", "varchar") + return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") } for i := range src.Elements { @@ -264,38 +243,23 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { } } - err := arrayHeader.EncodeBinary(ci, w) - if err != nil { - return false, err - } - - elemBuf := &bytes.Buffer{} + buf = arrayHeader.EncodeBinary(ci, buf) for i := range src.Elements { - elemBuf.Reset() + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) if err != nil { - return false, err + return nil, err } - if null { - _, err = pgio.WriteInt32(w, -1) - if err != nil { - return false, err - } - } else { - _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) - if err != nil { - return false, err - } - _, err = elemBuf.WriteTo(w) - if err != nil { - return false, err - } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } } - return false, err + return buf, nil } // Scan implements the database/sql Scanner interface. @@ -318,14 +282,13 @@ func (dst *VarcharArray) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *VarcharArray) Value() (driver.Value, error) { - buf := &bytes.Buffer{} - null, err := src.EncodeText(nil, buf) + buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil } diff --git a/pgtype/xid.go b/pgtype/xid.go index 90a8d691..84acd1b0 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -2,7 +2,6 @@ package pgtype import ( "database/sql/driver" - "io" ) // Xid is PostgreSQL's Transaction ID type. @@ -46,12 +45,12 @@ func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeText(ci, w) +func (src *Xid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeText(ci, buf) } -func (src *Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return (*pguint32)(src).EncodeBinary(ci, w) +func (src *Xid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. diff --git a/values.go b/values.go index da12952a..b1928b86 100644 --- a/values.go +++ b/values.go @@ -1,13 +1,13 @@ package pgx import ( - "bytes" "database/sql/driver" "fmt" "math" "reflect" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" ) @@ -33,15 +33,14 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e case driver.Valuer: return arg.Value() case pgtype.TextEncoder: - buf := &bytes.Buffer{} - null, err := arg.EncodeText(ci, buf) + buf, err := arg.EncodeText(ci, nil) if err != nil { return nil, err } - if null { + if buf == nil { return nil, nil } - return buf.String(), nil + return string(buf), nil case int64: return arg, nil case float64: @@ -106,27 +105,27 @@ func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interfa switch arg := arg.(type) { case pgtype.BinaryEncoder: - sp := wbuf.reserveSize() - null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, wbuf) + sp := len(wbuf.buf) + wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) + argBuf, err := arg.EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf) if err != nil { return err } - if null { - wbuf.setSize(sp, -1) - } else { - wbuf.setComputedSize(sp) + if argBuf != nil { + wbuf.buf = argBuf + pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) } return nil case pgtype.TextEncoder: - sp := wbuf.reserveSize() - null, err := arg.EncodeText(wbuf.conn.ConnInfo, wbuf) + sp := len(wbuf.buf) + wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) + argBuf, err := arg.EncodeText(wbuf.conn.ConnInfo, wbuf.buf) if err != nil { return err } - if null { - wbuf.setSize(sp, -1) - } else { - wbuf.setComputedSize(sp) + if argBuf != nil { + wbuf.buf = argBuf + pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) } return nil case driver.Valuer: @@ -159,15 +158,15 @@ func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interfa return err } - sp := wbuf.reserveSize() - null, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, wbuf) + sp := len(wbuf.buf) + wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) + argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf) if err != nil { return err } - if null { - wbuf.setSize(sp, -1) - } else { - wbuf.setComputedSize(sp) + if argBuf != nil { + wbuf.buf = argBuf + pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) } return nil } From 458dd24a9fd8067f86c5ab765b21a52374ce49b3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 2 May 2017 21:26:45 -0500 Subject: [PATCH 185/264] Remove unneeded WriteBuf --- conn.go | 118 +++++++++++++++++++++++++++++++------------------ copy_from.go | 57 ++++++++++++++---------- fastpath.go | 29 +++++++----- messages.go | 87 ------------------------------------ replication.go | 23 ++++++---- values.go | 70 ++++++++++++++--------------- 6 files changed, 173 insertions(+), 211 deletions(-) diff --git a/conn.go b/conn.go index bca9f6d8..a1781be2 100644 --- a/conn.go +++ b/conn.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -86,8 +87,7 @@ func (cc *ConnConfig) networkAddress() (network, address string) { type Conn struct { conn net.Conn // the underlying TCP or unix domain socket connection lastActivityTime time.Time // the last time the connection was used - wbuf [1024]byte - writeBuf WriteBuf + wbuf []byte pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server RuntimeParams map[string]string // parameters that have been reported by the server @@ -279,6 +279,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.cancelQueryCompleted = make(chan struct{}, 1) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) + c.wbuf = make([]byte, 0, 1024) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -707,32 +708,42 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } // parse - wbuf := newWriteBuf(c, 'P') - wbuf.WriteCString(name) - wbuf.WriteCString(sql) + buf := c.wbuf + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, name...) + buf = append(buf, 0) + buf = append(buf, sql...) + buf = append(buf, 0) if opts != nil { if len(opts.ParameterOids) > 65535 { return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) } - wbuf.WriteInt16(int16(len(opts.ParameterOids))) + buf = pgio.AppendInt16(buf, int16(len(opts.ParameterOids))) for _, oid := range opts.ParameterOids { - wbuf.WriteInt32(int32(oid)) + buf = pgio.AppendInt32(buf, int32(oid)) } } else { - wbuf.WriteInt16(0) + buf = pgio.AppendInt16(buf, 0) } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // describe - wbuf.startMsg('D') - wbuf.WriteByte('S') - wbuf.WriteCString(name) + buf = append(buf, 'D') + sp = len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 'S') + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // sync - wbuf.startMsg('S') - wbuf.closeMsg() + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) if err != nil { c.die(err) return nil, err @@ -813,15 +824,20 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { delete(c.preparedStatements, name) // close - wbuf := newWriteBuf(c, 'C') - wbuf.WriteByte('S') - wbuf.WriteCString(name) + buf := c.wbuf + buf = append(buf, 'C') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 'S') + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // flush - wbuf.startMsg('H') - wbuf.closeMsg() + buf = append(buf, 'H') + buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) if err != nil { c.die(err) return err @@ -943,11 +959,15 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { } if len(args) == 0 { - wbuf := newWriteBuf(c, 'Q') - wbuf.WriteCString(sql) - wbuf.closeMsg() + buf := c.wbuf + buf = append(buf, 'Q') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, sql...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err := c.conn.Write(wbuf.buf) + _, err := c.conn.Write(buf) if err != nil { c.die(err) return err @@ -975,37 +995,45 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} } // bind - wbuf := newWriteBuf(c, 'B') - wbuf.WriteByte(0) - wbuf.WriteCString(ps.Name) + buf := c.wbuf + buf = append(buf, 'B') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 0) + buf = append(buf, ps.Name...) + buf = append(buf, 0) - wbuf.WriteInt16(int16(len(ps.ParameterOids))) + buf = pgio.AppendInt16(buf, int16(len(ps.ParameterOids))) for i, oid := range ps.ParameterOids { - wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) + buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) } - wbuf.WriteInt16(int16(len(arguments))) + buf = pgio.AppendInt16(buf, int16(len(arguments))) for i, oid := range ps.ParameterOids { - if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil { + var err error + buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i]) + if err != nil { return err } } - wbuf.WriteInt16(int16(len(ps.FieldDescriptions))) + buf = pgio.AppendInt16(buf, int16(len(ps.FieldDescriptions))) for _, fd := range ps.FieldDescriptions { - wbuf.WriteInt16(fd.FormatCode) + buf = pgio.AppendInt16(buf, fd.FormatCode) } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) // execute - wbuf.startMsg('E') - wbuf.WriteByte(0) - wbuf.WriteInt32(0) + buf = append(buf, 'E') + buf = pgio.AppendInt32(buf, 9) + buf = append(buf, 0) + buf = pgio.AppendInt32(buf, 0) // sync - wbuf.startMsg('S') - wbuf.closeMsg() + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) if err != nil { c.die(err) } @@ -1180,11 +1208,15 @@ func (c *Conn) txStartupMessage(msg *startupMessage) error { } func (c *Conn) txPasswordMessage(password string) (err error) { - wbuf := newWriteBuf(c, 'p') - wbuf.WriteCString(password) - wbuf.closeMsg() + buf := c.wbuf + buf = append(buf, 'p') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, password...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write(buf) return err } diff --git a/copy_from.go b/copy_from.go index 7d8dead1..f3c77109 100644 --- a/copy_from.go +++ b/copy_from.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -89,14 +90,14 @@ func (ct *copyFrom) waitForReaderDone() error { func (ct *copyFrom) run() (int, error) { quotedTableName := ct.tableName.Sanitize() - buf := &bytes.Buffer{} + cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { if i != 0 { - buf.WriteString(", ") + cbuf.WriteString(", ") } - buf.WriteString(quoteIdentifier(cn)) + cbuf.WriteString(quoteIdentifier(cn)) } - quotedColumnNames := buf.String() + quotedColumnNames := cbuf.String() ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) if err != nil { @@ -116,11 +117,14 @@ func (ct *copyFrom) run() (int, error) { go ct.readUntilReadyForQuery() defer ct.waitForReaderDone() - wbuf := newWriteBuf(ct.conn, copyData) + buf := ct.conn.wbuf + buf = append(buf, copyData) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - wbuf.WriteBytes([]byte("PGCOPY\n\377\r\n\000")) - wbuf.WriteInt32(0) - wbuf.WriteInt32(0) + buf = append(buf, "PGCOPY\n\377\r\n\000"...) + buf = pgio.AppendInt32(buf, 0) + buf = pgio.AppendInt32(buf, 0) var sentCount int @@ -131,18 +135,16 @@ func (ct *copyFrom) run() (int, error) { default: } - if len(wbuf.buf) > 65536 { - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) + if len(buf) > 65536 { + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + _, err = ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return 0, err } // Directly manipulate wbuf to reset to reuse the same buffer - wbuf.buf = wbuf.buf[0:5] - wbuf.buf[0] = copyData - wbuf.sizeIdx = 1 + buf = buf[0:5] } sentCount++ @@ -157,9 +159,9 @@ func (ct *copyFrom) run() (int, error) { return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) } - wbuf.WriteInt16(int16(len(ct.columnNames))) + buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val) + buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val) if err != nil { ct.cancelCopyIn() return 0, err @@ -173,11 +175,13 @@ func (ct *copyFrom) run() (int, error) { return 0, ct.rowSrc.Err() } - wbuf.WriteInt16(-1) // terminate the copy stream + buf = pgio.AppendInt16(buf, -1) // terminate the copy stream + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - wbuf.startMsg(copyDone) - wbuf.closeMsg() - _, err = ct.conn.conn.Write(wbuf.buf) + buf = append(buf, copyDone) + buf = pgio.AppendInt32(buf, 4) + + _, err = ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return 0, err @@ -210,10 +214,15 @@ func (c *Conn) readUntilCopyInResponse() error { } func (ct *copyFrom) cancelCopyIn() error { - wbuf := newWriteBuf(ct.conn, copyFail) - wbuf.WriteCString("client error: abort") - wbuf.closeMsg() - _, err := ct.conn.conn.Write(wbuf.buf) + buf := ct.conn.wbuf + buf = append(buf, copyFail) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, "client error: abort"...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + _, err := ct.conn.conn.Write(buf) if err != nil { ct.conn.die(err) return err diff --git a/fastpath.go b/fastpath.go index 75681c9c..776be177 100644 --- a/fastpath.go +++ b/fastpath.go @@ -3,6 +3,7 @@ package pgx import ( "encoding/binary" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -55,19 +56,23 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) { return nil, err } - wbuf := newWriteBuf(f.cn, 'F') // function call - wbuf.WriteInt32(int32(oid)) // function object id - wbuf.WriteInt16(1) // # of argument format codes - wbuf.WriteInt16(1) // format code: binary - wbuf.WriteInt16(int16(len(args))) // # of arguments - for _, arg := range args { - wbuf.WriteInt32(int32(len(arg))) // length of argument - wbuf.WriteBytes(arg) // argument value - } - wbuf.WriteInt16(1) // response format code (binary) - wbuf.closeMsg() + buf := f.cn.wbuf + buf = append(buf, 'F') // function call + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - if _, err := f.cn.conn.Write(wbuf.buf); err != nil { + buf = pgio.AppendInt32(buf, int32(oid)) // function object id + buf = pgio.AppendInt16(buf, 1) // # of argument format codes + buf = pgio.AppendInt16(buf, 1) // format code: binary + buf = pgio.AppendInt16(buf, int16(len(args))) // # of arguments + for _, arg := range args { + buf = pgio.AppendInt32(buf, int32(len(arg))) // length of argument + buf = append(buf, arg...) // argument value + } + buf = pgio.AppendInt16(buf, 1) // response format code (binary) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + if _, err := f.cn.conn.Write(buf); err != nil { return nil, err } diff --git a/messages.go b/messages.go index 0f17a6d2..8e406602 100644 --- a/messages.go +++ b/messages.go @@ -92,90 +92,3 @@ type PgError struct { func (pe PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } - -func newWriteBuf(c *Conn, t byte) *WriteBuf { - buf := append(c.wbuf[0:0], t, 0, 0, 0, 0) - c.writeBuf = WriteBuf{buf: buf, sizeIdx: 1, conn: c} - return &c.writeBuf -} - -// WriteBuf is used build messages to send to the PostgreSQL server. It is used -// by the Encoder interface when implementing custom encoders. -type WriteBuf struct { - buf []byte - convBuf [8]byte - sizeIdx int - conn *Conn -} - -func (wb *WriteBuf) startMsg(t byte) { - wb.closeMsg() - wb.buf = append(wb.buf, t, 0, 0, 0, 0) - wb.sizeIdx = len(wb.buf) - 4 -} - -func (wb *WriteBuf) closeMsg() { - binary.BigEndian.PutUint32(wb.buf[wb.sizeIdx:wb.sizeIdx+4], uint32(len(wb.buf)-wb.sizeIdx)) -} - -func (wb *WriteBuf) reserveSize() int { - sizePosition := len(wb.buf) - wb.buf = append(wb.buf, 0, 0, 0, 0) - return sizePosition -} - -func (wb *WriteBuf) setComputedSize(sizePosition int) { - binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(len(wb.buf)-sizePosition-4)) -} - -func (wb *WriteBuf) setSize(sizePosition int, size int32) { - binary.BigEndian.PutUint32(wb.buf[sizePosition:], uint32(size)) -} - -func (wb *WriteBuf) WriteByte(b byte) { - wb.buf = append(wb.buf, b) -} - -func (wb *WriteBuf) WriteCString(s string) { - wb.buf = append(wb.buf, []byte(s)...) - wb.buf = append(wb.buf, 0) -} - -func (wb *WriteBuf) WriteInt16(n int16) { - wb.WriteUint16(uint16(n)) -} - -func (wb *WriteBuf) WriteUint16(n uint16) (int, error) { - binary.BigEndian.PutUint16(wb.convBuf[:2], n) - wb.buf = append(wb.buf, wb.convBuf[:2]...) - return 2, nil -} - -func (wb *WriteBuf) WriteInt32(n int32) { - wb.WriteUint32(uint32(n)) -} - -func (wb *WriteBuf) WriteUint32(n uint32) (int, error) { - binary.BigEndian.PutUint32(wb.convBuf[:4], n) - wb.buf = append(wb.buf, wb.convBuf[:4]...) - return 4, nil -} - -func (wb *WriteBuf) WriteInt64(n int64) { - wb.WriteUint64(uint64(n)) -} - -func (wb *WriteBuf) WriteUint64(n uint64) (int, error) { - binary.BigEndian.PutUint64(wb.convBuf[:8], n) - wb.buf = append(wb.buf, wb.convBuf[:8]...) - return 8, nil -} - -func (wb *WriteBuf) WriteBytes(b []byte) { - wb.buf = append(wb.buf, b...) -} - -func (wb *WriteBuf) Write(b []byte) (int, error) { - wb.buf = append(wb.buf, b...) - return len(b), nil -} diff --git a/replication.go b/replication.go index 594944e0..1260d3e7 100644 --- a/replication.go +++ b/replication.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -175,17 +176,21 @@ type ReplicationConn struct { // message to the server, as well as carries the WAL position of the // client, which then updates the server's replication slot position. func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { - writeBuf := newWriteBuf(rc.c, copyData) - writeBuf.WriteByte(standbyStatusUpdate) - writeBuf.WriteInt64(int64(k.WalWritePosition)) - writeBuf.WriteInt64(int64(k.WalFlushPosition)) - writeBuf.WriteInt64(int64(k.WalApplyPosition)) - writeBuf.WriteInt64(int64(k.ClientTime)) - writeBuf.WriteByte(k.ReplyRequested) + buf := rc.c.wbuf + buf = append(buf, copyData) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) - writeBuf.closeMsg() + buf = append(buf, standbyStatusUpdate) + buf = pgio.AppendInt64(buf, int64(k.WalWritePosition)) + buf = pgio.AppendInt64(buf, int64(k.WalFlushPosition)) + buf = pgio.AppendInt64(buf, int64(k.WalApplyPosition)) + buf = pgio.AppendInt64(buf, int64(k.ClientTime)) + buf = append(buf, k.ReplyRequested) - _, err = rc.c.conn.Write(writeBuf.buf) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + _, err = rc.c.conn.Write(buf) if err != nil { rc.c.die(err) } diff --git a/values.go b/values.go index b1928b86..ca5db50b 100644 --- a/values.go +++ b/values.go @@ -97,84 +97,82 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) } -func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { +func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.Oid, arg interface{}) ([]byte, error) { if arg == nil { - wbuf.WriteInt32(-1) - return nil + return pgio.AppendInt32(buf, -1), nil } switch arg := arg.(type) { case pgtype.BinaryEncoder: - sp := len(wbuf.buf) - wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) - argBuf, err := arg.EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeBinary(ci, buf) if err != nil { - return err + return nil, err } if argBuf != nil { - wbuf.buf = argBuf - pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return nil + return buf, nil case pgtype.TextEncoder: - sp := len(wbuf.buf) - wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) - argBuf, err := arg.EncodeText(wbuf.conn.ConnInfo, wbuf.buf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := arg.EncodeText(ci, buf) if err != nil { - return err + return nil, err } if argBuf != nil { - wbuf.buf = argBuf - pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return nil + return buf, nil case driver.Valuer: v, err := arg.Value() if err != nil { - return err + return nil, err } - return encodePreparedStatementArgument(wbuf, oid, v) + return encodePreparedStatementArgument(ci, buf, oid, v) case string: - wbuf.WriteInt32(int32(len(arg))) - wbuf.WriteBytes([]byte(arg)) - return nil + buf = pgio.AppendInt32(buf, int32(len(arg))) + buf = append(buf, arg...) + return buf, nil } refVal := reflect.ValueOf(arg) if refVal.Kind() == reflect.Ptr { if refVal.IsNil() { - wbuf.WriteInt32(-1) - return nil + return pgio.AppendInt32(buf, -1), nil } arg = refVal.Elem().Interface() - return encodePreparedStatementArgument(wbuf, oid, arg) + return encodePreparedStatementArgument(ci, buf, oid, arg) } - if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok { + if dt, ok := ci.DataTypeForOid(oid); ok { value := dt.Value err := value.Set(arg) if err != nil { - return err + return nil, err } - sp := len(wbuf.buf) - wbuf.buf = pgio.AppendInt32(wbuf.buf, -1) - argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, wbuf.buf) + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) if err != nil { - return err + return nil, err } if argBuf != nil { - wbuf.buf = argBuf - pgio.SetInt32(wbuf.buf[sp:], int32(len(wbuf.buf[sp:])-4)) + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - return nil + return buf, nil } if strippedArg, ok := stripNamedType(&refVal); ok { - return encodePreparedStatementArgument(wbuf, oid, strippedArg) + return encodePreparedStatementArgument(ci, buf, oid, strippedArg) } - return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } // chooseParameterFormatCode determines the correct format code for an From b1489a1eabc8ff45bc2db830d78210f014933e8f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 08:48:40 -0500 Subject: [PATCH 186/264] Update pgproto3 to enable pgmock --- pgio/read.go | 11 +++++ pgproto3/backend.go | 30 +++++++++++- pgproto3/frontend.go | 9 +++- pgproto3/startup_message.go | 95 +++++++++++++++++++++++++++++++++++++ 4 files changed, 141 insertions(+), 4 deletions(-) create mode 100644 pgproto3/startup_message.go diff --git a/pgio/read.go b/pgio/read.go index 7ddad508..033bada4 100644 --- a/pgio/read.go +++ b/pgio/read.go @@ -1,6 +1,7 @@ package pgio import ( + "bytes" "encoding/binary" ) @@ -38,3 +39,13 @@ func NextInt64(buf []byte) ([]byte, int64) { buf, n := NextUint64(buf) return buf, int64(n) } + +func NextCString(buf []byte) ([]byte, string, bool) { + idx := bytes.IndexByte(buf, 0) + if idx < 0 { + return buf, "", false + } + cstring := string(buf[:idx]) + buf = buf[:idx+1] + return buf, cstring, true +} diff --git a/pgproto3/backend.go b/pgproto3/backend.go index c04116a8..bd477315 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "errors" "fmt" "io" @@ -20,6 +19,7 @@ type Backend struct { parse Parse passwordMessage PasswordMessage query Query + startupMessage StartupMessage sync Sync terminate Terminate } @@ -30,7 +30,33 @@ func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { } func (b *Backend) Send(msg BackendMessage) error { - return errors.New("not implemented") + buf, err := msg.MarshalBinary() + if err != nil { + return nil + } + + _, err = b.w.Write(buf) + return err +} + +func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { + buf, err := b.cr.Next(4) + if err != nil { + return nil, err + } + msgSize := int(binary.BigEndian.Uint32(buf) - 4) + + buf, err = b.cr.Next(msgSize) + if err != nil { + return nil, err + } + + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + + return &b.startupMessage, nil } func (b *Backend) Receive() (FrontendMessage, error) { diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 50835836..27a9890a 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -2,7 +2,6 @@ package pgproto3 import ( "encoding/binary" - "errors" "fmt" "io" @@ -43,7 +42,13 @@ func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { } func (b *Frontend) Send(msg FrontendMessage) error { - return errors.New("not implemented") + buf, err := msg.MarshalBinary() + if err != nil { + return nil + } + + _, err = b.w.Write(buf) + return err } func (b *Frontend) Receive() (BackendMessage, error) { diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go new file mode 100644 index 00000000..ebb804fe --- /dev/null +++ b/pgproto3/startup_message.go @@ -0,0 +1,95 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" +) + +const ( + protocolVersionNumber = 196608 // 3.0 + sslRequestNumber = 80877103 +) + +type StartupMessage struct { + ProtocolVersion uint32 + Parameters map[string]string +} + +func (*StartupMessage) Frontend() {} + +func (dst *StartupMessage) Decode(src []byte) error { + if len(src) < 4 { + return fmt.Errorf("startup message too short") + } + + dst.ProtocolVersion = binary.BigEndian.Uint32(src) + rp := 4 + + if dst.ProtocolVersion == sslRequestNumber { + return fmt.Errorf("can't handle ssl connection request") + } + + if dst.ProtocolVersion != protocolVersionNumber { + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion) + } + + dst.Parameters = make(map[string]string) + for { + idx := bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + key := string(src[rp : rp+idx]) + rp += idx + 1 + + idx = bytes.IndexByte(src[rp:], 0) + if idx < 0 { + return &invalidMessageFormatErr{messageType: "StartupMesage"} + } + value := string(src[rp : rp+idx]) + rp += idx + 1 + + dst.Parameters[key] = value + + if len(src[rp:]) == 1 { + if src[rp] != 0 { + return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + } + break + } + } + + return nil +} + +func (src *StartupMessage) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + buf.Write(bigEndian.Uint32(0)) + buf.Write(bigEndian.Uint32(src.ProtocolVersion)) + for k, v := range src.Parameters { + buf.WriteString(k) + buf.WriteByte(0) + buf.WriteString(v) + buf.WriteByte(0) + } + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[0:4], uint32(buf.Len())) + + return buf.Bytes(), nil +} + +func (src *StartupMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "StartupMessage", + ProtocolVersion: src.ProtocolVersion, + Parameters: src.Parameters, + }) +} From 0cda099bb51032f95972fd116f3b2085b00c3698 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 08:53:37 -0500 Subject: [PATCH 187/264] Remove read functions from pgio and update docs --- pgio/doc.go | 8 +++---- pgio/read.go | 51 ------------------------------------------ pgio/read_test.go | 57 ----------------------------------------------- 3 files changed, 3 insertions(+), 113 deletions(-) delete mode 100644 pgio/read.go delete mode 100644 pgio/read_test.go diff --git a/pgio/doc.go b/pgio/doc.go index 36233a47..ef2dcc7f 100644 --- a/pgio/doc.go +++ b/pgio/doc.go @@ -1,8 +1,6 @@ -// Package pgio a extremely low-level IO toolkit for the PostgreSQL wire protocol. +// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol. /* -pgio provides functions for reading and writing integers from io.Reader and -io.Writer while doing byte order conversion. It publishes interfaces which -readers and writers may implement to decode and encode messages with the minimum -of memory allocations. +pgio provides functions for appending integers to a []byte while doing byte +order conversion. */ package pgio diff --git a/pgio/read.go b/pgio/read.go deleted file mode 100644 index 033bada4..00000000 --- a/pgio/read.go +++ /dev/null @@ -1,51 +0,0 @@ -package pgio - -import ( - "bytes" - "encoding/binary" -) - -func NextByte(buf []byte) ([]byte, byte) { - b := buf[0] - return buf[1:], b -} - -func NextUint16(buf []byte) ([]byte, uint16) { - n := binary.BigEndian.Uint16(buf) - return buf[2:], n -} - -func NextUint32(buf []byte) ([]byte, uint32) { - n := binary.BigEndian.Uint32(buf) - return buf[4:], n -} - -func NextUint64(buf []byte) ([]byte, uint64) { - n := binary.BigEndian.Uint64(buf) - return buf[8:], n -} - -func NextInt16(buf []byte) ([]byte, int16) { - buf, n := NextUint16(buf) - return buf, int16(n) -} - -func NextInt32(buf []byte) ([]byte, int32) { - buf, n := NextUint32(buf) - return buf, int32(n) -} - -func NextInt64(buf []byte) ([]byte, int64) { - buf, n := NextUint64(buf) - return buf, int64(n) -} - -func NextCString(buf []byte) ([]byte, string, bool) { - idx := bytes.IndexByte(buf, 0) - if idx < 0 { - return buf, "", false - } - cstring := string(buf[:idx]) - buf = buf[:idx+1] - return buf, cstring, true -} diff --git a/pgio/read_test.go b/pgio/read_test.go deleted file mode 100644 index fbe29ae4..00000000 --- a/pgio/read_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package pgio - -import ( - "testing" -) - -func TestNextByte(t *testing.T) { - buf := []byte{42, 1} - var b byte - buf, b = NextByte(buf) - if b != 42 { - t.Errorf("NextByte(buf) => %v, want %v", b, 42) - } - buf, b = NextByte(buf) - if b != 1 { - t.Errorf("NextByte(buf) => %v, want %v", b, 1) - } -} - -func TestNextUint16(t *testing.T) { - buf := []byte{0, 42, 0, 1} - var n uint16 - buf, n = NextUint16(buf) - if n != 42 { - t.Errorf("NextUint16(buf) => %v, want %v", n, 42) - } - buf, n = NextUint16(buf) - if n != 1 { - t.Errorf("NextUint16(buf) => %v, want %v", n, 1) - } -} - -func TestNextUint32(t *testing.T) { - buf := []byte{0, 0, 0, 42, 0, 0, 0, 1} - var n uint32 - buf, n = NextUint32(buf) - if n != 42 { - t.Errorf("NextUint32(buf) => %v, want %v", n, 42) - } - buf, n = NextUint32(buf) - if n != 1 { - t.Errorf("NextUint32(buf) => %v, want %v", n, 1) - } -} - -func TestNextUint64(t *testing.T) { - buf := []byte{0, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 1} - var n uint64 - buf, n = NextUint64(buf) - if n != 42 { - t.Errorf("NextUint64(buf) => %v, want %v", n, 42) - } - buf, n = NextUint64(buf) - if n != 1 { - t.Errorf("NextUint64(buf) => %v, want %v", n, 1) - } -} From 0a67735a8e7a8c6def3b147f1cab5b91c2ab7f85 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 09:25:58 -0500 Subject: [PATCH 188/264] ConnPool.Close does not wait for acquired conns --- conn_pool.go | 18 +++++++----------- v3.md | 2 ++ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 7bc022d0..6a1f37a2 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -230,25 +230,21 @@ func (p *ConnPool) removeFromAllConnections(conn *Conn) bool { return false } -// Close ends the use of a connection pool. It prevents any new connections -// from being acquired, waits until all acquired connections are released, -// then closes all underlying connections. +// Close ends the use of a connection pool. It prevents any new connections from +// being acquired and closes available underlying connections. Any acquired +// connections will be closed when they are released. func (p *ConnPool) Close() { p.cond.L.Lock() defer p.cond.L.Unlock() p.closed = true - // Wait until all connections are released - if len(p.availableConnections) != len(p.allConnections) { - for len(p.availableConnections) != len(p.allConnections) { - p.cond.Wait() - } - } - - for _, c := range p.allConnections { + for _, c := range p.availableConnections { _ = c.Close() } + + // This will cause any checked out connections to be closed on release + p.resetCount++ } // Reset closes all open connections, but leaves the pool open. It is intended diff --git a/v3.md b/v3.md index 72110888..d4dfe1bd 100644 --- a/v3.md +++ b/v3.md @@ -36,6 +36,8 @@ Removed ValueReader Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system +ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset. + ## TODO / Possible / Investigate Organize errors better From 8322171bd8350632be9666513f7de381bee84c27 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 09:37:49 -0500 Subject: [PATCH 189/264] Remove Rows.Fatal --- query.go | 58 +++++++++++++++++++++++++------------------------- replication.go | 6 +++--- v3.md | 2 ++ 3 files changed, 34 insertions(+), 32 deletions(-) diff --git a/query.go b/query.go index 3d081714..2e3957c2 100644 --- a/query.go +++ b/query.go @@ -93,9 +93,9 @@ func (rows *Rows) Err() error { return rows.err } -// Fatal signals an error occurred after the query was sent to the server. It +// fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. -func (rows *Rows) Fatal(err error) { +func (rows *Rows) fatal(err error) { if rows.err != nil { return } @@ -118,7 +118,7 @@ func (rows *Rows) Next() bool { for { msg, err := rows.conn.rxMsg() if err != nil { - rows.Fatal(err) + rows.fatal(err) return false } @@ -130,13 +130,13 @@ func (rows *Rows) Next() bool { rows.fields[i].DataTypeName = dt.Name rows.fields[i].FormatCode = TextFormatCode } else { - rows.Fatal(fmt.Errorf("unknown oid: %d", rows.fields[i].DataType)) + rows.fatal(fmt.Errorf("unknown oid: %d", rows.fields[i].DataType)) return false } } case *pgproto3.DataRow: if len(msg.Values) != len(rows.fields) { - rows.Fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values)))) + rows.fatal(ProtocolError(fmt.Sprintf("Row description field count (%v) and data row field count (%v) do not match", len(rows.fields), len(msg.Values)))) return false } @@ -149,7 +149,7 @@ func (rows *Rows) Next() bool { default: err = rows.conn.processContextFreeMsg(msg) if err != nil { - rows.Fatal(err) + rows.fatal(err) return false } } @@ -166,7 +166,7 @@ func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { return nil, nil, false } if len(rows.fields) <= rows.columnIdx { - rows.Fatal(ProtocolError("No next column available")) + rows.fatal(ProtocolError("No next column available")) return nil, nil, false } @@ -192,7 +192,7 @@ func (e scanArgError) Error() string { func (rows *Rows) Scan(dest ...interface{}) (err error) { if len(rows.fields) != len(dest) { err = fmt.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) - rows.Fatal(err) + rows.fatal(err) return err } @@ -206,12 +206,12 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode { err = s.DecodeBinary(rows.conn.ConnInfo, buf) if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + rows.fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode { err = s.DecodeText(rows.conn.ConnInfo, buf) if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + rows.fatal(scanArgError{col: i, err: err}) } } else { if dt, ok := rows.conn.ConnInfo.DataTypeForOid(fd.DataType); ok { @@ -221,40 +221,40 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if textDecoder, ok := value.(pgtype.TextDecoder); ok { err = textDecoder.DecodeText(rows.conn.ConnInfo, buf) if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + rows.fatal(scanArgError{col: i, err: err}) } } else { - rows.Fatal(scanArgError{col: i, err: fmt.Errorf("%T is not a pgtype.TextDecoder", value)}) + rows.fatal(scanArgError{col: i, err: fmt.Errorf("%T is not a pgtype.TextDecoder", value)}) } case BinaryFormatCode: if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf) if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + rows.fatal(scanArgError{col: i, err: err}) } } else { - rows.Fatal(scanArgError{col: i, err: fmt.Errorf("%T is not a pgtype.BinaryDecoder", value)}) + rows.fatal(scanArgError{col: i, err: fmt.Errorf("%T is not a pgtype.BinaryDecoder", value)}) } default: - rows.Fatal(scanArgError{col: i, err: fmt.Errorf("unknown format code: %v", fd.FormatCode)}) + rows.fatal(scanArgError{col: i, err: fmt.Errorf("unknown format code: %v", fd.FormatCode)}) } if rows.Err() == nil { if scanner, ok := d.(sql.Scanner); ok { sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) if err != nil { - rows.Fatal(err) + rows.fatal(err) } err = scanner.Scan(sqlSrc) if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + rows.fatal(scanArgError{col: i, err: err}) } } else if err := value.AssignTo(d); err != nil { - rows.Fatal(scanArgError{col: i, err: err}) + rows.fatal(scanArgError{col: i, err: err}) } } } else { - rows.Fatal(scanArgError{col: i, err: fmt.Errorf("unknown oid: %v", fd.DataType)}) + rows.fatal(scanArgError{col: i, err: fmt.Errorf("unknown oid: %v", fd.DataType)}) } } @@ -293,7 +293,7 @@ func (rows *Rows) Values() ([]interface{}, error) { } err := decoder.DecodeText(rows.conn.ConnInfo, buf) if err != nil { - rows.Fatal(err) + rows.fatal(err) } values = append(values, decoder.(pgtype.Value).Get()) case BinaryFormatCode: @@ -303,14 +303,14 @@ func (rows *Rows) Values() ([]interface{}, error) { } err := decoder.DecodeBinary(rows.conn.ConnInfo, buf) if err != nil { - rows.Fatal(err) + rows.fatal(err) } values = append(values, value.Get()) default: - rows.Fatal(errors.New("Unknown format code")) + rows.fatal(errors.New("Unknown format code")) } } else { - rows.Fatal(errors.New("Unknown type")) + rows.fatal(errors.New("Unknown type")) } if rows.Err() != nil { @@ -381,7 +381,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, rows = c.getRows(sql, args) if err := c.lock(); err != nil { - rows.Fatal(err) + rows.fatal(err) return rows, err } rows.unlockConn = true @@ -389,13 +389,13 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, if options != nil && options.SimpleProtocol { err = c.initContext(ctx) if err != nil { - rows.Fatal(err) + rows.fatal(err) return rows, err } err = c.sanitizeAndSendSimpleQuery(sql, args...) if err != nil { - rows.Fatal(err) + rows.fatal(err) return rows, err } @@ -407,7 +407,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, var err error ps, err = c.PrepareExContext(ctx, "", sql, nil) if err != nil { - rows.Fatal(err) + rows.fatal(err) return rows, rows.err } } @@ -416,13 +416,13 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, err = c.initContext(ctx) if err != nil { - rows.Fatal(err) + rows.fatal(err) return rows, err } err = c.sendPreparedQuery(ps, args...) if err != nil { - rows.Fatal(err) + rows.fatal(err) err = c.termContext(err) } diff --git a/replication.go b/replication.go index 1260d3e7..eacc0c3f 100644 --- a/replication.go +++ b/replication.go @@ -328,14 +328,14 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows := rc.c.getRows(sql, nil) if err := rc.c.lock(); err != nil { - rows.Fatal(err) + rows.fatal(err) return rows, err } rows.unlockConn = true err := rc.c.sendSimpleQuery(sql) if err != nil { - rows.Fatal(err) + rows.fatal(err) } msg, err := rc.c.rxMsg() @@ -351,7 +351,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { // only Oids. Not much we can do about this. default: if e := rc.c.processContextFreeMsg(msg); e != nil { - rows.Fatal(e) + rows.fatal(e) return rows, e } } diff --git a/v3.md b/v3.md index d4dfe1bd..baf4b101 100644 --- a/v3.md +++ b/v3.md @@ -38,6 +38,8 @@ Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset. +Removed Rows.Fatal(error) + ## TODO / Possible / Investigate Organize errors better From 2a4956974761d7cb0abe2acbb217d2a4dc2922af Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 10:00:49 -0500 Subject: [PATCH 190/264] Remove AfterClose() and Conn() from Tx and Rows --- conn_pool.go | 16 +++------------- query.go | 25 +++---------------------- tx.go | 37 ++++++++++--------------------------- tx_test.go | 39 ++------------------------------------- v3.md | 8 ++++++++ 5 files changed, 26 insertions(+), 99 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 6a1f37a2..49de6658 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -31,8 +31,6 @@ type ConnPool struct { preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration connInfo *pgtype.ConnInfo - txAfterClose func(tx *Tx) - rowsAfterClose func(rows *Rows) } type ConnPoolStat struct { @@ -75,14 +73,6 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p.logLevel = LogLevelNone } - p.txAfterClose = func(tx *Tx) { - p.Release(tx.Conn()) - } - - p.rowsAfterClose = func(rows *Rows) { - p.Release(rows.Conn()) - } - p.allConnections = make([]*Conn, 0, p.maxConnections) p.availableConnections = make([]*Conn, 0, p.maxConnections) p.preparedStatements = make(map[string]*PreparedStatement) @@ -381,7 +371,7 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { return rows, err } - rows.AfterClose(p.rowsAfterClose) + rows.connPool = p return rows, nil } @@ -399,7 +389,7 @@ func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOpti return rows, err } - rows.AfterClose(p.rowsAfterClose) + rows.connPool = p return rows, nil } @@ -531,7 +521,7 @@ func (p *ConnPool) BeginEx(txOptions *TxOptions) (*Tx, error) { continue } - tx.AfterClose(p.txAfterClose) + tx.connPool = p return tx, nil } } diff --git a/query.go b/query.go index 2e3957c2..681c133b 100644 --- a/query.go +++ b/query.go @@ -42,6 +42,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) { // calling Next() until it returns false, or when a fatal error occurs. type Rows struct { conn *Conn + connPool *ConnPool values [][]byte fields []FieldDescription rowCount int @@ -50,7 +51,6 @@ type Rows struct { startTime time.Time sql string args []interface{} - afterClose func(*Rows) unlockConn bool closed bool } @@ -84,8 +84,8 @@ func (rows *Rows) Close() { rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) } - if rows.afterClose != nil { - rows.afterClose(rows) + if rows.connPool != nil { + rows.connPool.Release(rows.conn) } } @@ -156,11 +156,6 @@ func (rows *Rows) Next() bool { } } -// Conn returns the *Conn this *Rows is using. -func (rows *Rows) Conn() *Conn { - return rows.conn -} - func (rows *Rows) nextColumn() ([]byte, *FieldDescription, bool) { if rows.closed { return nil, nil, false @@ -321,20 +316,6 @@ func (rows *Rows) Values() ([]interface{}, error) { return values, rows.Err() } -// AfterClose adds f to a LILO queue of functions that will be called when -// rows is closed. -func (rows *Rows) AfterClose(f func(*Rows)) { - if rows.afterClose == nil { - rows.afterClose = f - } else { - prevFn := rows.afterClose - rows.afterClose = func(rows *Rows) { - f(rows) - prevFn(rows) - } - } -} - // Query executes sql with args. If there is an error the returned *Rows will // be returned in an error state. So it is allowed to ignore the error returned // from Query and handle it in *Rows. diff --git a/tx.go b/tx.go index 099ef180..ea804449 100644 --- a/tx.go +++ b/tx.go @@ -94,10 +94,10 @@ func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) { // All Tx methods return ErrTxClosed if Commit or Rollback has already been // called on the Tx. type Tx struct { - conn *Conn - afterClose func(*Tx) - err error - status int8 + conn *Conn + connPool *ConnPool + err error + status int8 } // Commit commits the transaction @@ -117,9 +117,10 @@ func (tx *Tx) Commit() error { tx.err = err } - if tx.afterClose != nil { - tx.afterClose(tx) + if tx.connPool != nil { + tx.connPool.Release(tx.conn) } + return tx.err } @@ -139,9 +140,10 @@ func (tx *Tx) Rollback() error { tx.status = TxStatusRollbackFailure } - if tx.afterClose != nil { - tx.afterClose(tx) + if tx.connPool != nil { + tx.connPool.Release(tx.conn) } + return tx.err } @@ -194,11 +196,6 @@ func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFr return tx.conn.CopyFrom(tableName, columnNames, rowSrc) } -// Conn returns the *Conn this transaction is using. -func (tx *Tx) Conn() *Conn { - return tx.conn -} - // Status returns the status of the transaction from the set of // pgx.TxStatus* constants. func (tx *Tx) Status() int8 { @@ -209,17 +206,3 @@ func (tx *Tx) Status() int8 { func (tx *Tx) Err() error { return tx.err } - -// AfterClose adds f to a LILO queue of functions that will be called when -// the transaction is closed (either Commit or Rollback). -func (tx *Tx) AfterClose(f func(*Tx)) { - if tx.afterClose == nil { - tx.afterClose = f - } else { - prevFn := tx.afterClose - tx.afterClose = func(tx *Tx) { - f(tx) - prevFn(tx) - } - } -} diff --git a/tx_test.go b/tx_test.go index 0ba5904b..35abd4eb 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,9 +1,9 @@ package pgx_test import ( - "github.com/jackc/pgx" "testing" - "time" + + "github.com/jackc/pgx" ) func TestTransactionSuccessfulCommit(t *testing.T) { @@ -226,41 +226,6 @@ func TestBeginExReadOnly(t *testing.T) { } } -func TestTxAfterClose(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - tx, err := conn.Begin() - if err != nil { - t.Fatal(err) - } - - var zeroTime, t1, t2 time.Time - tx.AfterClose(func(tx *pgx.Tx) { - t1 = time.Now() - }) - - tx.AfterClose(func(tx *pgx.Tx) { - t2 = time.Now() - }) - - tx.Rollback() - - if t1 == zeroTime { - t.Error("First Tx.AfterClose callback not called") - } - - if t2 == zeroTime { - t.Error("Second Tx.AfterClose callback not called") - } - - if t1.Before(t2) { - t.Errorf("AfterClose callbacks called out of order: %v, %v", t1, t2) - } -} - func TestTxStatus(t *testing.T) { t.Parallel() diff --git a/v3.md b/v3.md index baf4b101..624c25eb 100644 --- a/v3.md +++ b/v3.md @@ -40,6 +40,14 @@ ConnPool.Close no longer waits for all acquired connections to be released. Inst Removed Rows.Fatal(error) +Removed Rows.AfterClose() + +Removed Rows.Conn() + +Removed Tx.AfterClose() + +Removed Tx.Conn() + ## TODO / Possible / Investigate Organize errors better From 8b6c32d13acab3447d4a024cf06867bf617cf3ee Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 15:20:40 -0500 Subject: [PATCH 191/264] Add ConnConfig.Merge --- conn.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/conn.go b/conn.go index a1781be2..4a77f769 100644 --- a/conn.go +++ b/conn.go @@ -445,6 +445,58 @@ func (c *Conn) Close() (err error) { return nil } +// Merge returns a new ConnConfig with the attributes of old and other +// combined. When an attribute is set on both, other takes precedence. +// +// As a security precaution, if the other TLSConfig is nil, all old TLS +// attributes will be preserved. +func (old ConnConfig) Merge(other ConnConfig) ConnConfig { + cc := old + + if other.Host != "" { + cc.Host = other.Host + } + if other.Port != 0 { + cc.Port = other.Port + } + if other.Database != "" { + cc.Database = other.Database + } + if other.User != "" { + cc.User = other.User + } + if other.Password != "" { + cc.Password = other.Password + } + + if other.TLSConfig != nil { + cc.TLSConfig = other.TLSConfig + cc.UseFallbackTLS = other.UseFallbackTLS + cc.FallbackTLSConfig = other.FallbackTLSConfig + } + + if other.Logger != nil { + cc.Logger = other.Logger + } + if other.LogLevel != 0 { + cc.LogLevel = other.LogLevel + } + + if other.Dial != nil { + cc.Dial = other.Dial + } + + cc.RuntimeParams = make(map[string]string) + for k, v := range old.RuntimeParams { + cc.RuntimeParams[k] = v + } + for k, v := range other.RuntimeParams { + cc.RuntimeParams[k] = v + } + + return cc +} + // ParseURI parses a database URI into ConnConfig // // Query parameters not used by the connection process are parsed into ConnConfig.RuntimeParams. From 78d344d1abebb939f5ac4cc9a88c4e072f0efddd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 15:28:16 -0500 Subject: [PATCH 192/264] Add DriverConfig system to stdlib --- stdlib/sql.go | 80 ++++++++++++++++++++++++++++++++++++++++++++-- stdlib/sql_test.go | 26 +++++++++++++++ 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 80a559af..19d96260 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -54,6 +54,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/binary" "errors" "fmt" "io" @@ -72,9 +73,13 @@ var ( // binary, anything else will be forced to text format var databaseSqlOids map[pgtype.Oid]bool +var pgxDriver *Driver + func init() { - d := &Driver{} - sql.Register("pgx", d) + pgxDriver = &Driver{ + configs: make(map[int64]*DriverConfig), + } + sql.Register("pgx", pgxDriver) databaseSqlOids = make(map[pgtype.Oid]bool) databaseSqlOids[pgtype.BoolOid] = true @@ -94,6 +99,10 @@ func init() { type Driver struct { Pool *pgx.ConnPool + + configMutex sync.Mutex + configCount int64 + configs map[int64]*DriverConfig } func (d *Driver) Open(name string) (driver.Conn, error) { @@ -106,20 +115,85 @@ func (d *Driver) Open(name string) (driver.Conn, error) { return &Conn{conn: conn, pool: d.Pool}, nil } - connConfig, err := pgx.ParseConnectionString(name) + var connConfig pgx.ConnConfig + var afterConnect func(*pgx.Conn) error + if len(name) >= 9 && name[0] == 0 { + idBuf := []byte(name)[1:9] + id := int64(binary.BigEndian.Uint64(idBuf)) + connConfig = d.configs[id].ConnConfig + afterConnect = d.configs[id].AfterConnect + name = name[9:] + } + + parsedConfig, err := pgx.ParseConnectionString(name) if err != nil { return nil, err } + connConfig = connConfig.Merge(parsedConfig) conn, err := pgx.Connect(connConfig) if err != nil { return nil, err } + if afterConnect != nil { + err = afterConnect(conn) + if err != nil { + return nil, err + } + } + c := &Conn{conn: conn} return c, nil } +type DriverConfig struct { + pgx.ConnConfig + AfterConnect func(*pgx.Conn) error // function to call on every new connection + driver *Driver + id int64 +} + +// ConnectionString encodes the DriverConfig into the original connection +// string. DriverConfig must be registered before calling ConnectionString. +func (c *DriverConfig) ConnectionString(original string) string { + if c.driver == nil { + panic("DriverConfig must be registered before calling ConnectionString") + } + + buf := make([]byte, 9) + binary.BigEndian.PutUint64(buf[1:], uint64(c.id)) + buf = append(buf, original...) + return string(buf) +} + +func (d *Driver) registerDriverConfig(c *DriverConfig) { + d.configMutex.Lock() + + c.driver = d + c.id = d.configCount + d.configs[d.configCount] = c + d.configCount++ + + d.configMutex.Unlock() +} + +func (d *Driver) unregisterDriverConfig(c *DriverConfig) { + d.configMutex.Lock() + delete(d.configs, c.id) + d.configMutex.Unlock() +} + +// RegisterDriverConfig registers a DriverConfig for use with Open. +func RegisterDriverConfig(c *DriverConfig) { + pgxDriver.registerDriverConfig(c) +} + +// UnregisterDriverConfig removes a DriverConfig registration. +func UnregisterDriverConfig(c *DriverConfig) { + pgxDriver.unregisterDriverConfig(c) +} + // OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB // with pool as the backend. This enables full control over the connection // process and configuration while maintaining compatibility with the diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index ba74560d..e4fbcb0c 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -202,6 +202,32 @@ func TestOpenFromConnPoolRace(t *testing.T) { wg.Wait() } +func TestOpenWithDriverConfigAfterConnect(t *testing.T) { + driverConfig := stdlib.DriverConfig{ + AfterConnect: func(c *pgx.Conn) error { + _, err := c.Exec("create temporary sequence pgx") + return err + }, + } + + stdlib.RegisterDriverConfig(&driverConfig) + defer stdlib.UnregisterDriverConfig(&driverConfig) + + db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + + var n int64 + err = db.QueryRow("select nextval('pgx')").Scan(&n) + if err != nil { + t.Fatalf("db.QueryRow unexpectedly failed: %v", err) + } + if n != 1 { + t.Fatalf("n => %d, want %d", n, 1) + } +} + func TestStmtExec(t *testing.T) { db := openDB(t) defer closeDB(t, db) From ffae1b134596a9a111fdb89b1bd09febad7e2dae Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 15:39:26 -0500 Subject: [PATCH 193/264] Remove stdlib.OpenFromConnPool --- stdlib/sql.go | 98 -------------------------------------- stdlib/sql_test.go | 115 ++++++--------------------------------------- 2 files changed, 14 insertions(+), 199 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 19d96260..439a5262 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -13,41 +13,6 @@ // if err != nil { // return err // } -// -// Or a normal pgx connection pool can be established and the database/sql -// connection can be created through stdlib.OpenFromConnPool(). This allows -// more control over the connection process (such as TLS), more control -// over the connection pool, setting an AfterConnect hook, and using both -// database/sql and pgx interfaces as needed. -// -// connConfig := pgx.ConnConfig{ -// Host: "localhost", -// User: "pgx_md5", -// Password: "secret", -// Database: "pgx_test", -// } -// -// config := pgx.ConnPoolConfig{ConnConfig: connConfig} -// pool, err := pgx.NewConnPool(config) -// if err != nil { -// return err -// } -// -// db, err := stdlib.OpenFromConnPool(pool) -// if err != nil { -// t.Fatalf("Unable to create connection pool: %v", err) -// } -// -// If the database/sql connection is established through -// stdlib.OpenFromConnPool then access to a pgx *ConnPool can be regained -// through db.Driver(). This allows writing a fast path for pgx while -// preserving compatibility with other drivers and database -// -// if driver, ok := db.Driver().(*stdlib.Driver); ok && driver.Pool != nil { -// // fast path with pgx -// } else { -// // normal path for other drivers and databases -// } package stdlib import ( @@ -55,7 +20,6 @@ import ( "database/sql" "database/sql/driver" "encoding/binary" - "errors" "fmt" "io" "sync" @@ -64,11 +28,6 @@ import ( "github.com/jackc/pgx/pgtype" ) -var ( - openFromConnPoolCountMu sync.Mutex - openFromConnPoolCount int -) - // oids that map to intrinsic database/sql types. These will be allowed to be // binary, anything else will be forced to text format var databaseSqlOids map[pgtype.Oid]bool @@ -98,23 +57,12 @@ func init() { } type Driver struct { - Pool *pgx.ConnPool - configMutex sync.Mutex configCount int64 configs map[int64]*DriverConfig } func (d *Driver) Open(name string) (driver.Conn, error) { - if d.Pool != nil { - conn, err := d.Pool.Acquire() - if err != nil { - return nil, err - } - - return &Conn{conn: conn, pool: d.Pool}, nil - } - var connConfig pgx.ConnConfig var afterConnect func(*pgx.Conn) error if len(name) >= 9 && name[0] == 0 { @@ -194,49 +142,8 @@ func UnregisterDriverConfig(c *DriverConfig) { pgxDriver.unregisterDriverConfig(c) } -// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB -// with pool as the backend. This enables full control over the connection -// process and configuration while maintaining compatibility with the -// database/sql interface. In addition, by calling Driver() on the returned -// *sql.DB and typecasting to *stdlib.Driver a reference to the pgx.ConnPool can -// be reaquired later. This allows fast paths targeting pgx to be used while -// still maintaining compatibility with other databases and drivers. -// -// pool connection size must be at least 2. -func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) { - d := &Driver{Pool: pool} - - openFromConnPoolCountMu.Lock() - name := fmt.Sprintf("pgx-%d", openFromConnPoolCount) - openFromConnPoolCount++ - openFromConnPoolCountMu.Unlock() - - sql.Register(name, d) - db, err := sql.Open(name, "") - if err != nil { - return nil, err - } - - // Presumably OpenFromConnPool is being used because the user wants to use - // database/sql most of the time, but fast path with pgx some of the time. - // Allow database/sql to use all the connections, but release 2 idle ones. - // Don't have database/sql immediately release all idle connections because - // that would mean that prepared statements would be lost (which kills - // performance if the prepared statements constantly have to be reprepared) - stat := pool.Stat() - - if stat.MaxConnections <= 2 { - return nil, errors.New("pool connection size must be at least 3") - } - db.SetMaxIdleConns(stat.MaxConnections - 2) - db.SetMaxOpenConns(stat.MaxConnections) - - return db, nil -} - type Conn struct { conn *pgx.Conn - pool *pgx.ConnPool psCount int64 // Counter used for creating unique prepared statement names } @@ -259,11 +166,6 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { } func (c *Conn) Close() error { - if c.pool != nil { - c.pool.Release(c.conn) - return nil - } - return c.conn.Close() } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index e4fbcb0c..bdafdd48 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -3,7 +3,6 @@ package stdlib_test import ( "bytes" "database/sql" - "sync" "testing" "github.com/jackc/pgx" @@ -120,88 +119,6 @@ func TestNormalLifeCycle(t *testing.T) { ensureConnValid(t, db) } -func TestSqlOpenDoesNotHavePool(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - driver := db.Driver().(*stdlib.Driver) - if driver.Pool != nil { - t.Fatal("Did not expect driver opened through database/sql to have Pool, but it did") - } -} - -func TestOpenFromConnPool(t *testing.T) { - connConfig := pgx.ConnConfig{ - Host: "127.0.0.1", - User: "pgx_md5", - Password: "secret", - Database: "pgx_test", - } - - config := pgx.ConnPoolConfig{ConnConfig: connConfig} - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() - - db, err := stdlib.OpenFromConnPool(pool) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer closeDB(t, db) - - // Can get pgx.ConnPool from driver - driver := db.Driver().(*stdlib.Driver) - if driver.Pool == nil { - t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") - } - - // Normal sql/database still works - var n int64 - err = db.QueryRow("select 1").Scan(&n) - if err != nil { - t.Fatalf("db.QueryRow unexpectedly failed: %v", err) - } -} - -func TestOpenFromConnPoolRace(t *testing.T) { - wg := &sync.WaitGroup{} - connConfig := pgx.ConnConfig{ - Host: "127.0.0.1", - User: "pgx_md5", - Password: "secret", - Database: "pgx_test", - } - - config := pgx.ConnPoolConfig{ConnConfig: connConfig} - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() - - wg.Add(10) - for i := 0; i < 10; i++ { - go func() { - defer wg.Done() - db, err := stdlib.OpenFromConnPool(pool) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer closeDB(t, db) - - // Can get pgx.ConnPool from driver - driver := db.Driver().(*stdlib.Driver) - if driver.Pool == nil { - t.Fatal("Expected driver opened through OpenFromConnPool to have Pool, but it did not") - } - }() - } - - wg.Wait() -} - func TestOpenWithDriverConfigAfterConnect(t *testing.T) { driverConfig := stdlib.DriverConfig{ AfterConnect: func(c *pgx.Conn) error { @@ -217,6 +134,7 @@ func TestOpenWithDriverConfigAfterConnect(t *testing.T) { if err != nil { t.Fatalf("sql.Open failed: %v", err) } + defer closeDB(t, db) var n int64 err = db.QueryRow("select nextval('pgx')").Scan(&n) @@ -407,37 +325,32 @@ func (l *testLogger) Log(lvl pgx.LogLevel, msg string, data map[string]interface func TestConnQueryLog(t *testing.T) { logger := &testLogger{} - connConfig := pgx.ConnConfig{ - Host: "127.0.0.1", - User: "pgx_md5", - Password: "secret", - Database: "pgx_test", - Logger: logger, + driverConfig := stdlib.DriverConfig{ + ConnConfig: pgx.ConnConfig{ + Host: "127.0.0.1", + User: "pgx_md5", + Password: "secret", + Database: "pgx_test", + Logger: logger, + }, } - config := pgx.ConnPoolConfig{ConnConfig: connConfig} - pool, err := pgx.NewConnPool(config) - if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) - } - defer pool.Close() + stdlib.RegisterDriverConfig(&driverConfig) + defer stdlib.UnregisterDriverConfig(&driverConfig) - db, err := stdlib.OpenFromConnPool(pool) + db, err := sql.Open("pgx", driverConfig.ConnectionString("")) if err != nil { - t.Fatalf("Unable to create connection pool: %v", err) + t.Fatalf("sql.Open failed: %v", err) } defer closeDB(t, db) - // clear logs from initial connection - logger.logs = []testLog{} - var n int64 err = db.QueryRow("select 1").Scan(&n) if err != nil { t.Fatalf("db.QueryRow unexpectedly failed: %v", err) } - l := logger.logs[0] + l := logger.logs[len(logger.logs)-1] if l.msg != "Query" { t.Errorf("Expected to log Query, but got %v", l) } From 4cbefbb27ed7c0ea9a2259250bc87276922228c7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 16:29:37 -0500 Subject: [PATCH 194/264] Add TxOptions support to stdlib --- stdlib/sql.go | 40 ++++++++++++---------- stdlib/sql_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 18 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 439a5262..7e635324 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -170,16 +170,34 @@ func (c *Conn) Close() error { } func (c *Conn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } - _, err := c.conn.Exec("begin") - if err != nil { - return nil, err + var pgxOpts pgx.TxOptions + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + case sql.LevelReadUncommitted: + pgxOpts.IsoLevel = pgx.ReadUncommitted + case sql.LevelReadCommitted: + pgxOpts.IsoLevel = pgx.ReadCommitted + case sql.LevelSnapshot: + pgxOpts.IsoLevel = pgx.RepeatableRead + case sql.LevelSerializable: + pgxOpts.IsoLevel = pgx.Serializable + default: + return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation) } - return &Tx{conn: c.conn}, nil + if opts.ReadOnly { + pgxOpts.AccessMode = pgx.ReadOnly + } + + return c.conn.BeginEx(&pgxOpts) } func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { @@ -389,17 +407,3 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} { } return args } - -type Tx struct { - conn *pgx.Conn -} - -func (t *Tx) Commit() error { - _, err := t.conn.Exec("commit") - return err -} - -func (t *Tx) Rollback() error { - _, err := t.conn.Exec("rollback") - return err -} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index bdafdd48..fdc93c0a 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -2,6 +2,7 @@ package stdlib_test import ( "bytes" + "context" "database/sql" "testing" @@ -603,3 +604,84 @@ func TestTransactionLifeCycle(t *testing.T) { ensureConnValid(t, db) } + +func TestConnBeginTxIsolation(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + var defaultIsoLevel string + err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) + if err != nil { + t.Fatalf("QueryRow failed: %v", err) + } + + supportedTests := []struct { + sqlIso sql.IsolationLevel + pgIso string + }{ + {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, + {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, + {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, + {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, + } + for i, tt := range supportedTests { + func() { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err != nil { + t.Errorf("%d. BeginTx failed: %v", i, err) + return + } + defer tx.Rollback() + + var pgIso string + err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + + if pgIso != tt.pgIso { + t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) + } + }() + } + + unsupportedTests := []struct { + sqlIso sql.IsolationLevel + }{ + {sqlIso: sql.LevelWriteCommitted}, + {sqlIso: sql.LevelLinearizable}, + } + for i, tt := range unsupportedTests { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err == nil { + t.Errorf("%d. BeginTx should have failed", i) + tx.Rollback() + } + } + + ensureConnValid(t, db) +} + +func TestConnBeginTxReadOnly(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + defer tx.Rollback() + + var pgReadOnly string + err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", err) + } + + if pgReadOnly != "on" { + t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") + } + + ensureConnValid(t, db) +} From c78d450c19eac535ea7417c6f8cc82b1e76271b0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 19:39:40 -0500 Subject: [PATCH 195/264] Add stdlib AcquireConn and ReleaseConn Also add some documentation. --- stdlib/sql.go | 113 ++++++++++++++++++++++++++++++++++++++++++++- stdlib/sql_test.go | 39 ++++++++++++++++ 2 files changed, 150 insertions(+), 2 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 7e635324..e70780c1 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -13,6 +13,54 @@ // if err != nil { // return err // } +// +// A DriverConfig can be used to further configure the connection process. This +// allows configuring TLS configuration, setting a custom dialer, logging, and +// setting an AfterConnect hook. +// +// driverConfig := stdlib.DriverConfig{ +// ConnConfig: ConnConfig: pgx.ConnConfig{ +// Logger: logger, +// }, +// AfterConnect: func(c *pgx.Conn) error { +// // Ensure all connections have this temp table available +// _, err := c.Exec("create temporary table foo(...)") +// return err +// }, +// } +// +// stdlib.RegisterDriverConfig(&driverConfig) +// +// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) +// if err != nil { +// return err +// } +// +// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard +// database/sql.DB connection pool. This allows operations that must be +// performed on a single connection, but should not be run in a transaction or +// to use pgx specific functionality. +// +// conn, err := stdlib.AcquireConn(db) +// if err != nil { +// return err +// } +// defer stdlib.ReleaseConn(db, conn) +// +// // do stuff with pgx.Conn +// +// It also can be used to enable a fast path for pgx while preserving +// compatibility with other drivers and database. +// +// conn, err := stdlib.AcquireConn(db) +// if err == nil { +// // fast path with pgx +// // ... +// // release conn when done +// stdlib.ReleaseConn(db, conn) +// } else { +// // normal path for other drivers and databases +// } package stdlib import ( @@ -20,6 +68,7 @@ import ( "database/sql" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" "sync" @@ -34,9 +83,16 @@ var databaseSqlOids map[pgtype.Oid]bool var pgxDriver *Driver +type ctxKey int + +var ctxKeyFakeTx ctxKey = 0 + +var ErrNotPgx = errors.New("not pgx *sql.DB") + func init() { pgxDriver = &Driver{ - configs: make(map[int64]*DriverConfig), + configs: make(map[int64]*DriverConfig), + fakeTxConns: make(map[*pgx.Conn]*sql.Tx), } sql.Register("pgx", pgxDriver) @@ -60,6 +116,9 @@ type Driver struct { configMutex sync.Mutex configCount int64 configs map[int64]*DriverConfig + + fakeTxMutex sync.Mutex + fakeTxConns map[*pgx.Conn]*sql.Tx } func (d *Driver) Open(name string) (driver.Conn, error) { @@ -91,7 +150,7 @@ func (d *Driver) Open(name string) (driver.Conn, error) { } } - c := &Conn{conn: conn} + c := &Conn{conn: conn, driver: d} return c, nil } @@ -145,6 +204,7 @@ func UnregisterDriverConfig(c *DriverConfig) { type Conn struct { conn *pgx.Conn psCount int64 // Counter used for creating unique prepared statement names + driver *Driver } func (c *Conn) Prepare(query string) (driver.Stmt, error) { @@ -178,6 +238,11 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return nil, driver.ErrBadConn } + if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { + *pconn = c.conn + return fakeTx{}, nil + } + var pgxOpts pgx.TxOptions switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: @@ -407,3 +472,47 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} { } return args } + +type fakeTx struct{} + +func (fakeTx) Commit() error { return nil } + +func (fakeTx) Rollback() error { return nil } + +func AcquireConn(db *sql.DB) (*pgx.Conn, error) { + driver, ok := db.Driver().(*Driver) + if !ok { + return nil, ErrNotPgx + } + + var conn *pgx.Conn + ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + driver.fakeTxMutex.Lock() + driver.fakeTxConns[conn] = tx + driver.fakeTxMutex.Unlock() + + return conn, nil +} + +func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { + var tx *sql.Tx + var ok bool + + driver := db.Driver().(*Driver) + driver.fakeTxMutex.Lock() + tx, ok = driver.fakeTxConns[conn] + if ok { + delete(driver.fakeTxConns, conn) + driver.fakeTxMutex.Unlock() + } else { + driver.fakeTxMutex.Unlock() + return fmt.Errorf("can't release conn that is not acquired") + } + + return tx.Rollback() +} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index fdc93c0a..e9fcd27b 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -685,3 +685,42 @@ func TestConnBeginTxReadOnly(t *testing.T) { ensureConnValid(t, db) } + +func TestAcquireConn(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + var conns []*pgx.Conn + + for i := 1; i < 6; i++ { + conn, err := stdlib.AcquireConn(db) + if err != nil { + t.Errorf("%d. AcquireConn failed: %v", i, err) + continue + } + + var n int32 + err = conn.QueryRow("select 1").Scan(&n) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + if n != 1 { + t.Errorf("%d. n => %d, want %d", i, n, 1) + } + + stats := db.Stats() + if stats.OpenConnections != i { + t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) + } + + conns = append(conns, conn) + } + + for i, conn := range conns { + if err := stdlib.ReleaseConn(db, conn); err != nil { + t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) + } + } + + ensureConnValid(t, db) +} From 6a2a5e28fd2b4210ed4c8cf19c0971cdffe24be6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 May 2017 19:48:03 -0500 Subject: [PATCH 196/264] Fix issues identified by go vet --- pgtype/interval.go | 8 ++++---- stdlib/sql_test.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pgtype/interval.go b/pgtype/interval.go index ea5c7d3e..85d76d99 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -118,26 +118,26 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { hours, err := strconv.ParseInt(timeParts[0], 10, 64) if err != nil { - return fmt.Errorf("bad interval hour format: %s", hours) + return fmt.Errorf("bad interval hour format: %s", timeParts[0]) } minutes, err := strconv.ParseInt(timeParts[1], 10, 64) if err != nil { - return fmt.Errorf("bad interval minute format: %s", minutes) + return fmt.Errorf("bad interval minute format: %s", timeParts[1]) } secondParts := strings.SplitN(timeParts[2], ".", 2) seconds, err := strconv.ParseInt(secondParts[0], 10, 64) if err != nil { - return fmt.Errorf("bad interval second format: %s", seconds) + return fmt.Errorf("bad interval second format: %s", secondParts[0]) } var uSeconds int64 if len(secondParts) == 2 { uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) if err != nil { - return fmt.Errorf("bad interval decimal format: %s", seconds) + return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) } for i := 0; i < 6-len(secondParts[1]); i++ { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index e9fcd27b..416a5a7e 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -676,7 +676,7 @@ func TestConnBeginTxReadOnly(t *testing.T) { var pgReadOnly string err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) if err != nil { - t.Errorf("%d. QueryRow failed: %v", err) + t.Errorf("QueryRow failed: %v", err) } if pgReadOnly != "on" { From d2d99eac651a5022c242c859b6861b606002dd77 Mon Sep 17 00:00:00 2001 From: Steve Atkins Date: Tue, 9 May 2017 08:24:31 -0700 Subject: [PATCH 197/264] Add godoc.org badge to README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index b85f9c0f..877c2a37 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx) + # Pgx ## Experimental Branch From 413871a8977c37f16a6265052abbb1bd62c7e185 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 8 May 2017 18:07:11 -0500 Subject: [PATCH 198/264] Fix Bind Decode to advance rp --- pgproto3/bind.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgproto3/bind.go b/pgproto3/bind.go index 6661a775..cbd71e13 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -52,6 +52,7 @@ func (dst *Bind) Decode(src []byte) error { return &invalidMessageFormatErr{messageType: "Bind"} } parameterCount := int(binary.BigEndian.Uint16(src[rp:])) + rp += 2 dst.Parameters = make([][]byte, parameterCount) From 479ebdfa1983a0964896928ce1361e06c2639d7a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 May 2017 17:56:54 -0500 Subject: [PATCH 199/264] Add basic pgmock support Primarily useful for testing pgx itself. Design is still subject to change. --- pgmock/pgmock.go | 478 ++++++++++++++++++++++++++++++++++++ pgproto3/startup_message.go | 6 +- stdlib/sql_test.go | 103 ++++++++ 3 files changed, 584 insertions(+), 3 deletions(-) create mode 100644 pgmock/pgmock.go diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go new file mode 100644 index 00000000..827fa87d --- /dev/null +++ b/pgmock/pgmock.go @@ -0,0 +1,478 @@ +package pgmock + +import ( + "errors" + "fmt" + "net" + "reflect" + + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/pgtype" +) + +type Server struct { + ln net.Listener + controller Controller +} + +func NewServer(controller Controller) (*Server, error) { + ln, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + return nil, err + } + + server := &Server{ + ln: ln, + controller: controller, + } + + return server, nil +} + +func (s *Server) Addr() net.Addr { + return s.ln.Addr() +} + +func (s *Server) ServeOne() error { + conn, err := s.ln.Accept() + if err != nil { + return err + } + + backend, err := pgproto3.NewBackend(conn, conn) + if err != nil { + conn.Close() + return err + } + + return s.controller.Serve(backend) +} + +func (s *Server) Close() error { + err := s.ln.Close() + if err != nil { + return err + } + + return nil +} + +type Controller interface { + Serve(backend *pgproto3.Backend) error +} + +type Step interface { + Step(*pgproto3.Backend) error +} + +type Script struct { + Steps []Step +} + +func (s *Script) Run(backend *pgproto3.Backend) error { + for _, step := range s.Steps { + err := step.Step(backend) + if err != nil { + return err + } + } + + return nil +} + +func (s *Script) Serve(backend *pgproto3.Backend) error { + for _, step := range s.Steps { + err := step.Step(backend) + if err != nil { + return err + } + } + + return nil +} + +func (s *Script) Step(backend *pgproto3.Backend) error { + return s.Serve(backend) +} + +type expectMessageStep struct { + want pgproto3.FrontendMessage + any bool +} + +func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.Receive() + if err != nil { + return err + } + + if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +type expectStartupMessageStep struct { + want *pgproto3.StartupMessage + any bool +} + +func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { + msg, err := backend.ReceiveStartupMessage() + if err != nil { + return err + } + + if e.any { + return nil + } + + if !reflect.DeepEqual(msg, e.want) { + return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + } + + return nil +} + +func ExpectMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, false) +} + +func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { + return expectMessage(want, true) +} + +func expectMessage(want pgproto3.FrontendMessage, any bool) Step { + if want, ok := want.(*pgproto3.StartupMessage); ok { + return &expectStartupMessageStep{want: want, any: any} + } + + return &expectMessageStep{want: want, any: any} +} + +type sendMessageStep struct { + msg pgproto3.BackendMessage +} + +func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { + return backend.Send(e.msg) +} + +func SendMessage(msg pgproto3.BackendMessage) Step { + return &sendMessageStep{msg: msg} +} + +func AcceptUnauthenticatedConnRequestSteps() []Step { + return []Step{ + ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + SendMessage(&pgproto3.Authentication{Type: pgproto3.AuthTypeOk}), + SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + } +} + +func PgxInitSteps() []Step { + steps := []Step{ + ExpectMessage(&pgproto3.Parse{ + Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)", + }), + ExpectMessage(&pgproto3.Describe{ + ObjectType: 'S', + }), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.ParseComplete{}), + SendMessage(&pgproto3.ParameterDescription{}), + SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: "oid", + TableOID: 1247, + TableAttributeNumber: 65534, + DataTypeOID: 26, + DataTypeSize: 4, + TypeModifier: 4294967295, + Format: 0, + }, + {Name: "typname", + TableOID: 1247, + TableAttributeNumber: 1, + DataTypeOID: 19, + DataTypeSize: 64, + TypeModifier: 4294967295, + Format: 0, + }, + }, + }), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + ExpectMessage(&pgproto3.Bind{ + ParameterFormatCodes: []int16{}, + Parameters: [][]byte{}, + ResultFormatCodes: []int16{1, 1}, + }), + ExpectMessage(&pgproto3.Execute{}), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.BindComplete{}), + } + + rowVals := []struct { + oid pgtype.Oid + name string + }{ + {16, "bool"}, + {17, "bytea"}, + {18, "char"}, + {19, "name"}, + {20, "int8"}, + {21, "int2"}, + {22, "int2vector"}, + {23, "int4"}, + {24, "regproc"}, + {25, "text"}, + {26, "oid"}, + {27, "tid"}, + {28, "xid"}, + {29, "cid"}, + {30, "oidvector"}, + {114, "json"}, + {142, "xml"}, + {143, "_xml"}, + {199, "_json"}, + {194, "pg_node_tree"}, + {32, "pg_ddl_command"}, + {210, "smgr"}, + {600, "point"}, + {601, "lseg"}, + {602, "path"}, + {603, "box"}, + {604, "polygon"}, + {628, "line"}, + {629, "_line"}, + {700, "float4"}, + {701, "float8"}, + {702, "abstime"}, + {703, "reltime"}, + {704, "tinterval"}, + {705, "unknown"}, + {718, "circle"}, + {719, "_circle"}, + {790, "money"}, + {791, "_money"}, + {829, "macaddr"}, + {869, "inet"}, + {650, "cidr"}, + {1000, "_bool"}, + {1001, "_bytea"}, + {1002, "_char"}, + {1003, "_name"}, + {1005, "_int2"}, + {1006, "_int2vector"}, + {1007, "_int4"}, + {1008, "_regproc"}, + {1009, "_text"}, + {1028, "_oid"}, + {1010, "_tid"}, + {1011, "_xid"}, + {1012, "_cid"}, + {1013, "_oidvector"}, + {1014, "_bpchar"}, + {1015, "_varchar"}, + {1016, "_int8"}, + {1017, "_point"}, + {1018, "_lseg"}, + {1019, "_path"}, + {1020, "_box"}, + {1021, "_float4"}, + {1022, "_float8"}, + {1023, "_abstime"}, + {1024, "_reltime"}, + {1025, "_tinterval"}, + {1027, "_polygon"}, + {1033, "aclitem"}, + {1034, "_aclitem"}, + {1040, "_macaddr"}, + {1041, "_inet"}, + {651, "_cidr"}, + {1263, "_cstring"}, + {1042, "bpchar"}, + {1043, "varchar"}, + {1082, "date"}, + {1083, "time"}, + {1114, "timestamp"}, + {1115, "_timestamp"}, + {1182, "_date"}, + {1183, "_time"}, + {1184, "timestamptz"}, + {1185, "_timestamptz"}, + {1186, "interval"}, + {1187, "_interval"}, + {1231, "_numeric"}, + {1266, "timetz"}, + {1270, "_timetz"}, + {1560, "bit"}, + {1561, "_bit"}, + {1562, "varbit"}, + {1563, "_varbit"}, + {1700, "numeric"}, + {1790, "refcursor"}, + {2201, "_refcursor"}, + {2202, "regprocedure"}, + {2203, "regoper"}, + {2204, "regoperator"}, + {2205, "regclass"}, + {2206, "regtype"}, + {4096, "regrole"}, + {4089, "regnamespace"}, + {2207, "_regprocedure"}, + {2208, "_regoper"}, + {2209, "_regoperator"}, + {2210, "_regclass"}, + {2211, "_regtype"}, + {4097, "_regrole"}, + {4090, "_regnamespace"}, + {2950, "uuid"}, + {2951, "_uuid"}, + {3220, "pg_lsn"}, + {3221, "_pg_lsn"}, + {3614, "tsvector"}, + {3642, "gtsvector"}, + {3615, "tsquery"}, + {3734, "regconfig"}, + {3769, "regdictionary"}, + {3643, "_tsvector"}, + {3644, "_gtsvector"}, + {3645, "_tsquery"}, + {3735, "_regconfig"}, + {3770, "_regdictionary"}, + {3802, "jsonb"}, + {3807, "_jsonb"}, + {2970, "txid_snapshot"}, + {2949, "_txid_snapshot"}, + {3904, "int4range"}, + {3905, "_int4range"}, + {3906, "numrange"}, + {3907, "_numrange"}, + {3908, "tsrange"}, + {3909, "_tsrange"}, + {3910, "tstzrange"}, + {3911, "_tstzrange"}, + {3912, "daterange"}, + {3913, "_daterange"}, + {3926, "int8range"}, + {3927, "_int8range"}, + {2249, "record"}, + {2287, "_record"}, + {2275, "cstring"}, + {2276, "any"}, + {2277, "anyarray"}, + {2278, "void"}, + {2279, "trigger"}, + {3838, "event_trigger"}, + {2280, "language_handler"}, + {2281, "internal"}, + {2282, "opaque"}, + {2283, "anyelement"}, + {2776, "anynonarray"}, + {3500, "anyenum"}, + {3115, "fdw_handler"}, + {325, "index_am_handler"}, + {3310, "tsm_handler"}, + {3831, "anyrange"}, + {51367, "gbtreekey4"}, + {51370, "_gbtreekey4"}, + {51371, "gbtreekey8"}, + {51374, "_gbtreekey8"}, + {51375, "gbtreekey16"}, + {51378, "_gbtreekey16"}, + {51379, "gbtreekey32"}, + {51382, "_gbtreekey32"}, + {51383, "gbtreekey_var"}, + {51386, "_gbtreekey_var"}, + {51921, "hstore"}, + {51926, "_hstore"}, + {52005, "ghstore"}, + {52008, "_ghstore"}, + } + + for _, rv := range rowVals { + step := SendMessage(mustBuildDataRow([]interface{}{rv.oid, rv.name}, []int16{pgproto3.BinaryFormat})) + steps = append(steps, step) + } + + steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"})) + steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) + + return steps +} + +type dataRowValue struct { + Value interface{} + FormatCode int16 +} + +func mustBuildDataRow(values []interface{}, formatCodes []int16) *pgproto3.DataRow { + dr, err := buildDataRow(values, formatCodes) + if err != nil { + panic(err) + } + + return dr +} + +func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, error) { + dr := &pgproto3.DataRow{ + Values: make([][]byte, len(values)), + } + + if len(formatCodes) == 1 { + for i := 1; i < len(values); i++ { + formatCodes = append(formatCodes, formatCodes[0]) + } + } + + for i := range values { + switch v := values[i].(type) { + case string: + values[i] = &pgtype.Text{String: v, Status: pgtype.Present} + case int16: + values[i] = &pgtype.Int2{Int: v, Status: pgtype.Present} + case int32: + values[i] = &pgtype.Int4{Int: v, Status: pgtype.Present} + case int64: + values[i] = &pgtype.Int8{Int: v, Status: pgtype.Present} + } + } + + for i := range values { + switch formatCodes[i] { + case pgproto3.TextFormat: + if e, ok := values[i].(pgtype.TextEncoder); ok { + buf, err := e.EncodeText(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to encode values[%d]", i) + } + dr.Values[i] = buf + } else { + return nil, fmt.Errorf("values[%d] does not implement TextExcoder", i) + } + + case pgproto3.BinaryFormat: + if e, ok := values[i].(pgtype.BinaryEncoder); ok { + buf, err := e.EncodeBinary(nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to encode values[%d]", i) + } + dr.Values[i] = buf + } else { + return nil, fmt.Errorf("values[%d] does not implement BinaryEncoder", i) + } + default: + return nil, errors.New("unknown FormatCode") + } + } + + return dr, nil +} diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index ebb804fe..4847d629 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -8,7 +8,7 @@ import ( ) const ( - protocolVersionNumber = 196608 // 3.0 + ProtocolVersionNumber = 196608 // 3.0 sslRequestNumber = 80877103 ) @@ -31,8 +31,8 @@ func (dst *StartupMessage) Decode(src []byte) error { return fmt.Errorf("can't handle ssl connection request") } - if dst.ProtocolVersion != protocolVersionNumber { - return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion) + if dst.ProtocolVersion != ProtocolVersionNumber { + return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 416a5a7e..4f2484d8 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -4,9 +4,12 @@ import ( "bytes" "context" "database/sql" + "fmt" "testing" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgmock" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/stdlib" ) @@ -686,6 +689,106 @@ func TestConnBeginTxReadOnly(t *testing.T) { ensureConnValid(t, db) } +func TestBeginTxContextCancel(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("drop table if exists t") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + ctx, cancelFn := context.WithCancel(context.Background()) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + + _, err = tx.Exec("create table t(id serial)") + if err != nil { + t.Fatalf("tx.Exec failed: %v", err) + } + + cancelFn() + + err = tx.Commit() + if err != context.Canceled { + t.Fatalf("err => %v, want %v", err, context.Canceled) + } + + var n int + err = db.QueryRow("select count(*) from t").Scan(&n) + if pgErr, ok := err.(pgx.PgError); !ok || pgErr.Code != "42P01" { + t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) + } + + ensureConnValid(t, db) +} + +func acceptStandardPgxConn(backend *pgproto3.Backend) error { + script := pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + + err := script.Run(backend) + if err != nil { + return err + } + + typeScript := pgmock.Script{ + Steps: pgmock.PgxInitSteps(), + } + + return typeScript.Run(backend) +} + +func TestBeginTxContextCancelWithDeadConn(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer closeDB(t, db) + + ctx, cancelFn := context.WithCancel(context.Background()) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("BeginTx failed: %v", err) + } + + cancelFn() + + err = tx.Commit() + if err != context.Canceled { + t.Fatalf("err => %v, want %v", err, context.Canceled) + } + + if err := <-errChan; err != nil { + t.Fatalf("mock server err: %v", err) + } +} + func TestAcquireConn(t *testing.T) { db := openDB(t) defer closeDB(t, db) From e1397613fd8ac3b4cac43a567613e48afcbeac17 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 08:02:51 -0500 Subject: [PATCH 200/264] Ping only makes sense with a context for timeout --- conn.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 4a77f769..bd098646 100644 --- a/conn.go +++ b/conn.go @@ -1388,11 +1388,7 @@ func (c *Conn) cancelQuery() { }() } -func (c *Conn) Ping() error { - return c.PingContext(context.Background()) -} - -func (c *Conn) PingContext(ctx context.Context) error { +func (c *Conn) Ping(ctx context.Context) error { _, err := c.ExecEx(ctx, ";", nil) return err } From 936cb688667497292ae67dddf013d3a5e067ad9b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 08:54:08 -0500 Subject: [PATCH 201/264] Add driver.Pinger support to stdlib.Conn --- stdlib/sql.go | 8 ++++++++ stdlib/sql_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index e70780c1..9e97af90 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -336,6 +336,14 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr return &Rows{rows: rows}, nil } +func (c *Conn) Ping(ctx context.Context) error { + if !c.conn.IsAlive() { + return driver.ErrBadConn + } + + return c.conn.Ping(ctx) +} + // Anything that isn't a database/sql compatible type needs to be forced to // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 4f2484d8..af2b9fe7 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -6,6 +6,7 @@ import ( "database/sql" "fmt" "testing" + "time" "github.com/jackc/pgx" "github.com/jackc/pgx/pgmock" @@ -827,3 +828,52 @@ func TestAcquireConn(t *testing.T) { ensureConnValid(t, db) } + +func TestConnPingContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("db.PingContext failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnPingContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: ";"}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + err = db.PingContext(ctx) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} From f8d7602270dafc831f101998a61c2b0e17b631b4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 14:31:01 -0500 Subject: [PATCH 202/264] Add driver.ConnPrepareContext support to stdlib.Conn --- stdlib/sql.go | 6 +++++- stdlib/sql_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 9e97af90..bc2849c2 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -208,6 +208,10 @@ type Conn struct { } func (c *Conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } @@ -215,7 +219,7 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { name := fmt.Sprintf("pgx_%d", c.psCount) c.psCount++ - ps, err := c.conn.Prepare(name, query) + ps, err := c.conn.PrepareExContext(ctx, name, query, nil) if err != nil { return nil, err } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index af2b9fe7..105de3d8 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -877,3 +877,56 @@ func TestConnPingContextCancel(t *testing.T) { t.Errorf("mock server err: %v", err) } } + +func TestConnPrepareContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + stmt, err := db.PrepareContext(context.Background(), "select now()") + if err != nil { + t.Fatalf("db.PrepareContext failed: %v", err) + } + stmt.Close() + + ensureConnValid(t, db) +} + +func TestConnPrepareContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = db.PrepareContext(ctx, "select now()") + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} From dbcfa46d8e485185a610cdda1b9c2d7b03250fa0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 14:57:49 -0500 Subject: [PATCH 203/264] Add driver.ExecerContext support to stdlib.Conn --- stdlib/sql.go | 11 ++++++++++ stdlib/sql_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index bc2849c2..400f8311 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -279,6 +279,17 @@ func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { return driver.RowsAffected(commandTag.RowsAffected()), err } +func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + commandTag, err := c.conn.ExecEx(ctx, query, nil, args...) + return driver.RowsAffected(commandTag.RowsAffected()), err +} + func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 105de3d8..f12c43d5 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -930,3 +930,53 @@ func TestConnPrepareContextCancel(t *testing.T) { t.Errorf("mock server err: %v", err) } } + +func TestConnExecContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") + if err != nil { + t.Fatalf("db.ExecContext failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnExecContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = db.ExecContext(ctx, "create temporary table exec_context_test(id serial primary key)") + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} From 3080d0ee4d5dd4a9608081e3d09b2614ec24dbf7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 15:50:27 -0500 Subject: [PATCH 204/264] Do not create empty slices in Bind.Decode --- pgmock/pgmock.go | 4 +--- pgproto3/bind.go | 58 ++++++++++++++++++++++++++---------------------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 827fa87d..8dccf811 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -209,9 +209,7 @@ func PgxInitSteps() []Step { }), SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), ExpectMessage(&pgproto3.Bind{ - ParameterFormatCodes: []int16{}, - Parameters: [][]byte{}, - ResultFormatCodes: []int16{1, 1}, + ResultFormatCodes: []int16{1, 1}, }), ExpectMessage(&pgproto3.Execute{}), ExpectMessage(&pgproto3.Sync{}), diff --git a/pgproto3/bind.go b/pgproto3/bind.go index cbd71e13..79fb4503 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -18,6 +18,8 @@ type Bind struct { func (*Bind) Frontend() {} func (dst *Bind) Decode(src []byte) error { + *dst = Bind{} + idx := bytes.IndexByte(src, 0) if idx < 0 { return &invalidMessageFormatErr{messageType: "Bind"} @@ -38,14 +40,16 @@ func (dst *Bind) Decode(src []byte) error { parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 - dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) + if parameterFormatCodeCount > 0 { + dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount) - if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { - return &invalidMessageFormatErr{messageType: "Bind"} - } - for i := 0; i < parameterFormatCodeCount; i++ { - dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) - rp += 2 + if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + for i := 0; i < parameterFormatCodeCount; i++ { + dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) + rp += 2 + } } if len(src[rp:]) < 2 { @@ -54,27 +58,29 @@ func (dst *Bind) Decode(src []byte) error { parameterCount := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 - dst.Parameters = make([][]byte, parameterCount) + if parameterCount > 0 { + dst.Parameters = make([][]byte, parameterCount) - for i := 0; i < parameterCount; i++ { - if len(src[rp:]) < 4 { - return &invalidMessageFormatErr{messageType: "Bind"} + for i := 0; i < parameterCount; i++ { + if len(src[rp:]) < 4 { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + // null + if msgSize == -1 { + continue + } + + if len(src[rp:]) < msgSize { + return &invalidMessageFormatErr{messageType: "Bind"} + } + + dst.Parameters[i] = src[rp : rp+msgSize] + rp += msgSize } - - msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) - rp += 4 - - // null - if msgSize == -1 { - continue - } - - if len(src[rp:]) < msgSize { - return &invalidMessageFormatErr{messageType: "Bind"} - } - - dst.Parameters[i] = src[rp : rp+msgSize] - rp += msgSize } if len(src[rp:]) < 2 { From 4f31904904101cb4df33473745ad6a13c999359e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 15:52:18 -0500 Subject: [PATCH 205/264] Remove spurious Println --- stdlib/sql.go | 1 - 1 file changed, 1 deletion(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 400f8311..5fb185d2 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -344,7 +344,6 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr rows, err := c.conn.QueryEx(ctx, name, nil, args...) if err != nil { - fmt.Println(err) return nil, err } From 5ee76a26c8e2206884a42478c450ba7f5901f300 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 15:52:36 -0500 Subject: [PATCH 206/264] Add tests for stdlib.Conn.QueryContext --- stdlib/sql_test.go | 93 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index f12c43d5..83f32ea8 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -980,3 +980,96 @@ func TestConnExecContextCancel(t *testing.T) { t.Errorf("mock server err: %v", err) } } + +func TestConnQueryContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") + if err != nil { + t.Fatalf("db.QueryContext failed: %v", err) + } + + for rows.Next() { + var n int64 + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + ensureConnValid(t, db) +} + +func TestConnQueryContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Query: "select * from generate_series(1,10) n"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S'}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.ParseComplete{}), + pgmock.SendMessage(&pgproto3.ParameterDescription{}), + pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + { + Name: "n", + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: 4294967295, + }, + }, + }), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + + pgmock.ExpectMessage(&pgproto3.Bind{ResultFormatCodes: []int16{1}}), + pgmock.ExpectMessage(&pgproto3.Execute{}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.BindComplete{}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + ctx, cancelFn := context.WithCancel(context.Background()) + + rows, err := db.QueryContext(ctx, "select * from generate_series(1,10) n") + if err != nil { + t.Fatalf("db.QueryContext failed: %v", err) + } + + cancelFn() + + for rows.Next() { + t.Fatalf("no rows should ever be received") + } + + if rows.Err() != context.Canceled { + t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} From f9cb22e4b86c0e311e3591ab383fa878530e4431 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 16:05:30 -0500 Subject: [PATCH 207/264] Add driver.RowsColumnTypeDatabaseTypeName support to stdlib.Rows --- stdlib/sql.go | 5 +++++ stdlib/sql_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index 5fb185d2..ce79edb6 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -71,6 +71,7 @@ import ( "errors" "fmt" "io" + "strings" "sync" "github.com/jackc/pgx" @@ -405,6 +406,10 @@ func (r *Rows) Columns() []string { return names } +func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { + return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName) +} + func (r *Rows) Close() error { r.rows.Close() return nil diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 83f32ea8..e7db03c9 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1073,3 +1073,30 @@ func TestConnQueryContextCancel(t *testing.T) { t.Errorf("mock server err: %v", err) } } + +func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.Query("select * from generate_series(1,10) n") + if err != nil { + t.Fatalf("db.Query failed: %v", err) + } + + columnTypes, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("rows.ColumnTypes failed: %v", err) + } + + if len(columnTypes) != 1 { + t.Fatalf("len(columnTypes) => %v, want %v", len(columnTypes), 1) + } + + if columnTypes[0].DatabaseTypeName() != "INT4" { + t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4") + } + + rows.Close() + + ensureConnValid(t, db) +} From 7f226539a0b8d9d05ea5e697c19530a25f12fa1b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 16:18:54 -0500 Subject: [PATCH 208/264] Add driver.StmtExecContext support to stdlib.Stmt --- stdlib/sql.go | 4 ++++ stdlib/sql_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index ce79edb6..408bc62a 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -388,6 +388,10 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { return s.conn.Exec(s.ps.Name, argsV) } +func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { + return s.conn.ExecContext(ctx, s.ps.Name, argsV) +} + func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return s.conn.queryPrepared(s.ps.Name, argsV) } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index e7db03c9..447aa8b6 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1100,3 +1100,51 @@ func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { ensureConnValid(t, db) } + +func TestStmtExecContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("create temporary table t(id int primary key)") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + stmt, err := db.Prepare("insert into t(id) values ($1::int4)") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + _, err = stmt.ExecContext(context.Background(), 42) + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, db) +} + +func TestStmtExecContextCancel(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("create temporary table t(id int primary key)") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = stmt.ExecContext(ctx, 42) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + ensureConnValid(t, db) +} From c6cb362b189885af1a7467b35c116a716c504fbb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 17:31:22 -0500 Subject: [PATCH 209/264] Add flush and close messages to pgproto3 --- pgproto3/backend.go | 6 +++++ pgproto3/close.go | 60 +++++++++++++++++++++++++++++++++++++++++++++ pgproto3/flush.go | 29 ++++++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 pgproto3/close.go create mode 100644 pgproto3/flush.go diff --git a/pgproto3/backend.go b/pgproto3/backend.go index bd477315..df66a799 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -14,8 +14,10 @@ type Backend struct { // Frontend message flyweights bind Bind + _close Close describe Describe execute Execute + flush Flush parse Parse passwordMessage PasswordMessage query Query @@ -72,10 +74,14 @@ func (b *Backend) Receive() (FrontendMessage, error) { switch msgType { case 'B': msg = &b.bind + case 'C': + msg = &b._close case 'D': msg = &b.describe case 'E': msg = &b.execute + case 'H': + msg = &b.flush case 'P': msg = &b.parse case 'p': diff --git a/pgproto3/close.go b/pgproto3/close.go new file mode 100644 index 00000000..454ef68e --- /dev/null +++ b/pgproto3/close.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "encoding/json" +) + +type Close struct { + ObjectType byte // 'S' = prepared statement, 'P' = portal + Name string +} + +func (*Close) Frontend() {} + +func (dst *Close) Decode(src []byte) error { + if len(src) < 2 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.ObjectType = src[0] + rp := 1 + + idx := bytes.IndexByte(src[rp:], 0) + if idx != len(src[rp:])-1 { + return &invalidMessageFormatErr{messageType: "Close"} + } + + dst.Name = string(src[rp : len(src)-1]) + + return nil +} + +func (src *Close) MarshalBinary() ([]byte, error) { + var bigEndian BigEndianBuf + buf := &bytes.Buffer{} + + buf.WriteByte('C') + buf.Write(bigEndian.Uint32(0)) + + buf.WriteByte(src.ObjectType) + buf.WriteString(src.Name) + buf.WriteByte(0) + + binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + + return buf.Bytes(), nil +} + +func (src *Close) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ObjectType string + Name string + }{ + Type: "Close", + ObjectType: string(src.ObjectType), + Name: src.Name, + }) +} diff --git a/pgproto3/flush.go b/pgproto3/flush.go new file mode 100644 index 00000000..d26f5c0c --- /dev/null +++ b/pgproto3/flush.go @@ -0,0 +1,29 @@ +package pgproto3 + +import ( + "encoding/json" +) + +type Flush struct{} + +func (*Flush) Frontend() {} + +func (dst *Flush) Decode(src []byte) error { + if len(src) != 0 { + return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)} + } + + return nil +} + +func (src *Flush) MarshalBinary() ([]byte, error) { + return []byte{'H', 0, 0, 0, 4}, nil +} + +func (src *Flush) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + }{ + Type: "Flush", + }) +} From e5820baebe59c24cd088f6bb27d7398ed175963b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 17:31:56 -0500 Subject: [PATCH 210/264] Add driver.StmtQueryContext support to stdlib.Stmt --- stdlib/sql.go | 4 ++ stdlib/sql_test.go | 110 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index 408bc62a..088095ab 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -396,6 +396,10 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return s.conn.queryPrepared(s.ps.Name, argsV) } +func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { + return s.conn.queryPreparedContext(ctx, s.ps.Name, argsV) +} + type Rows struct { rows *pgx.Rows values []interface{} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 447aa8b6..b26c815d 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1148,3 +1148,113 @@ func TestStmtExecContextCancel(t *testing.T) { ensureConnValid(t, db) } + +func TestStmtQueryContextSuccess(t *testing.T) { + // db := openDB(t) + // defer closeDB(t, db) + + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:15432/pgx_test?sslmode=disable") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + + stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + rows, err := stmt.QueryContext(context.Background(), 5) + if err != nil { + t.Fatalf("stmt.QueryContext failed: %v", err) + } + + for rows.Next() { + var n int64 + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + ensureConnValid(t, db) +} + +func TestStmtQueryContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select * from generate_series(1, $1::int4) n"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.ParseComplete{}), + pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: []uint32{23}}), + pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + { + Name: "n", + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: 4294967295, + }, + }, + }), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + + pgmock.ExpectMessage(&pgproto3.Bind{PreparedStatement: "pgx_0", ParameterFormatCodes: []int16{1}, Parameters: [][]uint8{[]uint8{0x0, 0x0, 0x0, 0x2a}}, ResultFormatCodes: []int16{1}}), + pgmock.ExpectMessage(&pgproto3.Execute{}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.BindComplete{}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + stmt, err := db.Prepare("select * from generate_series(1, $1::int4) n") + if err != nil { + t.Fatal(err) + } + // defer stmt.Close() + + ctx, cancelFn := context.WithCancel(context.Background()) + + rows, err := stmt.QueryContext(ctx, 42) + if err != nil { + t.Fatalf("stmt.QueryContext failed: %v", err) + } + + cancelFn() + + for rows.Next() { + t.Fatalf("no rows should ever be received") + } + + if rows.Err() != context.Canceled { + t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} From 2a27fb1817e824e8727764f0d9b3c4fd751ddf8f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:30:47 -0500 Subject: [PATCH 211/264] Remove accidentally committed mock db open --- stdlib/sql_test.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index b26c815d..aa3ae3ee 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1150,13 +1150,8 @@ func TestStmtExecContextCancel(t *testing.T) { } func TestStmtQueryContextSuccess(t *testing.T) { - // db := openDB(t) - // defer closeDB(t, db) - - db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:15432/pgx_test?sslmode=disable") - if err != nil { - t.Fatalf("sql.Open failed: %v", err) - } + db := openDB(t) + defer closeDB(t, db) stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") if err != nil { From a904e672c1579ff20897b2fef4011ac2b22e321f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:34:20 -0500 Subject: [PATCH 212/264] Uncomment Hstore tests --- pgtype/hstore_test.go | 58 +++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index 8189e4db..dc2439fc 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -9,41 +9,41 @@ import ( ) func TestHstoreTranscode(t *testing.T) { - // text := func(s string) pgtype.Text { - // return pgtype.Text{String: s, Status: pgtype.Present} - // } + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } values := []interface{}{ &pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, - // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, - // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, - // &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, - // &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, - // &pgtype.Hstore{Status: pgtype.Null}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + &pgtype.Hstore{Status: pgtype.Null}, } - // specialStrings := []string{ - // `"`, - // `'`, - // `\`, - // `\\`, - // `=>`, - // ` `, - // `\ / / \\ => " ' " '`, - // } - // for _, s := range specialStrings { - // // Special key values - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key - // // Special value values - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end - // values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key - // } + // Special value values + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, &pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } testutil.TestSuccessfulTranscodeEqFunc(t, "hstore", values, func(ai, bi interface{}) bool { a := ai.(pgtype.Hstore) From ace282df665bd85e4137bd0e5bda312b6e650bd8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:36:40 -0500 Subject: [PATCH 213/264] Test &pgtype.QChar --- pgtype/qchar_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index b810b89c..057a557f 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -11,12 +11,12 @@ import ( func TestQCharTranscode(t *testing.T) { testutil.TestPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ - pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, - pgtype.QChar{Int: -1, Status: pgtype.Present}, - pgtype.QChar{Int: 0, Status: pgtype.Present}, - pgtype.QChar{Int: 1, Status: pgtype.Present}, - pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, - pgtype.QChar{Int: 0, Status: pgtype.Null}, + &pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, + &pgtype.QChar{Int: -1, Status: pgtype.Present}, + &pgtype.QChar{Int: 0, Status: pgtype.Present}, + &pgtype.QChar{Int: 1, Status: pgtype.Present}, + &pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, + &pgtype.QChar{Int: 0, Status: pgtype.Null}, }, func(a, b interface{}) bool { return reflect.DeepEqual(a, b) }) From 6529b91111521a81de0530aa193c540be73e7c6f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:38:27 -0500 Subject: [PATCH 214/264] Fix TestNumericNormalize --- pgtype/numeric_test.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index d68a9347..5f3a3416 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -49,47 +49,47 @@ func TestNumericNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select '0'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(0), Exp: 0, Status: pgtype.Present}, }, { SQL: "select '1'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 0, Status: pgtype.Present}, }, { SQL: "select '10.00'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1000), Exp: -2, Status: pgtype.Present}, }, { SQL: "select '1e-3'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: -3, Status: pgtype.Present}, }, { SQL: "select '-1'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(-1), Exp: 0, Status: pgtype.Present}, }, { SQL: "select '10000'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1), Exp: 4, Status: pgtype.Present}, }, { SQL: "select '3.14'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(314), Exp: -2, Status: pgtype.Present}, }, { SQL: "select '1.1'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(11), Exp: -1, Status: pgtype.Present}, }, { SQL: "select '100010001'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(100010001), Exp: 0, Status: pgtype.Present}, }, { SQL: "select '100010001.0001'::numeric", - Value: pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, + Value: &pgtype.Numeric{Int: big.NewInt(1000100010001), Exp: -4, Status: pgtype.Present}, }, { SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: pgtype.Numeric{ + Value: &pgtype.Numeric{ Int: mustParseBigInt(t, "423723478923478928934789237432487213832189417894318904389012483210893443219085471578891547854892438945012347981"), Exp: -41, Status: pgtype.Present, @@ -97,7 +97,7 @@ func TestNumericNormalize(t *testing.T) { }, { SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: pgtype.Numeric{ + Value: &pgtype.Numeric{ Int: mustParseBigInt(t, "8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), Exp: -196, Status: pgtype.Present, @@ -105,7 +105,7 @@ func TestNumericNormalize(t *testing.T) { }, { SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: pgtype.Numeric{ + Value: &pgtype.Numeric{ Int: mustParseBigInt(t, "123"), Exp: -186, Status: pgtype.Present, From aa2bc93e31091c9f82f18fa9160d026b715cfb2a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:39:53 -0500 Subject: [PATCH 215/264] Fix TestIntervalNormalize --- pgtype/interval_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index 18e21ddd..76ea3240 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -33,31 +33,31 @@ func TestIntervalNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select '1 second'::interval", - Value: pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 1000000, Status: pgtype.Present}, }, { SQL: "select '1.000001 second'::interval", - Value: pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 1000001, Status: pgtype.Present}, }, { SQL: "select '34223 hours'::interval", - Value: pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, + Value: &pgtype.Interval{Microseconds: 123202800000000, Status: pgtype.Present}, }, { SQL: "select '1 day'::interval", - Value: pgtype.Interval{Days: 1, Status: pgtype.Present}, + Value: &pgtype.Interval{Days: 1, Status: pgtype.Present}, }, { SQL: "select '1 month'::interval", - Value: pgtype.Interval{Months: 1, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: 1, Status: pgtype.Present}, }, { SQL: "select '1 year'::interval", - Value: pgtype.Interval{Months: 12, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: 12, Status: pgtype.Present}, }, { SQL: "select '-13 mon'::interval", - Value: pgtype.Interval{Months: -13, Status: pgtype.Present}, + Value: &pgtype.Interval{Months: -13, Status: pgtype.Present}, }, }) } From dc753bf2a3846be56db0516cdb52709404d25467 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:42:39 -0500 Subject: [PATCH 216/264] Fix TestHstoreArrayTranscode --- pgtype/hstore_array_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go index d26497b1..fcf08c49 100644 --- a/pgtype/hstore_array_test.go +++ b/pgtype/hstore_array_test.go @@ -49,7 +49,7 @@ func TestHstoreArrayTranscode(t *testing.T) { values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key } - src := pgtype.HstoreArray{ + src := &pgtype.HstoreArray{ Elements: values, Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, Status: pgtype.Present, From b24ca9fa8ad630a9c4cac9a2556f354c1b8a5b4c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 08:45:57 -0500 Subject: [PATCH 217/264] Remove PG 9.0 hstore support from Travis --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 85981e4e..fd3850e4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -41,7 +41,7 @@ env: before_script: - mv conn_config_test.go.travis conn_config_test.go - psql -U postgres -c 'create database pgx_test' - - "[[ \"${PGVERSION}\" = '9.0' ]] && psql -U postgres -f /usr/share/postgresql/9.0/contrib/hstore.sql pgx_test || psql -U postgres pgx_test -c 'create extension hstore'" + - psql -U postgres pgx_test -c 'create extension hstore' - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" From 104192725a88460bb1a73a5dcb3546ef2da94b50 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 09:44:15 -0500 Subject: [PATCH 218/264] Ensure shopspring-numeric tests run --- pgtype/ext/shopspring-numeric/decimal_test.go | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pgtype/ext/shopspring-numeric/decimal_test.go b/pgtype/ext/shopspring-numeric/decimal_test.go index 50c0fb8b..08483dda 100644 --- a/pgtype/ext/shopspring-numeric/decimal_test.go +++ b/pgtype/ext/shopspring-numeric/decimal_test.go @@ -25,61 +25,61 @@ func TestNumericNormalize(t *testing.T) { testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ { SQL: "select '0'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, }, { SQL: "select '1'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}, }, { SQL: "select '10.00'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10.00"), Status: pgtype.Present}, }, { SQL: "select '1e-3'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.001"), Status: pgtype.Present}, }, { SQL: "select '-1'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "-1"), Status: pgtype.Present}, }, { SQL: "select '10000'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "10000"), Status: pgtype.Present}, }, { SQL: "select '3.14'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "3.14"), Status: pgtype.Present}, }, { SQL: "select '1.1'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1.1"), Status: pgtype.Present}, }, { SQL: "select '100010001'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001"), Status: pgtype.Present}, }, { SQL: "select '100010001.0001'::numeric", - Value: shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, + Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "100010001.0001"), Status: pgtype.Present}, }, { SQL: "select '4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981'::numeric", - Value: shopspring.Numeric{ + Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), Status: pgtype.Present, }, }, { SQL: "select '0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234'::numeric", - Value: shopspring.Numeric{ + Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), Status: pgtype.Present, }, }, { SQL: "select '0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123'::numeric", - Value: shopspring.Numeric{ + Value: &shopspring.Numeric{ Decimal: mustParseDecimal(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), Status: pgtype.Present, }, From b8c043780d38280fccc0e6b4b5aa5b22d2ef1eff Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 09:46:06 -0500 Subject: [PATCH 219/264] Fix shopsprint-numeric test --- pgtype/ext/shopspring-numeric/decimal_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pgtype/ext/shopspring-numeric/decimal_test.go b/pgtype/ext/shopspring-numeric/decimal_test.go index 08483dda..79121ef3 100644 --- a/pgtype/ext/shopspring-numeric/decimal_test.go +++ b/pgtype/ext/shopspring-numeric/decimal_test.go @@ -22,7 +22,7 @@ func mustParseDecimal(t *testing.T, src string) decimal.Decimal { } func TestNumericNormalize(t *testing.T) { - testutil.TestSuccessfulNormalize(t, []testutil.NormalizeTest{ + testutil.TestSuccessfulNormalizeEqFunc(t, []testutil.NormalizeTest{ { SQL: "select '0'::numeric", Value: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0"), Status: pgtype.Present}, @@ -84,6 +84,11 @@ func TestNumericNormalize(t *testing.T) { Status: pgtype.Present, }, }, + }, func(aa, bb interface{}) bool { + a := aa.(shopspring.Numeric) + b := bb.(shopspring.Numeric) + + return a.Status == b.Status && a.Decimal.Equal(b.Decimal) }) } From 2df4b1406b6e39d830622be140c18d778972cd37 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 10:58:44 -0500 Subject: [PATCH 220/264] Do not double call termContext in QueryEx QueryEx was calling termContext and rows.fatal on err of sendPreparedQuery. rows.fatal calls rows.Close which already calls termContext. This sequence of calls was causing underlying io timeout errors to be returned instead of context errors. In addition, added fatalWriteErr helper method to allow recovery of write timeout errors where no bytes were written. This should solve flickering errors on Travis. --- conn.go | 21 +++++++++++++++++---- query.go | 5 ++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index bd098646..20844e57 100644 --- a/conn.go +++ b/conn.go @@ -795,9 +795,11 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared buf = append(buf, 'S') buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(buf) + n, err := c.conn.Write(buf) if err != nil { - c.die(err) + if fatalWriteErr(n, err) { + c.die(err) + } return nil, err } c.readyForQuery = false @@ -1085,8 +1087,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} buf = append(buf, 'S') buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(buf) - if err != nil { + n, err := c.conn.Write(buf) + if err != nil && fatalWriteErr(n, err) { c.die(err) } c.readyForQuery = false @@ -1094,6 +1096,17 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return err } +// fatalWriteError takes the response of a net.Conn.Write and determines if it is fatal +func fatalWriteErr(bytesWritten int, err error) bool { + // Partial writes break the connection + if bytesWritten > 0 { + return true + } + + netErr, is := err.(net.Error) + return !(is && netErr.Timeout()) +} + // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { diff --git a/query.go b/query.go index 681c133b..44bf004a 100644 --- a/query.go +++ b/query.go @@ -398,16 +398,15 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, err = c.initContext(ctx) if err != nil { rows.fatal(err) - return rows, err + return rows, rows.err } err = c.sendPreparedQuery(ps, args...) if err != nil { rows.fatal(err) - err = c.termContext(err) } - return rows, err + return rows, rows.err } func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { From d1fd222ca574df832934c4b4ada8fd9efd47d25d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 17:58:19 -0500 Subject: [PATCH 221/264] Add transaction context support --- conn_pool.go | 6 +-- conn_pool_test.go | 3 +- pgmock/pgmock.go | 25 ++++++++++ stdlib/sql.go | 2 +- stdlib/sql_test.go | 11 +++-- tx.go | 28 +++++++++--- tx_test.go | 112 +++++++++++++++++++++++++++++++++++++++++++-- 7 files changed, 168 insertions(+), 19 deletions(-) diff --git a/conn_pool.go b/conn_pool.go index 49de6658..632692de 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -410,7 +410,7 @@ func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExO // Begin acquires a connection and begins a transaction on it. When the // transaction is closed the connection will be automatically released. func (p *ConnPool) Begin() (*Tx, error) { - return p.BeginEx(nil) + return p.BeginEx(context.Background(), nil) } // Prepare creates a prepared statement on a connection in the pool to test the @@ -499,14 +499,14 @@ func (p *ConnPool) Deallocate(name string) (err error) { // BeginEx acquires a connection and starts a transaction with txOptions // determining the transaction mode. When the transaction is closed the // connection will be automatically released. -func (p *ConnPool) BeginEx(txOptions *TxOptions) (*Tx, error) { +func (p *ConnPool) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { for { c, err := p.Acquire() if err != nil { return nil, err } - tx, err := c.BeginEx(txOptions) + tx, err := c.BeginEx(ctx, txOptions) if err != nil { alive := c.IsAlive() p.Release(c) diff --git a/conn_pool_test.go b/conn_pool_test.go index 42f37eb1..560ab3ae 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -1,6 +1,7 @@ package pgx_test import ( + "context" "errors" "fmt" "net" @@ -635,7 +636,7 @@ func TestConnPoolTransactionIso(t *testing.T) { pool := createConnPool(t, 2) defer pool.Close() - tx, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("pool.BeginEx failed: %v", err) } diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 8dccf811..3f1e54f4 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -3,6 +3,7 @@ package pgmock import ( "errors" "fmt" + "io" "net" "reflect" @@ -38,6 +39,9 @@ func (s *Server) ServeOne() error { if err != nil { return err } + defer conn.Close() + + s.Close() backend, err := pgproto3.NewBackend(conn, conn) if err != nil { @@ -167,6 +171,27 @@ func SendMessage(msg pgproto3.BackendMessage) Step { return &sendMessageStep{msg: msg} } +type waitForCloseMessageStep struct{} + +func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error { + for { + msg, err := backend.Receive() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + if _, ok := msg.(*pgproto3.Terminate); ok { + return nil + } + } +} + +func WaitForClose() Step { + return &waitForCloseMessageStep{} +} + func AcceptUnauthenticatedConnRequestSteps() []Step { return []Step{ ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), diff --git a/stdlib/sql.go b/stdlib/sql.go index 088095ab..a0aa6975 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -267,7 +267,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.AccessMode = pgx.ReadOnly } - return c.conn.BeginEx(&pgxOpts) + return c.conn.BeginEx(ctx, &pgxOpts) } func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index aa3ae3ee..415864cd 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -847,6 +847,7 @@ func TestConnPingContextCancel(t *testing.T) { script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: ";"}), + pgmock.WaitForClose(), ) server, err := pgmock.NewServer(script) @@ -855,7 +856,7 @@ func TestConnPingContextCancel(t *testing.T) { } defer server.Close() - errChan := make(chan error) + errChan := make(chan error, 1) go func() { errChan <- server.ServeOne() }() @@ -864,7 +865,7 @@ func TestConnPingContextCancel(t *testing.T) { if err != nil { t.Fatalf("sql.Open failed: %v", err) } - // defer closeDB(t, db) // mock DB doesn't close correctly yet + defer closeDB(t, db) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -900,6 +901,7 @@ func TestConnPrepareContextCancel(t *testing.T) { pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}), pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), pgmock.ExpectMessage(&pgproto3.Sync{}), + pgmock.WaitForClose(), ) server, err := pgmock.NewServer(script) @@ -917,7 +919,7 @@ func TestConnPrepareContextCancel(t *testing.T) { if err != nil { t.Fatalf("sql.Open failed: %v", err) } - // defer closeDB(t, db) // mock DB doesn't close correctly yet + defer closeDB(t, db) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -950,6 +952,7 @@ func TestConnExecContextCancel(t *testing.T) { script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}), + pgmock.WaitForClose(), ) server, err := pgmock.NewServer(script) @@ -967,7 +970,7 @@ func TestConnExecContextCancel(t *testing.T) { if err != nil { t.Fatalf("sql.Open failed: %v", err) } - // defer closeDB(t, db) // mock DB doesn't close correctly yet + defer closeDB(t, db) ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) diff --git a/tx.go b/tx.go index ea804449..f5309468 100644 --- a/tx.go +++ b/tx.go @@ -2,8 +2,10 @@ package pgx import ( "bytes" + "context" "errors" "fmt" + "time" ) type TxIsoLevel string @@ -56,12 +58,13 @@ var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") // Begin starts a transaction with the default transaction mode for the // current connection. To use a specific transaction mode see BeginEx. func (c *Conn) Begin() (*Tx, error) { - return c.BeginEx(nil) + return c.BeginEx(context.Background(), nil) } // BeginEx starts a transaction with txOptions determining the transaction -// mode. -func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) { +// mode. Unlike database/sql, the context only affects the begin command. i.e. +// there is no auto-rollback on context cancelation. +func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { var beginSQL string if txOptions == nil { beginSQL = "begin" @@ -81,8 +84,11 @@ func (c *Conn) BeginEx(txOptions *TxOptions) (*Tx, error) { beginSQL = buf.String() } - _, err := c.Exec(beginSQL) + _, err := c.ExecEx(ctx, beginSQL, nil) if err != nil { + // begin should never fail unless there is an underlying connection issue or + // a context timeout. In either case, the connection is possibly broken. + c.die(errors.New("failed to begin transaction")) return nil, err } @@ -102,11 +108,16 @@ type Tx struct { // Commit commits the transaction func (tx *Tx) Commit() error { + return tx.CommitEx(context.Background()) +} + +// CommitEx commits the transaction with a context. +func (tx *Tx) CommitEx(ctx context.Context) error { if tx.status != TxStatusInProgress { return ErrTxClosed } - commandTag, err := tx.conn.Exec("commit") + commandTag, err := tx.conn.ExecEx(ctx, "commit", nil) if err == nil && commandTag == "COMMIT" { tx.status = TxStatusCommitSuccess } else if err == nil && commandTag == "ROLLBACK" { @@ -115,6 +126,8 @@ func (tx *Tx) Commit() error { } else { tx.status = TxStatusCommitFailure tx.err = err + // A commit failure leaves the connection in an undefined state + tx.conn.die(errors.New("commit failed")) } if tx.connPool != nil { @@ -133,11 +146,14 @@ func (tx *Tx) Rollback() error { return ErrTxClosed } - _, tx.err = tx.conn.Exec("rollback") + ctx, _ := context.WithTimeout(context.Background(), 15*time.Second) + _, tx.err = tx.conn.ExecEx(ctx, "rollback", nil) if tx.err == nil { tx.status = TxStatusRollbackSuccess } else { tx.status = TxStatusRollbackFailure + // A rollback failure leaves the connection in an undefined state + tx.conn.die(errors.New("rollback failed")) } if tx.connPool != nil { diff --git a/tx_test.go b/tx_test.go index 35abd4eb..b25e1c9f 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,9 +1,14 @@ package pgx_test import ( + "context" + "fmt" "testing" + "time" "github.com/jackc/pgx" + "github.com/jackc/pgx/pgmock" + "github.com/jackc/pgx/pgproto3" ) func TestTransactionSuccessfulCommit(t *testing.T) { @@ -107,13 +112,13 @@ func TestTxCommitSerializationFailure(t *testing.T) { } defer pool.Exec(`drop table tx_serializable_sums`) - tx1, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx1, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("BeginEx failed: %v", err) } defer tx1.Rollback() - tx2, err := pool.BeginEx(&pgx.TxOptions{IsoLevel: pgx.Serializable}) + tx2, err := pool.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("BeginEx failed: %v", err) } @@ -190,7 +195,7 @@ func TestBeginExIsoLevels(t *testing.T) { isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { - tx, err := conn.BeginEx(&pgx.TxOptions{IsoLevel: iso}) + tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: iso}) if err != nil { t.Fatalf("conn.BeginEx failed: %v", err) } @@ -214,7 +219,7 @@ func TestBeginExReadOnly(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - tx, err := conn.BeginEx(&pgx.TxOptions{AccessMode: pgx.ReadOnly}) + tx, err := conn.BeginEx(context.Background(), &pgx.TxOptions{AccessMode: pgx.ReadOnly}) if err != nil { t.Fatalf("conn.BeginEx failed: %v", err) } @@ -226,6 +231,105 @@ func TestBeginExReadOnly(t *testing.T) { } } +func TestConnBeginExContextCancel(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error, 1) + go func() { + errChan <- server.ServeOne() + }() + + mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatal(err) + } + + conn := mustConnect(t, mockConfig) + + ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + + _, err = conn.BeginEx(ctx, nil) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if conn.IsAlive() { + t.Error("expected conn to be dead after BeginEx failure") + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + +func TestTxCommitExCancel(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "begin"}), + pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: "BEGIN"}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'T'}), + pgmock.WaitForClose(), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error, 1) + go func() { + errChan <- server.ServeOne() + }() + + mockConfig, err := pgx.ParseURI(fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatal(err) + } + + conn := mustConnect(t, mockConfig) + defer conn.Close() + + tx, err := conn.Begin() + if err != nil { + t.Fatal(err) + } + + ctx, _ := context.WithTimeout(context.Background(), 50*time.Millisecond) + err = tx.CommitEx(ctx) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if conn.IsAlive() { + t.Error("expected conn to be dead after CommitEx failure") + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +} + func TestTxStatus(t *testing.T) { t.Parallel() From 8a7165dd988dcc83f3ce76babdec23c59e44e502 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 May 2017 18:03:59 -0500 Subject: [PATCH 222/264] Add ctx to PrepareEx Remove PrepareExContext --- conn.go | 10 +++------- conn_pool.go | 8 ++++---- conn_test.go | 2 +- query.go | 2 +- stdlib/sql.go | 4 ++-- tx.go | 6 +++--- v3.md | 2 ++ 7 files changed, 16 insertions(+), 18 deletions(-) diff --git a/conn.go b/conn.go index 20844e57..04299de7 100644 --- a/conn.go +++ b/conn.go @@ -710,7 +710,7 @@ func configSSL(sslmode string, cc *ConnConfig) error { // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { - return c.PrepareEx(name, sql, nil) + return c.PrepareEx(context.Background(), name, sql, nil) } // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders @@ -720,11 +720,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. -func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - return c.PrepareExContext(context.Background(), name, sql, opts) -} - -func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { +func (c *Conn) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { err = c.waitForPreviousCancelQuery(ctx) if err != nil { return nil, err @@ -1455,7 +1451,7 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, ps, ok := c.preparedStatements[sql] if !ok { var err error - ps, err = c.PrepareExContext(ctx, "", sql, nil) + ps, err = c.PrepareEx(ctx, "", sql, nil) if err != nil { return "", err } diff --git a/conn_pool.go b/conn_pool.go index 632692de..42200b85 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -425,7 +425,7 @@ func (p *ConnPool) Begin() (*Tx, error) { // the same name and sql arguments. This allows a code path to Prepare and // Query/Exec/PrepareEx without concern for if the statement has already been prepared. func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { - return p.PrepareEx(name, sql, nil) + return p.PrepareEx(context.Background(), name, sql, nil) } // PrepareEx creates a prepared statement on a connection in the pool to test the @@ -439,7 +439,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without // concern for if the statement has already been prepared. -func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { +func (p *ConnPool) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { p.cond.L.Lock() defer p.cond.L.Unlock() @@ -461,13 +461,13 @@ func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*Prepare return ps, nil } - ps, err := c.PrepareEx(name, sql, opts) + ps, err := c.PrepareEx(ctx, name, sql, opts) if err != nil { return nil, err } for _, c := range p.availableConnections { - _, err := c.PrepareEx(name, sql, opts) + _, err := c.PrepareEx(ctx, name, sql, opts) if err != nil { return nil, err } diff --git a/conn_test.go b/conn_test.go index f887e030..acee1b49 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1326,7 +1326,7 @@ func TestPrepareEx(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgtype.Oid{pgtype.TextOid}}) + _, err := conn.PrepareEx(context.Background(), "test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgtype.Oid{pgtype.TextOid}}) if err != nil { t.Errorf("Unable to prepare statement: %v", err) return diff --git a/query.go b/query.go index 44bf004a..10eda1bc 100644 --- a/query.go +++ b/query.go @@ -386,7 +386,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, ps, ok := c.preparedStatements[sql] if !ok { var err error - ps, err = c.PrepareExContext(ctx, "", sql, nil) + ps, err = c.PrepareEx(ctx, "", sql, nil) if err != nil { rows.fatal(err) return rows, rows.err diff --git a/stdlib/sql.go b/stdlib/sql.go index a0aa6975..aa45dd40 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -220,7 +220,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e name := fmt.Sprintf("pgx_%d", c.psCount) c.psCount++ - ps, err := c.conn.PrepareExContext(ctx, name, query, nil) + ps, err := c.conn.PrepareEx(ctx, name, query, nil) if err != nil { return nil, err } @@ -311,7 +311,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na return nil, driver.ErrBadConn } - ps, err := c.conn.PrepareExContext(ctx, "", query, nil) + ps, err := c.conn.PrepareEx(ctx, "", query, nil) if err != nil { return nil, err } diff --git a/tx.go b/tx.go index f5309468..07cae4ba 100644 --- a/tx.go +++ b/tx.go @@ -174,16 +174,16 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, // Prepare delegates to the underlying *Conn func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) { - return tx.PrepareEx(name, sql, nil) + return tx.PrepareEx(context.Background(), name, sql, nil) } // PrepareEx delegates to the underlying *Conn -func (tx *Tx) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { +func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { if tx.status != TxStatusInProgress { return nil, ErrTxClosed } - return tx.conn.PrepareEx(name, sql, opts) + return tx.conn.PrepareEx(ctx, name, sql, opts) } // Query delegates to the underlying *Conn diff --git a/v3.md b/v3.md index 624c25eb..33a27d2d 100644 --- a/v3.md +++ b/v3.md @@ -48,6 +48,8 @@ Removed Tx.AfterClose() Removed Tx.Conn() +Added ctx parameter to (Conn/Tx/ConnPool).PrepareEx + ## TODO / Possible / Investigate Organize errors better From 749fdfe7d5e0ccca384376a583030b0a2fa74b9d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 21 May 2017 19:35:37 -0500 Subject: [PATCH 223/264] Resolve race on conn.Close/die Use sync.Mutex instead of atomic operations for clarity. --- conn.go | 68 +++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/conn.go b/conn.go index 04299de7..c4c054dd 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,7 @@ import ( "regexp" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -102,7 +103,8 @@ type Conn struct { poolResetCount int preallocatedRows []Rows - status int32 // One of connStatus* constants + mux sync.Mutex + status byte // One of connStatus* constants causeOfDeath error readyForQuery bool // connection has received ReadyForQuery message since last query was sent @@ -267,20 +269,25 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl defer func() { if c != nil && err != nil { c.conn.Close() - atomic.StoreInt32(&c.status, connStatusClosed) + c.mux.Lock() + c.status = connStatusClosed + c.mux.Unlock() } }() c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) - atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() c.cancelQueryCompleted = make(chan struct{}, 1) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) + c.mux.Lock() + c.status = connStatusIdle + c.mux.Unlock() + if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "starting TLS handshake", nil) @@ -401,19 +408,17 @@ func (c *Conn) PID() uint32 { // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - for { - status := atomic.LoadInt32(&c.status) - if status < connStatusIdle { - return nil - } - if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) { - break - } + c.mux.Lock() + defer c.mux.Unlock() + + if c.status < connStatusIdle { + return nil } + c.status = connStatusClosed defer func() { c.conn.Close() - c.die(errors.New("Closed")) + c.causeOfDeath = errors.New("Closed") if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "closed connection", nil) } @@ -989,10 +994,14 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat } func (c *Conn) IsAlive() bool { - return atomic.LoadInt32(&c.status) >= connStatusIdle + c.mux.Lock() + defer c.mux.Unlock() + return c.status >= connStatusIdle } func (c *Conn) CauseOfDeath() error { + c.mux.Lock() + defer c.mux.Unlock() return c.causeOfDeath } @@ -1131,7 +1140,7 @@ func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { } func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { - if atomic.LoadInt32(&c.status) < connStatusIdle { + if !c.IsAlive() { return nil, ErrDeadConn } @@ -1283,23 +1292,40 @@ func (c *Conn) txPasswordMessage(password string) (err error) { } func (c *Conn) die(err error) { - atomic.StoreInt32(&c.status, connStatusClosed) + c.mux.Lock() + defer c.mux.Unlock() + + if c.status == connStatusClosed { + return + } + + c.status = connStatusClosed c.causeOfDeath = err c.conn.Close() } func (c *Conn) lock() error { - if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) { - return nil + c.mux.Lock() + defer c.mux.Unlock() + + if c.status != connStatusIdle { + return ErrConnBusy } - return ErrConnBusy + + c.status = connStatusBusy + return nil } func (c *Conn) unlock() error { - if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) { - return nil + c.mux.Lock() + defer c.mux.Unlock() + + if c.status != connStatusBusy { + return errors.New("unlock conn that is not busy") } - return errors.New("unlock conn that is not busy") + + c.status = connStatusIdle + return nil } func (c *Conn) shouldLog(lvl int) bool { From 21d2ed09349333d874bb165c37781595ab0a53b3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 22 May 2017 08:51:23 -0500 Subject: [PATCH 224/264] Add mock close --- stdlib/sql_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 415864cd..bfeb07c6 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1036,6 +1036,7 @@ func TestConnQueryContextCancel(t *testing.T) { pgmock.ExpectMessage(&pgproto3.Sync{}), pgmock.SendMessage(&pgproto3.BindComplete{}), + pgmock.WaitForClose(), ) server, err := pgmock.NewServer(script) @@ -1053,7 +1054,7 @@ func TestConnQueryContextCancel(t *testing.T) { if err != nil { t.Fatalf("sql.Open failed: %v", err) } - // defer closeDB(t, db) // mock DB doesn't close correctly yet + defer db.Close() ctx, cancelFn := context.WithCancel(context.Background()) From 2e2c2ad778d138c68da0c198304c2dd7595ad83f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 26 May 2017 17:00:44 -0500 Subject: [PATCH 225/264] Replace MarshalBinary with Encode This new approach can avoid allocations. --- pgproto3/authentication.go | 20 +++++++------- pgproto3/backend.go | 7 +---- pgproto3/backend_key_data.go | 17 ++++++------ pgproto3/bind.go | 43 ++++++++++++++---------------- pgproto3/bind_complete.go | 4 +-- pgproto3/close.go | 23 ++++++++-------- pgproto3/close_complete.go | 4 +-- pgproto3/command_complete.go | 18 +++++++------ pgproto3/copy_both_response.go | 21 ++++++++------- pgproto3/copy_data.go | 17 +++++------- pgproto3/copy_in_response.go | 21 ++++++++------- pgproto3/copy_out_response.go | 21 ++++++++------- pgproto3/data_row.go | 26 +++++++++--------- pgproto3/describe.go | 23 ++++++++-------- pgproto3/empty_query_response.go | 4 +-- pgproto3/error_response.go | 8 +++--- pgproto3/execute.go | 22 +++++++-------- pgproto3/flush.go | 4 +-- pgproto3/frontend.go | 7 +---- pgproto3/function_call_response.go | 23 ++++++++-------- pgproto3/no_data.go | 4 +-- pgproto3/notice_response.go | 4 +-- pgproto3/notification_response.go | 22 ++++++++------- pgproto3/parameter_description.go | 21 ++++++++------- pgproto3/parameter_status.go | 25 +++++++++-------- pgproto3/parse.go | 31 +++++++++++---------- pgproto3/parse_complete.go | 4 +-- pgproto3/password_message.go | 18 +++++++------ pgproto3/pgproto3.go | 9 ++++--- pgproto3/query.go | 18 +++++++------ pgproto3/ready_for_query.go | 4 +-- pgproto3/row_description.go | 35 ++++++++++++------------ pgproto3/startup_message.go | 26 +++++++++--------- pgproto3/sync.go | 4 +-- pgproto3/terminate.go | 4 +-- 35 files changed, 277 insertions(+), 285 deletions(-) diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go index 54f4978f..c04ee448 100644 --- a/pgproto3/authentication.go +++ b/pgproto3/authentication.go @@ -1,9 +1,10 @@ package pgproto3 import ( - "bytes" "encoding/binary" "fmt" + + "github.com/jackc/pgx/pgio" ) const ( @@ -36,19 +37,18 @@ func (dst *Authentication) Decode(src []byte) error { return nil } -func (src *Authentication) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.WriteByte('R') - buf.Write(bigEndian.Uint32(0)) - buf.Write(bigEndian.Uint32(src.Type)) +func (src *Authentication) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, src.Type) switch src.Type { case AuthTypeMD5Password: - buf.Write(src.Salt[:]) + dst = append(dst, src.Salt[:]...) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } diff --git a/pgproto3/backend.go b/pgproto3/backend.go index df66a799..bf96ba95 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -32,12 +32,7 @@ func NewBackend(r io.Reader, w io.Writer) (*Backend, error) { } func (b *Backend) Send(msg BackendMessage) error { - buf, err := msg.MarshalBinary() - if err != nil { - return nil - } - - _, err = b.w.Write(buf) + _, err := b.w.Write(msg.Encode(nil)) return err } diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index 04f31aec..5a478f10 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -1,9 +1,10 @@ package pgproto3 import ( - "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type BackendKeyData struct { @@ -24,14 +25,12 @@ func (dst *BackendKeyData) Decode(src []byte) error { return nil } -func (src *BackendKeyData) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.WriteByte('K') - buf.Write(bigEndian.Uint32(12)) - buf.Write(bigEndian.Uint32(src.ProcessID)) - buf.Write(bigEndian.Uint32(src.SecretKey)) - return buf.Bytes(), nil +func (src *BackendKeyData) Encode(dst []byte) []byte { + dst = append(dst, 'K') + dst = pgio.AppendUint32(dst, 12) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst } func (src *BackendKeyData) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/bind.go b/pgproto3/bind.go index 79fb4503..cceee6ab 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Bind struct { @@ -101,45 +103,40 @@ func (dst *Bind) Decode(src []byte) error { return nil } -func (src *Bind) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Bind) Encode(dst []byte) []byte { + dst = append(dst, 'B') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('B') - buf.Write(bigEndian.Uint32(0)) - - buf.WriteString(src.DestinationPortal) - buf.WriteByte(0) - buf.WriteString(src.PreparedStatement) - buf.WriteByte(0) - - buf.Write(bigEndian.Uint16(uint16(len(src.ParameterFormatCodes)))) + dst = append(dst, src.DestinationPortal...) + dst = append(dst, 0) + dst = append(dst, src.PreparedStatement...) + dst = append(dst, 0) + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) for _, fc := range src.ParameterFormatCodes { - buf.Write(bigEndian.Int16(fc)) + dst = pgio.AppendInt16(dst, fc) } - buf.Write(bigEndian.Uint16(uint16(len(src.Parameters)))) - + dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) for _, p := range src.Parameters { if p == nil { - buf.Write(bigEndian.Int32(-1)) + dst = pgio.AppendInt32(dst, -1) continue } - buf.Write(bigEndian.Int32(int32(len(p)))) - buf.Write(p) + dst = pgio.AppendInt32(dst, int32(len(p))) + dst = append(dst, p...) } - buf.Write(bigEndian.Uint16(uint16(len(src.ResultFormatCodes)))) - + dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) for _, fc := range src.ResultFormatCodes { - buf.Write(bigEndian.Int16(fc)) + dst = pgio.AppendInt16(dst, fc) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *Bind) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go index 4f1c44b8..60360519 100644 --- a/pgproto3/bind_complete.go +++ b/pgproto3/bind_complete.go @@ -16,8 +16,8 @@ func (dst *BindComplete) Decode(src []byte) error { return nil } -func (src *BindComplete) MarshalBinary() ([]byte, error) { - return []byte{'2', 0, 0, 0, 4}, nil +func (src *BindComplete) Encode(dst []byte) []byte { + return append(dst, '2', 0, 0, 0, 4) } func (src *BindComplete) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/close.go b/pgproto3/close.go index 454ef68e..5ff4c886 100644 --- a/pgproto3/close.go +++ b/pgproto3/close.go @@ -2,8 +2,9 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Close struct { @@ -31,20 +32,18 @@ func (dst *Close) Decode(src []byte) error { return nil } -func (src *Close) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Close) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('C') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) - buf.WriteByte(src.ObjectType) - buf.WriteString(src.Name) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes(), nil + return dst } func (src *Close) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go index 9bab3e8c..db793c94 100644 --- a/pgproto3/close_complete.go +++ b/pgproto3/close_complete.go @@ -16,8 +16,8 @@ func (dst *CloseComplete) Decode(src []byte) error { return nil } -func (src *CloseComplete) MarshalBinary() ([]byte, error) { - return []byte{'3', 0, 0, 0, 4}, nil +func (src *CloseComplete) Encode(dst []byte) []byte { + return append(dst, '3', 0, 0, 0, 4) } func (src *CloseComplete) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index 86653804..85848532 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -3,6 +3,8 @@ package pgproto3 import ( "bytes" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CommandComplete struct { @@ -22,17 +24,17 @@ func (dst *CommandComplete) Decode(src []byte) error { return nil } -func (src *CommandComplete) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *CommandComplete) Encode(dst []byte) []byte { + dst = append(dst, 'C') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('C') - buf.Write(bigEndian.Uint32(uint32(4 + len(src.CommandTag) + 1))) + dst = append(dst, src.CommandTag...) + dst = append(dst, 0) - buf.WriteString(src.CommandTag) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *CommandComplete) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 3857c187..2862a34f 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CopyBothResponse struct { @@ -37,20 +39,19 @@ func (dst *CopyBothResponse) Decode(src []byte) error { return nil } -func (src *CopyBothResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('W') - buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) - - buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) +func (src *CopyBothResponse) Encode(dst []byte) []byte { + dst = append(dst, 'W') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { - buf.Write(bigEndian.Uint16(fc)) + dst = pgio.AppendUint16(dst, fc) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *CopyBothResponse) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go index de7ab4ff..fab139e6 100644 --- a/pgproto3/copy_data.go +++ b/pgproto3/copy_data.go @@ -1,9 +1,10 @@ package pgproto3 import ( - "bytes" "encoding/hex" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CopyData struct { @@ -18,15 +19,11 @@ func (dst *CopyData) Decode(src []byte) error { return nil } -func (src *CopyData) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('d') - buf.Write(bigEndian.Uint32(uint32(4 + len(src.Data)))) - buf.Write(src.Data) - - return buf.Bytes(), nil +func (src *CopyData) Encode(dst []byte) []byte { + dst = append(dst, 'd') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + dst = append(dst, src.Data...) + return dst } func (src *CopyData) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 9854d665..54083cd6 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CopyInResponse struct { @@ -37,20 +39,19 @@ func (dst *CopyInResponse) Decode(src []byte) error { return nil } -func (src *CopyInResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('G') - buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) - - buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) +func (src *CopyInResponse) Encode(dst []byte) []byte { + dst = append(dst, 'G') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { - buf.Write(bigEndian.Uint16(fc)) + dst = pgio.AppendUint16(dst, fc) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *CopyInResponse) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index 5ef6e4c1..eaa33b8b 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type CopyOutResponse struct { @@ -37,20 +39,19 @@ func (dst *CopyOutResponse) Decode(src []byte) error { return nil } -func (src *CopyOutResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('H') - buf.Write(bigEndian.Uint32(uint32(4 + 1 + 2 + 2*len(src.ColumnFormatCodes)))) - - buf.Write(bigEndian.Uint16(uint16(len(src.ColumnFormatCodes)))) +func (src *CopyOutResponse) Encode(dst []byte) []byte { + dst = append(dst, 'H') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { - buf.Write(bigEndian.Uint16(fc)) + dst = pgio.AppendUint16(dst, fc) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *CopyOutResponse) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index 3e600e84..e46d3cc0 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -1,10 +1,11 @@ package pgproto3 import ( - "bytes" "encoding/binary" "encoding/hex" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type DataRow struct { @@ -58,28 +59,25 @@ func (dst *DataRow) Decode(src []byte) error { return nil } -func (src *DataRow) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('D') - buf.Write(bigEndian.Uint32(0)) - - buf.Write(bigEndian.Uint16(uint16(len(src.Values)))) +func (src *DataRow) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { if v == nil { - buf.Write(bigEndian.Int32(-1)) + dst = pgio.AppendInt32(dst, -1) continue } - buf.Write(bigEndian.Int32(int32(len(v)))) - buf.Write(v) + dst = pgio.AppendInt32(dst, int32(len(v))) + dst = append(dst, v...) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *DataRow) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/describe.go b/pgproto3/describe.go index ea55ed9d..bb7bc056 100644 --- a/pgproto3/describe.go +++ b/pgproto3/describe.go @@ -2,8 +2,9 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Describe struct { @@ -31,20 +32,18 @@ func (dst *Describe) Decode(src []byte) error { return nil } -func (src *Describe) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Describe) Encode(dst []byte) []byte { + dst = append(dst, 'D') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('D') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.ObjectType) + dst = append(dst, src.Name...) + dst = append(dst, 0) - buf.WriteByte(src.ObjectType) - buf.WriteString(src.Name) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes(), nil + return dst } func (src *Describe) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go index 13ed1886..d283b06d 100644 --- a/pgproto3/empty_query_response.go +++ b/pgproto3/empty_query_response.go @@ -16,8 +16,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { return nil } -func (src *EmptyQueryResponse) MarshalBinary() ([]byte, error) { - return []byte{'I', 0, 0, 0, 4}, nil +func (src *EmptyQueryResponse) Encode(dst []byte) []byte { + return append(dst, 'I', 0, 0, 0, 4) } func (src *EmptyQueryResponse) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go index 602dd2a1..160234f2 100644 --- a/pgproto3/error_response.go +++ b/pgproto3/error_response.go @@ -103,11 +103,11 @@ func (dst *ErrorResponse) Decode(src []byte) error { return nil } -func (src *ErrorResponse) MarshalBinary() ([]byte, error) { - return src.marshalBinary('E') +func (src *ErrorResponse) Encode(dst []byte) []byte { + return append(dst, src.marshalBinary('E')...) } -func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) { +func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { var bigEndian BigEndianBuf buf := &bytes.Buffer{} @@ -193,5 +193,5 @@ func (src *ErrorResponse) marshalBinary(typeByte byte) ([]byte, error) { binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - return buf.Bytes(), nil + return buf.Bytes() } diff --git a/pgproto3/execute.go b/pgproto3/execute.go index 4892e7b3..76da9943 100644 --- a/pgproto3/execute.go +++ b/pgproto3/execute.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Execute struct { @@ -30,21 +32,19 @@ func (dst *Execute) Decode(src []byte) error { return nil } -func (src *Execute) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Execute) Encode(dst []byte) []byte { + dst = append(dst, 'E') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('E') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.Portal...) + dst = append(dst, 0) - buf.WriteString(src.Portal) - buf.WriteByte(0) + dst = pgio.AppendUint32(dst, src.MaxRows) - buf.Write(bigEndian.Uint32(src.MaxRows)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes(), nil + return dst } func (src *Execute) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/flush.go b/pgproto3/flush.go index d26f5c0c..7fd5e987 100644 --- a/pgproto3/flush.go +++ b/pgproto3/flush.go @@ -16,8 +16,8 @@ func (dst *Flush) Decode(src []byte) error { return nil } -func (src *Flush) MarshalBinary() ([]byte, error) { - return []byte{'H', 0, 0, 0, 4}, nil +func (src *Flush) Encode(dst []byte) []byte { + return append(dst, 'H', 0, 0, 0, 4) } func (src *Flush) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 27a9890a..630a5cba 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -42,12 +42,7 @@ func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) { } func (b *Frontend) Send(msg FrontendMessage) error { - buf, err := msg.MarshalBinary() - if err != nil { - return nil - } - - _, err = b.w.Write(buf) + _, err := b.w.Write(msg.Encode(nil)) return err } diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go index 1e0f16af..bb325b69 100644 --- a/pgproto3/function_call_response.go +++ b/pgproto3/function_call_response.go @@ -1,10 +1,11 @@ package pgproto3 import ( - "bytes" "encoding/binary" "encoding/hex" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type FunctionCallResponse struct { @@ -34,21 +35,21 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { return nil } -func (src *FunctionCallResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('V') - buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Result)))) +func (src *FunctionCallResponse) Encode(dst []byte) []byte { + dst = append(dst, 'V') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) if src.Result == nil { - buf.Write(bigEndian.Int32(-1)) + dst = pgio.AppendInt32(dst, -1) } else { - buf.Write(bigEndian.Int32(int32(len(src.Result)))) - buf.Write(src.Result) + dst = pgio.AppendInt32(dst, int32(len(src.Result))) + dst = append(dst, src.Result...) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *FunctionCallResponse) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go index 3adec4ad..1fb47c2a 100644 --- a/pgproto3/no_data.go +++ b/pgproto3/no_data.go @@ -16,8 +16,8 @@ func (dst *NoData) Decode(src []byte) error { return nil } -func (src *NoData) MarshalBinary() ([]byte, error) { - return []byte{'n', 0, 0, 0, 4}, nil +func (src *NoData) Encode(dst []byte) []byte { + return append(dst, 'n', 0, 0, 0, 4) } func (src *NoData) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go index 8af55baf..e4595aa5 100644 --- a/pgproto3/notice_response.go +++ b/pgproto3/notice_response.go @@ -8,6 +8,6 @@ func (dst *NoticeResponse) Decode(src []byte) error { return (*ErrorResponse)(dst).Decode(src) } -func (src *NoticeResponse) MarshalBinary() ([]byte, error) { - return (*ErrorResponse)(src).marshalBinary('N') +func (src *NoticeResponse) Encode(dst []byte) []byte { + return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) } diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index 7262844e..b14007b4 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type NotificationResponse struct { @@ -35,19 +37,19 @@ func (dst *NotificationResponse) Decode(src []byte) error { return nil } -func (src *NotificationResponse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *NotificationResponse) Encode(dst []byte) []byte { + dst = append(dst, 'A') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('A') - buf.Write(bigEndian.Uint32(uint32(4 + 4 + len(src.Channel) + len(src.Payload)))) + dst = append(dst, src.Channel...) + dst = append(dst, 0) + dst = append(dst, src.Payload...) + dst = append(dst, 0) - buf.WriteString(src.Channel) - buf.WriteByte(0) - buf.WriteString(src.Payload) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *NotificationResponse) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index 32b6e1c1..1fa3c927 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type ParameterDescription struct { @@ -33,20 +35,19 @@ func (dst *ParameterDescription) Decode(src []byte) error { return nil } -func (src *ParameterDescription) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('t') - buf.Write(bigEndian.Uint32(uint32(4 + 2 + 4*len(src.ParameterOIDs)))) - - buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) +func (src *ParameterDescription) Encode(dst []byte) []byte { + dst = append(dst, 't') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { - buf.Write(bigEndian.Uint32(oid)) + dst = pgio.AppendUint32(dst, oid) } - return buf.Bytes(), nil + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst } func (src *ParameterDescription) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go index 9b10824c..b3bac33f 100644 --- a/pgproto3/parameter_status.go +++ b/pgproto3/parameter_status.go @@ -2,8 +2,9 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type ParameterStatus struct { @@ -32,21 +33,19 @@ func (dst *ParameterStatus) Decode(src []byte) error { return nil } -func (src *ParameterStatus) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *ParameterStatus) Encode(dst []byte) []byte { + dst = append(dst, 'S') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('S') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Value...) + dst = append(dst, 0) - buf.WriteString(src.Name) - buf.WriteByte(0) - buf.WriteString(src.Value) - buf.WriteByte(0) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) - - return buf.Bytes(), nil + return dst } func (ps *ParameterStatus) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/parse.go b/pgproto3/parse.go index 5d17ed11..b8775547 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Parse struct { @@ -44,27 +46,24 @@ func (dst *Parse) Decode(src []byte) error { return nil } -func (src *Parse) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} +func (src *Parse) Encode(dst []byte) []byte { + dst = append(dst, 'P') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) - buf.WriteByte('P') - buf.Write(bigEndian.Uint32(0)) + dst = append(dst, src.Name...) + dst = append(dst, 0) + dst = append(dst, src.Query...) + dst = append(dst, 0) - buf.WriteString(src.Name) - buf.WriteByte(0) - buf.WriteString(src.Query) - buf.WriteByte(0) - - buf.Write(bigEndian.Uint16(uint16(len(src.ParameterOIDs)))) - - for _, v := range src.ParameterOIDs { - buf.Write(bigEndian.Uint32(v)) + dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) + for _, oid := range src.ParameterOIDs { + dst = pgio.AppendUint32(dst, oid) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *Parse) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go index e949c14c..462a89ba 100644 --- a/pgproto3/parse_complete.go +++ b/pgproto3/parse_complete.go @@ -16,8 +16,8 @@ func (dst *ParseComplete) Decode(src []byte) error { return nil } -func (src *ParseComplete) MarshalBinary() ([]byte, error) { - return []byte{'1', 0, 0, 0, 4}, nil +func (src *ParseComplete) Encode(dst []byte) []byte { + return append(dst, '1', 0, 0, 0, 4) } func (src *ParseComplete) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go index 69df6362..2ad3fe4a 100644 --- a/pgproto3/password_message.go +++ b/pgproto3/password_message.go @@ -3,6 +3,8 @@ package pgproto3 import ( "bytes" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type PasswordMessage struct { @@ -23,14 +25,14 @@ func (dst *PasswordMessage) Decode(src []byte) error { return nil } -func (src *PasswordMessage) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.WriteByte('p') - buf.Write(bigEndian.Uint32(uint32(4 + len(src.Password) + 1))) - buf.WriteString(src.Password) - buf.WriteByte(0) - return buf.Bytes(), nil +func (src *PasswordMessage) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) + + dst = append(dst, src.Password...) + dst = append(dst, 0) + + return dst } func (src *PasswordMessage) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index 3fe8fc93..fe7b085b 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -4,12 +4,13 @@ import "fmt" // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. -// -// Decode is allowed and expected to retain a reference to data after -// returning (unlike encoding.BinaryUnmarshaler). type Message interface { + // Decode is allowed and expected to retain a reference to data after + // returning (unlike encoding.BinaryUnmarshaler). Decode(data []byte) error - MarshalBinary() (data []byte, err error) + + // Encode appends itself to dst and returns the new buffer. + Encode(dst []byte) []byte } type FrontendMessage interface { diff --git a/pgproto3/query.go b/pgproto3/query.go index b5fc2dbc..d80c0fb4 100644 --- a/pgproto3/query.go +++ b/pgproto3/query.go @@ -3,6 +3,8 @@ package pgproto3 import ( "bytes" "encoding/json" + + "github.com/jackc/pgx/pgio" ) type Query struct { @@ -22,14 +24,14 @@ func (dst *Query) Decode(src []byte) error { return nil } -func (src *Query) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.WriteByte('Q') - buf.Write(bigEndian.Uint32(uint32(4 + len(src.String) + 1))) - buf.WriteString(src.String) - buf.WriteByte(0) - return buf.Bytes(), nil +func (src *Query) Encode(dst []byte) []byte { + dst = append(dst, 'Q') + dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) + + dst = append(dst, src.String...) + dst = append(dst, 0) + + return dst } func (src *Query) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go index e0e4707a..63b902bd 100644 --- a/pgproto3/ready_for_query.go +++ b/pgproto3/ready_for_query.go @@ -20,8 +20,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error { return nil } -func (src *ReadyForQuery) MarshalBinary() ([]byte, error) { - return []byte{'Z', 0, 0, 0, 5, src.TxStatus}, nil +func (src *ReadyForQuery) Encode(dst []byte) []byte { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) } func (src *ReadyForQuery) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index b1110290..d0df11b0 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + + "github.com/jackc/pgx/pgio" ) const ( @@ -64,30 +66,27 @@ func (dst *RowDescription) Decode(src []byte) error { return nil } -func (src *RowDescription) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte('T') - buf.Write(bigEndian.Uint32(0)) - - buf.Write(bigEndian.Uint16(uint16(len(src.Fields)))) +func (src *RowDescription) Encode(dst []byte) []byte { + dst = append(dst, 'T') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { - buf.WriteString(fd.Name) - buf.WriteByte(0) + dst = append(dst, fd.Name...) + dst = append(dst, 0) - buf.Write(bigEndian.Uint32(fd.TableOID)) - buf.Write(bigEndian.Uint16(fd.TableAttributeNumber)) - buf.Write(bigEndian.Uint32(fd.DataTypeOID)) - buf.Write(bigEndian.Uint16(uint16(fd.DataTypeSize))) - buf.Write(bigEndian.Uint32(fd.TypeModifier)) - buf.Write(bigEndian.Uint16(uint16(fd.Format))) + dst = pgio.AppendUint32(dst, fd.TableOID) + dst = pgio.AppendUint16(dst, fd.TableAttributeNumber) + dst = pgio.AppendUint32(dst, fd.DataTypeOID) + dst = pgio.AppendInt16(dst, fd.DataTypeSize) + dst = pgio.AppendUint32(dst, fd.TypeModifier) + dst = pgio.AppendInt16(dst, fd.Format) } - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *RowDescription) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index 4847d629..4e2df27d 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "encoding/json" "fmt" + + "github.com/jackc/pgx/pgio" ) const ( @@ -64,22 +66,22 @@ func (dst *StartupMessage) Decode(src []byte) error { return nil } -func (src *StartupMessage) MarshalBinary() ([]byte, error) { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - buf.Write(bigEndian.Uint32(0)) - buf.Write(bigEndian.Uint32(src.ProtocolVersion)) +func (src *StartupMessage) Encode(dst []byte) []byte { + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = pgio.AppendUint32(dst, src.ProtocolVersion) for k, v := range src.Parameters { - buf.WriteString(k) - buf.WriteByte(0) - buf.WriteString(v) - buf.WriteByte(0) + dst = append(dst, k...) + dst = append(dst, 0) + dst = append(dst, v...) + dst = append(dst, 0) } - buf.WriteByte(0) + dst = append(dst, 0) - binary.BigEndian.PutUint32(buf.Bytes()[0:4], uint32(buf.Len())) + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return buf.Bytes(), nil + return dst } func (src *StartupMessage) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/sync.go b/pgproto3/sync.go index da3fa727..85f4749a 100644 --- a/pgproto3/sync.go +++ b/pgproto3/sync.go @@ -16,8 +16,8 @@ func (dst *Sync) Decode(src []byte) error { return nil } -func (src *Sync) MarshalBinary() ([]byte, error) { - return []byte{'S', 0, 0, 0, 4}, nil +func (src *Sync) Encode(dst []byte) []byte { + return append(dst, 'S', 0, 0, 0, 4) } func (src *Sync) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go index 77977f20..0a3310da 100644 --- a/pgproto3/terminate.go +++ b/pgproto3/terminate.go @@ -16,8 +16,8 @@ func (dst *Terminate) Decode(src []byte) error { return nil } -func (src *Terminate) MarshalBinary() ([]byte, error) { - return []byte{'X', 0, 0, 0, 4}, nil +func (src *Terminate) Encode(dst []byte) []byte { + return append(dst, 'X', 0, 0, 0, 4) } func (src *Terminate) MarshalJSON() ([]byte, error) { From dd5de3e49e56174a232d7156c745fd3ea3a70d7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 09:11:52 -0500 Subject: [PATCH 226/264] Add single round-trip mode for ExecEx --- conn.go | 144 ++++++++++++++++++++++++++++++++++++--------------- conn_test.go | 41 +++++++++++++++ query.go | 1 + 3 files changed, 143 insertions(+), 43 deletions(-) diff --git a/conn.go b/conn.go index c4c054dd..6b0cc0c5 100644 --- a/conn.go +++ b/conn.go @@ -1428,82 +1428,82 @@ func (c *Conn) Ping(ctx context.Context) error { return err } -func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { - err = c.waitForPreviousCancelQuery(ctx) +func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { + err := c.waitForPreviousCancelQuery(ctx) if err != nil { return "", err } - if err = c.lock(); err != nil { - return commandTag, err + if err := c.lock(); err != nil { + return "", err } + defer c.unlock() startTime := time.Now() c.lastActivityTime = startTime - defer func() { - if err == nil { - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) - } - } else { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) - } + commandTag, err := c.execEx(ctx, sql, options, arguments...) + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) } + return commandTag, err + } - if unlockErr := c.unlock(); unlockErr != nil && err == nil { - err = unlockErr - } + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + } + + return commandTag, err +} + +func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.initContext(ctx) + if err != nil { + return "", err + } + defer func() { + err = c.termContext(err) }() if options != nil && options.SimpleProtocol { - err = c.initContext(ctx) - if err != nil { - return "", err - } - defer func() { - err = c.termContext(err) - }() - err = c.sanitizeAndSendSimpleQuery(sql, arguments...) if err != nil { return "", err - } + } else if options != nil && len(options.ParameterOids) > 0 { + buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments) + if err != nil { + return "", err + } + + // sync + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) + + n, err := c.conn.Write(buf) + if err != nil && fatalWriteErr(n, err) { + c.die(err) + return "", err + } + c.readyForQuery = false } else { if len(arguments) > 0 { ps, ok := c.preparedStatements[sql] if !ok { var err error - ps, err = c.PrepareEx(ctx, "", sql, nil) + ps, err = c.prepareEx("", sql, nil) if err != nil { return "", err } } - err = c.initContext(ctx) - if err != nil { - return "", err - } - defer func() { - err = c.termContext(err) - }() - err = c.sendPreparedQuery(ps, arguments...) if err != nil { return "", err } } else { - err = c.initContext(ctx) - if err != nil { - return "", err - } - defer func() { - err = c.termContext(err) - }() - if err = c.sendQuery(sql, arguments...); err != nil { return } @@ -1532,6 +1532,64 @@ func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, } } +func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { + if len(arguments) != len(options.ParameterOids) { + return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids)) + } + + if len(options.ParameterOids) > 65535 { + return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids)) + } + + // parse + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 0) + buf = append(buf, sql...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids))) + for _, oid := range options.ParameterOids { + buf = pgio.AppendUint32(buf, uint32(oid)) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + // bind + buf = append(buf, 'B') + sp = len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, 0) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids))) + for i, oid := range options.ParameterOids { + buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) + } + + buf = pgio.AppendInt16(buf, int16(len(arguments))) + for i, oid := range options.ParameterOids { + var err error + buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i]) + if err != nil { + return nil, err + } + } + + // No result values for an exec + buf = pgio.AppendInt16(buf, 0) + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + // execute + buf = append(buf, 'E') + buf = pgio.AppendInt32(buf, 9) + buf = append(buf, 0) + buf = pgio.AppendInt32(buf, 0) + + return buf, nil +} + func (c *Conn) initContext(ctx context.Context) error { if c.ctxInProgress { return errors.New("ctx already in progress") diff --git a/conn_test.go b/conn_test.go index acee1b49..4d001da5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1155,6 +1155,47 @@ func TestExecExSimpleProtocol(t *testing.T) { } } +func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table foo(name varchar primary key);") + + commandTag, err := conn.ExecEx( + context.Background(), + "insert into foo(name) values($1);", + &pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.VarcharOid}}, + "bar'; drop table foo;--", + ) + if err != nil { + t.Fatal(err) + } + if commandTag != "INSERT 0 1" { + t.Fatalf("Unexpected results from ExecEx: %v", commandTag) + } +} + +func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table foo(name varchar primary key);") + + _, err := conn.ExecEx( + context.Background(), + "insert into foo(name) values($1);", + &pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.Int4Oid}}, + "bar'; drop table foo;--", + ) + if err == nil { + t.Fatal("expected error but got none") + } +} + func TestPrepare(t *testing.T) { t.Parallel() diff --git a/query.go b/query.go index 10eda1bc..0962b352 100644 --- a/query.go +++ b/query.go @@ -348,6 +348,7 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } type QueryExOptions struct { + ParameterOids []pgtype.Oid SimpleProtocol bool } From 4ca7ad120787e13b58d62110f27f6eb151c6638b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 09:12:56 -0500 Subject: [PATCH 227/264] Remove unused code --- messages.go | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/messages.go b/messages.go index 8e406602..b96d25c3 100644 --- a/messages.go +++ b/messages.go @@ -11,26 +11,9 @@ const ( ) const ( - backendKeyData = 'K' - authenticationX = 'R' - readyForQuery = 'Z' - rowDescription = 'T' - dataRow = 'D' - commandComplete = 'C' - errorResponse = 'E' - noticeResponse = 'N' - parseComplete = '1' - parameterDescription = 't' - bindComplete = '2' - notificationResponse = 'A' - emptyQueryResponse = 'I' - noData = 'n' - closeComplete = '3' - flush = 'H' - copyInResponse = 'G' - copyData = 'd' - copyFail = 'f' - copyDone = 'c' + copyData = 'd' + copyFail = 'f' + copyDone = 'c' ) type startupMessage struct { From 4ee21a15de306889a17140af02c060378ee3feee Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 09:18:41 -0500 Subject: [PATCH 228/264] Use pgproto3 for startup message --- conn.go | 20 +++++++++----------- messages.go | 28 ---------------------------- 2 files changed, 9 insertions(+), 39 deletions(-) diff --git a/conn.go b/conn.go index 6b0cc0c5..2c4f4907 100644 --- a/conn.go +++ b/conn.go @@ -302,27 +302,30 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return err } - msg := newStartupMessage() + startupMsg := pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: make(map[string]string), + } // Default to disabling TLS renegotiation. // // Go does not support (https://github.com/golang/go/issues/5742) // PostgreSQL recommends disabling (http://www.postgresql.org/docs/9.4/static/runtime-config-connection.html#GUC-SSL-RENEGOTIATION-LIMIT) if tlsConfig != nil { - msg.options["ssl_renegotiation_limit"] = "0" + startupMsg.Parameters["ssl_renegotiation_limit"] = "0" } // Copy default run-time params for k, v := range config.RuntimeParams { - msg.options[k] = v + startupMsg.Parameters[k] = v } - msg.options["user"] = c.config.User + startupMsg.Parameters["user"] = c.config.User if c.config.Database != "" { - msg.options["database"] = c.config.Database + startupMsg.Parameters["database"] = c.config.Database } - if err = c.txStartupMessage(msg); err != nil { + if _, err := c.conn.Write(startupMsg.Encode(nil)); err != nil { return err } @@ -1272,11 +1275,6 @@ func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (c *Conn) txStartupMessage(msg *startupMessage) error { - _, err := c.conn.Write(msg.Bytes()) - return err -} - func (c *Conn) txPasswordMessage(password string) (err error) { buf := c.wbuf buf = append(buf, 'p') diff --git a/messages.go b/messages.go index b96d25c3..f06f8b41 100644 --- a/messages.go +++ b/messages.go @@ -1,43 +1,15 @@ package pgx import ( - "encoding/binary" - "github.com/jackc/pgx/pgtype" ) -const ( - protocolVersionNumber = 196608 // 3.0 -) - const ( copyData = 'd' copyFail = 'f' copyDone = 'c' ) -type startupMessage struct { - options map[string]string -} - -func newStartupMessage() *startupMessage { - return &startupMessage{map[string]string{}} -} - -func (s *startupMessage) Bytes() (buf []byte) { - buf = make([]byte, 8, 128) - binary.BigEndian.PutUint32(buf[4:8], uint32(protocolVersionNumber)) - for key, value := range s.options { - buf = append(buf, key...) - buf = append(buf, 0) - buf = append(buf, value...) - buf = append(buf, 0) - } - buf = append(buf, ("\000")...) - binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf))) - return buf -} - type FieldDescription struct { Name string Table pgtype.Oid From 90975ab5c274c25b11861ae873faa99ab176d24d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 10:01:07 -0500 Subject: [PATCH 229/264] Extract append message functions. In general, pgproto3 types should be used. But these functions may be easier to without incurring additional memory allocations. --- conn.go | 146 +++++++++------------------------------------------- messages.go | 109 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 122 deletions(-) diff --git a/conn.go b/conn.go index 2c4f4907..be64f104 100644 --- a/conn.go +++ b/conn.go @@ -763,41 +763,17 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared }() } - // parse - buf := c.wbuf - buf = append(buf, 'P') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, name...) - buf = append(buf, 0) - buf = append(buf, sql...) - buf = append(buf, 0) - - if opts != nil { - if len(opts.ParameterOids) > 65535 { - return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) - } - buf = pgio.AppendInt16(buf, int16(len(opts.ParameterOids))) - for _, oid := range opts.ParameterOids { - buf = pgio.AppendInt32(buf, int32(oid)) - } - } else { - buf = pgio.AppendInt16(buf, 0) + if opts == nil { + opts = &PrepareExOptions{} } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - // describe - buf = append(buf, 'D') - sp = len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 'S') - buf = append(buf, name...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + if len(opts.ParameterOids) > 65535 { + return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) + } - // sync - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) + buf := appendParse(c.wbuf, name, sql, opts.ParameterOids) + buf = appendDescribe(buf, 'S', name) + buf = appendSync(buf) n, err := c.conn.Write(buf) if err != nil { @@ -1021,13 +997,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { } if len(args) == 0 { - buf := c.wbuf - buf = append(buf, 'Q') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, sql...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + buf := appendQuery(c.wbuf, sql) _, err := c.conn.Write(buf) if err != nil { @@ -1056,44 +1026,17 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return err } - // bind - buf := c.wbuf - buf = append(buf, 'B') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 0) - buf = append(buf, ps.Name...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(ps.ParameterOids))) - for i, oid := range ps.ParameterOids { - buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) + resultFormatCodes := make([]int16, len(ps.FieldDescriptions)) + for i, fd := range ps.FieldDescriptions { + resultFormatCodes[i] = fd.FormatCode + } + buf, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOids, arguments, resultFormatCodes) + if err != nil { + return err } - buf = pgio.AppendInt16(buf, int16(len(arguments))) - for i, oid := range ps.ParameterOids { - var err error - buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i]) - if err != nil { - return err - } - } - - buf = pgio.AppendInt16(buf, int16(len(ps.FieldDescriptions))) - for _, fd := range ps.FieldDescriptions { - buf = pgio.AppendInt16(buf, fd.FormatCode) - } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - // execute - buf = append(buf, 'E') - buf = pgio.AppendInt32(buf, 9) - buf = append(buf, 0) - buf = pgio.AppendInt32(buf, 0) - - // sync - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) + buf = appendExecute(buf, "", 0) + buf = appendSync(buf) n, err := c.conn.Write(buf) if err != nil && fatalWriteErr(n, err) { @@ -1476,9 +1419,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, return "", err } - // sync - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) + buf = appendSync(buf) n, err := c.conn.Write(buf) if err != nil && fatalWriteErr(n, err) { @@ -1539,51 +1480,12 @@ func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOpt return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids)) } - // parse - buf = append(buf, 'P') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 0) - buf = append(buf, sql...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids))) - for _, oid := range options.ParameterOids { - buf = pgio.AppendUint32(buf, uint32(oid)) + buf = appendParse(buf, "", sql, options.ParameterOids) + buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOids, arguments, nil) + if err != nil { + return nil, err } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - // bind - buf = append(buf, 'B') - sp = len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, 0) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(options.ParameterOids))) - for i, oid := range options.ParameterOids { - buf = pgio.AppendInt16(buf, chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) - } - - buf = pgio.AppendInt16(buf, int16(len(arguments))) - for i, oid := range options.ParameterOids { - var err error - buf, err = encodePreparedStatementArgument(c.ConnInfo, buf, oid, arguments[i]) - if err != nil { - return nil, err - } - } - - // No result values for an exec - buf = pgio.AppendInt16(buf, 0) - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - // execute - buf = append(buf, 'E') - buf = pgio.AppendInt32(buf, 9) - buf = append(buf, 0) - buf = pgio.AppendInt32(buf, 0) + buf = appendExecute(buf, "", 0) return buf, nil } diff --git a/messages.go b/messages.go index f06f8b41..0bf501b4 100644 --- a/messages.go +++ b/messages.go @@ -1,6 +1,7 @@ package pgx import ( + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" ) @@ -47,3 +48,111 @@ type PgError struct { func (pe PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } + +// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. +func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.Oid) []byte { + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, name...) + buf = append(buf, 0) + buf = append(buf, query...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) + for _, oid := range parameterOIDs { + buf = pgio.AppendUint32(buf, uint32(oid)) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. +func appendDescribe(buf []byte, objectType byte, name string) []byte { + buf = append(buf, 'D') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, objectType) + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. +func appendSync(buf []byte) []byte { + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) + + return buf +} + +// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. +func appendBind( + buf []byte, + destinationPortal, + preparedStatement string, + connInfo *pgtype.ConnInfo, + parameterOIDs []pgtype.Oid, + arguments []interface{}, + resultFormatCodes []int16, +) ([]byte, error) { + buf = append(buf, 'B') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, destinationPortal...) + buf = append(buf, 0) + buf = append(buf, preparedStatement...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(parameterOIDs))) + for i, oid := range parameterOIDs { + buf = pgio.AppendInt16(buf, chooseParameterFormatCode(connInfo, oid, arguments[i])) + } + + buf = pgio.AppendInt16(buf, int16(len(arguments))) + for i, oid := range parameterOIDs { + var err error + buf, err = encodePreparedStatementArgument(connInfo, buf, oid, arguments[i]) + if err != nil { + return nil, err + } + } + + buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) + for _, fc := range resultFormatCodes { + buf = pgio.AppendInt16(buf, fc) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf, nil +} + +// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. +func appendExecute(buf []byte, portal string, maxRows uint32) []byte { + buf = append(buf, 'E') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf = append(buf, portal...) + buf = append(buf, 0) + buf = pgio.AppendUint32(buf, maxRows) + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. +func appendQuery(buf []byte, query string) []byte { + buf = append(buf, 'Q') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, query...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} From 85f30d10d27a0ea2997cd5d3031adccfe3eb2899 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 11:24:49 -0500 Subject: [PATCH 230/264] Ensure pgproto3.Parse.Decode overwrites itself entirely --- pgproto3/parse.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgproto3/parse.go b/pgproto3/parse.go index b8775547..ca4834c6 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -17,6 +17,8 @@ type Parse struct { func (*Parse) Frontend() {} func (dst *Parse) Decode(src []byte) error { + *dst = Parse{} + buf := bytes.NewBuffer(src) b, err := buf.ReadBytes(0) From dd5e6a77dc81414a78548a1b4e04c5b56dcb67f0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 11:27:44 -0500 Subject: [PATCH 231/264] Add QueryEx single round-trip mode --- query.go | 122 ++++++++++++++++++++++++++++++++++++++++++++------ query_test.go | 26 +++++++++++ 2 files changed, 134 insertions(+), 14 deletions(-) diff --git a/query.go b/query.go index 0962b352..447a55ac 100644 --- a/query.go +++ b/query.go @@ -348,7 +348,12 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } type QueryExOptions struct { - ParameterOids []pgtype.Oid + // When ParameterOids are present and the query is not a prepared statement, + // then ParameterOids and ResultFormatCodes will be used to avoid an extra + // network round-trip. + ParameterOids []pgtype.Oid + ResultFormatCodes []int16 + SimpleProtocol bool } @@ -358,6 +363,10 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, return nil, err } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + c.lastActivityTime = time.Now() rows = c.getRows(sql, args) @@ -368,13 +377,13 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, } rows.unlockConn = true - if options != nil && options.SimpleProtocol { - err = c.initContext(ctx) - if err != nil { - rows.fatal(err) - return rows, err - } + err = c.initContext(ctx) + if err != nil { + rows.fatal(err) + return rows, rows.err + } + if options != nil && options.SimpleProtocol { err = c.sanitizeAndSendSimpleQuery(sql, args...) if err != nil { rows.fatal(err) @@ -384,10 +393,54 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, return rows, nil } + if options != nil && len(options.ParameterOids) > 0 { + + buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args) + if err != nil { + rows.fatal(err) + return rows, err + } + + buf = appendSync(buf) + + n, err := c.conn.Write(buf) + if err != nil && fatalWriteErr(n, err) { + rows.fatal(err) + c.die(err) + return nil, err + } + c.readyForQuery = false + + fieldDescriptions, err := c.readUntilRowDescription() + if err != nil { + rows.fatal(err) + return nil, err + } + + if len(options.ResultFormatCodes) == 0 { + for i := range fieldDescriptions { + fieldDescriptions[i].FormatCode = TextFormatCode + } + } else if len(options.ResultFormatCodes) == 1 { + fc := options.ResultFormatCodes[0] + for i := range fieldDescriptions { + fieldDescriptions[i].FormatCode = fc + } + } else { + for i := range options.ResultFormatCodes { + fieldDescriptions[i].FormatCode = options.ResultFormatCodes[i] + } + } + + rows.sql = sql + rows.fields = fieldDescriptions + return rows, nil + } + ps, ok := c.preparedStatements[sql] if !ok { var err error - ps, err = c.PrepareEx(ctx, "", sql, nil) + ps, err = c.prepareEx("", sql, nil) if err != nil { rows.fatal(err) return rows, rows.err @@ -396,12 +449,6 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, rows.sql = ps.SQL rows.fields = ps.FieldDescriptions - err = c.initContext(ctx) - if err != nil { - rows.fatal(err) - return rows, rows.err - } - err = c.sendPreparedQuery(ps, args...) if err != nil { rows.fatal(err) @@ -410,6 +457,53 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, return rows, rows.err } +func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { + if len(arguments) != len(options.ParameterOids) { + return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids)) + } + + if len(options.ParameterOids) > 65535 { + return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids)) + } + + buf = appendParse(buf, "", sql, options.ParameterOids) + buf = appendDescribe(buf, 'S', "") + buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOids, arguments, options.ResultFormatCodes) + if err != nil { + return nil, err + } + buf = appendExecute(buf, "", 0) + + return buf, nil +} + +func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { + for { + msg, err := c.rxMsg() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + case *pgproto3.RowDescription: + fieldDescriptions := c.rxRowDescription(msg) + for i := range fieldDescriptions { + if dt, ok := c.ConnInfo.DataTypeForOid(fieldDescriptions[i].DataType); ok { + fieldDescriptions[i].DataTypeName = dt.Name + } else { + return nil, fmt.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) + } + } + return fieldDescriptions, nil + default: + if err := c.processContextFreeMsg(msg); err != nil { + return nil, err + } + } + } +} + func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { if c.RuntimeParams["standard_conforming_strings"] != "on" { return errors.New("simple protocol queries must be run with standard_conforming_strings=on") diff --git a/query_test.go b/query_test.go index 801b34dd..4e128fb2 100644 --- a/query_test.go +++ b/query_test.go @@ -1182,6 +1182,32 @@ func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) { ensureConnValid(t, conn) } +func TestConnQueryRowExSingleRoundTrip(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + var result int32 + err := conn.QueryRowEx( + context.Background(), + "select $1 + $2", + &pgx.QueryExOptions{ + ParameterOids: []pgtype.Oid{pgtype.Int4Oid, pgtype.Int4Oid}, + ResultFormatCodes: []int16{pgx.BinaryFormatCode}, + }, + 1, 2, + ).Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != 3 { + t.Fatal("result => %d, want %d", result, 3) + } + + ensureConnValid(t, conn) +} + func TestConnSimpleProtocol(t *testing.T) { t.Parallel() From 07c5b76a24e74f7d685eef7f18864634619496a2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 11:39:06 -0500 Subject: [PATCH 232/264] Allow for either of 2 possible errors from tx context cancelation --- stdlib/sql_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index bfeb07c6..1aa1f261 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -781,8 +781,8 @@ func TestBeginTxContextCancelWithDeadConn(t *testing.T) { cancelFn() err = tx.Commit() - if err != context.Canceled { - t.Fatalf("err => %v, want %v", err, context.Canceled) + if err != context.Canceled && err != sql.ErrTxDone { + t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) } if err := <-errChan; err != nil { From e896e8c311ce20acc7babb17f7c57e32905c012f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 19:15:16 -0500 Subject: [PATCH 233/264] Extract TxOptions beginSQL --- tx.go | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/tx.go b/tx.go index 07cae4ba..e144337d 100644 --- a/tx.go +++ b/tx.go @@ -48,6 +48,26 @@ type TxOptions struct { DeferrableMode TxDeferrableMode } +func (txOptions *TxOptions) beginSQL() string { + if txOptions == nil { + return "begin" + } + + buf := &bytes.Buffer{} + buf.WriteString("begin") + if txOptions.IsoLevel != "" { + fmt.Fprintf(buf, " isolation level %s", txOptions.IsoLevel) + } + if txOptions.AccessMode != "" { + fmt.Fprintf(buf, " %s", txOptions.AccessMode) + } + if txOptions.DeferrableMode != "" { + fmt.Fprintf(buf, " %s", txOptions.DeferrableMode) + } + + return buf.String() +} + var ErrTxClosed = errors.New("tx is closed") // ErrTxCommitRollback occurs when an error has occurred in a transaction and @@ -65,26 +85,7 @@ func (c *Conn) Begin() (*Tx, error) { // mode. Unlike database/sql, the context only affects the begin command. i.e. // there is no auto-rollback on context cancelation. func (c *Conn) BeginEx(ctx context.Context, txOptions *TxOptions) (*Tx, error) { - var beginSQL string - if txOptions == nil { - beginSQL = "begin" - } else { - buf := &bytes.Buffer{} - buf.WriteString("begin") - if txOptions.IsoLevel != "" { - fmt.Fprintf(buf, " isolation level %s", txOptions.IsoLevel) - } - if txOptions.AccessMode != "" { - fmt.Fprintf(buf, " %s", txOptions.AccessMode) - } - if txOptions.DeferrableMode != "" { - fmt.Fprintf(buf, " %s", txOptions.DeferrableMode) - } - - beginSQL = buf.String() - } - - _, err := c.ExecEx(ctx, beginSQL, nil) + _, err := c.ExecEx(ctx, txOptions.beginSQL(), nil) if err != nil { // begin should never fail unless there is an underlying connection issue or // a context timeout. In either case, the connection is possibly broken. From 95c11a1fd15165a1404ddf6eed80ea1cf20b9335 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 May 2017 07:57:22 -0500 Subject: [PATCH 234/264] Remove bad channel sync causing orphaned goroutine --- conn_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/conn_test.go b/conn_test.go index 4d001da5..a7fbbcf1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1543,13 +1543,9 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) { } }() - notifierDone := make(chan bool) go func() { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - defer func() { - notifierDone <- true - }() for i := 0; i < 100000; i++ { mustExec(t, conn, "notify busysafe, 'hello'") From dcf3ee2781c52d5aa7cd0b8dae2b55b1b27fc109 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 31 May 2017 18:33:01 -0500 Subject: [PATCH 235/264] Fix sendPreparedQuery write error hang If the Write call in sendPreparedQuery encountered a non-fatal error - which means it sent no bytes. It still was marking the connection as not ready for query. That caused the next call to hang. --- conn.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index be64f104..68312222 100644 --- a/conn.go +++ b/conn.go @@ -1039,12 +1039,15 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} buf = appendSync(buf) n, err := c.conn.Write(buf) - if err != nil && fatalWriteErr(n, err) { - c.die(err) + if err != nil { + if fatalWriteErr(n, err) { + c.die(err) + } + return err } c.readyForQuery = false - return err + return nil } // fatalWriteError takes the response of a net.Conn.Write and determines if it is fatal From dfe250c13b3164cc01430a6aaaaf73e6ab5fd8c8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 2 Jun 2017 08:38:27 -0500 Subject: [PATCH 236/264] Allow either error message --- stdlib/sql_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 1aa1f261..bf99a8bb 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -714,8 +714,8 @@ func TestBeginTxContextCancel(t *testing.T) { cancelFn() err = tx.Commit() - if err != context.Canceled { - t.Fatalf("err => %v, want %v", err, context.Canceled) + if err != context.Canceled && err != sql.ErrTxDone { + t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) } var n int From fe0af9b35724be39ea7aa894d63b7cf563a9959c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 29 May 2017 19:15:42 -0500 Subject: [PATCH 237/264] Happy-path batch query mode --- batch.go | 211 +++++++++++++++++++++++++++++++++++++++++++++++++ batch_test.go | 150 +++++++++++++++++++++++++++++++++++ conn.go | 20 ++--- helper_test.go | 3 +- query.go | 2 +- 5 files changed, 375 insertions(+), 11 deletions(-) create mode 100644 batch.go create mode 100644 batch_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 00000000..722ce340 --- /dev/null +++ b/batch.go @@ -0,0 +1,211 @@ +package pgx + +import ( + "context" + + "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgx/pgtype" +) + +type batchItem struct { + query string + arguments []interface{} + parameterOids []pgtype.Oid + resultFormatCodes []int16 +} + +type Batch struct { + conn *Conn + items []*batchItem + resultsRead int + sent bool +} + +// Begin starts a transaction with the default transaction mode for the +// current connection. To use a specific transaction mode see BeginEx. +func (c *Conn) BeginBatch() *Batch { + // TODO - the type stuff below + + // err = c.waitForPreviousCancelQuery(ctx) + // if err != nil { + // return nil, err + // } + + // if err := c.ensureConnectionReadyForQuery(); err != nil { + // return nil, err + // } + + // c.lastActivityTime = time.Now() + + // rows = c.getRows(sql, args) + + // if err := c.lock(); err != nil { + // rows.fatal(err) + // return rows, err + // } + // rows.unlockConn = true + + // err = c.initContext(ctx) + // if err != nil { + // rows.fatal(err) + // return rows, rows.err + // } + + // if options != nil && options.SimpleProtocol { + // err = c.sanitizeAndSendSimpleQuery(sql, args...) + // if err != nil { + // rows.fatal(err) + // return rows, err + // } + + // return rows, nil + // } + + return &Batch{conn: c} +} + +func (b *Batch) Conn() *Conn { + return b.conn +} + +func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgtype.Oid, resultFormatCodes []int16) { + b.items = append(b.items, &batchItem{ + query: query, + arguments: arguments, + parameterOids: parameterOids, + resultFormatCodes: resultFormatCodes, + }) +} + +func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { + buf := appendQuery(b.conn.wbuf, txOptions.beginSQL()) + + for _, bi := range b.items { + // TODO - don't parse if named prepared statement + buf = appendParse(buf, "", bi.query, bi.parameterOids) + + var err error + buf, err = appendBind(buf, "", "", b.conn.ConnInfo, bi.parameterOids, bi.arguments, bi.resultFormatCodes) + if err != nil { + return err + } + + buf = appendDescribe(buf, 'P', "") + buf = appendExecute(buf, "", 0) + } + + buf = appendSync(buf) + buf = appendQuery(buf, "commit") + + n, err := b.conn.conn.Write(buf) + if err != nil { + if fatalWriteErr(n, err) { + b.conn.die(err) + } + return err + } + + // expect ReadyForQuery from sync and from commit + b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2 + + b.sent = true + + for { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + return nil + default: + if err := b.conn.processContextFreeMsg(msg); err != nil { + return err + } + } + } + + return nil +} + +func (b *Batch) ExecResults() (CommandTag, error) { + b.resultsRead++ + + for { + msg, err := b.conn.rxMsg() + if err != nil { + return "", err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + return CommandTag(msg.CommandTag), nil + default: + if err := b.conn.processContextFreeMsg(msg); err != nil { + return "", err + } + } + } +} + +func (b *Batch) QueryResults() (*Rows, error) { + b.resultsRead++ + + rows := b.conn.getRows("batch query", nil) + + fieldDescriptions, err := b.conn.readUntilRowDescription() + if err != nil { + rows.fatal(err) + return nil, err + } + + rows.fields = fieldDescriptions + return rows, nil +} + +func (b *Batch) QueryRowResults() *Row { + rows, _ := b.QueryResults() + return (*Row)(rows) + +} + +func (b *Batch) Finish() error { + for i := b.resultsRead; i < len(b.items); i++ { + _, err := b.ExecResults() + if err != nil { + return err + } + } + + // readyForQueryCount := 0 + + // for { + // msg, err := b.conn.rxMsg() + // if err != nil { + // return "", err + // } + + // switch msg := msg.(type) { + // case *pgproto3.ReadyForQuery: + // c.rxReadyForQuery(msg) + // default: + // if err := b.conn.processContextFreeMsg(msg); err != nil { + // return "", err + // } + // } + // } + + // switch msg := msg.(type) { + // case *pgproto3.ErrorResponse: + // return c.rxErrorResponse(msg) + // case *pgproto3.NotificationResponse: + // c.rxNotificationResponse(msg) + // case *pgproto3.ReadyForQuery: + // c.rxReadyForQuery(msg) + // case *pgproto3.ParameterStatus: + // c.rxParameterStatus(msg) + // } + + return nil +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 00000000..aeef52f4 --- /dev/null +++ b/batch_test.go @@ -0,0 +1,150 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestConnBeginBatch(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q1", 1}, + []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + nil, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q2", 2}, + []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + nil, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q3", 3}, + []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + nil, + ) + batch.Queue("select id, description, amount from ledger order by id", + nil, + nil, + []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode}, + ) + batch.Queue("select sum(amount) from ledger", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } + + ct, err = batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + var id int32 + var description string + var amount int32 + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 1 { + t.Errorf("id => %v, want %v", id, 1) + } + if description != "q1" { + t.Errorf("description => %v, want %v", description, "q1") + } + if amount != 1 { + t.Errorf("amount => %v, want %v", amount, 1) + } + + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 2 { + t.Errorf("id => %v, want %v", id, 2) + } + if description != "q2" { + t.Errorf("description => %v, want %v", description, "q2") + } + if amount != 2 { + t.Errorf("amount => %v, want %v", amount, 2) + } + + if !rows.Next() { + t.Fatal("expected a row to be available") + } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatal(err) + } + if id != 3 { + t.Errorf("id => %v, want %v", id, 3) + } + if description != "q3" { + t.Errorf("description => %v, want %v", description, "q3") + } + if amount != 3 { + t.Errorf("amount => %v, want %v", amount, 3) + } + + if rows.Next() { + t.Fatal("did not expect a row to be available") + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + err = batch.QueryRowResults().Scan(&amount) + if err != nil { + t.Error(err) + } + if amount != 6 { + t.Errorf("amount => %v, want %v", amount, 6) + } + + err = batch.Finish() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} diff --git a/conn.go b/conn.go index 68312222..491f2a9e 100644 --- a/conn.go +++ b/conn.go @@ -107,9 +107,9 @@ type Conn struct { status byte // One of connStatus* constants causeOfDeath error - readyForQuery bool // connection has received ReadyForQuery message since last query was sent - cancelQueryInProgress int32 - cancelQueryCompleted chan struct{} + pendingReadyForQueryCount int // numer of ReadyForQuery messages expected + cancelQueryInProgress int32 + cancelQueryCompleted chan struct{} // context support ctxInProgress bool @@ -329,6 +329,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return err } + c.pendingReadyForQueryCount = 1 + for { msg, err := c.rxMsg() if err != nil { @@ -782,7 +784,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } return nil, err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ ps = &PreparedStatement{Name: name, SQL: sql} @@ -1004,7 +1006,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { c.die(err) return err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ return nil } @@ -1045,7 +1047,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} } return err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ return nil } @@ -1167,7 +1169,7 @@ func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { } func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) { - c.readyForQuery = true + c.pendingReadyForQueryCount-- c.txStatus = msg.TxStatus } @@ -1429,7 +1431,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, c.die(err) return "", err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ } else { if len(arguments) > 0 { ps, ok := c.preparedStatements[sql] @@ -1563,7 +1565,7 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { } func (c *Conn) ensureConnectionReadyForQuery() error { - for !c.readyForQuery { + for c.pendingReadyForQueryCount > 0 { msg, err := c.rxMsg() if err != nil { return err diff --git a/helper_test.go b/helper_test.go index 21f86de5..78063107 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,8 +1,9 @@ package pgx_test import ( - "github.com/jackc/pgx" "testing" + + "github.com/jackc/pgx" ) func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn { diff --git a/query.go b/query.go index 447a55ac..a3903a22 100644 --- a/query.go +++ b/query.go @@ -409,7 +409,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, c.die(err) return nil, err } - c.readyForQuery = false + c.pendingReadyForQueryCount++ fieldDescriptions, err := c.readUntilRowDescription() if err != nil { From 73f496d7de2ee5bc59155e42b145743cebdb2614 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:49:27 -0500 Subject: [PATCH 238/264] Finish core batch operations --- batch.go | 180 ++++++++++++++++------------ batch_test.go | 296 +++++++++++++++++++++++++++++++++++++++++++++- bench_test.go | 90 ++++++++++++++ conn_pool.go | 7 ++ conn_pool_test.go | 67 +++++++++++ query.go | 5 + v3.md | 2 + 7 files changed, 572 insertions(+), 75 deletions(-) diff --git a/batch.go b/batch.go index 722ce340..a2e8e042 100644 --- a/batch.go +++ b/batch.go @@ -14,60 +14,31 @@ type batchItem struct { resultFormatCodes []int16 } +// Batch queries are a way of bundling multiple queries together to avoid +// unnecessary network round trips. type Batch struct { conn *Conn + connPool *ConnPool items []*batchItem resultsRead int sent bool + ctx context.Context + err error } -// Begin starts a transaction with the default transaction mode for the -// current connection. To use a specific transaction mode see BeginEx. +// BeginBatch returns a *Batch query for c. func (c *Conn) BeginBatch() *Batch { - // TODO - the type stuff below - - // err = c.waitForPreviousCancelQuery(ctx) - // if err != nil { - // return nil, err - // } - - // if err := c.ensureConnectionReadyForQuery(); err != nil { - // return nil, err - // } - - // c.lastActivityTime = time.Now() - - // rows = c.getRows(sql, args) - - // if err := c.lock(); err != nil { - // rows.fatal(err) - // return rows, err - // } - // rows.unlockConn = true - - // err = c.initContext(ctx) - // if err != nil { - // rows.fatal(err) - // return rows, rows.err - // } - - // if options != nil && options.SimpleProtocol { - // err = c.sanitizeAndSendSimpleQuery(sql, args...) - // if err != nil { - // rows.fatal(err) - // return rows, err - // } - - // return rows, nil - // } - return &Batch{conn: c} } +// Conn returns the underlying connection that b will or was performed on. func (b *Batch) Conn() *Conn { return b.conn } +// Queue queues a query to batch b. parameterOids are required if there are +// parameters and query is not the name of a prepared statement. +// resultFormatCodes are required if there is a result. func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgtype.Oid, resultFormatCodes []int16) { b.items = append(b.items, &batchItem{ query: query, @@ -77,15 +48,46 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgt }) } +// Send sends all queued queries to the server at once. All queries are wrapped +// in a transaction. The transaction can optionally be configured with +// txOptions. The context is in effect until the Batch is closed. func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { + if b.err != nil { + return b.err + } + + b.ctx = ctx + + err := b.conn.waitForPreviousCancelQuery(ctx) + if err != nil { + return err + } + + if err := b.conn.ensureConnectionReadyForQuery(); err != nil { + return err + } + + err = b.conn.initContext(ctx) + if err != nil { + return err + } + buf := appendQuery(b.conn.wbuf, txOptions.beginSQL()) for _, bi := range b.items { - // TODO - don't parse if named prepared statement - buf = appendParse(buf, "", bi.query, bi.parameterOids) + var psName string + var psParameterOids []pgtype.Oid + + if ps, ok := b.conn.preparedStatements[bi.query]; ok { + psName = ps.Name + psParameterOids = ps.ParameterOids + } else { + psParameterOids = bi.parameterOids + buf = appendParse(buf, "", bi.query, psParameterOids) + } var err error - buf, err = appendBind(buf, "", "", b.conn.ConnInfo, bi.parameterOids, bi.arguments, bi.resultFormatCodes) + buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOids, bi.arguments, bi.resultFormatCodes) if err != nil { return err } @@ -129,7 +131,20 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { return nil } +// ExecResults reads the results from the next query in the batch as if the +// query has been sent with Exec. func (b *Batch) ExecResults() (CommandTag, error) { + if b.err != nil { + return "", b.err + } + + select { + case <-b.ctx.Done(): + b.die(b.ctx.Err()) + return "", b.ctx.Err() + default: + } + b.resultsRead++ for { @@ -149,63 +164,80 @@ func (b *Batch) ExecResults() (CommandTag, error) { } } +// QueryResults reads the results from the next query in the batch as if the +// query has been sent with Query. func (b *Batch) QueryResults() (*Rows, error) { + if b.err != nil { + return nil, b.err + } + + select { + case <-b.ctx.Done(): + b.die(b.ctx.Err()) + return nil, b.ctx.Err() + default: + } + b.resultsRead++ rows := b.conn.getRows("batch query", nil) fieldDescriptions, err := b.conn.readUntilRowDescription() if err != nil { - rows.fatal(err) + b.die(b.ctx.Err()) return nil, err } + rows.batch = b rows.fields = fieldDescriptions return rows, nil } +// QueryRowResults reads the results from the next query in the batch as if the +// query has been sent with QueryRow. func (b *Batch) QueryRowResults() *Row { rows, _ := b.QueryResults() return (*Row)(rows) } -func (b *Batch) Finish() error { +// Close closes the batch operation. Any error that occured during a batch +// operation may have made it impossible to resyncronize the connection with the +// server. In this case the underlying connection will have been closed. +func (b *Batch) Close() (err error) { + if b.err != nil { + return b.err + } + + defer func() { + err = b.conn.termContext(err) + if b.conn != nil && b.connPool != nil { + b.connPool.Release(b.conn) + } + }() + for i := b.resultsRead; i < len(b.items); i++ { - _, err := b.ExecResults() - if err != nil { + if _, err = b.ExecResults(); err != nil { return err } } - // readyForQueryCount := 0 - - // for { - // msg, err := b.conn.rxMsg() - // if err != nil { - // return "", err - // } - - // switch msg := msg.(type) { - // case *pgproto3.ReadyForQuery: - // c.rxReadyForQuery(msg) - // default: - // if err := b.conn.processContextFreeMsg(msg); err != nil { - // return "", err - // } - // } - // } - - // switch msg := msg.(type) { - // case *pgproto3.ErrorResponse: - // return c.rxErrorResponse(msg) - // case *pgproto3.NotificationResponse: - // c.rxNotificationResponse(msg) - // case *pgproto3.ReadyForQuery: - // c.rxReadyForQuery(msg) - // case *pgproto3.ParameterStatus: - // c.rxParameterStatus(msg) - // } + if err = b.conn.ensureConnectionReadyForQuery(); err != nil { + return err + } return nil } + +func (b *Batch) die(err error) { + if b.err != nil { + return + } + + b.err = err + b.conn.die(err) + + if b.conn != nil && b.connPool != nil { + b.connPool.Release(b.conn) + } +} diff --git a/batch_test.go b/batch_test.go index aeef52f4..bccf9a20 100644 --- a/batch_test.go +++ b/batch_test.go @@ -141,10 +141,304 @@ func TestConnBeginBatch(t *testing.T) { t.Errorf("amount => %v, want %v", amount, 6) } - err = batch.Finish() + err = batch.Close() if err != nil { t.Fatal(err) } ensureConnValid(t, conn) } + +func TestConnBeginBatchWithPreparedStatement(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + _, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n") + if err != nil { + t.Fatal(err) + } + + batch := conn.BeginBatch() + + queryCount := 3 + for i := 0; i < queryCount; i++ { + batch.Queue("ps1", + []interface{}{5}, + nil, + []int16{pgx.BinaryFormatCode}, + ) + } + + err = batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < queryCount; i++ { + rows, err := batch.QueryResults() + if err != nil { + t.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Fatal(err) + } + if n != k { + t.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q1", 1}, + []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + _, err = batch.ExecResults() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + batch := conn.BeginBatch() + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + _, err = batch.QueryResults() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + batch := conn.BeginBatch() + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + err = batch.Close() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + batch := conn.BeginBatch() + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; i < 3; i++ { + if !rows.Next() { + t.Error("expected a row to be available") + } + + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + rows.Close() + + rows, err = batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func TestConnBeginBatchQueryError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + batch := conn.BeginBatch() + batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if pgErr, ok := rows.Err().(pgx.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) + } + + err = batch.Close() + if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", err, 22012) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} diff --git a/bench_test.go b/bench_test.go index d3525df5..7f82891e 100644 --- a/bench_test.go +++ b/bench_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "bytes" + "context" "fmt" "strings" "testing" @@ -609,3 +610,92 @@ func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { func BenchmarkWrite10000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10000) } + +func BenchmarkMultipleQueriesNonBatch(b *testing.B) { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} + pool, err := pgx.NewConnPool(config) + if err != nil { + b.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + queryCount := 3 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < queryCount; j++ { + rows, err := pool.Query("select n from generate_series(0, 5) n") + if err != nil { + b.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + b.Fatal(err) + } + if n != k { + b.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + } +} + +func BenchmarkMultipleQueriesBatch(b *testing.B) { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} + pool, err := pgx.NewConnPool(config) + if err != nil { + b.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + queryCount := 3 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batch := pool.BeginBatch() + for j := 0; j < queryCount; j++ { + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + } + + err := batch.Send(context.Background(), nil) + if err != nil { + b.Fatal(err) + } + + for j := 0; j < queryCount; j++ { + rows, err := batch.QueryResults() + if err != nil { + b.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + b.Fatal(err) + } + if n != k { + b.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + + err = batch.Close() + if err != nil { + b.Fatal(err) + } + } +} diff --git a/conn_pool.go b/conn_pool.go index 42200b85..fdfc70f5 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -536,3 +536,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C return c.CopyFrom(tableName, columnNames, rowSrc) } + +// BeginBatch acquires a connection and begins a batch on that connection. When +// *Batch is finished, the connection is released automatically. +func (p *ConnPool) BeginBatch() *Batch { + c, err := p.Acquire() + return &Batch{conn: c, connPool: p, err: err} +} diff --git a/conn_pool_test.go b/conn_pool_test.go index 560ab3ae..4e0dc199 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -981,3 +981,70 @@ func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) { t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) } } + +func TestConnPoolBeginBatch(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + batch := pool.BeginBatch() + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + rows, err = batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } +} diff --git a/query.go b/query.go index a3903a22..6c9f6ab0 100644 --- a/query.go +++ b/query.go @@ -43,6 +43,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) { type Rows struct { conn *Conn connPool *ConnPool + batch *Batch values [][]byte fields []FieldDescription rowCount int @@ -84,6 +85,10 @@ func (rows *Rows) Close() { rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) } + if rows.batch != nil && rows.err != nil { + rows.batch.die(rows.err) + } + if rows.connPool != nil { rows.connPool.Release(rows.conn) } diff --git a/v3.md b/v3.md index 33a27d2d..b369be18 100644 --- a/v3.md +++ b/v3.md @@ -50,6 +50,8 @@ Removed Tx.Conn() Added ctx parameter to (Conn/Tx/ConnPool).PrepareEx +Added batch operations + ## TODO / Possible / Investigate Organize errors better From 27ab289096aa2af3ebe5530c30cef8cd73d8e520 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:53:49 -0500 Subject: [PATCH 239/264] Use Go casing convention for OID --- batch.go | 18 +++---- batch_test.go | 8 +-- conn.go | 68 +++++++++++------------ conn_pool.go | 2 +- conn_test.go | 6 +-- example_custom_type_test.go | 2 +- fastpath.go | 14 ++--- large_objects.go | 12 ++--- messages.go | 8 +-- pgmock/pgmock.go | 2 +- pgtype/array.go | 6 +-- pgtype/bool_array.go | 2 +- pgtype/bytea_array.go | 2 +- pgtype/cidr_array.go | 2 +- pgtype/date_array.go | 2 +- pgtype/float4_array.go | 2 +- pgtype/float8_array.go | 2 +- pgtype/hstore_array.go | 2 +- pgtype/inet_array.go | 2 +- pgtype/int2_array.go | 2 +- pgtype/int4_array.go | 2 +- pgtype/int8_array.go | 2 +- pgtype/numeric_array.go | 2 +- pgtype/oid.go | 30 +++++------ pgtype/oid_value.go | 28 +++++----- pgtype/oid_value_test.go | 30 +++++------ pgtype/pgtype.go | 104 ++++++++++++++++++------------------ pgtype/record.go | 6 +-- pgtype/text_array.go | 2 +- pgtype/timestamp_array.go | 2 +- pgtype/timestamptz_array.go | 2 +- pgtype/typed_array.go.erb | 2 +- pgtype/varchar_array.go | 2 +- query.go | 28 +++++----- query_test.go | 12 ++--- replication.go | 6 +-- stdlib/sql.go | 62 ++++++++++----------- values.go | 8 +-- 38 files changed, 247 insertions(+), 247 deletions(-) diff --git a/batch.go b/batch.go index a2e8e042..3c16fd13 100644 --- a/batch.go +++ b/batch.go @@ -10,7 +10,7 @@ import ( type batchItem struct { query string arguments []interface{} - parameterOids []pgtype.Oid + parameterOIDs []pgtype.OID resultFormatCodes []int16 } @@ -36,14 +36,14 @@ func (b *Batch) Conn() *Conn { return b.conn } -// Queue queues a query to batch b. parameterOids are required if there are +// Queue queues a query to batch b. parameterOIDs are required if there are // parameters and query is not the name of a prepared statement. // resultFormatCodes are required if there is a result. -func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgtype.Oid, resultFormatCodes []int16) { +func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgtype.OID, resultFormatCodes []int16) { b.items = append(b.items, &batchItem{ query: query, arguments: arguments, - parameterOids: parameterOids, + parameterOIDs: parameterOIDs, resultFormatCodes: resultFormatCodes, }) } @@ -76,18 +76,18 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { for _, bi := range b.items { var psName string - var psParameterOids []pgtype.Oid + var psParameterOIDs []pgtype.OID if ps, ok := b.conn.preparedStatements[bi.query]; ok { psName = ps.Name - psParameterOids = ps.ParameterOids + psParameterOIDs = ps.ParameterOIDs } else { - psParameterOids = bi.parameterOids - buf = appendParse(buf, "", bi.query, psParameterOids) + psParameterOIDs = bi.parameterOIDs + buf = appendParse(buf, "", bi.query, psParameterOIDs) } var err error - buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOids, bi.arguments, bi.resultFormatCodes) + buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOIDs, bi.arguments, bi.resultFormatCodes) if err != nil { return err } diff --git a/batch_test.go b/batch_test.go index bccf9a20..ffd3cc50 100644 --- a/batch_test.go +++ b/batch_test.go @@ -24,17 +24,17 @@ func TestConnBeginBatch(t *testing.T) { batch := conn.BeginBatch() batch.Queue("insert into ledger(description, amount) values($1, $2)", []interface{}{"q1", 1}, - []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, nil, ) batch.Queue("insert into ledger(description, amount) values($1, $2)", []interface{}{"q2", 2}, - []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, nil, ) batch.Queue("insert into ledger(description, amount) values($1, $2)", []interface{}{"q3", 3}, - []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, nil, ) batch.Queue("select id, description, amount from ledger order by id", @@ -220,7 +220,7 @@ func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) { batch := conn.BeginBatch() batch.Queue("insert into ledger(description, amount) values($1, $2)", []interface{}{"q1", 1}, - []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, nil, ) batch.Queue("select pg_sleep(2)", diff --git a/conn.go b/conn.go index 491f2a9e..223808c5 100644 --- a/conn.go +++ b/conn.go @@ -39,11 +39,11 @@ var minimalConnInfo *pgtype.ConnInfo func init() { minimalConnInfo = pgtype.NewConnInfo() - minimalConnInfo.InitializeDataTypes(map[string]pgtype.Oid{ - "int4": pgtype.Int4Oid, - "name": pgtype.NameOid, - "oid": pgtype.OidOid, - "text": pgtype.TextOid, + minimalConnInfo.InitializeDataTypes(map[string]pgtype.OID{ + "int4": pgtype.Int4OID, + "name": pgtype.NameOID, + "oid": pgtype.OIDOID, + "text": pgtype.TextOID, }) } @@ -126,12 +126,12 @@ type PreparedStatement struct { Name string SQL string FieldDescriptions []FieldDescription - ParameterOids []pgtype.Oid + ParameterOIDs []pgtype.OID } // PrepareExOptions is an option struct that can be passed to PrepareEx type PrepareExOptions struct { - ParameterOids []pgtype.Oid + ParameterOIDs []pgtype.OID } // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system @@ -373,7 +373,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } func (c *Conn) initConnInfo() error { - nameOids := make(map[string]pgtype.Oid, 256) + nameOIDs := make(map[string]pgtype.OID, 256) rows, err := c.Query(`select t.oid, t.typname from pg_type t @@ -387,13 +387,13 @@ where ( } for rows.Next() { - var oid pgtype.Oid + var oid pgtype.OID var name pgtype.Text if err := rows.Scan(&oid, &name); err != nil { return err } - nameOids[name.String] = oid + nameOIDs[name.String] = oid } if rows.Err() != nil { @@ -401,7 +401,7 @@ where ( } c.ConnInfo = pgtype.NewConnInfo() - c.ConnInfo.InitializeDataTypes(nameOids) + c.ConnInfo.InitializeDataTypes(nameOIDs) return nil } @@ -725,7 +725,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// It defers from Prepare as it allows additional options (such as parameter Oids) to be passed via struct +// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct // // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without @@ -769,11 +769,11 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared opts = &PrepareExOptions{} } - if len(opts.ParameterOids) > 65535 { - return nil, fmt.Errorf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)) + if len(opts.ParameterOIDs) > 65535 { + return nil, fmt.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) } - buf := appendParse(c.wbuf, name, sql, opts.ParameterOids) + buf := appendParse(c.wbuf, name, sql, opts.ParameterOIDs) buf = appendDescribe(buf, 'S', name) buf = appendSync(buf) @@ -798,15 +798,15 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared switch msg := msg.(type) { case *pgproto3.ParameterDescription: - ps.ParameterOids = c.rxParameterDescription(msg) + ps.ParameterOIDs = c.rxParameterDescription(msg) - if len(ps.ParameterOids) > 65535 && softErr == nil { - softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) + if len(ps.ParameterOIDs) > 65535 && softErr == nil { + softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) } case *pgproto3.RowDescription: ps.FieldDescriptions = c.rxRowDescription(msg) for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok { + if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { ps.FieldDescriptions[i].DataTypeName = dt.Name if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { ps.FieldDescriptions[i].FormatCode = BinaryFormatCode @@ -1020,8 +1020,8 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { } func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { - if len(ps.ParameterOids) != len(arguments) { - return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) + if len(ps.ParameterOIDs) != len(arguments) { + return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) } if err := c.ensureConnectionReadyForQuery(); err != nil { @@ -1032,7 +1032,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} for i, fd := range ps.FieldDescriptions { resultFormatCodes[i] = fd.FormatCode } - buf, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOids, arguments, resultFormatCodes) + buf, err := appendBind(c.wbuf, "", ps.Name, c.ConnInfo, ps.ParameterOIDs, arguments, resultFormatCodes) if err != nil { return err } @@ -1177,9 +1177,9 @@ func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription fields := make([]FieldDescription, len(msg.Fields)) for i := 0; i < len(fields); i++ { fields[i].Name = msg.Fields[i].Name - fields[i].Table = pgtype.Oid(msg.Fields[i].TableOID) + fields[i].Table = pgtype.OID(msg.Fields[i].TableOID) fields[i].AttributeNumber = msg.Fields[i].TableAttributeNumber - fields[i].DataType = pgtype.Oid(msg.Fields[i].DataTypeOID) + fields[i].DataType = pgtype.OID(msg.Fields[i].DataTypeOID) fields[i].DataTypeSize = msg.Fields[i].DataTypeSize fields[i].Modifier = msg.Fields[i].TypeModifier fields[i].FormatCode = msg.Fields[i].Format @@ -1187,10 +1187,10 @@ func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription return fields } -func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.Oid { - parameters := make([]pgtype.Oid, len(msg.ParameterOIDs)) +func (c *Conn) rxParameterDescription(msg *pgproto3.ParameterDescription) []pgtype.OID { + parameters := make([]pgtype.OID, len(msg.ParameterOIDs)) for i := 0; i < len(parameters); i++ { - parameters[i] = pgtype.Oid(msg.ParameterOIDs[i]) + parameters[i] = pgtype.OID(msg.ParameterOIDs[i]) } return parameters } @@ -1418,7 +1418,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, if err != nil { return "", err } - } else if options != nil && len(options.ParameterOids) > 0 { + } else if options != nil && len(options.ParameterOIDs) > 0 { buf, err := c.buildOneRoundTripExec(c.wbuf, sql, options, arguments) if err != nil { return "", err @@ -1477,16 +1477,16 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, } func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { - if len(arguments) != len(options.ParameterOids) { - return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids)) + if len(arguments) != len(options.ParameterOIDs) { + return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) } - if len(options.ParameterOids) > 65535 { - return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids)) + if len(options.ParameterOIDs) > 65535 { + return nil, fmt.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) } - buf = appendParse(buf, "", sql, options.ParameterOids) - buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOids, arguments, nil) + buf = appendParse(buf, "", sql, options.ParameterOIDs) + buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, nil) if err != nil { return nil, err } diff --git a/conn_pool.go b/conn_pool.go index fdfc70f5..40c58f49 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -434,7 +434,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { // // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// It defers from Prepare as it allows additional options (such as parameter Oids) to be passed via struct +// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct // // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without diff --git a/conn_test.go b/conn_test.go index a7fbbcf1..8ec3c131 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1166,7 +1166,7 @@ func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { commandTag, err := conn.ExecEx( context.Background(), "insert into foo(name) values($1);", - &pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.VarcharOid}}, + &pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.VarcharOID}}, "bar'; drop table foo;--", ) if err != nil { @@ -1188,7 +1188,7 @@ func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { _, err := conn.ExecEx( context.Background(), "insert into foo(name) values($1);", - &pgx.QueryExOptions{ParameterOids: []pgtype.Oid{pgtype.Int4Oid}}, + &pgx.QueryExOptions{ParameterOIDs: []pgtype.OID{pgtype.Int4OID}}, "bar'; drop table foo;--", ) if err == nil { @@ -1367,7 +1367,7 @@ func TestPrepareEx(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - _, err := conn.PrepareEx(context.Background(), "test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgtype.Oid{pgtype.TextOid}}) + _, err := conn.PrepareEx(context.Background(), "test", "select $1", &pgx.PrepareExOptions{ParameterOIDs: []pgtype.OID{pgtype.TextOID}}) if err != nil { t.Errorf("Unable to prepare statement: %v", err) return diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 647b97e6..66ed6c53 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -81,7 +81,7 @@ func Example_CustomType() { conn.ConnInfo.RegisterDataType(pgtype.DataType{ Value: &Point{}, Name: "point", - Oid: 600, + OID: 600, }) p := &Point{} diff --git a/fastpath.go b/fastpath.go index 776be177..06e1354a 100644 --- a/fastpath.go +++ b/fastpath.go @@ -9,26 +9,26 @@ import ( ) func newFastpath(cn *Conn) *fastpath { - return &fastpath{cn: cn, fns: make(map[string]pgtype.Oid)} + return &fastpath{cn: cn, fns: make(map[string]pgtype.OID)} } type fastpath struct { cn *Conn - fns map[string]pgtype.Oid + fns map[string]pgtype.OID } -func (f *fastpath) functionOid(name string) pgtype.Oid { +func (f *fastpath) functionOID(name string) pgtype.OID { return f.fns[name] } -func (f *fastpath) addFunction(name string, oid pgtype.Oid) { +func (f *fastpath) addFunction(name string, oid pgtype.OID) { f.fns[name] = oid } func (f *fastpath) addFunctions(rows *Rows) error { for rows.Next() { var name string - var oid pgtype.Oid + var oid pgtype.OID if err := rows.Scan(&name, &oid); err != nil { return err } @@ -51,7 +51,7 @@ func fpInt64Arg(n int64) fpArg { return res } -func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) { +func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) { if err := f.cn.ensureConnectionReadyForQuery(); err != nil { return nil, err } @@ -98,7 +98,7 @@ func (f *fastpath) Call(oid pgtype.Oid, args []fpArg) (res []byte, err error) { } func (f *fastpath) CallFn(fn string, args []fpArg) ([]byte, error) { - return f.Call(f.functionOid(fn), args) + return f.Call(f.functionOID(fn), args) } func fpInt32(data []byte, err error) (int32, error) { diff --git a/large_objects.go b/large_objects.go index bb65e623..e109bce2 100644 --- a/large_objects.go +++ b/large_objects.go @@ -61,20 +61,20 @@ const ( ) // Create creates a new large object. If id is zero, the server assigns an -// unused Oid. -func (o *LargeObjects) Create(id pgtype.Oid) (pgtype.Oid, error) { - newOid, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) - return pgtype.Oid(newOid), err +// unused OID. +func (o *LargeObjects) Create(id pgtype.OID) (pgtype.OID, error) { + newOID, err := fpInt32(o.fp.CallFn("lo_create", []fpArg{fpIntArg(int32(id))})) + return pgtype.OID(newOID), err } // Open opens an existing large object with the given mode. -func (o *LargeObjects) Open(oid pgtype.Oid, mode LargeObjectMode) (*LargeObject, error) { +func (o *LargeObjects) Open(oid pgtype.OID, mode LargeObjectMode) (*LargeObject, error) { fd, err := fpInt32(o.fp.CallFn("lo_open", []fpArg{fpIntArg(int32(oid)), fpIntArg(int32(mode))})) return &LargeObject{fd: fd, lo: o}, err } // Unlink removes a large object from the database. -func (o *LargeObjects) Unlink(oid pgtype.Oid) error { +func (o *LargeObjects) Unlink(oid pgtype.OID) error { _, err := o.fp.CallFn("lo_unlink", []fpArg{fpIntArg(int32(oid))}) return err } diff --git a/messages.go b/messages.go index 0bf501b4..841aa286 100644 --- a/messages.go +++ b/messages.go @@ -13,9 +13,9 @@ const ( type FieldDescription struct { Name string - Table pgtype.Oid + Table pgtype.OID AttributeNumber uint16 - DataType pgtype.Oid + DataType pgtype.OID DataTypeSize int16 DataTypeName string Modifier uint32 @@ -50,7 +50,7 @@ func (pe PgError) Error() string { } // appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. -func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.Oid) []byte { +func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.OID) []byte { buf = append(buf, 'P') sp := len(buf) buf = pgio.AppendInt32(buf, -1) @@ -95,7 +95,7 @@ func appendBind( destinationPortal, preparedStatement string, connInfo *pgtype.ConnInfo, - parameterOIDs []pgtype.Oid, + parameterOIDs []pgtype.OID, arguments []interface{}, resultFormatCodes []int16, ) ([]byte, error) { diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 3f1e54f4..b3a51729 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -242,7 +242,7 @@ func PgxInitSteps() []Step { } rowVals := []struct { - oid pgtype.Oid + oid pgtype.OID name string }{ {16, "bool"}, diff --git a/pgtype/array.go b/pgtype/array.go index 2f9ef66b..e5504455 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -18,7 +18,7 @@ import ( type ArrayHeader struct { ContainsNull bool - ElementOid int32 + ElementOID int32 Dimensions []ArrayDimension } @@ -40,7 +40,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 - dst.ElementOid = int32(binary.BigEndian.Uint32(src[rp:])) + dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 if numDims > 0 { @@ -69,7 +69,7 @@ func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte { } buf = pgio.AppendInt32(buf, containsNull) - buf = pgio.AppendInt32(buf, src.ElementOid) + buf = pgio.AppendInt32(buf, src.ElementOID) for i := range src.Dimensions { buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 3c3d4184..e20a0381 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -231,7 +231,7 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("bool"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "bool") } diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 67e114f5..0d381693 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -231,7 +231,7 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("bytea"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") } diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 01237aa1..b8a70d63 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -260,7 +260,7 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("cidr"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") } diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 2175f2aa..ef91cf3e 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -232,7 +232,7 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("date"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "date") } diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 37db8acc..a35657b0 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -231,7 +231,7 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("float4"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "float4") } diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index dd3fccf1..486e3a4e 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -231,7 +231,7 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("float8"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "float8") } diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index 2d61fa52..3e5a003f 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -231,7 +231,7 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("hstore"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") } diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index e448a2ca..57123c1c 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -260,7 +260,7 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("inet"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "inet") } diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 1d145584..e4993104 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -259,7 +259,7 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("int2"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "int2") } diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 1c746503..6bc06e86 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -259,7 +259,7 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("int4"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "int4") } diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 56ebcab8..4404d22a 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -259,7 +259,7 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("int8"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "int8") } diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go index 20f33dff..f193a2a5 100644 --- a/pgtype/numeric_array.go +++ b/pgtype/numeric_array.go @@ -259,7 +259,7 @@ func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) } if dt, ok := ci.DataTypeForName("numeric"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") } diff --git a/pgtype/oid.go b/pgtype/oid.go index 6ceacc73..d37f4e57 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -9,18 +9,18 @@ import ( "github.com/jackc/pgx/pgio" ) -// Oid (Object Identifier Type) is, according to +// OID (Object Identifier Type) is, according to // https://www.postgresql.org/docs/current/static/datatype-oid.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be // found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is -// so frequently required to be in a NOT NULL condition Oid cannot be NULL. To -// allow for NULL Oids use OidValue. -type Oid uint32 +// so frequently required to be in a NOT NULL condition OID cannot be NULL. To +// allow for NULL OIDs use OIDValue. +type OID uint32 -func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into Oid") + return fmt.Errorf("cannot decode nil into OID") } n, err := strconv.ParseUint(string(src), 10, 32) @@ -28,13 +28,13 @@ func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Oid(n) + *dst = OID(n) return nil } -func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into Oid") + return fmt.Errorf("cannot decode nil into OID") } if len(src) != 4 { @@ -42,27 +42,27 @@ func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { } n := binary.BigEndian.Uint32(src) - *dst = Oid(n) + *dst = OID(n) return nil } -func (src Oid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src OID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, strconv.FormatUint(uint64(src), 10)...), nil } -func (src Oid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return pgio.AppendUint32(buf, uint32(src)), nil } // Scan implements the database/sql Scanner interface. -func (dst *Oid) Scan(src interface{}) error { +func (dst *OID) Scan(src interface{}) error { if src == nil { return fmt.Errorf("cannot scan NULL into %T", src) } switch src := src.(type) { case int64: - *dst = Oid(src) + *dst = OID(src) return nil case string: return dst.DecodeText(nil, []byte(src)) @@ -76,6 +76,6 @@ func (dst *Oid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src Oid) Value() (driver.Value, error) { +func (src OID) Value() (driver.Value, error) { return int64(src), nil } diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go index 882d54fb..7eae4bf1 100644 --- a/pgtype/oid_value.go +++ b/pgtype/oid_value.go @@ -4,52 +4,52 @@ import ( "database/sql/driver" ) -// OidValue (Object Identifier Type) is, according to -// https://www.postgresql.org/docs/current/static/datatype-OidValue.html, used +// OIDValue (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-OIDValue.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be // found in src/include/postgres_ext.h in the PostgreSQL sources. -type OidValue pguint32 +type OIDValue pguint32 -// Set converts from src to dst. Note that as OidValue is not a general +// Set converts from src to dst. Note that as OIDValue is not a general // number type Set does not do automatic type conversion as other number // types do. -func (dst *OidValue) Set(src interface{}) error { +func (dst *OIDValue) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *OidValue) Get() interface{} { +func (dst *OIDValue) Get() interface{} { return (*pguint32)(dst).Get() } -// AssignTo assigns from src to dst. Note that as OidValue is not a general number +// AssignTo assigns from src to dst. Note that as OIDValue is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *OidValue) AssignTo(dst interface{}) error { +func (src *OIDValue) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OidValue) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *OIDValue) DecodeText(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *OIDValue) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *OidValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *OIDValue) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeText(ci, buf) } -func (src *OidValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *OIDValue) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. -func (dst *OidValue) Scan(src interface{}) error { +func (dst *OIDValue) Scan(src interface{}) error { return (*pguint32)(dst).Scan(src) } // Value implements the database/sql/driver Valuer interface. -func (src *OidValue) Value() (driver.Value, error) { +func (src *OIDValue) Value() (driver.Value, error) { return (*pguint32)(src).Value() } diff --git a/pgtype/oid_value_test.go b/pgtype/oid_value_test.go index 52ce4064..f5ff16cf 100644 --- a/pgtype/oid_value_test.go +++ b/pgtype/oid_value_test.go @@ -8,23 +8,23 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestOidValueTranscode(t *testing.T) { +func TestOIDValueTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "oid", []interface{}{ - &pgtype.OidValue{Uint: 42, Status: pgtype.Present}, - &pgtype.OidValue{Status: pgtype.Null}, + &pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, + &pgtype.OIDValue{Status: pgtype.Null}, }) } -func TestOidValueSet(t *testing.T) { +func TestOIDValueSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.OidValue + result pgtype.OIDValue }{ - {source: uint32(1), result: pgtype.OidValue{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.OIDValue{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.OidValue + var r pgtype.OIDValue err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -36,17 +36,17 @@ func TestOidValueSet(t *testing.T) { } } -func TestOidValueAssignTo(t *testing.T) { +func TestOIDValueAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.OidValue + src pgtype.OIDValue dst interface{} expected interface{} }{ - {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.OidValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -61,11 +61,11 @@ func TestOidValueAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.OidValue + src pgtype.OIDValue dst interface{} expected interface{} }{ - {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.OIDValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -80,10 +80,10 @@ func TestOidValueAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.OidValue + src pgtype.OIDValue dst interface{} }{ - {src: pgtype.OidValue{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.OIDValue{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 847fce0f..4c1e86f6 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -7,47 +7,47 @@ import ( // PostgreSQL oids for common types const ( - BoolOid = 16 - ByteaOid = 17 - CharOid = 18 - NameOid = 19 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - TidOid = 27 - XidOid = 28 - CidOid = 29 - JsonOid = 114 - CidrOid = 650 - CidrArrayOid = 651 - Float4Oid = 700 - Float8Oid = 701 - UnknownOid = 705 - InetOid = 869 - BoolArrayOid = 1000 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - ByteaArrayOid = 1001 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - AclitemOid = 1033 - AclitemArrayOid = 1034 - InetArrayOid = 1041 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampArrayOid = 1115 - DateArrayOid = 1182 - TimestamptzOid = 1184 - TimestamptzArrayOid = 1185 - RecordOid = 2249 - UuidOid = 2950 - JsonbOid = 3802 + BoolOID = 16 + ByteaOID = 17 + CharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TidOID = 27 + XidOID = 28 + CidOID = 29 + JsonOID = 114 + CidrOID = 650 + CidrArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + UnknownOID = 705 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + AclitemOID = 1033 + AclitemArrayOID = 1034 + InetArrayOID = 1041 + VarcharOID = 1043 + DateOID = 1082 + TimestampOID = 1114 + TimestampArrayOID = 1115 + DateArrayOID = 1182 + TimestamptzOID = 1184 + TimestamptzArrayOID = 1185 + RecordOID = 2249 + UuidOID = 2950 + JsonbOID = 3802 ) type Status byte @@ -133,42 +133,42 @@ var errBadStatus = errors.New("invalid status") type DataType struct { Value Value Name string - Oid Oid + OID OID } type ConnInfo struct { - oidToDataType map[Oid]*DataType + oidToDataType map[OID]*DataType nameToDataType map[string]*DataType reflectTypeToDataType map[reflect.Type]*DataType } func NewConnInfo() *ConnInfo { return &ConnInfo{ - oidToDataType: make(map[Oid]*DataType, 256), + oidToDataType: make(map[OID]*DataType, 256), nameToDataType: make(map[string]*DataType, 256), reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), } } -func (ci *ConnInfo) InitializeDataTypes(nameOids map[string]Oid) { - for name, oid := range nameOids { +func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]OID) { + for name, oid := range nameOIDs { var value Value if t, ok := nameValues[name]; ok { value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) } else { value = &GenericText{} } - ci.RegisterDataType(DataType{Value: value, Name: name, Oid: oid}) + ci.RegisterDataType(DataType{Value: value, Name: name, OID: oid}) } } func (ci *ConnInfo) RegisterDataType(t DataType) { - ci.oidToDataType[t.Oid] = &t + ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t } -func (ci *ConnInfo) DataTypeForOid(oid Oid) (*DataType, bool) { +func (ci *ConnInfo) DataTypeForOID(oid OID) (*DataType, bool) { dt, ok := ci.oidToDataType[oid] return dt, ok } @@ -186,7 +186,7 @@ func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { // DeepCopy makes a deep copy of the ConnInfo. func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2 := &ConnInfo{ - oidToDataType: make(map[Oid]*DataType, len(ci.oidToDataType)), + oidToDataType: make(map[OID]*DataType, len(ci.oidToDataType)), nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), } @@ -195,7 +195,7 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { ci2.RegisterDataType(DataType{ Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), Name: dt.Name, - Oid: dt.Oid, + OID: dt.OID, }) } @@ -250,7 +250,7 @@ func init() { "name": &Name{}, "numeric": &Numeric{}, "numrange": &Numrange{}, - "oid": &OidValue{}, + "oid": &OIDValue{}, "path": &Path{}, "point": &Point{}, "polygon": &Polygon{}, diff --git a/pgtype/record.go b/pgtype/record.go index 3b315d40..7c8736df 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -88,16 +88,16 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { if len(src[rp:]) < 8 { return fmt.Errorf("Record incomplete %v", src) } - fieldOid := Oid(binary.BigEndian.Uint32(src[rp:])) + fieldOID := OID(binary.BigEndian.Uint32(src[rp:])) rp += 4 fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 var binaryDecoder BinaryDecoder - if dt, ok := ci.DataTypeForOid(fieldOid); ok { + if dt, ok := ci.DataTypeForOID(fieldOID); ok { if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { - return fmt.Errorf("unknown oid while decoding record: %v", fieldOid) + return fmt.Errorf("unknown oid while decoding record: %v", fieldOID) } } diff --git a/pgtype/text_array.go b/pgtype/text_array.go index ed240e12..dab7d36e 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -231,7 +231,7 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } if dt, ok := ci.DataTypeForName("text"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "text") } diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index a4f1b9dd..fca9ad93 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -232,7 +232,7 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error } if dt, ok := ci.DataTypeForName("timestamp"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") } diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 34d4f8a8..e0866d69 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -232,7 +232,7 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err } if dt, ok := ci.DataTypeForName("timestamptz"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") } diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 0d454ac8..01072549 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -234,7 +234,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byt } if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") } diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index c34ac0b6..95b5cfc1 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -231,7 +231,7 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) } if dt, ok := ci.DataTypeForName("varchar"); ok { - arrayHeader.ElementOid = int32(dt.Oid) + arrayHeader.ElementOID = int32(dt.OID) } else { return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") } diff --git a/query.go b/query.go index 6c9f6ab0..c12d64f0 100644 --- a/query.go +++ b/query.go @@ -131,7 +131,7 @@ func (rows *Rows) Next() bool { case *pgproto3.RowDescription: rows.fields = rows.conn.rxRowDescription(msg) for i := range rows.fields { - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok { + if dt, ok := rows.conn.ConnInfo.DataTypeForOID(rows.fields[i].DataType); ok { rows.fields[i].DataTypeName = dt.Name rows.fields[i].FormatCode = TextFormatCode } else { @@ -214,7 +214,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { rows.fatal(scanArgError{col: i, err: err}) } } else { - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(fd.DataType); ok { + if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { value := dt.Value switch fd.FormatCode { case TextFormatCode: @@ -282,7 +282,7 @@ func (rows *Rows) Values() ([]interface{}, error) { continue } - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(fd.DataType); ok { + if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok { value := dt.Value switch fd.FormatCode { @@ -353,10 +353,10 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } type QueryExOptions struct { - // When ParameterOids are present and the query is not a prepared statement, - // then ParameterOids and ResultFormatCodes will be used to avoid an extra + // When ParameterOIDs are present and the query is not a prepared statement, + // then ParameterOIDs and ResultFormatCodes will be used to avoid an extra // network round-trip. - ParameterOids []pgtype.Oid + ParameterOIDs []pgtype.OID ResultFormatCodes []int16 SimpleProtocol bool @@ -398,7 +398,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, return rows, nil } - if options != nil && len(options.ParameterOids) > 0 { + if options != nil && len(options.ParameterOIDs) > 0 { buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args) if err != nil { @@ -463,17 +463,17 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, } func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { - if len(arguments) != len(options.ParameterOids) { - return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOids (%d)", len(arguments), len(options.ParameterOids)) + if len(arguments) != len(options.ParameterOIDs) { + return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) } - if len(options.ParameterOids) > 65535 { - return nil, fmt.Errorf("Number of QueryExOptions ParameterOids must be between 0 and 65535, received %d", len(options.ParameterOids)) + if len(options.ParameterOIDs) > 65535 { + return nil, fmt.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) } - buf = appendParse(buf, "", sql, options.ParameterOids) + buf = appendParse(buf, "", sql, options.ParameterOIDs) buf = appendDescribe(buf, 'S', "") - buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOids, arguments, options.ResultFormatCodes) + buf, err := appendBind(buf, "", "", c.ConnInfo, options.ParameterOIDs, arguments, options.ResultFormatCodes) if err != nil { return nil, err } @@ -494,7 +494,7 @@ func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { case *pgproto3.RowDescription: fieldDescriptions := c.rxRowDescription(msg) for i := range fieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOid(fieldDescriptions[i].DataType); ok { + if dt, ok := c.ConnInfo.DataTypeForOID(fieldDescriptions[i].DataType); ok { fieldDescriptions[i].DataTypeName = dt.Name } else { return nil, fmt.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) diff --git a/query_test.go b/query_test.go index 4e128fb2..9379bd23 100644 --- a/query_test.go +++ b/query_test.go @@ -251,7 +251,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: Can't convert Oid 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { + if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -389,7 +389,7 @@ func TestQueryRowCoreTypes(t *testing.T) { f64 float64 b bool t time.Time - oid pgtype.Oid + oid pgtype.OID } var actual, zero allTypes @@ -407,7 +407,7 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, - {"select $1::oid", []interface{}{pgtype.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, + {"select $1::oid", []interface{}{pgtype.OID(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { @@ -768,12 +768,12 @@ func TestQueryRowUnknownType(t *testing.T) { conn.ConnInfo.RegisterDataType(pgtype.DataType{ Value: &pgtype.GenericText{}, Name: "point", - Oid: 600, + OID: 600, }) conn.ConnInfo.RegisterDataType(pgtype.DataType{ Value: &pgtype.Int4{}, Name: "int4", - Oid: pgtype.Int4Oid, + OID: pgtype.Int4OID, }) sql := "select $1::point" @@ -1193,7 +1193,7 @@ func TestConnQueryRowExSingleRoundTrip(t *testing.T) { context.Background(), "select $1 + $2", &pgx.QueryExOptions{ - ParameterOids: []pgtype.Oid{pgtype.Int4Oid, pgtype.Int4Oid}, + ParameterOIDs: []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, ResultFormatCodes: []int16{pgx.BinaryFormatCode}, }, 1, 2, diff --git a/replication.go b/replication.go index eacc0c3f..1bf69c4e 100644 --- a/replication.go +++ b/replication.go @@ -348,7 +348,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows.fields = rc.c.rxRowDescription(msg) // We don't have c.PgTypes here because we're a replication // connection. This means the field descriptions will have - // only Oids. Not much we can do about this. + // only OIDs. Not much we can do about this. default: if e := rc.c.processContextFreeMsg(msg); e != nil { rows.fatal(e) @@ -368,7 +368,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { // // NOTE: Because this is a replication mode connection, we don't have // type names, so the field descriptions in the result will have only -// Oids and no DataTypeName values +// OIDs and no DataTypeName values func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { return rc.sendReplicationModeQuery("IDENTIFY_SYSTEM") } @@ -383,7 +383,7 @@ func (rc *ReplicationConn) IdentifySystem() (r *Rows, err error) { // // NOTE: Because this is a replication mode connection, we don't have // type names, so the field descriptions in the result will have only -// Oids and no DataTypeName values +// OIDs and no DataTypeName values func (rc *ReplicationConn) TimelineHistory(timeline int) (r *Rows, err error) { return rc.sendReplicationModeQuery(fmt.Sprintf("TIMELINE_HISTORY %d", timeline)) } diff --git a/stdlib/sql.go b/stdlib/sql.go index aa45dd40..00329617 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -80,7 +80,7 @@ import ( // oids that map to intrinsic database/sql types. These will be allowed to be // binary, anything else will be forced to text format -var databaseSqlOids map[pgtype.Oid]bool +var databaseSqlOIDs map[pgtype.OID]bool var pgxDriver *Driver @@ -97,20 +97,20 @@ func init() { } sql.Register("pgx", pgxDriver) - databaseSqlOids = make(map[pgtype.Oid]bool) - databaseSqlOids[pgtype.BoolOid] = true - databaseSqlOids[pgtype.ByteaOid] = true - databaseSqlOids[pgtype.CidOid] = true - databaseSqlOids[pgtype.DateOid] = true - databaseSqlOids[pgtype.Float4Oid] = true - databaseSqlOids[pgtype.Float8Oid] = true - databaseSqlOids[pgtype.Int2Oid] = true - databaseSqlOids[pgtype.Int4Oid] = true - databaseSqlOids[pgtype.Int8Oid] = true - databaseSqlOids[pgtype.OidOid] = true - databaseSqlOids[pgtype.TimestampOid] = true - databaseSqlOids[pgtype.TimestamptzOid] = true - databaseSqlOids[pgtype.XidOid] = true + databaseSqlOIDs = make(map[pgtype.OID]bool) + databaseSqlOIDs[pgtype.BoolOID] = true + databaseSqlOIDs[pgtype.ByteaOID] = true + databaseSqlOIDs[pgtype.CidOID] = true + databaseSqlOIDs[pgtype.DateOID] = true + databaseSqlOIDs[pgtype.Float4OID] = true + databaseSqlOIDs[pgtype.Float8OID] = true + databaseSqlOIDs[pgtype.Int2OID] = true + databaseSqlOIDs[pgtype.Int4OID] = true + databaseSqlOIDs[pgtype.Int8OID] = true + databaseSqlOIDs[pgtype.OIDOID] = true + databaseSqlOIDs[pgtype.TimestampOID] = true + databaseSqlOIDs[pgtype.TimestamptzOID] = true + databaseSqlOIDs[pgtype.XidOID] = true } type Driver struct { @@ -364,7 +364,7 @@ func (c *Conn) Ping(ctx context.Context) error { // (e.g. []int32) func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) { for i, _ := range ps.FieldDescriptions { - intrinsic, _ := databaseSqlOids[ps.FieldDescriptions[i].DataType] + intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType] if !intrinsic { ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode } @@ -381,7 +381,7 @@ func (s *Stmt) Close() error { } func (s *Stmt) NumInput() int { - return len(s.ps.ParameterOids) + return len(s.ps.ParameterOIDs) } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { @@ -428,31 +428,31 @@ func (r *Rows) Next(dest []driver.Value) error { r.values = make([]interface{}, len(r.rows.FieldDescriptions())) for i, fd := range r.rows.FieldDescriptions() { switch fd.DataType { - case pgtype.BoolOid: + case pgtype.BoolOID: r.values[i] = &pgtype.Bool{} - case pgtype.ByteaOid: + case pgtype.ByteaOID: r.values[i] = &pgtype.Bytea{} - case pgtype.CidOid: + case pgtype.CidOID: r.values[i] = &pgtype.Cid{} - case pgtype.DateOid: + case pgtype.DateOID: r.values[i] = &pgtype.Date{} - case pgtype.Float4Oid: + case pgtype.Float4OID: r.values[i] = &pgtype.Float4{} - case pgtype.Float8Oid: + case pgtype.Float8OID: r.values[i] = &pgtype.Float8{} - case pgtype.Int2Oid: + case pgtype.Int2OID: r.values[i] = &pgtype.Int2{} - case pgtype.Int4Oid: + case pgtype.Int4OID: r.values[i] = &pgtype.Int4{} - case pgtype.Int8Oid: + case pgtype.Int8OID: r.values[i] = &pgtype.Int8{} - case pgtype.OidOid: - r.values[i] = &pgtype.OidValue{} - case pgtype.TimestampOid: + case pgtype.OIDOID: + r.values[i] = &pgtype.OIDValue{} + case pgtype.TimestampOID: r.values[i] = &pgtype.Timestamp{} - case pgtype.TimestamptzOid: + case pgtype.TimestamptzOID: r.values[i] = &pgtype.Timestamptz{} - case pgtype.XidOid: + case pgtype.XidOID: r.values[i] = &pgtype.Xid{} default: r.values[i] = &pgtype.GenericText{} diff --git a/values.go b/values.go index ca5db50b..a6c350f6 100644 --- a/values.go +++ b/values.go @@ -97,7 +97,7 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) } -func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.Oid, arg interface{}) ([]byte, error) { +func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype.OID, arg interface{}) ([]byte, error) { if arg == nil { return pgio.AppendInt32(buf, -1), nil } @@ -149,7 +149,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype return encodePreparedStatementArgument(ci, buf, oid, arg) } - if dt, ok := ci.DataTypeForOid(oid); ok { + if dt, ok := ci.DataTypeForOID(oid); ok { value := dt.Value err := value.Set(arg) if err != nil { @@ -178,7 +178,7 @@ func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid pgtype // chooseParameterFormatCode determines the correct format code for an // argument to a prepared statement. It defaults to TextFormatCode if no // determination can be made. -func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interface{}) int16 { +func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.OID, arg interface{}) int16 { switch arg.(type) { case pgtype.BinaryEncoder: return BinaryFormatCode @@ -186,7 +186,7 @@ func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interfac return TextFormatCode } - if dt, ok := ci.DataTypeForOid(oid); ok { + if dt, ok := ci.DataTypeForOID(oid); ok { if _, ok := dt.Value.(pgtype.BinaryEncoder); ok { if arg, ok := arg.(driver.Valuer); ok { if err := dt.Value.Set(arg); err != nil { From 3bdc94cee2f445760a98d1edfe0de730a29baea2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:54:57 -0500 Subject: [PATCH 240/264] Use Go casing convention for UUID --- pgtype/ext/satori-uuid/uuid.go | 52 ++++++++++++------------- pgtype/ext/satori-uuid/uuid_test.go | 26 ++++++------- pgtype/pgtype.go | 4 +- pgtype/uuid.go | 60 ++++++++++++++--------------- pgtype/uuid_test.go | 26 ++++++------- 5 files changed, 84 insertions(+), 84 deletions(-) diff --git a/pgtype/ext/satori-uuid/uuid.go b/pgtype/ext/satori-uuid/uuid.go index cff98348..b7b776f9 100644 --- a/pgtype/ext/satori-uuid/uuid.go +++ b/pgtype/ext/satori-uuid/uuid.go @@ -11,43 +11,43 @@ import ( var errUndefined = errors.New("cannot encode status undefined") -type Uuid struct { +type UUID struct { UUID uuid.UUID Status pgtype.Status } -func (dst *Uuid) Set(src interface{}) error { +func (dst *UUID) Set(src interface{}) error { switch value := src.(type) { case uuid.UUID: - *dst = Uuid{UUID: value, Status: pgtype.Present} + *dst = UUID{UUID: value, Status: pgtype.Present} case [16]byte: - *dst = Uuid{UUID: uuid.UUID(value), Status: pgtype.Present} + *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} case []byte: if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } - *dst = Uuid{Status: pgtype.Present} + *dst = UUID{Status: pgtype.Present} copy(dst.UUID[:], value) case string: uuid, err := uuid.FromString(value) if err != nil { return err } - *dst = Uuid{UUID: uuid, Status: pgtype.Present} + *dst = UUID{UUID: uuid, Status: pgtype.Present} default: - // If all else fails see if pgtype.Uuid can handle it. If so, translate through that. - pgUuid := &pgtype.Uuid{} - if err := pgUuid.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to Uuid", value) + // If all else fails see if pgtype.UUID can handle it. If so, translate through that. + pgUUID := &pgtype.UUID{} + if err := pgUUID.Set(value); err != nil { + return fmt.Errorf("cannot convert %v to UUID", value) } - *dst = Uuid{UUID: uuid.UUID(pgUuid.Bytes), Status: pgUuid.Status} + *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} } return nil } -func (dst *Uuid) Get() interface{} { +func (dst *UUID) Get() interface{} { switch dst.Status { case pgtype.Present: return dst.UUID @@ -58,7 +58,7 @@ func (dst *Uuid) Get() interface{} { } } -func (src *Uuid) AssignTo(dst interface{}) error { +func (src *UUID) AssignTo(dst interface{}) error { switch src.Status { case pgtype.Present: switch v := dst.(type) { @@ -86,9 +86,9 @@ func (src *Uuid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v into %T", src, dst) } -func (dst *Uuid) DecodeText(ci *pgtype.ConnInfo, src []byte) error { +func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = Uuid{Status: pgtype.Null} + *dst = UUID{Status: pgtype.Null} return nil } @@ -97,26 +97,26 @@ func (dst *Uuid) DecodeText(ci *pgtype.ConnInfo, src []byte) error { return err } - *dst = Uuid{UUID: u, Status: pgtype.Present} + *dst = UUID{UUID: u, Status: pgtype.Present} return nil } -func (dst *Uuid) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { +func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { if src == nil { - *dst = Uuid{Status: pgtype.Null} + *dst = UUID{Status: pgtype.Null} return nil } if len(src) != 16 { - return fmt.Errorf("invalid length for Uuid: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } - *dst = Uuid{Status: pgtype.Present} + *dst = UUID{Status: pgtype.Present} copy(dst.UUID[:], src) return nil } -func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { +func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: return nil, nil @@ -127,7 +127,7 @@ func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.UUID.String()...), nil } -func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { +func (src *UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case pgtype.Null: return nil, nil @@ -139,9 +139,9 @@ func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Uuid) Scan(src interface{}) error { +func (dst *UUID) Scan(src interface{}) error { if src == nil { - *dst = Uuid{Status: pgtype.Null} + *dst = UUID{Status: pgtype.Null} return nil } @@ -156,6 +156,6 @@ func (dst *Uuid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Uuid) Value() (driver.Value, error) { +func (src *UUID) Value() (driver.Value, error) { return pgtype.EncodeValueText(src) } diff --git a/pgtype/ext/satori-uuid/uuid_test.go b/pgtype/ext/satori-uuid/uuid_test.go index 993fb837..02ebb770 100644 --- a/pgtype/ext/satori-uuid/uuid_test.go +++ b/pgtype/ext/satori-uuid/uuid_test.go @@ -9,34 +9,34 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestUuidTranscode(t *testing.T) { +func TestUUIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &satori.Uuid{Status: pgtype.Null}, + &satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &satori.UUID{Status: pgtype.Null}, }) } -func TestUuidSet(t *testing.T) { +func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result satori.Uuid + result satori.UUID }{ { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, } for i, tt := range successfulTests { - var r satori.Uuid + var r satori.UUID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -48,9 +48,9 @@ func TestUuidSet(t *testing.T) { } } -func TestUuidAssignTo(t *testing.T) { +func TestUUIDAssignTo(t *testing.T) { { - src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst [16]byte expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -65,7 +65,7 @@ func TestUuidAssignTo(t *testing.T) { } { - src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst []byte expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -80,7 +80,7 @@ func TestUuidAssignTo(t *testing.T) { } { - src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := satori.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 4c1e86f6..60fab59f 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -46,7 +46,7 @@ const ( TimestamptzOID = 1184 TimestamptzArrayOID = 1185 RecordOID = 2249 - UuidOID = 2950 + UUIDOID = 2950 JsonbOID = 3802 ) @@ -262,7 +262,7 @@ func init() { "tsrange": &Tsrange{}, "tstzrange": &Tstzrange{}, "unknown": &Unknown{}, - "uuid": &Uuid{}, + "uuid": &UUID{}, "varbit": &Varbit{}, "varchar": &Varchar{}, "xid": &Xid{}, diff --git a/pgtype/uuid.go b/pgtype/uuid.go index c73c501e..d1ab1a38 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -6,38 +6,38 @@ import ( "fmt" ) -type Uuid struct { +type UUID struct { Bytes [16]byte Status Status } -func (dst *Uuid) Set(src interface{}) error { +func (dst *UUID) Set(src interface{}) error { switch value := src.(type) { case [16]byte: - *dst = Uuid{Bytes: value, Status: Present} + *dst = UUID{Bytes: value, Status: Present} case []byte: if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } - *dst = Uuid{Status: Present} + *dst = UUID{Status: Present} copy(dst.Bytes[:], value) case string: - uuid, err := parseUuid(value) + uuid, err := parseUUID(value) if err != nil { return err } - *dst = Uuid{Bytes: uuid, Status: Present} + *dst = UUID{Bytes: uuid, Status: Present} default: if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Uuid", value) + return fmt.Errorf("cannot convert %v to UUID", value) } return nil } -func (dst *Uuid) Get() interface{} { +func (dst *UUID) Get() interface{} { switch dst.Status { case Present: return dst.Bytes @@ -48,7 +48,7 @@ func (dst *Uuid) Get() interface{} { } } -func (src *Uuid) AssignTo(dst interface{}) error { +func (src *UUID) AssignTo(dst interface{}) error { switch src.Status { case Present: switch v := dst.(type) { @@ -60,7 +60,7 @@ func (src *Uuid) AssignTo(dst interface{}) error { copy(*v, src.Bytes[:]) return nil case *string: - *v = encodeUuid(src.Bytes) + *v = encodeUUID(src.Bytes) return nil default: if nextDst, retry := GetAssignToDstType(v); retry { @@ -74,8 +74,8 @@ func (src *Uuid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v into %T", src, dst) } -// parseUuid converts a string UUID in standard form to a byte array. -func parseUuid(src string) (dst [16]byte, err error) { +// parseUUID converts a string UUID in standard form to a byte array. +func parseUUID(src string) (dst [16]byte, err error) { src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] buf, err := hex.DecodeString(src) if err != nil { @@ -86,46 +86,46 @@ func parseUuid(src string) (dst [16]byte, err error) { return dst, err } -// encodeUuid converts a uuid byte array to UUID standard string form. -func encodeUuid(src [16]byte) string { +// encodeUUID converts a uuid byte array to UUID standard string form. +func encodeUUID(src [16]byte) string { return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16]) } -func (dst *Uuid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Uuid{Status: Null} + *dst = UUID{Status: Null} return nil } if len(src) != 36 { - return fmt.Errorf("invalid length for Uuid: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } - buf, err := parseUuid(string(src)) + buf, err := parseUUID(string(src)) if err != nil { return err } - *dst = Uuid{Bytes: buf, Status: Present} + *dst = UUID{Bytes: buf, Status: Present} return nil } -func (dst *Uuid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Uuid{Status: Null} + *dst = UUID{Status: Null} return nil } if len(src) != 16 { - return fmt.Errorf("invalid length for Uuid: %v", len(src)) + return fmt.Errorf("invalid length for UUID: %v", len(src)) } - *dst = Uuid{Status: Present} + *dst = UUID{Status: Present} copy(dst.Bytes[:], src) return nil } -func (src *Uuid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -133,10 +133,10 @@ func (src *Uuid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } - return append(buf, encodeUuid(src.Bytes)...), nil + return append(buf, encodeUUID(src.Bytes)...), nil } -func (src *Uuid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -148,9 +148,9 @@ func (src *Uuid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Uuid) Scan(src interface{}) error { +func (dst *UUID) Scan(src interface{}) error { if src == nil { - *dst = Uuid{Status: Null} + *dst = UUID{Status: Null} return nil } @@ -167,6 +167,6 @@ func (dst *Uuid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Uuid) Value() (driver.Value, error) { +func (src *UUID) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 4c6ad2cd..5ab52b35 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -8,34 +8,34 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestUuidTranscode(t *testing.T) { +func TestUUIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, - &pgtype.Uuid{Status: pgtype.Null}, + &pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &pgtype.UUID{Status: pgtype.Null}, }) } -func TestUuidSet(t *testing.T) { +func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Uuid + result pgtype.UUID }{ { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, } for i, tt := range successfulTests { - var r pgtype.Uuid + var r pgtype.UUID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -47,9 +47,9 @@ func TestUuidSet(t *testing.T) { } } -func TestUuidAssignTo(t *testing.T) { +func TestUUIDAssignTo(t *testing.T) { { - src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst [16]byte expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -64,7 +64,7 @@ func TestUuidAssignTo(t *testing.T) { } { - src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst []byte expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} @@ -79,7 +79,7 @@ func TestUuidAssignTo(t *testing.T) { } { - src := pgtype.Uuid{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" From 6688466123c2f2ffada0472895af87dd5ba48099 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:57:14 -0500 Subject: [PATCH 241/264] Use Go casing convention for JSON(B) --- pgtype/json.go | 40 +++++++++++++++++----------------- pgtype/json_test.go | 52 ++++++++++++++++++++++---------------------- pgtype/jsonb.go | 38 ++++++++++++++++---------------- pgtype/jsonb_test.go | 52 ++++++++++++++++++++++---------------------- pgtype/pgtype.go | 8 +++---- 5 files changed, 95 insertions(+), 95 deletions(-) diff --git a/pgtype/json.go b/pgtype/json.go index 91d31129..ee00e9a4 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -6,44 +6,44 @@ import ( "fmt" ) -type Json struct { +type JSON struct { Bytes []byte Status Status } -func (dst *Json) Set(src interface{}) error { +func (dst *JSON) Set(src interface{}) error { if src == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} return nil } switch value := src.(type) { case string: - *dst = Json{Bytes: []byte(value), Status: Present} + *dst = JSON{Bytes: []byte(value), Status: Present} case *string: if value == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} } else { - *dst = Json{Bytes: []byte(*value), Status: Present} + *dst = JSON{Bytes: []byte(*value), Status: Present} } case []byte: if value == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} } else { - *dst = Json{Bytes: value, Status: Present} + *dst = JSON{Bytes: value, Status: Present} } default: buf, err := json.Marshal(value) if err != nil { return err } - *dst = Json{Bytes: buf, Status: Present} + *dst = JSON{Bytes: buf, Status: Present} } return nil } -func (dst *Json) Get() interface{} { +func (dst *JSON) Get() interface{} { switch dst.Status { case Present: var i interface{} @@ -59,7 +59,7 @@ func (dst *Json) Get() interface{} { } } -func (src *Json) AssignTo(dst interface{}) error { +func (src *JSON) AssignTo(dst interface{}) error { switch v := dst.(type) { case *string: if src.Status != Present { @@ -90,21 +90,21 @@ func (src *Json) AssignTo(dst interface{}) error { return nil } -func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *JSON) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} return nil } - *dst = Json{Bytes: src, Status: Present} + *dst = JSON{Bytes: src, Status: Present} return nil } -func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *JSON) DecodeBinary(ci *ConnInfo, src []byte) error { return dst.DecodeText(ci, src) } -func (src *Json) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *JSON) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -115,14 +115,14 @@ func (src *Json) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.Bytes...), nil } -func (src *Json) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *JSON) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return src.EncodeText(ci, buf) } // Scan implements the database/sql Scanner interface. -func (dst *Json) Scan(src interface{}) error { +func (dst *JSON) Scan(src interface{}) error { if src == nil { - *dst = Json{Status: Null} + *dst = JSON{Status: Null} return nil } @@ -139,7 +139,7 @@ func (dst *Json) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Json) Value() (driver.Value, error) { +func (src *JSON) Value() (driver.Value, error) { switch src.Status { case Present: return string(src.Bytes), nil diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 3d8d2a68..82c02539 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -9,31 +9,31 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestJsonTranscode(t *testing.T) { +func TestJSONTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "json", []interface{}{ - &pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.Json{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.Json{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.Json{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.Json{Status: pgtype.Null}, + &pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.JSON{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.JSON{Status: pgtype.Null}, }) } -func TestJsonSet(t *testing.T) { +func TestJSONSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Json + result pgtype.JSON }{ - {source: "{}", result: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.Json{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.Json{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.Json{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.Json{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + {source: "{}", result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.JSON{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.JSON{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.JSON{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, } for i, tt := range successfulTests { - var d pgtype.Json + var d pgtype.JSON err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -45,17 +45,17 @@ func TestJsonSet(t *testing.T) { } } -func TestJsonAssignTo(t *testing.T) { +func TestJSONAssignTo(t *testing.T) { var s string var ps *string var b []byte rawStringTests := []struct { - src pgtype.Json + src pgtype.JSON dst *string expected string }{ - {src: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, } for i, tt := range rawStringTests { @@ -70,12 +70,12 @@ func TestJsonAssignTo(t *testing.T) { } rawBytesTests := []struct { - src pgtype.Json + src pgtype.JSON dst *[]byte expected []byte }{ - {src: pgtype.Json{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.Json{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + {src: pgtype.JSON{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSON{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, } for i, tt := range rawBytesTests { @@ -97,12 +97,12 @@ func TestJsonAssignTo(t *testing.T) { var strDst structDst unmarshalTests := []struct { - src pgtype.Json + src pgtype.JSON dst interface{} expected interface{} }{ - {src: pgtype.Json{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.Json{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + {src: pgtype.JSON{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSON{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, } for i, tt := range unmarshalTests { err := tt.src.AssignTo(tt.dst) @@ -116,11 +116,11 @@ func TestJsonAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Json + src pgtype.JSON dst **string expected *string }{ - {src: pgtype.Json{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.JSON{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range pointerAllocTests { diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index f7914202..9a06c1b4 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -5,27 +5,27 @@ import ( "fmt" ) -type Jsonb Json +type JSONB JSON -func (dst *Jsonb) Set(src interface{}) error { - return (*Json)(dst).Set(src) +func (dst *JSONB) Set(src interface{}) error { + return (*JSON)(dst).Set(src) } -func (dst *Jsonb) Get() interface{} { - return (*Json)(dst).Get() +func (dst *JSONB) Get() interface{} { + return (*JSON)(dst).Get() } -func (src *Jsonb) AssignTo(dst interface{}) error { - return (*Json)(src).AssignTo(dst) +func (src *JSONB) AssignTo(dst interface{}) error { + return (*JSON)(src).AssignTo(dst) } -func (dst *Jsonb) DecodeText(ci *ConnInfo, src []byte) error { - return (*Json)(dst).DecodeText(ci, src) +func (dst *JSONB) DecodeText(ci *ConnInfo, src []byte) error { + return (*JSON)(dst).DecodeText(ci, src) } -func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Jsonb{Status: Null} + *dst = JSONB{Status: Null} return nil } @@ -37,16 +37,16 @@ func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("unknown jsonb version number %d", src[0]) } - *dst = Jsonb{Bytes: src[1:], Status: Present} + *dst = JSONB{Bytes: src[1:], Status: Present} return nil } -func (src *Jsonb) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - return (*Json)(src).EncodeText(ci, buf) +func (src *JSONB) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + return (*JSON)(src).EncodeText(ci, buf) } -func (src *Jsonb) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *JSONB) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -59,11 +59,11 @@ func (src *Jsonb) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Jsonb) Scan(src interface{}) error { - return (*Json)(dst).Scan(src) +func (dst *JSONB) Scan(src interface{}) error { + return (*JSON)(dst).Scan(src) } // Value implements the database/sql/driver Valuer interface. -func (src *Jsonb) Value() (driver.Value, error) { - return (*Json)(src).Value() +func (src *JSONB) Value() (driver.Value, error) { + return (*JSON)(src).Value() } diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 86c8a12c..1a9a3056 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -9,7 +9,7 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestJsonbTranscode(t *testing.T) { +func TestJSONBTranscode(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustClose(t, conn) if _, ok := conn.ConnInfo.DataTypeForName("jsonb"); !ok { @@ -17,29 +17,29 @@ func TestJsonbTranscode(t *testing.T) { } testutil.TestSuccessfulTranscode(t, "jsonb", []interface{}{ - &pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, - &pgtype.Jsonb{Bytes: []byte("null"), Status: pgtype.Present}, - &pgtype.Jsonb{Bytes: []byte("42"), Status: pgtype.Present}, - &pgtype.Jsonb{Bytes: []byte(`"hello"`), Status: pgtype.Present}, - &pgtype.Jsonb{Status: pgtype.Null}, + &pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte("null"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte("42"), Status: pgtype.Present}, + &pgtype.JSONB{Bytes: []byte(`"hello"`), Status: pgtype.Present}, + &pgtype.JSONB{Status: pgtype.Null}, }) } -func TestJsonbSet(t *testing.T) { +func TestJSONBSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Jsonb + result pgtype.JSONB }{ - {source: "{}", result: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: []byte("{}"), result: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}}, - {source: ([]byte)(nil), result: pgtype.Jsonb{Status: pgtype.Null}}, - {source: (*string)(nil), result: pgtype.Jsonb{Status: pgtype.Null}}, - {source: []int{1, 2, 3}, result: pgtype.Jsonb{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, - {source: map[string]interface{}{"foo": "bar"}, result: pgtype.Jsonb{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, + {source: "{}", result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: []byte("{}"), result: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}}, + {source: ([]byte)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, + {source: (*string)(nil), result: pgtype.JSONB{Status: pgtype.Null}}, + {source: []int{1, 2, 3}, result: pgtype.JSONB{Bytes: []byte("[1,2,3]"), Status: pgtype.Present}}, + {source: map[string]interface{}{"foo": "bar"}, result: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}}, } for i, tt := range successfulTests { - var d pgtype.Jsonb + var d pgtype.JSONB err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -51,17 +51,17 @@ func TestJsonbSet(t *testing.T) { } } -func TestJsonbAssignTo(t *testing.T) { +func TestJSONBAssignTo(t *testing.T) { var s string var ps *string var b []byte rawStringTests := []struct { - src pgtype.Jsonb + src pgtype.JSONB dst *string expected string }{ - {src: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, + {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &s, expected: "{}"}, } for i, tt := range rawStringTests { @@ -76,12 +76,12 @@ func TestJsonbAssignTo(t *testing.T) { } rawBytesTests := []struct { - src pgtype.Jsonb + src pgtype.JSONB dst *[]byte expected []byte }{ - {src: pgtype.Jsonb{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, - {src: pgtype.Jsonb{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, + {src: pgtype.JSONB{Bytes: []byte("{}"), Status: pgtype.Present}, dst: &b, expected: []byte("{}")}, + {src: pgtype.JSONB{Status: pgtype.Null}, dst: &b, expected: (([]byte)(nil))}, } for i, tt := range rawBytesTests { @@ -103,12 +103,12 @@ func TestJsonbAssignTo(t *testing.T) { var strDst structDst unmarshalTests := []struct { - src pgtype.Jsonb + src pgtype.JSONB dst interface{} expected interface{} }{ - {src: pgtype.Jsonb{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, - {src: pgtype.Jsonb{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, + {src: pgtype.JSONB{Bytes: []byte(`{"foo":"bar"}`), Status: pgtype.Present}, dst: &mapDst, expected: map[string]interface{}{"foo": "bar"}}, + {src: pgtype.JSONB{Bytes: []byte(`{"name":"John","age":42}`), Status: pgtype.Present}, dst: &strDst, expected: structDst{Name: "John", Age: 42}}, } for i, tt := range unmarshalTests { err := tt.src.AssignTo(tt.dst) @@ -122,11 +122,11 @@ func TestJsonbAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Jsonb + src pgtype.JSONB dst **string expected *string }{ - {src: pgtype.Jsonb{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.JSONB{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range pointerAllocTests { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 60fab59f..2bfc9527 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -19,7 +19,7 @@ const ( TidOID = 27 XidOID = 28 CidOID = 29 - JsonOID = 114 + JSONOID = 114 CidrOID = 650 CidrArrayOID = 651 Float4OID = 700 @@ -47,7 +47,7 @@ const ( TimestamptzArrayOID = 1185 RecordOID = 2249 UUIDOID = 2950 - JsonbOID = 3802 + JSONBOID = 3802 ) type Status byte @@ -242,8 +242,8 @@ func init() { "int4range": &Int4range{}, "int8": &Int8{}, "int8range": &Int8range{}, - "json": &Json{}, - "jsonb": &Jsonb{}, + "json": &JSON{}, + "jsonb": &JSONB{}, "line": &Line{}, "lseg": &Lseg{}, "macaddr": &Macaddr{}, From 24fb04edb59135bd428a3a0492f8b428d993fb6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:58:40 -0500 Subject: [PATCH 242/264] Use Go casing convention for ACLItem --- pgtype/aclitem.go | 32 ++++++++-------- pgtype/aclitem_array.go | 38 +++++++++--------- pgtype/aclitem_array_test.go | 74 ++++++++++++++++++------------------ pgtype/aclitem_test.go | 34 ++++++++--------- pgtype/pgtype.go | 8 ++-- pgtype/typed_array_gen.sh | 2 +- 6 files changed, 94 insertions(+), 94 deletions(-) diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index 27dc15d1..829eb908 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -5,7 +5,7 @@ import ( "fmt" ) -// Aclitem is used for PostgreSQL's aclitem data type. A sample aclitem +// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem // might look like this: // // postgres=arwdDxt/postgres @@ -17,32 +17,32 @@ import ( // // postgres=arwdDxt/"role with spaces" // -type Aclitem struct { +type ACLItem struct { String string Status Status } -func (dst *Aclitem) Set(src interface{}) error { +func (dst *ACLItem) Set(src interface{}) error { switch value := src.(type) { case string: - *dst = Aclitem{String: value, Status: Present} + *dst = ACLItem{String: value, Status: Present} case *string: if value == nil { - *dst = Aclitem{Status: Null} + *dst = ACLItem{Status: Null} } else { - *dst = Aclitem{String: *value, Status: Present} + *dst = ACLItem{String: *value, Status: Present} } default: if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Aclitem", value) + return fmt.Errorf("cannot convert %v to ACLItem", value) } return nil } -func (dst *Aclitem) Get() interface{} { +func (dst *ACLItem) Get() interface{} { switch dst.Status { case Present: return dst.String @@ -53,7 +53,7 @@ func (dst *Aclitem) Get() interface{} { } } -func (src *Aclitem) AssignTo(dst interface{}) error { +func (src *ACLItem) AssignTo(dst interface{}) error { switch src.Status { case Present: switch v := dst.(type) { @@ -72,17 +72,17 @@ func (src *Aclitem) AssignTo(dst interface{}) error { return fmt.Errorf("cannot decode %v into %T", src, dst) } -func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Aclitem{Status: Null} + *dst = ACLItem{Status: Null} return nil } - *dst = Aclitem{String: string(src), Status: Present} + *dst = ACLItem{String: string(src), Status: Present} return nil } -func (src *Aclitem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -94,9 +94,9 @@ func (src *Aclitem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Aclitem) Scan(src interface{}) error { +func (dst *ACLItem) Scan(src interface{}) error { if src == nil { - *dst = Aclitem{Status: Null} + *dst = ACLItem{Status: Null} return nil } @@ -113,7 +113,7 @@ func (dst *Aclitem) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Aclitem) Value() (driver.Value, error) { +func (src *ACLItem) Value() (driver.Value, error) { switch src.Status { case Present: return src.String, nil diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 7df0b503..f9215a93 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -5,28 +5,28 @@ import ( "fmt" ) -type AclitemArray struct { - Elements []Aclitem +type ACLItemArray struct { + Elements []ACLItem Dimensions []ArrayDimension Status Status } -func (dst *AclitemArray) Set(src interface{}) error { +func (dst *ACLItemArray) Set(src interface{}) error { switch value := src.(type) { case []string: if value == nil { - *dst = AclitemArray{Status: Null} + *dst = ACLItemArray{Status: Null} } else if len(value) == 0 { - *dst = AclitemArray{Status: Present} + *dst = ACLItemArray{Status: Present} } else { - elements := make([]Aclitem, len(value)) + elements := make([]ACLItem, len(value)) for i := range value { if err := elements[i].Set(value[i]); err != nil { return err } } - *dst = AclitemArray{ + *dst = ACLItemArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -37,13 +37,13 @@ func (dst *AclitemArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Aclitem", value) + return fmt.Errorf("cannot convert %v to ACLItem", value) } return nil } -func (dst *AclitemArray) Get() interface{} { +func (dst *ACLItemArray) Get() interface{} { switch dst.Status { case Present: return dst @@ -54,7 +54,7 @@ func (dst *AclitemArray) Get() interface{} { } } -func (src *AclitemArray) AssignTo(dst interface{}) error { +func (src *ACLItemArray) AssignTo(dst interface{}) error { switch src.Status { case Present: switch v := dst.(type) { @@ -80,9 +80,9 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { return fmt.Errorf("cannot decode %v into %T", src, dst) } -func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = AclitemArray{Status: Null} + *dst = ACLItemArray{Status: Null} return nil } @@ -91,13 +91,13 @@ func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { return err } - var elements []Aclitem + var elements []ACLItem if len(uta.Elements) > 0 { - elements = make([]Aclitem, len(uta.Elements)) + elements = make([]ACLItem, len(uta.Elements)) for i, s := range uta.Elements { - var elem Aclitem + var elem ACLItem var elemSrc []byte if s != "NULL" { elemSrc = []byte(s) @@ -111,12 +111,12 @@ func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = AclitemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} return nil } -func (src *AclitemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -174,7 +174,7 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *AclitemArray) Scan(src interface{}) error { +func (dst *ACLItemArray) Scan(src interface{}) error { if src == nil { return dst.DecodeText(nil, nil) } @@ -192,7 +192,7 @@ func (dst *AclitemArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *AclitemArray) Value() (driver.Value, error) { +func (src *ACLItemArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/pgtype/aclitem_array_test.go b/pgtype/aclitem_array_test.go index 951e7847..c01eaa13 100644 --- a/pgtype/aclitem_array_test.go +++ b/pgtype/aclitem_array_test.go @@ -8,40 +8,40 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestAclitemArrayTranscode(t *testing.T) { +func TestACLItemArrayTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "aclitem[]", []interface{}{ - &pgtype.AclitemArray{ + &pgtype.ACLItemArray{ Elements: nil, Dimensions: nil, Status: pgtype.Present, }, - &pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{ - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.AclitemArray{Status: pgtype.Null}, - &pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{ - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{Status: pgtype.Null}, - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, + &pgtype.ACLItemArray{Status: pgtype.Null}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{Status: pgtype.Null}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{ - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: "=r/postgres", Status: pgtype.Present}, - pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{ + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "=r/postgres", Status: pgtype.Present}, + pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -52,26 +52,26 @@ func TestAclitemArrayTranscode(t *testing.T) { }) } -func TestAclitemArraySet(t *testing.T) { +func TestACLItemArraySet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.AclitemArray + result pgtype.ACLItemArray }{ { source: []string{"=r/postgres"}, - result: pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, + result: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, { source: (([]string)(nil)), - result: pgtype.AclitemArray{Status: pgtype.Null}, + result: pgtype.ACLItemArray{Status: pgtype.Null}, }, } for i, tt := range successfulTests { - var r pgtype.AclitemArray + var r pgtype.ACLItemArray err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -83,19 +83,19 @@ func TestAclitemArraySet(t *testing.T) { } } -func TestAclitemArrayAssignTo(t *testing.T) { +func TestACLItemArrayAssignTo(t *testing.T) { var stringSlice []string type _stringSlice []string var namedStringSlice _stringSlice simpleTests := []struct { - src pgtype.AclitemArray + src pgtype.ACLItemArray dst interface{} expected interface{} }{ { - src: pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -103,8 +103,8 @@ func TestAclitemArrayAssignTo(t *testing.T) { expected: []string{"=r/postgres"}, }, { - src: pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{{String: "=r/postgres", Status: pgtype.Present}}, + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{String: "=r/postgres", Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -112,7 +112,7 @@ func TestAclitemArrayAssignTo(t *testing.T) { expected: _stringSlice{"=r/postgres"}, }, { - src: pgtype.AclitemArray{Status: pgtype.Null}, + src: pgtype.ACLItemArray{Status: pgtype.Null}, dst: &stringSlice, expected: (([]string)(nil)), }, @@ -130,12 +130,12 @@ func TestAclitemArrayAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.AclitemArray + src pgtype.ACLItemArray dst interface{} }{ { - src: pgtype.AclitemArray{ - Elements: []pgtype.Aclitem{{Status: pgtype.Null}}, + src: pgtype.ACLItemArray{ + Elements: []pgtype.ACLItem{{Status: pgtype.Null}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, diff --git a/pgtype/aclitem_test.go b/pgtype/aclitem_test.go index 13c63395..65399a30 100644 --- a/pgtype/aclitem_test.go +++ b/pgtype/aclitem_test.go @@ -8,25 +8,25 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestAclitemTranscode(t *testing.T) { +func TestACLItemTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "aclitem", []interface{}{ - &pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, - &pgtype.Aclitem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, - &pgtype.Aclitem{Status: pgtype.Null}, + &pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, + &pgtype.ACLItem{String: `postgres=arwdDxt/" tricky, ' } "" \ test user "`, Status: pgtype.Present}, + &pgtype.ACLItem{Status: pgtype.Null}, }) } -func TestAclitemSet(t *testing.T) { +func TestACLItemSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Aclitem + result pgtype.ACLItem }{ - {source: "postgres=arwdDxt/postgres", result: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, - {source: (*string)(nil), result: pgtype.Aclitem{Status: pgtype.Null}}, + {source: "postgres=arwdDxt/postgres", result: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.ACLItem{Status: pgtype.Null}}, } for i, tt := range successfulTests { - var d pgtype.Aclitem + var d pgtype.ACLItem err := d.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -38,17 +38,17 @@ func TestAclitemSet(t *testing.T) { } } -func TestAclitemAssignTo(t *testing.T) { +func TestACLItemAssignTo(t *testing.T) { var s string var ps *string simpleTests := []struct { - src pgtype.Aclitem + src pgtype.ACLItem dst interface{} expected interface{} }{ - {src: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, - {src: pgtype.Aclitem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &s, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, } for i, tt := range simpleTests { @@ -63,11 +63,11 @@ func TestAclitemAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Aclitem + src pgtype.ACLItem dst interface{} expected interface{} }{ - {src: pgtype.Aclitem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, + {src: pgtype.ACLItem{String: "postgres=arwdDxt/postgres", Status: pgtype.Present}, dst: &ps, expected: "postgres=arwdDxt/postgres"}, } for i, tt := range pointerAllocTests { @@ -82,10 +82,10 @@ func TestAclitemAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Aclitem + src pgtype.ACLItem dst interface{} }{ - {src: pgtype.Aclitem{Status: pgtype.Null}, dst: &s}, + {src: pgtype.ACLItem{Status: pgtype.Null}, dst: &s}, } for i, tt := range errorTests { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 2bfc9527..4fdcf3c2 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -35,8 +35,8 @@ const ( Int8ArrayOID = 1016 Float4ArrayOID = 1021 Float8ArrayOID = 1022 - AclitemOID = 1033 - AclitemArrayOID = 1034 + ACLItemOID = 1033 + ACLItemArrayOID = 1034 InetArrayOID = 1041 VarcharOID = 1043 DateOID = 1082 @@ -206,7 +206,7 @@ var nameValues map[string]Value func init() { nameValues = map[string]Value{ - "_aclitem": &AclitemArray{}, + "_aclitem": &ACLItemArray{}, "_bool": &BoolArray{}, "_bytea": &ByteaArray{}, "_cidr": &CidrArray{}, @@ -222,7 +222,7 @@ func init() { "_timestamp": &TimestampArray{}, "_timestamptz": &TimestamptzArray{}, "_varchar": &VarcharArray{}, - "aclitem": &Aclitem{}, + "aclitem": &ACLItem{}, "bool": &Bool{}, "box": &Box{}, "bytea": &Bytea{}, diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 2e36b8b3..d7abcbcf 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -12,7 +12,7 @@ erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.I erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go -erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[]string element_type_name=aclitem text_null=NULL binary_format=false typed_array.go.erb > aclitem_array.go erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go goimports -w *_array.go From 87126272573da01a15aface9efc05e461e2aefef Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 12:01:49 -0500 Subject: [PATCH 243/264] Use Go casing convention for CID/TID/XID/CIDR --- doc.go | 2 +- pgtype/cid.go | 26 +++++----- pgtype/cid_test.go | 30 +++++------ pgtype/cidr.go | 16 +++--- pgtype/cidr_array.go | 58 ++++++++++----------- pgtype/cidr_array_test.go | 92 ++++++++++++++++----------------- pgtype/inet_array_test.go | 36 ++++++------- pgtype/inet_test.go | 36 ++++++------- pgtype/pgtype.go | 20 +++---- pgtype/pgtype_test.go | 2 +- pgtype/pguint32.go | 2 +- pgtype/tid.go | 34 ++++++------ pgtype/tid_test.go | 8 +-- pgtype/typed_array_gen.sh | 2 +- pgtype/xid.go | 26 +++++----- pgtype/xid_test.go | 30 +++++------ stdlib/sql.go | 12 ++--- values_test.go | 106 +++++++++++++++++++------------------- 18 files changed, 269 insertions(+), 269 deletions(-) diff --git a/doc.go b/doc.go index a0f0bd72..a9b9e461 100644 --- a/doc.go +++ b/doc.go @@ -146,7 +146,7 @@ JSON and JSONB Mapping pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB. -Inet and Cidr Mapping +Inet and CIDR Mapping pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In addition, as a convenience pgx will encode from a net.IP; it will assume a /32 diff --git a/pgtype/cid.go b/pgtype/cid.go index b7718f88..0ed54f44 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" ) -// Cid is PostgreSQL's Command Identifier type. +// CID is PostgreSQL's Command Identifier type. // // When one does // @@ -15,47 +15,47 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/c.h as CommandId // in the PostgreSQL sources. -type Cid pguint32 +type CID pguint32 -// Set converts from src to dst. Note that as Cid is not a general +// Set converts from src to dst. Note that as CID is not a general // number type Set does not do automatic type conversion as other number // types do. -func (dst *Cid) Set(src interface{}) error { +func (dst *CID) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *Cid) Get() interface{} { +func (dst *CID) Get() interface{} { return (*pguint32)(dst).Get() } -// AssignTo assigns from src to dst. Note that as Cid is not a general number +// AssignTo assigns from src to dst. Note that as CID is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *Cid) AssignTo(dst interface{}) error { +func (src *CID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Cid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *CID) DecodeText(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *CID) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *Cid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeText(ci, buf) } -func (src *Cid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. -func (dst *Cid) Scan(src interface{}) error { +func (dst *CID) Scan(src interface{}) error { return (*pguint32)(dst).Scan(src) } // Value implements the database/sql/driver Valuer interface. -func (src *Cid) Value() (driver.Value, error) { +func (src *CID) Value() (driver.Value, error) { return (*pguint32)(src).Value() } diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index c3bf3132..0dfc56d4 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -8,11 +8,11 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestCidTranscode(t *testing.T) { +func TestCIDTranscode(t *testing.T) { pgTypeName := "cid" values := []interface{}{ - &pgtype.Cid{Uint: 42, Status: pgtype.Present}, - &pgtype.Cid{Status: pgtype.Null}, + &pgtype.CID{Uint: 42, Status: pgtype.Present}, + &pgtype.CID{Status: pgtype.Null}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) @@ -28,16 +28,16 @@ func TestCidTranscode(t *testing.T) { } } -func TestCidSet(t *testing.T) { +func TestCIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Cid + result pgtype.CID }{ - {source: uint32(1), result: pgtype.Cid{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.CID{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.Cid + var r pgtype.CID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -49,17 +49,17 @@ func TestCidSet(t *testing.T) { } } -func TestCidAssignTo(t *testing.T) { +func TestCIDAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.Cid + src pgtype.CID dst interface{} expected interface{} }{ - {src: pgtype.Cid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Cid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -74,11 +74,11 @@ func TestCidAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Cid + src pgtype.CID dst interface{} expected interface{} }{ - {src: pgtype.Cid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.CID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -93,10 +93,10 @@ func TestCidAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Cid + src pgtype.CID dst interface{} }{ - {src: pgtype.Cid{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.CID{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/pgtype/cidr.go b/pgtype/cidr.go index 2b45d2d0..519b9cae 100644 --- a/pgtype/cidr.go +++ b/pgtype/cidr.go @@ -1,31 +1,31 @@ package pgtype -type Cidr Inet +type CIDR Inet -func (dst *Cidr) Set(src interface{}) error { +func (dst *CIDR) Set(src interface{}) error { return (*Inet)(dst).Set(src) } -func (dst *Cidr) Get() interface{} { +func (dst *CIDR) Get() interface{} { return (*Inet)(dst).Get() } -func (src *Cidr) AssignTo(dst interface{}) error { +func (src *CIDR) AssignTo(dst interface{}) error { return (*Inet)(src).AssignTo(dst) } -func (dst *Cidr) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *CIDR) DecodeText(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeText(ci, src) } -func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *CIDR) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Inet)(dst).DecodeBinary(ci, src) } -func (src *Cidr) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDR) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (*Inet)(src).EncodeText(ci, buf) } -func (src *Cidr) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDR) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (*Inet)(src).EncodeBinary(ci, buf) } diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index b8a70d63..9b7b50fa 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -9,28 +9,28 @@ import ( "github.com/jackc/pgx/pgio" ) -type CidrArray struct { - Elements []Cidr +type CIDRArray struct { + Elements []CIDR Dimensions []ArrayDimension Status Status } -func (dst *CidrArray) Set(src interface{}) error { +func (dst *CIDRArray) Set(src interface{}) error { switch value := src.(type) { case []*net.IPNet: if value == nil { - *dst = CidrArray{Status: Null} + *dst = CIDRArray{Status: Null} } else if len(value) == 0 { - *dst = CidrArray{Status: Present} + *dst = CIDRArray{Status: Present} } else { - elements := make([]Cidr, len(value)) + elements := make([]CIDR, len(value)) for i := range value { if err := elements[i].Set(value[i]); err != nil { return err } } - *dst = CidrArray{ + *dst = CIDRArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -39,17 +39,17 @@ func (dst *CidrArray) Set(src interface{}) error { case []net.IP: if value == nil { - *dst = CidrArray{Status: Null} + *dst = CIDRArray{Status: Null} } else if len(value) == 0 { - *dst = CidrArray{Status: Present} + *dst = CIDRArray{Status: Present} } else { - elements := make([]Cidr, len(value)) + elements := make([]CIDR, len(value)) for i := range value { if err := elements[i].Set(value[i]); err != nil { return err } } - *dst = CidrArray{ + *dst = CIDRArray{ Elements: elements, Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, Status: Present, @@ -60,13 +60,13 @@ func (dst *CidrArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Cidr", value) + return fmt.Errorf("cannot convert %v to CIDR", value) } return nil } -func (dst *CidrArray) Get() interface{} { +func (dst *CIDRArray) Get() interface{} { switch dst.Status { case Present: return dst @@ -77,7 +77,7 @@ func (dst *CidrArray) Get() interface{} { } } -func (src *CidrArray) AssignTo(dst interface{}) error { +func (src *CIDRArray) AssignTo(dst interface{}) error { switch src.Status { case Present: switch v := dst.(type) { @@ -112,9 +112,9 @@ func (src *CidrArray) AssignTo(dst interface{}) error { return fmt.Errorf("cannot decode %v into %T", src, dst) } -func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = CidrArray{Status: Null} + *dst = CIDRArray{Status: Null} return nil } @@ -123,13 +123,13 @@ func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { return err } - var elements []Cidr + var elements []CIDR if len(uta.Elements) > 0 { - elements = make([]Cidr, len(uta.Elements)) + elements = make([]CIDR, len(uta.Elements)) for i, s := range uta.Elements { - var elem Cidr + var elem CIDR var elemSrc []byte if s != "NULL" { elemSrc = []byte(s) @@ -143,14 +143,14 @@ func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { } } - *dst = CidrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + *dst = CIDRArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} return nil } -func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *CIDRArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = CidrArray{Status: Null} + *dst = CIDRArray{Status: Null} return nil } @@ -161,7 +161,7 @@ func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(arrayHeader.Dimensions) == 0 { - *dst = CidrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = CIDRArray{Dimensions: arrayHeader.Dimensions, Status: Present} return nil } @@ -170,7 +170,7 @@ func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { elementCount *= d.Length } - elements := make([]Cidr, elementCount) + elements := make([]CIDR, elementCount) for i := range elements { elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) @@ -186,11 +186,11 @@ func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { } } - *dst = CidrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + *dst = CIDRArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} return nil } -func (src *CidrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDRArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -247,7 +247,7 @@ func (src *CidrArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *CidrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -292,7 +292,7 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *CidrArray) Scan(src interface{}) error { +func (dst *CIDRArray) Scan(src interface{}) error { if src == nil { return dst.DecodeText(nil, nil) } @@ -310,7 +310,7 @@ func (dst *CidrArray) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *CidrArray) Value() (driver.Value, error) { +func (src *CIDRArray) Value() (driver.Value, error) { buf, err := src.EncodeText(nil, nil) if err != nil { return nil, err diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go index 1ebe5195..70d3f65b 100644 --- a/pgtype/cidr_array_test.go +++ b/pgtype/cidr_array_test.go @@ -9,40 +9,40 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestCidrArrayTranscode(t *testing.T) { +func TestCIDRArrayTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "cidr[]", []interface{}{ - &pgtype.CidrArray{ + &pgtype.CIDRArray{ Elements: nil, Dimensions: nil, Status: pgtype.Present, }, - &pgtype.CidrArray{ - Elements: []pgtype.Cidr{ - pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Cidr{Status: pgtype.Null}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.CidrArray{Status: pgtype.Null}, - &pgtype.CidrArray{ - Elements: []pgtype.Cidr{ - pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, - pgtype.Cidr{Status: pgtype.Null}, - pgtype.Cidr{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.CIDRArray{Status: pgtype.Null}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.CIDR{Status: pgtype.Null}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, - &pgtype.CidrArray{ - Elements: []pgtype.Cidr{ - pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.CIDRArray{ + Elements: []pgtype.CIDR{ + pgtype.CIDR{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.CIDR{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -53,37 +53,37 @@ func TestCidrArrayTranscode(t *testing.T) { }) } -func TestCidrArraySet(t *testing.T) { +func TestCIDRArraySet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.CidrArray + result pgtype.CIDRArray }{ { - source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, - result: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, { source: (([]*net.IPNet)(nil)), - result: pgtype.CidrArray{Status: pgtype.Null}, + result: pgtype.CIDRArray{Status: pgtype.Null}, }, { - source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, - result: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, + result: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, { source: (([]net.IP)(nil)), - result: pgtype.CidrArray{Status: pgtype.Null}, + result: pgtype.CIDRArray{Status: pgtype.Null}, }, } for i, tt := range successfulTests { - var r pgtype.CidrArray + var r pgtype.CIDRArray err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -95,27 +95,27 @@ func TestCidrArraySet(t *testing.T) { } } -func TestCidrArrayAssignTo(t *testing.T) { +func TestCIDRArrayAssignTo(t *testing.T) { var ipnetSlice []*net.IPNet var ipSlice []net.IP simpleTests := []struct { - src pgtype.CidrArray + src pgtype.CIDRArray dst interface{} expected interface{} }{ { - src: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, }, { - src: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{Status: pgtype.Null}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -123,17 +123,17 @@ func TestCidrArrayAssignTo(t *testing.T) { expected: []*net.IPNet{nil}, }, { - src: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipSlice, - expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, }, { - src: pgtype.CidrArray{ - Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + src: pgtype.CIDRArray{ + Elements: []pgtype.CIDR{{Status: pgtype.Null}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, @@ -141,12 +141,12 @@ func TestCidrArrayAssignTo(t *testing.T) { expected: []net.IP{nil}, }, { - src: pgtype.CidrArray{Status: pgtype.Null}, + src: pgtype.CIDRArray{Status: pgtype.Null}, dst: &ipnetSlice, expected: (([]*net.IPNet)(nil)), }, { - src: pgtype.CidrArray{Status: pgtype.Null}, + src: pgtype.CIDRArray{Status: pgtype.Null}, dst: &ipSlice, expected: (([]net.IP)(nil)), }, diff --git a/pgtype/inet_array_test.go b/pgtype/inet_array_test.go index c0465922..3e2b6a3c 100644 --- a/pgtype/inet_array_test.go +++ b/pgtype/inet_array_test.go @@ -18,7 +18,7 @@ func TestInetArrayTranscode(t *testing.T) { }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, }, Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, @@ -27,22 +27,22 @@ func TestInetArrayTranscode(t *testing.T) { &pgtype.InetArray{Status: pgtype.Null}, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, pgtype.Inet{Status: pgtype.Null}, - pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, }, &pgtype.InetArray{ Elements: []pgtype.Inet{ - pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, - pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{ {Length: 2, LowerBound: 4}, @@ -59,9 +59,9 @@ func TestInetArraySet(t *testing.T) { result pgtype.InetArray }{ { - source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + source: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, @@ -70,9 +70,9 @@ func TestInetArraySet(t *testing.T) { result: pgtype.InetArray{Status: pgtype.Null}, }, { - source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + source: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, result: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present}, }, @@ -106,12 +106,12 @@ func TestInetArrayAssignTo(t *testing.T) { }{ { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipnetSlice, - expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + expected: []*net.IPNet{mustParseCIDR(t, "127.0.0.1/32")}, }, { src: pgtype.InetArray{ @@ -124,12 +124,12 @@ func TestInetArrayAssignTo(t *testing.T) { }, { src: pgtype.InetArray{ - Elements: []pgtype.Inet{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Elements: []pgtype.Inet{{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, Status: pgtype.Present, }, dst: &ipSlice, - expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + expected: []net.IP{mustParseCIDR(t, "127.0.0.1/32").IP}, }, { src: pgtype.InetArray{ diff --git a/pgtype/inet_test.go b/pgtype/inet_test.go index b883df8e..32d66999 100644 --- a/pgtype/inet_test.go +++ b/pgtype/inet_test.go @@ -12,16 +12,16 @@ import ( func TestInetTranscode(t *testing.T) { for _, pgTypeName := range []string{"inet", "cidr"} { testutil.TestSuccessfulTranscode(t, pgTypeName, []interface{}{ - &pgtype.Inet{IPNet: mustParseCidr(t, "0.0.0.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "192.168.1.0/24"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "255.255.255.255/32"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "::/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "::/0"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "::1/128"), Status: pgtype.Present}, - &pgtype.Inet{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "0.0.0.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "12.34.56.0/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.1.0/24"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.0.0.0/8"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::/0"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "::1/128"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, &pgtype.Inet{Status: pgtype.Null}, }) } @@ -32,9 +32,9 @@ func TestInetSet(t *testing.T) { source interface{} result pgtype.Inet }{ - {source: mustParseCidr(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: mustParseCidr(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, + {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -61,8 +61,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCidr(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCidr(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &ip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &pipnet, expected: ((*net.IPNet)(nil))}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &pip, expected: ((*net.IP)(nil))}, } @@ -83,8 +83,8 @@ func TestInetAssignTo(t *testing.T) { dst interface{} expected interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCidr(t, "127.0.0.1/32")}, - {src: pgtype.Inet{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCidr(t, "127.0.0.1/32").IP}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pipnet, expected: *mustParseCIDR(t, "127.0.0.1/32")}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, dst: &pip, expected: mustParseCIDR(t, "127.0.0.1/32").IP}, } for i, tt := range pointerAllocTests { @@ -102,7 +102,7 @@ func TestInetAssignTo(t *testing.T) { src pgtype.Inet dst interface{} }{ - {src: pgtype.Inet{IPNet: mustParseCidr(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, + {src: pgtype.Inet{IPNet: mustParseCIDR(t, "192.168.0.0/16"), Status: pgtype.Present}, dst: &ip}, {src: pgtype.Inet{Status: pgtype.Null}, dst: &ipnet}, } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 4fdcf3c2..4302a5fe 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -16,12 +16,12 @@ const ( Int4OID = 23 TextOID = 25 OIDOID = 26 - TidOID = 27 - XidOID = 28 - CidOID = 29 + TIDOID = 27 + XIDOID = 28 + CIDOID = 29 JSONOID = 114 - CidrOID = 650 - CidrArrayOID = 651 + CIDROID = 650 + CIDRArrayOID = 651 Float4OID = 700 Float8OID = 701 UnknownOID = 705 @@ -209,7 +209,7 @@ func init() { "_aclitem": &ACLItemArray{}, "_bool": &BoolArray{}, "_bytea": &ByteaArray{}, - "_cidr": &CidrArray{}, + "_cidr": &CIDRArray{}, "_date": &DateArray{}, "_float4": &Float4Array{}, "_float8": &Float8Array{}, @@ -227,8 +227,8 @@ func init() { "box": &Box{}, "bytea": &Bytea{}, "char": &QChar{}, - "cid": &Cid{}, - "cidr": &Cidr{}, + "cid": &CID{}, + "cidr": &CIDR{}, "circle": &Circle{}, "date": &Date{}, "daterange": &Daterange{}, @@ -256,7 +256,7 @@ func init() { "polygon": &Polygon{}, "record": &Record{}, "text": &Text{}, - "tid": &Tid{}, + "tid": &TID{}, "timestamp": &Timestamp{}, "timestamptz": &Timestamptz{}, "tsrange": &Tsrange{}, @@ -265,6 +265,6 @@ func init() { "uuid": &UUID{}, "varbit": &Varbit{}, "varchar": &Varchar{}, - "xid": &Xid{}, + "xid": &XID{}, } } diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 716e063d..f7e743b2 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -20,7 +20,7 @@ type _float32Slice []float32 type _float64Slice []float64 type _byteSlice []byte -func mustParseCidr(t testing.TB, s string) *net.IPNet { +func mustParseCIDR(t testing.TB, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index c15ee6d7..15b0f38d 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -11,7 +11,7 @@ import ( ) // pguint32 is the core type that is used to implement PostgreSQL types such as -// Cid and Xid. +// CID and XID. type pguint32 struct { Uint uint32 Status Status diff --git a/pgtype/tid.go b/pgtype/tid.go index 2f4412cb..d44ea3a6 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -10,7 +10,7 @@ import ( "github.com/jackc/pgx/pgio" ) -// Tid is PostgreSQL's Tuple Identifier type. +// TID is PostgreSQL's Tuple Identifier type. // // When one does // @@ -21,17 +21,17 @@ import ( // It is currently implemented as a pair unsigned two byte integers. // Its conversion functions can be found in src/backend/utils/adt/tid.c // in the PostgreSQL sources. -type Tid struct { +type TID struct { BlockNumber uint32 OffsetNumber uint16 Status Status } -func (dst *Tid) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Tid", src) +func (dst *TID) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to TID", src) } -func (dst *Tid) Get() interface{} { +func (dst *TID) Get() interface{} { switch dst.Status { case Present: return dst @@ -42,13 +42,13 @@ func (dst *Tid) Get() interface{} { } } -func (src *Tid) AssignTo(dst interface{}) error { +func (src *TID) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } -func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Tid{Status: Null} + *dst = TID{Status: Null} return nil } @@ -71,13 +71,13 @@ func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Tid{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} + *dst = TID{BlockNumber: uint32(blockNumber), OffsetNumber: uint16(offsetNumber), Status: Present} return nil } -func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Tid{Status: Null} + *dst = TID{Status: Null} return nil } @@ -85,7 +85,7 @@ func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("invalid length for tid: %v", len(src)) } - *dst = Tid{ + *dst = TID{ BlockNumber: binary.BigEndian.Uint32(src), OffsetNumber: binary.BigEndian.Uint16(src[4:]), Status: Present, @@ -93,7 +93,7 @@ func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { return nil } -func (src *Tid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -105,7 +105,7 @@ func (src *Tid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return buf, nil } -func (src *Tid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *TID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { switch src.Status { case Null: return nil, nil @@ -119,9 +119,9 @@ func (src *Tid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { } // Scan implements the database/sql Scanner interface. -func (dst *Tid) Scan(src interface{}) error { +func (dst *TID) Scan(src interface{}) error { if src == nil { - *dst = Tid{Status: Null} + *dst = TID{Status: Null} return nil } @@ -138,6 +138,6 @@ func (dst *Tid) Scan(src interface{}) error { } // Value implements the database/sql/driver Valuer interface. -func (src *Tid) Value() (driver.Value, error) { +func (src *TID) Value() (driver.Value, error) { return EncodeValueText(src) } diff --git a/pgtype/tid_test.go b/pgtype/tid_test.go index a5430d11..9185cb31 100644 --- a/pgtype/tid_test.go +++ b/pgtype/tid_test.go @@ -7,10 +7,10 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestTidTranscode(t *testing.T) { +func TestTIDTranscode(t *testing.T) { testutil.TestSuccessfulTranscode(t, "tid", []interface{}{ - &pgtype.Tid{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, - &pgtype.Tid{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, - &pgtype.Tid{Status: pgtype.Null}, + &pgtype.TID{BlockNumber: 42, OffsetNumber: 43, Status: pgtype.Present}, + &pgtype.TID{BlockNumber: 4294967295, OffsetNumber: 65535, Status: pgtype.Present}, + &pgtype.TID{Status: pgtype.Null}, }) } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index d7abcbcf..1aa6c354 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -8,7 +8,7 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL binary_format=true typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL binary_format=true typed_array.go.erb > float8_array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL binary_format=true typed_array.go.erb > inet_array.go -erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go +erb pgtype_array_type=CIDRArray pgtype_element_type=CIDR go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL binary_format=true typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' binary_format=true typed_array.go.erb > text_array.go erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' binary_format=true typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL binary_format=true typed_array.go.erb > bytea_array.go diff --git a/pgtype/xid.go b/pgtype/xid.go index 84acd1b0..f66f5367 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -4,7 +4,7 @@ import ( "database/sql/driver" ) -// Xid is PostgreSQL's Transaction ID type. +// XID is PostgreSQL's Transaction ID type. // // In later versions of PostgreSQL, it is the type used for the backend_xid // and backend_xmin columns of the pg_stat_activity system view. @@ -18,47 +18,47 @@ import ( // It is currently implemented as an unsigned four byte integer. // Its definition can be found in src/include/postgres_ext.h as TransactionId // in the PostgreSQL sources. -type Xid pguint32 +type XID pguint32 -// Set converts from src to dst. Note that as Xid is not a general +// Set converts from src to dst. Note that as XID is not a general // number type Set does not do automatic type conversion as other number // types do. -func (dst *Xid) Set(src interface{}) error { +func (dst *XID) Set(src interface{}) error { return (*pguint32)(dst).Set(src) } -func (dst *Xid) Get() interface{} { +func (dst *XID) Get() interface{} { return (*pguint32)(dst).Get() } -// AssignTo assigns from src to dst. Note that as Xid is not a general number +// AssignTo assigns from src to dst. Note that as XID is not a general number // type AssignTo does not do automatic type conversion as other number types do. -func (src *Xid) AssignTo(dst interface{}) error { +func (src *XID) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Xid) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *XID) DecodeText(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *XID) DecodeBinary(ci *ConnInfo, src []byte) error { return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src *Xid) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *XID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeText(ci, buf) } -func (src *Xid) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { +func (src *XID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return (*pguint32)(src).EncodeBinary(ci, buf) } // Scan implements the database/sql Scanner interface. -func (dst *Xid) Scan(src interface{}) error { +func (dst *XID) Scan(src interface{}) error { return (*pguint32)(dst).Scan(src) } // Value implements the database/sql/driver Valuer interface. -func (src *Xid) Value() (driver.Value, error) { +func (src *XID) Value() (driver.Value, error) { return (*pguint32)(src).Value() } diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index c4a1bec3..d0f3f0ab 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -8,11 +8,11 @@ import ( "github.com/jackc/pgx/pgtype/testutil" ) -func TestXidTranscode(t *testing.T) { +func TestXIDTranscode(t *testing.T) { pgTypeName := "xid" values := []interface{}{ - &pgtype.Xid{Uint: 42, Status: pgtype.Present}, - &pgtype.Xid{Status: pgtype.Null}, + &pgtype.XID{Uint: 42, Status: pgtype.Present}, + &pgtype.XID{Status: pgtype.Null}, } eqFunc := func(a, b interface{}) bool { return reflect.DeepEqual(a, b) @@ -28,16 +28,16 @@ func TestXidTranscode(t *testing.T) { } } -func TestXidSet(t *testing.T) { +func TestXIDSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Xid + result pgtype.XID }{ - {source: uint32(1), result: pgtype.Xid{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.XID{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.Xid + var r pgtype.XID err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -49,17 +49,17 @@ func TestXidSet(t *testing.T) { } } -func TestXidAssignTo(t *testing.T) { +func TestXIDAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.Xid + src pgtype.XID dst interface{} expected interface{} }{ - {src: pgtype.Xid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Xid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.XID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -74,11 +74,11 @@ func TestXidAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Xid + src pgtype.XID dst interface{} expected interface{} }{ - {src: pgtype.Xid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.XID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -93,10 +93,10 @@ func TestXidAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Xid + src pgtype.XID dst interface{} }{ - {src: pgtype.Xid{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.XID{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests { diff --git a/stdlib/sql.go b/stdlib/sql.go index 00329617..b9cd3295 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -100,7 +100,7 @@ func init() { databaseSqlOIDs = make(map[pgtype.OID]bool) databaseSqlOIDs[pgtype.BoolOID] = true databaseSqlOIDs[pgtype.ByteaOID] = true - databaseSqlOIDs[pgtype.CidOID] = true + databaseSqlOIDs[pgtype.CIDOID] = true databaseSqlOIDs[pgtype.DateOID] = true databaseSqlOIDs[pgtype.Float4OID] = true databaseSqlOIDs[pgtype.Float8OID] = true @@ -110,7 +110,7 @@ func init() { databaseSqlOIDs[pgtype.OIDOID] = true databaseSqlOIDs[pgtype.TimestampOID] = true databaseSqlOIDs[pgtype.TimestamptzOID] = true - databaseSqlOIDs[pgtype.XidOID] = true + databaseSqlOIDs[pgtype.XIDOID] = true } type Driver struct { @@ -432,8 +432,8 @@ func (r *Rows) Next(dest []driver.Value) error { r.values[i] = &pgtype.Bool{} case pgtype.ByteaOID: r.values[i] = &pgtype.Bytea{} - case pgtype.CidOID: - r.values[i] = &pgtype.Cid{} + case pgtype.CIDOID: + r.values[i] = &pgtype.CID{} case pgtype.DateOID: r.values[i] = &pgtype.Date{} case pgtype.Float4OID: @@ -452,8 +452,8 @@ func (r *Rows) Next(dest []driver.Value) error { r.values[i] = &pgtype.Timestamp{} case pgtype.TimestamptzOID: r.values[i] = &pgtype.Timestamptz{} - case pgtype.XidOID: - r.values[i] = &pgtype.Xid{} + case pgtype.XIDOID: + r.values[i] = &pgtype.XID{} default: r.values[i] = &pgtype.GenericText{} } diff --git a/values_test.go b/values_test.go index 37bf91cc..b8aec46a 100644 --- a/values_test.go +++ b/values_test.go @@ -225,7 +225,7 @@ func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { } } -func mustParseCidr(t *testing.T, s string) *net.IPNet { +func mustParseCIDR(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) @@ -260,7 +260,7 @@ func TestStringToNotTextTypeTranscode(t *testing.T) { } } -func TestInetCidrTranscodeIPNet(t *testing.T) { +func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -270,26 +270,26 @@ func TestInetCidrTranscodeIPNet(t *testing.T) { sql string value *net.IPNet }{ - {"select $1::inet", mustParseCidr(t, "0.0.0.0/32")}, - {"select $1::inet", mustParseCidr(t, "127.0.0.1/32")}, - {"select $1::inet", mustParseCidr(t, "12.34.56.0/32")}, - {"select $1::inet", mustParseCidr(t, "192.168.1.0/24")}, - {"select $1::inet", mustParseCidr(t, "255.0.0.0/8")}, - {"select $1::inet", mustParseCidr(t, "255.255.255.255/32")}, - {"select $1::inet", mustParseCidr(t, "::/128")}, - {"select $1::inet", mustParseCidr(t, "::/0")}, - {"select $1::inet", mustParseCidr(t, "::1/128")}, - {"select $1::inet", mustParseCidr(t, "2607:f8b0:4009:80b::200e/128")}, - {"select $1::cidr", mustParseCidr(t, "0.0.0.0/32")}, - {"select $1::cidr", mustParseCidr(t, "127.0.0.1/32")}, - {"select $1::cidr", mustParseCidr(t, "12.34.56.0/32")}, - {"select $1::cidr", mustParseCidr(t, "192.168.1.0/24")}, - {"select $1::cidr", mustParseCidr(t, "255.0.0.0/8")}, - {"select $1::cidr", mustParseCidr(t, "255.255.255.255/32")}, - {"select $1::cidr", mustParseCidr(t, "::/128")}, - {"select $1::cidr", mustParseCidr(t, "::/0")}, - {"select $1::cidr", mustParseCidr(t, "::1/128")}, - {"select $1::cidr", mustParseCidr(t, "2607:f8b0:4009:80b::200e/128")}, + {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::inet", mustParseCIDR(t, "::/128")}, + {"select $1::inet", mustParseCIDR(t, "::/0")}, + {"select $1::inet", mustParseCIDR(t, "::1/128")}, + {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::cidr", mustParseCIDR(t, "::/128")}, + {"select $1::cidr", mustParseCIDR(t, "::/0")}, + {"select $1::cidr", mustParseCIDR(t, "::1/128")}, + {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, } for i, tt := range tests { @@ -309,7 +309,7 @@ func TestInetCidrTranscodeIPNet(t *testing.T) { } } -func TestInetCidrTranscodeIP(t *testing.T) { +func TestInetCIDRTranscodeIP(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -353,8 +353,8 @@ func TestInetCidrTranscodeIP(t *testing.T) { sql string value *net.IPNet }{ - {"select $1::inet", mustParseCidr(t, "192.168.1.0/24")}, - {"select $1::cidr", mustParseCidr(t, "192.168.1.0/24")}, + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, } for i, tt := range failTests { var actual net.IP @@ -369,7 +369,7 @@ func TestInetCidrTranscodeIP(t *testing.T) { } } -func TestInetCidrArrayTranscodeIPNet(t *testing.T) { +func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -382,31 +382,31 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { { "select $1::inet[]", []*net.IPNet{ - mustParseCidr(t, "0.0.0.0/32"), - mustParseCidr(t, "127.0.0.1/32"), - mustParseCidr(t, "12.34.56.0/32"), - mustParseCidr(t, "192.168.1.0/24"), - mustParseCidr(t, "255.0.0.0/8"), - mustParseCidr(t, "255.255.255.255/32"), - mustParseCidr(t, "::/128"), - mustParseCidr(t, "::/0"), - mustParseCidr(t, "::1/128"), - mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), }, }, { "select $1::cidr[]", []*net.IPNet{ - mustParseCidr(t, "0.0.0.0/32"), - mustParseCidr(t, "127.0.0.1/32"), - mustParseCidr(t, "12.34.56.0/32"), - mustParseCidr(t, "192.168.1.0/24"), - mustParseCidr(t, "255.0.0.0/8"), - mustParseCidr(t, "255.255.255.255/32"), - mustParseCidr(t, "::/128"), - mustParseCidr(t, "::/0"), - mustParseCidr(t, "::1/128"), - mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), }, }, } @@ -428,7 +428,7 @@ func TestInetCidrArrayTranscodeIPNet(t *testing.T) { } } -func TestInetCidrArrayTranscodeIP(t *testing.T) { +func TestInetCIDRArrayTranscodeIP(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -483,15 +483,15 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { { "select $1::inet[]", []*net.IPNet{ - mustParseCidr(t, "12.34.56.0/32"), - mustParseCidr(t, "192.168.1.0/24"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), }, }, { "select $1::cidr[]", []*net.IPNet{ - mustParseCidr(t, "12.34.56.0/32"), - mustParseCidr(t, "192.168.1.0/24"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), }, }, } @@ -509,7 +509,7 @@ func TestInetCidrArrayTranscodeIP(t *testing.T) { } } -func TestInetCidrTranscodeWithJustIP(t *testing.T) { +func TestInetCIDRTranscodeWithJustIP(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -534,7 +534,7 @@ func TestInetCidrTranscodeWithJustIP(t *testing.T) { } for i, tt := range tests { - expected := mustParseCidr(t, tt.value) + expected := mustParseCIDR(t, tt.value) var actual net.IPNet err := conn.QueryRow(tt.sql, expected.IP).Scan(&actual) From 605222780301ef21e80363ab31da5c00327e9f88 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 12:03:19 -0500 Subject: [PATCH 244/264] Update v3 changelog --- v3.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/v3.md b/v3.md index b369be18..a30ca474 100644 --- a/v3.md +++ b/v3.md @@ -52,6 +52,8 @@ Added ctx parameter to (Conn/Tx/ConnPool).PrepareEx Added batch operations +Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR + ## TODO / Possible / Investigate Organize errors better From fb90fb27295133aa4dc9cf597be0e78fa00b2e84 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 4 Jun 2017 21:18:26 -0500 Subject: [PATCH 245/264] Add notification response hook refs #239 --- conn.go | 41 +++++++++++++++++++++++++++++++++++++++++ conn_test.go | 27 +++++++++++++++++++++++++++ messages.go | 4 ++++ v3.md | 2 ++ 4 files changed, 74 insertions(+) diff --git a/conn.go b/conn.go index 223808c5..20a56807 100644 --- a/conn.go +++ b/conn.go @@ -47,6 +47,13 @@ func init() { }) } +// NoticeHandler is a function that can handle notices received from the +// PostgreSQL server. Notices can be received at any time, usually during +// handling of a query response. The *Conn is provided so the handler is aware +// of the origin of the notice, but it must not invoke any query method. Be +// aware that this is distinct from LISTEN/NOTIFY notification. +type NoticeHandler func(*Conn, *Notice) + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -64,6 +71,7 @@ type ConnConfig struct { LogLevel int Dial DialFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + OnNotice NoticeHandler // Callback function called when a notice response is received. } func (cc *ConnConfig) networkAddress() (network, address string) { @@ -102,6 +110,7 @@ type Conn struct { fp *fastpath poolResetCount int preallocatedRows []Rows + onNotice NoticeHandler mux sync.Mutex status byte // One of connStatus* constants @@ -235,6 +244,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } + c.onNotice = config.OnNotice + network, address := c.config.networkAddress() if c.config.Dial == nil { c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial @@ -1079,6 +1090,8 @@ func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { switch msg := msg.(type) { case *pgproto3.ErrorResponse: return c.rxErrorResponse(msg) + case *pgproto3.NoticeResponse: + c.rxNoticeResponse(msg) case *pgproto3.NotificationResponse: c.rxNotificationResponse(msg) case *pgproto3.ReadyForQuery: @@ -1163,6 +1176,34 @@ func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError { return err } +func (c *Conn) rxNoticeResponse(msg *pgproto3.NoticeResponse) { + if c.onNotice == nil { + return + } + + notice := &Notice{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, + } + + c.onNotice(c, notice) +} + func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { c.pid = msg.ProcessID c.secretKey = msg.SecretKey diff --git a/conn_test.go b/conn_test.go index 8ec3c131..d9369a1a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1898,3 +1898,30 @@ func TestIdentifierSanitize(t *testing.T) { } } } + +func TestConnOnNotice(t *testing.T) { + t.Parallel() + + var msg string + + connConfig := *defaultConnConfig + connConfig.OnNotice = func(c *pgx.Conn, notice *pgx.Notice) { + msg = notice.Message + } + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) + + _, err := conn.Exec(`do $$ +begin + raise notice 'hello, world'; +end$$;`) + if err != nil { + t.Fatal(err) + } + + if msg != "hello, world" { + t.Errorf("msg => %v, want %v", msg, "hello, world") + } + + ensureConnValid(t, conn) +} diff --git a/messages.go b/messages.go index 841aa286..53a5a67c 100644 --- a/messages.go +++ b/messages.go @@ -49,6 +49,10 @@ func (pe PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } +// Notice represents a notice response message reported by the PostgreSQL +// server. Be aware that this is distinct from LISTEN/NOTIFY notification. +type Notice PgError + // appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. func appendParse(buf []byte, name string, query string, parameterOIDs []pgtype.OID) []byte { buf = append(buf, 'P') diff --git a/v3.md b/v3.md index a30ca474..993f9e24 100644 --- a/v3.md +++ b/v3.md @@ -54,6 +54,8 @@ Added batch operations Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR +Add OnNotice + ## TODO / Possible / Investigate Organize errors better From 3ea41e6972e7b0d5dd44216388969ff30180ef5a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 4 Jun 2017 21:22:34 -0500 Subject: [PATCH 246/264] Remove unused global error --- conn.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/conn.go b/conn.go index 20a56807..75b8bfb3 100644 --- a/conn.go +++ b/conn.go @@ -181,9 +181,6 @@ func (ident Identifier) Sanitize() string { // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") -// ErrNotificationTimeout occurs when WaitForNotification times out. -var ErrNotificationTimeout = errors.New("notification timeout") - // ErrDeadConn occurs on an attempt to use a dead connection var ErrDeadConn = errors.New("conn is dead") From 8f4178b3d3fc8922384e0989a683b112fb5d1066 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 4 Jun 2017 21:30:03 -0500 Subject: [PATCH 247/264] Use github.com/pkg/errors --- .travis.yml | 1 + conn.go | 17 +++---- conn_pool.go | 3 +- conn_pool_test.go | 3 +- copy_from.go | 3 +- copy_from_test.go | 6 +-- example_custom_type_test.go | 11 +++-- internal/sanitize/sanitize.go | 11 +++-- logger.go | 3 +- pgmock/pgmock.go | 16 +++---- pgproto3/authentication.go | 4 +- pgproto3/backend.go | 4 +- pgproto3/frontend.go | 4 +- pgproto3/startup_message.go | 10 ++-- pgtype/aclitem.go | 9 ++-- pgtype/aclitem_array.go | 9 ++-- pgtype/array.go | 36 +++++++-------- pgtype/bool.go | 13 +++--- pgtype/bool_array.go | 10 ++-- pgtype/box.go | 11 +++-- pgtype/bytea.go | 11 +++-- pgtype/bytea_array.go | 10 ++-- pgtype/cidr_array.go | 10 ++-- pgtype/circle.go | 11 +++-- pgtype/convert.go | 59 ++++++++++++------------ pgtype/database_sql.go | 3 +- pgtype/date.go | 12 ++--- pgtype/date_array.go | 10 ++-- pgtype/daterange.go | 24 +++++----- pgtype/ext/satori-uuid/uuid.go | 14 +++--- pgtype/ext/shopspring-numeric/decimal.go | 52 ++++++++++----------- pgtype/float4.go | 20 ++++---- pgtype/float4_array.go | 10 ++-- pgtype/float8.go | 16 +++---- pgtype/float8_array.go | 10 ++-- pgtype/hstore.go | 32 ++++++------- pgtype/hstore_array.go | 10 ++-- pgtype/inet.go | 15 +++--- pgtype/inet_array.go | 10 ++-- pgtype/int2.go | 32 ++++++------- pgtype/int2_array.go | 10 ++-- pgtype/int4.go | 26 +++++------ pgtype/int4_array.go | 10 ++-- pgtype/int4range.go | 24 +++++----- pgtype/int8.go | 16 +++---- pgtype/int8_array.go | 10 ++-- pgtype/int8range.go | 24 +++++----- pgtype/interval.go | 23 ++++----- pgtype/json.go | 5 +- pgtype/jsonb.go | 7 +-- pgtype/line.go | 13 +++--- pgtype/lseg.go | 11 +++-- pgtype/macaddr.go | 11 +++-- pgtype/numeric.go | 56 +++++++++++----------- pgtype/numeric_array.go | 10 ++-- pgtype/numrange.go | 24 +++++----- pgtype/oid.go | 12 ++--- pgtype/path.go | 13 +++--- pgtype/pgtype.go | 3 +- pgtype/pguint32.go | 14 +++--- pgtype/point.go | 13 +++--- pgtype/polygon.go | 13 +++--- pgtype/qchar.go | 33 ++++++------- pgtype/range.go | 39 ++++++++-------- pgtype/record.go | 15 +++--- pgtype/text.go | 9 ++-- pgtype/text_array.go | 10 ++-- pgtype/tid.go | 13 +++--- pgtype/timestamp.go | 16 +++---- pgtype/timestamp_array.go | 10 ++-- pgtype/timestamptz.go | 12 ++--- pgtype/timestamptz_array.go | 10 ++-- pgtype/tsrange.go | 24 +++++----- pgtype/tstzrange.go | 24 +++++----- pgtype/typed_array.go.erb | 8 ++-- pgtype/typed_range.go.erb | 22 ++++----- pgtype/uuid.go | 14 +++--- pgtype/varbit.go | 10 ++-- pgtype/varchar_array.go | 10 ++-- query.go | 21 +++++---- replication.go | 3 +- stdlib/sql.go | 7 +-- stress_test.go | 13 +++--- tx.go | 3 +- v3.md | 2 + values.go | 5 +- 86 files changed, 639 insertions(+), 597 deletions(-) diff --git a/.travis.yml b/.travis.yml index fd3850e4..971b46a9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -56,6 +56,7 @@ install: - go get -u github.com/hashicorp/go-version - go get -u github.com/satori/go.uuid - go get -u github.com/Sirupsen/logrus + - go get -u github.com/pkg/errors script: - go test -v -race ./... diff --git a/conn.go b/conn.go index 75b8bfb3..da8ed655 100644 --- a/conn.go +++ b/conn.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "encoding/binary" "encoding/hex" - "errors" "fmt" "io" "net" @@ -21,6 +20,8 @@ import ( "sync/atomic" "time" + "github.com/pkg/errors" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" @@ -778,7 +779,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } if len(opts.ParameterOIDs) > 65535 { - return nil, fmt.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) + return nil, errors.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) } buf := appendParse(c.wbuf, name, sql, opts.ParameterOIDs) @@ -809,7 +810,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared ps.ParameterOIDs = c.rxParameterDescription(msg) if len(ps.ParameterOIDs) > 65535 && softErr == nil { - softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) + softErr = errors.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) } case *pgproto3.RowDescription: ps.FieldDescriptions = c.rxRowDescription(msg) @@ -822,7 +823,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared ps.FieldDescriptions[i].FormatCode = TextFormatCode } } else { - return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) + return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) } } case *pgproto3.ReadyForQuery: @@ -1029,7 +1030,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { if len(ps.ParameterOIDs) != len(arguments) { - return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) + return errors.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments)) } if err := c.ensureConnectionReadyForQuery(); err != nil { @@ -1392,7 +1393,7 @@ func (c *Conn) cancelQuery() { _, err = cancelConn.Read(buf) if err != io.EOF { - return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) + return errors.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) } return nil @@ -1516,11 +1517,11 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, func (c *Conn) buildOneRoundTripExec(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { if len(arguments) != len(options.ParameterOIDs) { - return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) + return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) } if len(options.ParameterOIDs) > 65535 { - return nil, fmt.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) + return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) } buf = appendParse(buf, "", sql, options.ParameterOIDs) diff --git a/conn_pool.go b/conn_pool.go index 40c58f49..5fa923b7 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -2,10 +2,11 @@ package pgx import ( "context" - "errors" "sync" "time" + "github.com/pkg/errors" + "github.com/jackc/pgx/pgtype" ) diff --git a/conn_pool_test.go b/conn_pool_test.go index 4e0dc199..ccc38ba9 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -2,13 +2,14 @@ package pgx_test import ( "context" - "errors" "fmt" "net" "sync" "testing" "time" + "github.com/pkg/errors" + "github.com/jackc/pgx" ) diff --git a/copy_from.go b/copy_from.go index f3c77109..8b7c3d5b 100644 --- a/copy_from.go +++ b/copy_from.go @@ -6,6 +6,7 @@ import ( "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" + "github.com/pkg/errors" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice @@ -156,7 +157,7 @@ func (ct *copyFrom) run() (int, error) { } if len(values) != len(ct.columnNames) { ct.cancelCopyIn() - return 0, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) + return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) } buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) diff --git a/copy_from_test.go b/copy_from_test.go index 6df4ebb1..ec674855 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -1,12 +1,12 @@ package pgx_test import ( - "fmt" "reflect" "testing" "time" "github.com/jackc/pgx" + "github.com/pkg/errors" ) func TestConnCopyFromSmall(t *testing.T) { @@ -186,7 +186,7 @@ func (cfs *clientFailSource) Next() bool { func (cfs *clientFailSource) Values() ([]interface{}, error) { if cfs.count == 3 { - cfs.err = fmt.Errorf("client error") + cfs.err = errors.Errorf("client error") return nil, cfs.err } return []interface{}{make([]byte, 100000)}, nil @@ -381,7 +381,7 @@ func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { } func (cfs *clientFinalErrSource) Err() error { - return fmt.Errorf("final error") + return errors.Errorf("final error") } func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 66ed6c53..d3cc9085 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + "github.com/pkg/errors" ) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) @@ -18,7 +19,7 @@ type Point struct { } func (dst *Point) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Point", src) + return errors.Errorf("cannot convert %v to Point", src) } func (dst *Point) Get() interface{} { @@ -33,7 +34,7 @@ func (dst *Point) Get() interface{} { } func (src *Point) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { @@ -45,16 +46,16 @@ func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { s := string(src) match := pointRegexp.FindStringSubmatch(s) if match == nil { - return fmt.Errorf("Received invalid point: %v", s) + return errors.Errorf("Received invalid point: %v", s) } x, err := strconv.ParseFloat(match[1], 64) if err != nil { - return fmt.Errorf("Received invalid point: %v", s) + return errors.Errorf("Received invalid point: %v", s) } y, err := strconv.ParseFloat(match[2], 64) if err != nil { - return fmt.Errorf("Received invalid point: %v", s) + return errors.Errorf("Received invalid point: %v", s) } *dst = Point{X: x, Y: y, Status: pgtype.Present} diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 92b892b9..53543b89 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -3,11 +3,12 @@ package sanitize import ( "bytes" "encoding/hex" - "fmt" "strconv" "strings" "time" "unicode/utf8" + + "github.com/pkg/errors" ) // Part is either a string or an int. A string is raw SQL. An int is a @@ -30,7 +31,7 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) { case int: argIdx := part - 1 if argIdx >= len(args) { - return "", fmt.Errorf("insufficient arguments") + return "", errors.Errorf("insufficient arguments") } arg := args[argIdx] switch arg := arg.(type) { @@ -49,18 +50,18 @@ func (q *Query) Sanitize(args ...interface{}) (string, error) { case time.Time: str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'") default: - return "", fmt.Errorf("invalid arg type: %T", arg) + return "", errors.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true default: - return "", fmt.Errorf("invalid Part type: %T", part) + return "", errors.Errorf("invalid Part type: %T", part) } buf.WriteString(str) } for i, used := range argUse { if !used { - return "", fmt.Errorf("unused argument: %d", i) + return "", errors.Errorf("unused argument: %d", i) } } return buf.String(), nil diff --git a/logger.go b/logger.go index c2df1d7d..528698b1 100644 --- a/logger.go +++ b/logger.go @@ -2,8 +2,9 @@ package pgx import ( "encoding/hex" - "errors" "fmt" + + "github.com/pkg/errors" ) // The values for log levels are chosen such that the zero value means that no diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index b3a51729..5e340881 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -1,12 +1,12 @@ package pgmock import ( - "errors" - "fmt" "io" "net" "reflect" + "github.com/pkg/errors" + "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" ) @@ -115,7 +115,7 @@ func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { } if !reflect.DeepEqual(msg, e.want) { - return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want) } return nil @@ -137,7 +137,7 @@ func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { } if !reflect.DeepEqual(msg, e.want) { - return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) + return errors.Errorf("msg => %#v, e.want => %#v", msg, e.want) } return nil @@ -475,22 +475,22 @@ func buildDataRow(values []interface{}, formatCodes []int16) (*pgproto3.DataRow, if e, ok := values[i].(pgtype.TextEncoder); ok { buf, err := e.EncodeText(nil, nil) if err != nil { - return nil, fmt.Errorf("failed to encode values[%d]", i) + return nil, errors.Errorf("failed to encode values[%d]", i) } dr.Values[i] = buf } else { - return nil, fmt.Errorf("values[%d] does not implement TextExcoder", i) + return nil, errors.Errorf("values[%d] does not implement TextExcoder", i) } case pgproto3.BinaryFormat: if e, ok := values[i].(pgtype.BinaryEncoder); ok { buf, err := e.EncodeBinary(nil, nil) if err != nil { - return nil, fmt.Errorf("failed to encode values[%d]", i) + return nil, errors.Errorf("failed to encode values[%d]", i) } dr.Values[i] = buf } else { - return nil, fmt.Errorf("values[%d] does not implement BinaryEncoder", i) + return nil, errors.Errorf("values[%d] does not implement BinaryEncoder", i) } default: return nil, errors.New("unknown FormatCode") diff --git a/pgproto3/authentication.go b/pgproto3/authentication.go index c04ee448..77750b86 100644 --- a/pgproto3/authentication.go +++ b/pgproto3/authentication.go @@ -2,9 +2,9 @@ package pgproto3 import ( "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const ( @@ -31,7 +31,7 @@ func (dst *Authentication) Decode(src []byte) error { case AuthTypeMD5Password: copy(dst.Salt[:], src[4:8]) default: - return fmt.Errorf("unknown authentication type: %d", dst.Type) + return errors.Errorf("unknown authentication type: %d", dst.Type) } return nil diff --git a/pgproto3/backend.go b/pgproto3/backend.go index bf96ba95..9a7ef342 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -2,10 +2,10 @@ package pgproto3 import ( "encoding/binary" - "fmt" "io" "github.com/jackc/pgx/chunkreader" + "github.com/pkg/errors" ) type Backend struct { @@ -88,7 +88,7 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'X': msg = &b.terminate default: - return nil, fmt.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", msgType) } msgBody, err := b.cr.Next(bodyLen) diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 630a5cba..c8ab5f15 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -2,10 +2,10 @@ package pgproto3 import ( "encoding/binary" - "fmt" "io" "github.com/jackc/pgx/chunkreader" + "github.com/pkg/errors" ) type Frontend struct { @@ -100,7 +100,7 @@ func (b *Frontend) Receive() (BackendMessage, error) { case 'Z': msg = &b.readyForQuery default: - return nil, fmt.Errorf("unknown message type: %c", msgType) + return nil, errors.Errorf("unknown message type: %c", msgType) } msgBody, err := b.cr.Next(bodyLen) diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index 4e2df27d..6c5d4f99 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -4,9 +4,9 @@ import ( "bytes" "encoding/binary" "encoding/json" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const ( @@ -23,18 +23,18 @@ func (*StartupMessage) Frontend() {} func (dst *StartupMessage) Decode(src []byte) error { if len(src) < 4 { - return fmt.Errorf("startup message too short") + return errors.Errorf("startup message too short") } dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 if dst.ProtocolVersion == sslRequestNumber { - return fmt.Errorf("can't handle ssl connection request") + return errors.Errorf("can't handle ssl connection request") } if dst.ProtocolVersion != ProtocolVersionNumber { - return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) + return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) } dst.Parameters = make(map[string]string) @@ -57,7 +57,7 @@ func (dst *StartupMessage) Decode(src []byte) error { if len(src[rp:]) == 1 { if src[rp] != 0 { - return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) + return errors.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp]) } break } diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index 829eb908..35269e91 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -2,7 +2,8 @@ package pgtype import ( "database/sql/driver" - "fmt" + + "github.com/pkg/errors" ) // ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem @@ -36,7 +37,7 @@ func (dst *ACLItem) Set(src interface{}) error { if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to ACLItem", value) + return errors.Errorf("cannot convert %v to ACLItem", value) } return nil @@ -69,7 +70,7 @@ func (src *ACLItem) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error { @@ -109,7 +110,7 @@ func (dst *ACLItem) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index f9215a93..fe0af434 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -2,7 +2,8 @@ package pgtype import ( "database/sql/driver" - "fmt" + + "github.com/pkg/errors" ) type ACLItemArray struct { @@ -37,7 +38,7 @@ func (dst *ACLItemArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to ACLItem", value) + return errors.Errorf("cannot convert %v to ACLItem", value) } return nil @@ -77,7 +78,7 @@ func (src *ACLItemArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -188,7 +189,7 @@ func (dst *ACLItemArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/array.go b/pgtype/array.go index e5504455..5b852ed5 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -3,13 +3,13 @@ package pgtype import ( "bytes" "encoding/binary" - "fmt" "io" "strconv" "strings" "unicode" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // Information on the internals of PostgreSQL arrays can be found in @@ -29,7 +29,7 @@ type ArrayDimension struct { func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { - return 0, fmt.Errorf("array header too short: %d", len(src)) + return 0, errors.Errorf("array header too short: %d", len(src)) } rp := 0 @@ -47,7 +47,7 @@ func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { dst.Dimensions = make([]ArrayDimension, numDims) } if len(src) < 12+numDims*8 { - return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) + return 0, errors.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } for i := range dst.Dimensions { dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) @@ -93,7 +93,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } var explicitDimensions []ArrayDimension @@ -105,41 +105,41 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } if r == '=' { break } else if r != '[' { - return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) + return nil, errors.Errorf("invalid array, expected '[' or '=' got %v", r) } lower, err := arrayParseInteger(buf) if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } if r != ':' { - return nil, fmt.Errorf("invalid array, expected ':' got %v", r) + return nil, errors.Errorf("invalid array, expected ':' got %v", r) } upper, err := arrayParseInteger(buf) if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } if r != ']' { - return nil, fmt.Errorf("invalid array, expected ']' got %v", r) + return nil, errors.Errorf("invalid array, expected ']' got %v", r) } explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) @@ -147,12 +147,12 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } } if r != '{' { - return nil, fmt.Errorf("invalid array, expected '{': %v", err) + return nil, errors.Errorf("invalid array, expected '{': %v", err) } implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} @@ -161,7 +161,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } if r == '{' { @@ -178,7 +178,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { for { r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid array: %v", err) + return nil, errors.Errorf("invalid array: %v", err) } switch r { @@ -197,7 +197,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { buf.UnreadRune() value, err := arrayParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid array value: %v", err) + return nil, errors.Errorf("invalid array value: %v", err) } if currentDim == counterDim { implicitDimensions[currentDim].Length++ @@ -213,7 +213,7 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { skipWhitespace(buf) if buf.Len() > 0 { - return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) } if len(dst.Elements) == 0 { diff --git a/pgtype/bool.go b/pgtype/bool.go index 7c66a534..3a3eef48 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -2,8 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "strconv" + + "github.com/pkg/errors" ) type Bool struct { @@ -25,7 +26,7 @@ func (dst *Bool) Set(src interface{}) error { if originalSrc, ok := underlyingBoolType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Bool", value) + return errors.Errorf("cannot convert %v to Bool", value) } return nil @@ -58,7 +59,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { @@ -68,7 +69,7 @@ func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + return errors.Errorf("invalid length for bool: %v", len(src)) } *dst = Bool{Bool: src[0] == 't', Status: Present} @@ -82,7 +83,7 @@ func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return fmt.Errorf("invalid length for bool: %v", len(src)) + return errors.Errorf("invalid length for bool: %v", len(src)) } *dst = Bool{Bool: src[0] == 1, Status: Present} @@ -142,7 +143,7 @@ func (dst *Bool) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index e20a0381..e23c27e5 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type BoolArray struct { @@ -40,7 +40,7 @@ func (dst *BoolArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Bool", value) + return errors.Errorf("cannot convert %v to Bool", value) } return nil @@ -80,7 +80,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("bool"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "bool") + return nil, errors.Errorf("unable to find oid for type name %v", "bool") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *BoolArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/box.go b/pgtype/box.go index 2d098058..83df0499 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Box struct { @@ -17,7 +18,7 @@ type Box struct { } func (dst *Box) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Box", src) + return errors.Errorf("cannot convert %v to Box", src) } func (dst *Box) Get() interface{} { @@ -32,7 +33,7 @@ func (dst *Box) Get() interface{} { } func (src *Box) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { @@ -42,7 +43,7 @@ func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 11 { - return fmt.Errorf("invalid length for Box: %v", len(src)) + return errors.Errorf("invalid length for Box: %v", len(src)) } str := string(src[1:]) @@ -89,7 +90,7 @@ func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 32 { - return fmt.Errorf("invalid length for Box: %v", len(src)) + return errors.Errorf("invalid length for Box: %v", len(src)) } x1 := binary.BigEndian.Uint64(src) @@ -152,7 +153,7 @@ func (dst *Box) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 2ddac7da..c7117f48 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -3,7 +3,8 @@ package pgtype import ( "database/sql/driver" "encoding/hex" - "fmt" + + "github.com/pkg/errors" ) type Bytea struct { @@ -28,7 +29,7 @@ func (dst *Bytea) Set(src interface{}) error { if originalSrc, ok := underlyingBytesType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Bytea", value) + return errors.Errorf("cannot convert %v to Bytea", value) } return nil @@ -63,7 +64,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } // DecodeText only supports the hex format. This has been the default since @@ -75,7 +76,7 @@ func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 2 || src[0] != '\\' || src[1] != 'x' { - return fmt.Errorf("invalid hex format") + return errors.Errorf("invalid hex format") } buf := make([]byte, (len(src)-2)/2) @@ -139,7 +140,7 @@ func (dst *Bytea) Scan(src interface{}) error { return nil } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 0d381693..f2842179 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type ByteaArray struct { @@ -40,7 +40,7 @@ func (dst *ByteaArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Bytea", value) + return errors.Errorf("cannot convert %v to Bytea", value) } return nil @@ -80,7 +80,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("bytea"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "bytea") + return nil, errors.Errorf("unable to find oid for type name %v", "bytea") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *ByteaArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 9b7b50fa..2373da46 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "net" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type CIDRArray struct { @@ -60,7 +60,7 @@ func (dst *CIDRArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to CIDR", value) + return errors.Errorf("cannot convert %v to CIDR", value) } return nil @@ -109,7 +109,7 @@ func (src *CIDRArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *CIDRArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -262,7 +262,7 @@ func (src *CIDRArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("cidr"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "cidr") + return nil, errors.Errorf("unable to find oid for type name %v", "cidr") } for i := range src.Elements { @@ -306,7 +306,7 @@ func (dst *CIDRArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/circle.go b/pgtype/circle.go index 8626a99d..97ecbf31 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Circle struct { @@ -18,7 +19,7 @@ type Circle struct { } func (dst *Circle) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Circle", src) + return errors.Errorf("cannot convert %v to Circle", src) } func (dst *Circle) Get() interface{} { @@ -33,7 +34,7 @@ func (dst *Circle) Get() interface{} { } func (src *Circle) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { @@ -43,7 +44,7 @@ func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 9 { - return fmt.Errorf("invalid length for Circle: %v", len(src)) + return errors.Errorf("invalid length for Circle: %v", len(src)) } str := string(src[2:]) @@ -79,7 +80,7 @@ func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 24 { - return fmt.Errorf("invalid length for Circle: %v", len(src)) + return errors.Errorf("invalid length for Circle: %v", len(src)) } x := binary.BigEndian.Uint64(src) @@ -136,7 +137,7 @@ func (dst *Circle) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/convert.go b/pgtype/convert.go index 2b406426..5dfb738e 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -1,10 +1,11 @@ package pgtype import ( - "fmt" "math" "reflect" "time" + + "github.com/pkg/errors" ) const maxUint = ^uint(0) @@ -189,70 +190,70 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { switch v := dst.(type) { case *int: if srcVal < int64(minInt) { - return fmt.Errorf("%d is less than minimum value for int", srcVal) + return errors.Errorf("%d is less than minimum value for int", srcVal) } else if srcVal > int64(maxInt) { - return fmt.Errorf("%d is greater than maximum value for int", srcVal) + return errors.Errorf("%d is greater than maximum value for int", srcVal) } *v = int(srcVal) case *int8: if srcVal < math.MinInt8 { - return fmt.Errorf("%d is less than minimum value for int8", srcVal) + return errors.Errorf("%d is less than minimum value for int8", srcVal) } else if srcVal > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for int8", srcVal) + return errors.Errorf("%d is greater than maximum value for int8", srcVal) } *v = int8(srcVal) case *int16: if srcVal < math.MinInt16 { - return fmt.Errorf("%d is less than minimum value for int16", srcVal) + return errors.Errorf("%d is less than minimum value for int16", srcVal) } else if srcVal > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for int16", srcVal) + return errors.Errorf("%d is greater than maximum value for int16", srcVal) } *v = int16(srcVal) case *int32: if srcVal < math.MinInt32 { - return fmt.Errorf("%d is less than minimum value for int32", srcVal) + return errors.Errorf("%d is less than minimum value for int32", srcVal) } else if srcVal > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for int32", srcVal) + return errors.Errorf("%d is greater than maximum value for int32", srcVal) } *v = int32(srcVal) case *int64: if srcVal < math.MinInt64 { - return fmt.Errorf("%d is less than minimum value for int64", srcVal) + return errors.Errorf("%d is less than minimum value for int64", srcVal) } else if srcVal > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for int64", srcVal) + return errors.Errorf("%d is greater than maximum value for int64", srcVal) } *v = int64(srcVal) case *uint: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint", srcVal) + return errors.Errorf("%d is less than zero for uint", srcVal) } else if uint64(srcVal) > uint64(maxUint) { - return fmt.Errorf("%d is greater than maximum value for uint", srcVal) + return errors.Errorf("%d is greater than maximum value for uint", srcVal) } *v = uint(srcVal) case *uint8: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint8", srcVal) + return errors.Errorf("%d is less than zero for uint8", srcVal) } else if srcVal > math.MaxUint8 { - return fmt.Errorf("%d is greater than maximum value for uint8", srcVal) + return errors.Errorf("%d is greater than maximum value for uint8", srcVal) } *v = uint8(srcVal) case *uint16: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) + return errors.Errorf("%d is less than zero for uint32", srcVal) } else if srcVal > math.MaxUint16 { - return fmt.Errorf("%d is greater than maximum value for uint16", srcVal) + return errors.Errorf("%d is greater than maximum value for uint16", srcVal) } *v = uint16(srcVal) case *uint32: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint32", srcVal) + return errors.Errorf("%d is less than zero for uint32", srcVal) } else if srcVal > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for uint32", srcVal) + return errors.Errorf("%d is greater than maximum value for uint32", srcVal) } *v = uint32(srcVal) case *uint64: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for uint64", srcVal) + return errors.Errorf("%d is less than zero for uint64", srcVal) } *v = uint64(srcVal) default: @@ -268,22 +269,22 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { return int64AssignTo(srcVal, srcStatus, el.Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if el.OverflowInt(int64(srcVal)) { - return fmt.Errorf("cannot put %d into %T", srcVal, dst) + return errors.Errorf("cannot put %d into %T", srcVal, dst) } el.SetInt(int64(srcVal)) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if srcVal < 0 { - return fmt.Errorf("%d is less than zero for %T", srcVal, dst) + return errors.Errorf("%d is less than zero for %T", srcVal, dst) } if el.OverflowUint(uint64(srcVal)) { - return fmt.Errorf("cannot put %d into %T", srcVal, dst) + return errors.Errorf("cannot put %d into %T", srcVal, dst) } el.SetUint(uint64(srcVal)) return nil } } - return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + return errors.Errorf("cannot assign %v into %T", srcVal, dst) } return nil } @@ -297,7 +298,7 @@ func int64AssignTo(srcVal int64, srcStatus Status, dst interface{}) error { } } - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) + return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { @@ -325,7 +326,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { } } } - return fmt.Errorf("cannot assign %v into %T", srcVal, dst) + return errors.Errorf("cannot assign %v into %T", srcVal, dst) } return nil } @@ -339,7 +340,7 @@ func float64AssignTo(srcVal float64, srcStatus Status, dst interface{}) error { } } - return fmt.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) + return errors.Errorf("cannot assign %v %v into %T", srcVal, srcStatus, dst) } func NullAssignTo(dst interface{}) error { @@ -347,7 +348,7 @@ func NullAssignTo(dst interface{}) error { // AssignTo dst must always be a pointer if dstPtr.Kind() != reflect.Ptr { - return fmt.Errorf("cannot assign NULL to %T", dst) + return errors.Errorf("cannot assign NULL to %T", dst) } dstVal := dstPtr.Elem() @@ -358,7 +359,7 @@ func NullAssignTo(dst interface{}) error { return nil } - return fmt.Errorf("cannot assign NULL to %T", dst) + return errors.Errorf("cannot assign NULL to %T", dst) } var kindTypes map[reflect.Kind]reflect.Type diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go index 9d1cf822..969536dd 100644 --- a/pgtype/database_sql.go +++ b/pgtype/database_sql.go @@ -2,7 +2,8 @@ package pgtype import ( "database/sql/driver" - "errors" + + "github.com/pkg/errors" ) func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { diff --git a/pgtype/date.go b/pgtype/date.go index 8e049254..f1c0d8bd 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Date struct { @@ -33,7 +33,7 @@ func (dst *Date) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Date", value) + return errors.Errorf("cannot convert %v to Date", value) } return nil @@ -59,7 +59,7 @@ func (src *Date) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -72,7 +72,7 @@ func (src *Date) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { @@ -106,7 +106,7 @@ func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return fmt.Errorf("invalid length for date: %v", len(src)) + return errors.Errorf("invalid length for date: %v", len(src)) } dayOffset := int32(binary.BigEndian.Uint32(src)) @@ -190,7 +190,7 @@ func (dst *Date) Scan(src interface{}) error { return nil } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/date_array.go b/pgtype/date_array.go index ef91cf3e..383945e7 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type DateArray struct { @@ -41,7 +41,7 @@ func (dst *DateArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Date", value) + return errors.Errorf("cannot convert %v to Date", value) } return nil @@ -81,7 +81,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -234,7 +234,7 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("date"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "date") + return nil, errors.Errorf("unable to find oid for type name %v", "date") } for i := range src.Elements { @@ -278,7 +278,7 @@ func (dst *DateArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/daterange.go b/pgtype/daterange.go index bbe7b17a..47cd7e46 100644 --- a/pgtype/daterange.go +++ b/pgtype/daterange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Daterange struct { @@ -16,7 +16,7 @@ type Daterange struct { } func (dst *Daterange) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Daterange", src) + return errors.Errorf("cannot convert %v to Daterange", src) } func (dst *Daterange) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Daterange) Get() interface{} { } func (src *Daterange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Daterange) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Daterange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Daterange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Daterange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/ext/satori-uuid/uuid.go b/pgtype/ext/satori-uuid/uuid.go index b7b776f9..78a90035 100644 --- a/pgtype/ext/satori-uuid/uuid.go +++ b/pgtype/ext/satori-uuid/uuid.go @@ -2,8 +2,8 @@ package uuid import ( "database/sql/driver" - "errors" - "fmt" + + "github.com/pkg/errors" "github.com/jackc/pgx/pgtype" uuid "github.com/satori/go.uuid" @@ -24,7 +24,7 @@ func (dst *UUID) Set(src interface{}) error { *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} case []byte: if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } *dst = UUID{Status: pgtype.Present} copy(dst.UUID[:], value) @@ -38,7 +38,7 @@ func (dst *UUID) Set(src interface{}) error { // If all else fails see if pgtype.UUID can handle it. If so, translate through that. pgUUID := &pgtype.UUID{} if err := pgUUID.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to UUID", value) + return errors.Errorf("cannot convert %v to UUID", value) } *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} @@ -83,7 +83,7 @@ func (src *UUID) AssignTo(dst interface{}) error { return pgtype.NullAssignTo(dst) } - return fmt.Errorf("cannot assign %v into %T", src, dst) + return errors.Errorf("cannot assign %v into %T", src, dst) } func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { @@ -108,7 +108,7 @@ func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { } if len(src) != 16 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) + return errors.Errorf("invalid length for UUID: %v", len(src)) } *dst = UUID{Status: pgtype.Present} @@ -152,7 +152,7 @@ func (dst *UUID) Scan(src interface{}) error { return dst.DecodeText(nil, src) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/ext/shopspring-numeric/decimal.go b/pgtype/ext/shopspring-numeric/decimal.go index 277f3709..507a93dc 100644 --- a/pgtype/ext/shopspring-numeric/decimal.go +++ b/pgtype/ext/shopspring-numeric/decimal.go @@ -2,10 +2,10 @@ package numeric import ( "database/sql/driver" - "errors" - "fmt" "strconv" + "github.com/pkg/errors" + "github.com/jackc/pgx/pgtype" "github.com/shopspring/decimal" ) @@ -70,17 +70,17 @@ func (dst *Numeric) Set(src interface{}) error { // If all else fails see if pgtype.Numeric can handle it. If so, translate through that. num := &pgtype.Numeric{} if err := num.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } buf, err := num.EncodeText(nil, nil) if err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } dec, err := decimal.NewFromString(string(buf)) if err != nil { - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } *dst = Numeric{Decimal: dec, Status: pgtype.Present} } @@ -113,92 +113,92 @@ func (src *Numeric) AssignTo(dst interface{}) error { *v = f case *int: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int(n) case *int8: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 8) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int8(n) case *int16: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 16) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int16(n) case *int32: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 32) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int32(n) case *int64: if src.Decimal.Exponent() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseInt(src.Decimal.String(), 10, 64) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = int64(n) case *uint: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint(n) case *uint8: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 8) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint8(n) case *uint16: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 16) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint16(n) case *uint32: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 32) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint32(n) case *uint64: if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } n, err := strconv.ParseUint(src.Decimal.String(), 10, 64) if err != nil { - return fmt.Errorf("cannot convert %v to %T", dst, *v) + return errors.Errorf("cannot convert %v to %T", dst, *v) } *v = uint64(n) default: @@ -301,7 +301,7 @@ func (dst *Numeric) Scan(src interface{}) error { return dst.DecodeText(nil, src) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/float4.go b/pgtype/float4.go index b24654b6..2207594a 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Float4 struct { @@ -39,42 +39,42 @@ func (dst *Float4) Set(src interface{}) error { if int32(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case uint32: f32 := float32(value) if uint32(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case int64: f32 := float32(value) if int64(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case uint64: f32 := float32(value) if uint64(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case int: f32 := float32(value) if int(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case uint: f32 := float32(value) if uint(f32) == value { *dst = Float4{Float: f32, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float32", value) + return errors.Errorf("%v cannot be exactly represented as float32", value) } case string: num, err := strconv.ParseFloat(value, 32) @@ -86,7 +86,7 @@ func (dst *Float4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Float8", value) + return errors.Errorf("cannot convert %v to Float8", value) } return nil @@ -129,7 +129,7 @@ func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return fmt.Errorf("invalid length for float4: %v", len(src)) + return errors.Errorf("invalid length for float4: %v", len(src)) } n := int32(binary.BigEndian.Uint32(src)) @@ -181,7 +181,7 @@ func (dst *Float4) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index a35657b0..6499064b 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Float4Array struct { @@ -40,7 +40,7 @@ func (dst *Float4Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Float4", value) + return errors.Errorf("cannot convert %v to Float4", value) } return nil @@ -80,7 +80,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("float4"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "float4") + return nil, errors.Errorf("unable to find oid for type name %v", "float4") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *Float4Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/float8.go b/pgtype/float8.go index c3ecdcc2..dd34f541 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Float8 struct { @@ -43,28 +43,28 @@ func (dst *Float8) Set(src interface{}) error { if int64(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) + return errors.Errorf("%v cannot be exactly represented as float64", value) } case uint64: f64 := float64(value) if uint64(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) + return errors.Errorf("%v cannot be exactly represented as float64", value) } case int: f64 := float64(value) if int(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) + return errors.Errorf("%v cannot be exactly represented as float64", value) } case uint: f64 := float64(value) if uint(f64) == value { *dst = Float8{Float: f64, Status: Present} } else { - return fmt.Errorf("%v cannot be exactly represented as float64", value) + return errors.Errorf("%v cannot be exactly represented as float64", value) } case string: num, err := strconv.ParseFloat(value, 64) @@ -76,7 +76,7 @@ func (dst *Float8) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Float8", value) + return errors.Errorf("cannot convert %v to Float8", value) } return nil @@ -119,7 +119,7 @@ func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for float4: %v", len(src)) + return errors.Errorf("invalid length for float4: %v", len(src)) } n := int64(binary.BigEndian.Uint64(src)) @@ -171,7 +171,7 @@ func (dst *Float8) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 486e3a4e..27b24836 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Float8Array struct { @@ -40,7 +40,7 @@ func (dst *Float8Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Float8", value) + return errors.Errorf("cannot convert %v to Float8", value) } return nil @@ -80,7 +80,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("float8"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "float8") + return nil, errors.Errorf("unable to find oid for type name %v", "float8") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *Float8Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 09506242..347446ae 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -4,12 +4,12 @@ import ( "bytes" "database/sql/driver" "encoding/binary" - "errors" - "fmt" "strings" "unicode" "unicode/utf8" + "github.com/pkg/errors" + "github.com/jackc/pgx/pgio" ) @@ -34,7 +34,7 @@ func (dst *Hstore) Set(src interface{}) error { } *dst = Hstore{Map: m, Status: Present} default: - return fmt.Errorf("cannot convert %v to Hstore", src) + return errors.Errorf("cannot convert %v to Hstore", src) } return nil @@ -59,7 +59,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { *v = make(map[string]string, len(src.Map)) for k, val := range src.Map { if val.Status != Present { - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } (*v)[k] = val.String } @@ -73,7 +73,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { @@ -105,7 +105,7 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 0 if len(src[rp:]) < 4 { - return fmt.Errorf("hstore incomplete %v", src) + return errors.Errorf("hstore incomplete %v", src) } pairCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -114,19 +114,19 @@ func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { for i := 0; i < pairCount; i++ { if len(src[rp:]) < 4 { - return fmt.Errorf("hstore incomplete %v", src) + return errors.Errorf("hstore incomplete %v", src) } keyLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 if len(src[rp:]) < keyLen { - return fmt.Errorf("hstore incomplete %v", src) + return errors.Errorf("hstore incomplete %v", src) } key := string(src[rp : rp+keyLen]) rp += keyLen if len(src[rp:]) < 4 { - return fmt.Errorf("hstore incomplete %v", src) + return errors.Errorf("hstore incomplete %v", src) } valueLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -333,13 +333,13 @@ func parseHstore(s string) (k []string, v []Text, err error) { case r == 'N': state = hsNul default: - err = fmt.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) + err = errors.Errorf("Invalid character '%c' after '=>', expecting '\"' or 'NULL'", r) } default: - err = fmt.Errorf("Invalid character after '=', expecting '>'") + err = errors.Errorf("Invalid character after '=', expecting '>'") } } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting '='", r) + err = errors.Errorf("Invalid character '%c' after value, expecting '='", r) } case hsVal: switch r { @@ -376,7 +376,7 @@ func parseHstore(s string) (k []string, v []Text, err error) { values = append(values, Text{Status: Null}) state = hsNext } else { - err = fmt.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) + err = errors.Errorf("Invalid NULL value: 'N%s'", string(nulBuf)) } case hsNext: if r == ',' { @@ -388,10 +388,10 @@ func parseHstore(s string) (k []string, v []Text, err error) { r, end = p.Consume() state = hsKey default: - err = fmt.Errorf("Invalid character '%c' after ', ', expecting \"", r) + err = errors.Errorf("Invalid character '%c' after ', ', expecting \"", r) } } else { - err = fmt.Errorf("Invalid character '%c' after value, expecting ','", r) + err = errors.Errorf("Invalid character '%c' after value, expecting ','", r) } } @@ -425,7 +425,7 @@ func (dst *Hstore) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index 3e5a003f..38ce457b 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type HstoreArray struct { @@ -40,7 +40,7 @@ func (dst *HstoreArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Hstore", value) + return errors.Errorf("cannot convert %v to Hstore", value) } return nil @@ -80,7 +80,7 @@ func (src *HstoreArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("hstore"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "hstore") + return nil, errors.Errorf("unable to find oid for type name %v", "hstore") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *HstoreArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/inet.go b/pgtype/inet.go index 7aa1df95..01fc0e5b 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -2,8 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "net" + + "github.com/pkg/errors" ) // Network address family is dependent on server socket.h value for AF_INET. @@ -45,7 +46,7 @@ func (dst *Inet) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Inet", value) + return errors.Errorf("cannot convert %v to Inet", value) } return nil @@ -76,7 +77,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil case *net.IP: if oneCount, bitCount := src.IPNet.Mask.Size(); oneCount != bitCount { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } *v = make(net.IP, len(src.IPNet.IP)) copy(*v, src.IPNet.IP) @@ -90,7 +91,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { @@ -128,7 +129,7 @@ func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 && len(src) != 20 { - return fmt.Errorf("Received an invalid size for a inet: %d", len(src)) + return errors.Errorf("Received an invalid size for a inet: %d", len(src)) } // ignore family @@ -173,7 +174,7 @@ func (src *Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case net.IPv6len: family = defaultAFInet6 default: - return nil, fmt.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) + return nil, errors.Errorf("Unexpected IP length: %v", len(src.IPNet.IP)) } buf = append(buf, family) @@ -205,7 +206,7 @@ func (dst *Inet) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 57123c1c..3ece23eb 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "net" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type InetArray struct { @@ -60,7 +60,7 @@ func (dst *InetArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Inet", value) + return errors.Errorf("cannot convert %v to Inet", value) } return nil @@ -109,7 +109,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -262,7 +262,7 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("inet"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "inet") + return nil, errors.Errorf("unable to find oid for type name %v", "inet") } for i := range src.Elements { @@ -306,7 +306,7 @@ func (dst *InetArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int2.go b/pgtype/int2.go index a58c3355..45bce93c 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int2 struct { @@ -30,46 +30,46 @@ func (dst *Int2) Set(src interface{}) error { *dst = Int2{Int: int16(value), Status: Present} case uint16: if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int32: if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint32: if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int64: if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint64: if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case int: if value < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case uint: if value > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", value) + return errors.Errorf("%d is greater than maximum value for Int2", value) } *dst = Int2{Int: int16(value), Status: Present} case string: @@ -82,7 +82,7 @@ func (dst *Int2) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int2", value) + return errors.Errorf("cannot convert %v to Int2", value) } return nil @@ -125,7 +125,7 @@ func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 2 { - return fmt.Errorf("invalid length for int2: %v", len(src)) + return errors.Errorf("invalid length for int2: %v", len(src)) } n := int16(binary.BigEndian.Uint16(src)) @@ -165,10 +165,10 @@ func (dst *Int2) Scan(src interface{}) error { switch src := src.(type) { case int64: if src < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", src) + return errors.Errorf("%d is greater than maximum value for Int2", src) } if src > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", src) + return errors.Errorf("%d is greater than maximum value for Int2", src) } *dst = Int2{Int: int16(src), Status: Present} return nil @@ -180,7 +180,7 @@ func (dst *Int2) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index e4993104..e939411b 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int2Array struct { @@ -59,7 +59,7 @@ func (dst *Int2Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int2", value) + return errors.Errorf("cannot convert %v to Int2", value) } return nil @@ -108,7 +108,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -261,7 +261,7 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int2"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int2") + return nil, errors.Errorf("unable to find oid for type name %v", "int2") } for i := range src.Elements { @@ -305,7 +305,7 @@ func (dst *Int2Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int4.go b/pgtype/int4.go index 6f95013b..a3499fef 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int4 struct { @@ -34,33 +34,33 @@ func (dst *Int4) Set(src interface{}) error { *dst = Int4{Int: int32(value), Status: Present} case uint32: if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case int64: if value < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case uint64: if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case int: if value < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case uint: if value > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", value) + return errors.Errorf("%d is greater than maximum value for Int4", value) } *dst = Int4{Int: int32(value), Status: Present} case string: @@ -73,7 +73,7 @@ func (dst *Int4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int4", value) + return errors.Errorf("cannot convert %v to Int4", value) } return nil @@ -116,7 +116,7 @@ func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return fmt.Errorf("invalid length for int4: %v", len(src)) + return errors.Errorf("invalid length for int4: %v", len(src)) } n := int32(binary.BigEndian.Uint32(src)) @@ -156,10 +156,10 @@ func (dst *Int4) Scan(src interface{}) error { switch src := src.(type) { case int64: if src < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", src) + return errors.Errorf("%d is greater than maximum value for Int4", src) } if src > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", src) + return errors.Errorf("%d is greater than maximum value for Int4", src) } *dst = Int4{Int: int32(src), Status: Present} return nil @@ -171,7 +171,7 @@ func (dst *Int4) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 6bc06e86..1a907d2e 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int4Array struct { @@ -59,7 +59,7 @@ func (dst *Int4Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int4", value) + return errors.Errorf("cannot convert %v to Int4", value) } return nil @@ -108,7 +108,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -261,7 +261,7 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int4"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int4") + return nil, errors.Errorf("unable to find oid for type name %v", "int4") } for i := range src.Elements { @@ -305,7 +305,7 @@ func (dst *Int4Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int4range.go b/pgtype/int4range.go index 4f27ff0d..95ad1521 100644 --- a/pgtype/int4range.go +++ b/pgtype/int4range.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int4range struct { @@ -16,7 +16,7 @@ type Int4range struct { } func (dst *Int4range) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Int4range", src) + return errors.Errorf("cannot convert %v to Int4range", src) } func (dst *Int4range) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Int4range) Get() interface{} { } func (src *Int4range) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Int4range) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Int4range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Int4range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Int4range) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int8.go b/pgtype/int8.go index 939c0554..d671eda7 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int8 struct { @@ -38,20 +38,20 @@ func (dst *Int8) Set(src interface{}) error { *dst = Int8{Int: int64(value), Status: Present} case uint64: if value > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) + return errors.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case int: if int64(value) < math.MinInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) + return errors.Errorf("%d is greater than maximum value for Int8", value) } if int64(value) > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) + return errors.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case uint: if uint64(value) > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", value) + return errors.Errorf("%d is greater than maximum value for Int8", value) } *dst = Int8{Int: int64(value), Status: Present} case string: @@ -64,7 +64,7 @@ func (dst *Int8) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int8", value) + return errors.Errorf("cannot convert %v to Int8", value) } return nil @@ -107,7 +107,7 @@ func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for int8: %v", len(src)) + return errors.Errorf("invalid length for int8: %v", len(src)) } n := int64(binary.BigEndian.Uint64(src)) @@ -157,7 +157,7 @@ func (dst *Int8) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 4404d22a..4f3ab4dc 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int8Array struct { @@ -59,7 +59,7 @@ func (dst *Int8Array) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int8", value) + return errors.Errorf("cannot convert %v to Int8", value) } return nil @@ -108,7 +108,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { @@ -261,7 +261,7 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("int8"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "int8") + return nil, errors.Errorf("unable to find oid for type name %v", "int8") } for i := range src.Elements { @@ -305,7 +305,7 @@ func (dst *Int8Array) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/int8range.go b/pgtype/int8range.go index 128a853f..61d860d3 100644 --- a/pgtype/int8range.go +++ b/pgtype/int8range.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Int8range struct { @@ -16,7 +16,7 @@ type Int8range struct { } func (dst *Int8range) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Int8range", src) + return errors.Errorf("cannot convert %v to Int8range", src) } func (dst *Int8range) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Int8range) Get() interface{} { } func (src *Int8range) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Int8range) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Int8range) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Int8range) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Int8range) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/interval.go b/pgtype/interval.go index 85d76d99..799ce53a 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -9,6 +9,7 @@ import ( "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const ( @@ -37,7 +38,7 @@ func (dst *Interval) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Interval", value) + return errors.Errorf("cannot convert %v to Interval", value) } return nil @@ -60,7 +61,7 @@ func (src *Interval) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Duration: if src.Days > 0 || src.Months > 0 { - return fmt.Errorf("interval with months or days cannot be decoded into %T", dst) + return errors.Errorf("interval with months or days cannot be decoded into %T", dst) } *v = time.Duration(src.Microseconds) * time.Microsecond return nil @@ -73,7 +74,7 @@ func (src *Interval) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { @@ -91,7 +92,7 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { for i := 0; i < len(parts)-1; i += 2 { scalar, err := strconv.ParseInt(parts[i], 10, 64) if err != nil { - return fmt.Errorf("bad interval format") + return errors.Errorf("bad interval format") } switch parts[i+1] { @@ -107,7 +108,7 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { if len(parts)%2 == 1 { timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) if len(timeParts) != 3 { - return fmt.Errorf("bad interval format") + return errors.Errorf("bad interval format") } var negative bool @@ -118,26 +119,26 @@ func (dst *Interval) DecodeText(ci *ConnInfo, src []byte) error { hours, err := strconv.ParseInt(timeParts[0], 10, 64) if err != nil { - return fmt.Errorf("bad interval hour format: %s", timeParts[0]) + return errors.Errorf("bad interval hour format: %s", timeParts[0]) } minutes, err := strconv.ParseInt(timeParts[1], 10, 64) if err != nil { - return fmt.Errorf("bad interval minute format: %s", timeParts[1]) + return errors.Errorf("bad interval minute format: %s", timeParts[1]) } secondParts := strings.SplitN(timeParts[2], ".", 2) seconds, err := strconv.ParseInt(secondParts[0], 10, 64) if err != nil { - return fmt.Errorf("bad interval second format: %s", secondParts[0]) + return errors.Errorf("bad interval second format: %s", secondParts[0]) } var uSeconds int64 if len(secondParts) == 2 { uSeconds, err = strconv.ParseInt(secondParts[1], 10, 64) if err != nil { - return fmt.Errorf("bad interval decimal format: %s", secondParts[1]) + return errors.Errorf("bad interval decimal format: %s", secondParts[1]) } for i := 0; i < 6-len(secondParts[1]); i++ { @@ -166,7 +167,7 @@ func (dst *Interval) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return fmt.Errorf("Received an invalid size for a interval: %d", len(src)) + return errors.Errorf("Received an invalid size for a interval: %d", len(src)) } microseconds := int64(binary.BigEndian.Uint64(src)) @@ -240,7 +241,7 @@ func (dst *Interval) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/json.go b/pgtype/json.go index ee00e9a4..562722aa 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -3,7 +3,8 @@ package pgtype import ( "database/sql/driver" "encoding/json" - "fmt" + + "github.com/pkg/errors" ) type JSON struct { @@ -135,7 +136,7 @@ func (dst *JSON) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 9a06c1b4..c315c588 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -2,7 +2,8 @@ package pgtype import ( "database/sql/driver" - "fmt" + + "github.com/pkg/errors" ) type JSONB JSON @@ -30,11 +31,11 @@ func (dst *JSONB) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) == 0 { - return fmt.Errorf("jsonb too short") + return errors.Errorf("jsonb too short") } if src[0] != 1 { - return fmt.Errorf("unknown jsonb version number %d", src[0]) + return errors.Errorf("unknown jsonb version number %d", src[0]) } *dst = JSONB{Bytes: src[1:], Status: Present} diff --git a/pgtype/line.go b/pgtype/line.go index 47f636a5..f6eadf0e 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Line struct { @@ -17,7 +18,7 @@ type Line struct { } func (dst *Line) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Line", src) + return errors.Errorf("cannot convert %v to Line", src) } func (dst *Line) Get() interface{} { @@ -32,7 +33,7 @@ func (dst *Line) Get() interface{} { } func (src *Line) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { @@ -42,12 +43,12 @@ func (dst *Line) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return fmt.Errorf("invalid length for Line: %v", len(src)) + return errors.Errorf("invalid length for Line: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 3) if len(parts) < 3 { - return fmt.Errorf("invalid format for line") + return errors.Errorf("invalid format for line") } a, err := strconv.ParseFloat(parts[0], 64) @@ -76,7 +77,7 @@ func (dst *Line) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 24 { - return fmt.Errorf("invalid length for Line: %v", len(src)) + return errors.Errorf("invalid length for Line: %v", len(src)) } a := binary.BigEndian.Uint64(src) @@ -133,7 +134,7 @@ func (dst *Line) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 44c2b63c..a9d740cf 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Lseg struct { @@ -17,7 +18,7 @@ type Lseg struct { } func (dst *Lseg) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Lseg", src) + return errors.Errorf("cannot convert %v to Lseg", src) } func (dst *Lseg) Get() interface{} { @@ -32,7 +33,7 @@ func (dst *Lseg) Get() interface{} { } func (src *Lseg) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { @@ -42,7 +43,7 @@ func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 11 { - return fmt.Errorf("invalid length for Lseg: %v", len(src)) + return errors.Errorf("invalid length for Lseg: %v", len(src)) } str := string(src[2:]) @@ -89,7 +90,7 @@ func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 32 { - return fmt.Errorf("invalid length for Lseg: %v", len(src)) + return errors.Errorf("invalid length for Lseg: %v", len(src)) } x1 := binary.BigEndian.Uint64(src) @@ -151,7 +152,7 @@ func (dst *Lseg) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index e38701eb..4c6e2212 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -2,8 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "net" + + "github.com/pkg/errors" ) type Macaddr struct { @@ -32,7 +33,7 @@ func (dst *Macaddr) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Macaddr", value) + return errors.Errorf("cannot convert %v to Macaddr", value) } return nil @@ -69,7 +70,7 @@ func (src *Macaddr) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Macaddr) DecodeText(ci *ConnInfo, src []byte) error { @@ -94,7 +95,7 @@ func (dst *Macaddr) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 6 { - return fmt.Errorf("Received an invalid size for a macaddr: %d", len(src)) + return errors.Errorf("Received an invalid size for a macaddr: %d", len(src)) } addr := make(net.HardwareAddr, 6) @@ -144,7 +145,7 @@ func (dst *Macaddr) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/numeric.go b/pgtype/numeric.go index dffb9963..fded6359 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -3,13 +3,13 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "math/big" "strconv" "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 @@ -97,7 +97,7 @@ func (dst *Numeric) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } return nil @@ -136,10 +136,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int(normalizedInt.Int64()) case *int8: @@ -148,10 +148,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt8) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt8) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int8(normalizedInt.Int64()) case *int16: @@ -160,10 +160,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt16) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt16) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int16(normalizedInt.Int64()) case *int32: @@ -172,10 +172,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt32) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt32) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = int32(normalizedInt.Int64()) case *int64: @@ -184,10 +184,10 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(bigMaxInt64) > 0 { - return fmt.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%v is greater than maximum value for %T", normalizedInt, *v) } if normalizedInt.Cmp(bigMinInt64) < 0 { - return fmt.Errorf("%v is less than minimum value for %T", normalizedInt, *v) + return errors.Errorf("%v is less than minimum value for %T", normalizedInt, *v) } *v = normalizedInt.Int64() case *uint: @@ -196,9 +196,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint(normalizedInt.Uint64()) case *uint8: @@ -207,9 +207,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint8) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint8(normalizedInt.Uint64()) case *uint16: @@ -218,9 +218,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint16) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint16(normalizedInt.Uint64()) case *uint32: @@ -229,9 +229,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint32) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = uint32(normalizedInt.Uint64()) case *uint64: @@ -240,9 +240,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { return err } if normalizedInt.Cmp(big0) < 0 { - return fmt.Errorf("%d is less than zero for %T", normalizedInt, *v) + return errors.Errorf("%d is less than zero for %T", normalizedInt, *v) } else if normalizedInt.Cmp(bigMaxUint64) > 0 { - return fmt.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) + return errors.Errorf("%d is greater than maximum value for %T", normalizedInt, *v) } *v = normalizedInt.Uint64() default: @@ -276,7 +276,7 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { remainder := &big.Int{} num.DivMod(num, div, remainder) if remainder.Cmp(big0) != 0 { - return nil, fmt.Errorf("cannot convert %v to integer", dst) + return nil, errors.Errorf("cannot convert %v to integer", dst) } return num, nil } @@ -328,7 +328,7 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { accum := &big.Int{} if _, ok := accum.SetString(digits, 10); !ok { - return nil, 0, fmt.Errorf("%s is not a number", str) + return nil, 0, errors.Errorf("%s is not a number", str) } return accum, exp, nil @@ -341,7 +341,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 8 { - return fmt.Errorf("numeric incomplete %v", src) + return errors.Errorf("numeric incomplete %v", src) } rp := 0 @@ -361,7 +361,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if len(src[rp:]) < int(ndigits)*2 { - return fmt.Errorf("numeric incomplete %v", src) + return errors.Errorf("numeric incomplete %v", src) } accum := &big.Int{} @@ -382,7 +382,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { case 4: mul = bigNBaseX4 default: - return fmt.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) + return errors.Errorf("invalid digitsRead: %d (this can't happen)", digitsRead) } accum.Mul(accum, mul) } @@ -575,7 +575,7 @@ func (dst *Numeric) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/numeric_array.go b/pgtype/numeric_array.go index f193a2a5..6dfbe5e3 100644 --- a/pgtype/numeric_array.go +++ b/pgtype/numeric_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type NumericArray struct { @@ -59,7 +59,7 @@ func (dst *NumericArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Numeric", value) + return errors.Errorf("cannot convert %v to Numeric", value) } return nil @@ -108,7 +108,7 @@ func (src *NumericArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *NumericArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -261,7 +261,7 @@ func (src *NumericArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) if dt, ok := ci.DataTypeForName("numeric"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "numeric") + return nil, errors.Errorf("unable to find oid for type name %v", "numeric") } for i := range src.Elements { @@ -305,7 +305,7 @@ func (dst *NumericArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/numrange.go b/pgtype/numrange.go index 00133296..aaed62ce 100644 --- a/pgtype/numrange.go +++ b/pgtype/numrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Numrange struct { @@ -16,7 +16,7 @@ type Numrange struct { } func (dst *Numrange) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Numrange", src) + return errors.Errorf("cannot convert %v to Numrange", src) } func (dst *Numrange) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Numrange) Get() interface{} { } func (src *Numrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Numrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Numrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Numrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Numrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/oid.go b/pgtype/oid.go index d37f4e57..59370d66 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // OID (Object Identifier Type) is, according to @@ -20,7 +20,7 @@ type OID uint32 func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into OID") + return errors.Errorf("cannot decode nil into OID") } n, err := strconv.ParseUint(string(src), 10, 32) @@ -34,11 +34,11 @@ func (dst *OID) DecodeText(ci *ConnInfo, src []byte) error { func (dst *OID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - return fmt.Errorf("cannot decode nil into OID") + return errors.Errorf("cannot decode nil into OID") } if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) + return errors.Errorf("invalid length: %v", len(src)) } n := binary.BigEndian.Uint32(src) @@ -57,7 +57,7 @@ func (src OID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *OID) Scan(src interface{}) error { if src == nil { - return fmt.Errorf("cannot scan NULL into %T", src) + return errors.Errorf("cannot scan NULL into %T", src) } switch src := src.(type) { @@ -72,7 +72,7 @@ func (dst *OID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/path.go b/pgtype/path.go index 3575342d..aa0cee8e 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Path struct { @@ -18,7 +19,7 @@ type Path struct { } func (dst *Path) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Path", src) + return errors.Errorf("cannot convert %v to Path", src) } func (dst *Path) Get() interface{} { @@ -33,7 +34,7 @@ func (dst *Path) Get() interface{} { } func (src *Path) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { @@ -43,7 +44,7 @@ func (dst *Path) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return fmt.Errorf("invalid length for Path: %v", len(src)) + return errors.Errorf("invalid length for Path: %v", len(src)) } closed := src[0] == '(' @@ -86,7 +87,7 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return fmt.Errorf("invalid length for Path: %v", len(src)) + return errors.Errorf("invalid length for Path: %v", len(src)) } closed := src[0] == 1 @@ -95,7 +96,7 @@ func (dst *Path) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 5 if 5+pointCount*16 != len(src) { - return fmt.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) + return errors.Errorf("invalid length for Path with %d points: %v", pointCount, len(src)) } points := make([]Vec2, pointCount) @@ -183,7 +184,7 @@ func (dst *Path) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 4302a5fe..6f8e7986 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1,8 +1,9 @@ package pgtype import ( - "errors" "reflect" + + "github.com/pkg/errors" ) // PostgreSQL oids for common types diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 15b0f38d..e441a690 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -3,11 +3,11 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "math" "strconv" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // pguint32 is the core type that is used to implement PostgreSQL types such as @@ -24,16 +24,16 @@ func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { case int64: if value < 0 { - return fmt.Errorf("%d is less than minimum value for pguint32", value) + return errors.Errorf("%d is less than minimum value for pguint32", value) } if value > math.MaxUint32 { - return fmt.Errorf("%d is greater than maximum value for pguint32", value) + return errors.Errorf("%d is greater than maximum value for pguint32", value) } *dst = pguint32{Uint: uint32(value), Status: Present} case uint32: *dst = pguint32{Uint: value, Status: Present} default: - return fmt.Errorf("cannot convert %v to pguint32", value) + return errors.Errorf("cannot convert %v to pguint32", value) } return nil @@ -58,7 +58,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { if src.Status == Present { *v = src.Uint } else { - return fmt.Errorf("cannot assign %v into %T", src, dst) + return errors.Errorf("cannot assign %v into %T", src, dst) } case **uint32: if src.Status == Present { @@ -94,7 +94,7 @@ func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 4 { - return fmt.Errorf("invalid length: %v", len(src)) + return errors.Errorf("invalid length: %v", len(src)) } n := binary.BigEndian.Uint32(src) @@ -146,7 +146,7 @@ func (dst *pguint32) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/point.go b/pgtype/point.go index 3d5d4e1a..3132a939 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Vec2 struct { @@ -22,7 +23,7 @@ type Point struct { } func (dst *Point) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Point", src) + return errors.Errorf("cannot convert %v to Point", src) } func (dst *Point) Get() interface{} { @@ -37,7 +38,7 @@ func (dst *Point) Get() interface{} { } func (src *Point) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { @@ -47,12 +48,12 @@ func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return fmt.Errorf("invalid length for point: %v", len(src)) + return errors.Errorf("invalid length for point: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { - return fmt.Errorf("invalid format for point") + return errors.Errorf("invalid format for point") } x, err := strconv.ParseFloat(parts[0], 64) @@ -76,7 +77,7 @@ func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return fmt.Errorf("invalid length for point: %v", len(src)) + return errors.Errorf("invalid length for point: %v", len(src)) } x := binary.BigEndian.Uint64(src) @@ -129,7 +130,7 @@ func (dst *Point) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/polygon.go b/pgtype/polygon.go index d0b50061..3f3d9f53 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Polygon struct { @@ -17,7 +18,7 @@ type Polygon struct { } func (dst *Polygon) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Polygon", src) + return errors.Errorf("cannot convert %v to Polygon", src) } func (dst *Polygon) Get() interface{} { @@ -32,7 +33,7 @@ func (dst *Polygon) Get() interface{} { } func (src *Polygon) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { @@ -42,7 +43,7 @@ func (dst *Polygon) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 7 { - return fmt.Errorf("invalid length for Polygon: %v", len(src)) + return errors.Errorf("invalid length for Polygon: %v", len(src)) } points := make([]Vec2, 0) @@ -84,14 +85,14 @@ func (dst *Polygon) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return fmt.Errorf("invalid length for Polygon: %v", len(src)) + return errors.Errorf("invalid length for Polygon: %v", len(src)) } pointCount := int(binary.BigEndian.Uint32(src)) rp := 4 if 4+pointCount*16 != len(src) { - return fmt.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) + return errors.Errorf("invalid length for Polygon with %d points: %v", pointCount, len(src)) } points := make([]Vec2, pointCount) @@ -164,7 +165,7 @@ func (dst *Polygon) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 9c40ce18..064dab1e 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -1,9 +1,10 @@ package pgtype import ( - "fmt" "math" "strconv" + + "github.com/pkg/errors" ) // QChar is for PostgreSQL's special 8-bit-only "char" type more akin to the C @@ -33,59 +34,59 @@ func (dst *QChar) Set(src interface{}) error { *dst = QChar{Int: value, Status: Present} case uint8: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int16: if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint16: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int32: if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint32: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int64: if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint64: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case int: if value < math.MinInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case uint: if value > math.MaxInt8 { - return fmt.Errorf("%d is greater than maximum value for QChar", value) + return errors.Errorf("%d is greater than maximum value for QChar", value) } *dst = QChar{Int: int8(value), Status: Present} case string: @@ -98,7 +99,7 @@ func (dst *QChar) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to QChar", value) + return errors.Errorf("cannot convert %v to QChar", value) } return nil @@ -126,7 +127,7 @@ func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 1 { - return fmt.Errorf(`invalid length for "char": %v`, len(src)) + return errors.Errorf(`invalid length for "char": %v`, len(src)) } *dst = QChar{Int: int8(src[0]), Status: Present} diff --git a/pgtype/range.go b/pgtype/range.go index 76daf8cc..d870834f 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -3,7 +3,8 @@ package pgtype import ( "bytes" "encoding/binary" - "fmt" + + "github.com/pkg/errors" ) type BoundType byte @@ -36,7 +37,7 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { r, _, err := buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid lower bound: %v", err) + return nil, errors.Errorf("invalid lower bound: %v", err) } switch r { case '(': @@ -44,12 +45,12 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { case '[': utr.LowerType = Inclusive default: - return nil, fmt.Errorf("missing lower bound, instead got: %v", string(r)) + return nil, errors.Errorf("missing lower bound, instead got: %v", string(r)) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) + return nil, errors.Errorf("invalid lower value: %v", err) } buf.UnreadRune() @@ -58,21 +59,21 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { } else { utr.Lower, err = rangeParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid lower value: %v", err) + return nil, errors.Errorf("invalid lower value: %v", err) } } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("missing range separator: %v", err) + return nil, errors.Errorf("missing range separator: %v", err) } if r != ',' { - return nil, fmt.Errorf("missing range separator: %v", r) + return nil, errors.Errorf("missing range separator: %v", r) } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) + return nil, errors.Errorf("invalid upper value: %v", err) } buf.UnreadRune() @@ -81,13 +82,13 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { } else { utr.Upper, err = rangeParseValue(buf) if err != nil { - return nil, fmt.Errorf("invalid upper value: %v", err) + return nil, errors.Errorf("invalid upper value: %v", err) } } r, _, err = buf.ReadRune() if err != nil { - return nil, fmt.Errorf("missing upper bound: %v", err) + return nil, errors.Errorf("missing upper bound: %v", err) } switch r { case ')': @@ -95,13 +96,13 @@ func ParseUntypedTextRange(src string) (*UntypedTextRange, error) { case ']': utr.UpperType = Inclusive default: - return nil, fmt.Errorf("missing upper bound, instead got: %v", string(r)) + return nil, errors.Errorf("missing upper bound, instead got: %v", string(r)) } skipWhitespace(buf) if buf.Len() > 0 { - return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) + return nil, errors.Errorf("unexpected trailing data: %v", buf.String()) } return utr, nil @@ -197,7 +198,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { ubr := &UntypedBinaryRange{} if len(src) == 0 { - return nil, fmt.Errorf("range too short: %v", len(src)) + return nil, errors.Errorf("range too short: %v", len(src)) } rangeType := src[0] @@ -205,7 +206,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { if rangeType&emptyMask > 0 { if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) + return nil, errors.Errorf("unexpected trailing bytes parsing empty range: %v", len(src[rp:])) } ubr.LowerType = Empty ubr.UpperType = Empty @@ -230,13 +231,13 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { if ubr.LowerType == Unbounded && ubr.UpperType == Unbounded { if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) + return nil, errors.Errorf("unexpected trailing bytes parsing unbounded range: %v", len(src[rp:])) } return ubr, nil } if len(src[rp:]) < 4 { - return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) } valueLen := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -249,14 +250,14 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } else { ubr.Upper = val if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) } return ubr, nil } if ubr.UpperType != Unbounded { if len(src[rp:]) < 4 { - return nil, fmt.Errorf("too few bytes for size: %v", src[rp:]) + return nil, errors.Errorf("too few bytes for size: %v", src[rp:]) } valueLen := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -265,7 +266,7 @@ func ParseUntypedBinaryRange(src []byte) (*UntypedBinaryRange, error) { } if len(src[rp:]) > 0 { - return nil, fmt.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) + return nil, errors.Errorf("unexpected trailing bytes parsing range: %v", len(src[rp:])) } return ubr, nil diff --git a/pgtype/record.go b/pgtype/record.go index 7c8736df..14b415c3 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -2,7 +2,8 @@ package pgtype import ( "encoding/binary" - "fmt" + + "github.com/pkg/errors" ) // Record is the generic PostgreSQL record type such as is created with the @@ -25,7 +26,7 @@ func (dst *Record) Set(src interface{}) error { case []Value: *dst = Record{Fields: value, Status: Present} default: - return fmt.Errorf("cannot convert %v to Record", src) + return errors.Errorf("cannot convert %v to Record", src) } return nil @@ -65,7 +66,7 @@ func (src *Record) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { @@ -77,7 +78,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 0 if len(src[rp:]) < 4 { - return fmt.Errorf("Record incomplete %v", src) + return errors.Errorf("Record incomplete %v", src) } fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -86,7 +87,7 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { for i := 0; i < fieldCount; i++ { if len(src[rp:]) < 8 { - return fmt.Errorf("Record incomplete %v", src) + return errors.Errorf("Record incomplete %v", src) } fieldOID := OID(binary.BigEndian.Uint32(src[rp:])) rp += 4 @@ -97,14 +98,14 @@ func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { var binaryDecoder BinaryDecoder if dt, ok := ci.DataTypeForOID(fieldOID); ok { if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { - return fmt.Errorf("unknown oid while decoding record: %v", fieldOID) + return errors.Errorf("unknown oid while decoding record: %v", fieldOID) } } var fieldBytes []byte if fieldLen >= 0 { if len(src[rp:]) < fieldLen { - return fmt.Errorf("Record incomplete %v", src) + return errors.Errorf("Record incomplete %v", src) } fieldBytes = src[rp : rp+fieldLen] rp += fieldLen diff --git a/pgtype/text.go b/pgtype/text.go index 6638c354..f05e1e89 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -3,7 +3,8 @@ package pgtype import ( "database/sql/driver" "encoding/json" - "fmt" + + "github.com/pkg/errors" ) type Text struct { @@ -36,7 +37,7 @@ func (dst *Text) Set(src interface{}) error { if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Text", value) + return errors.Errorf("cannot convert %v to Text", value) } return nil @@ -73,7 +74,7 @@ func (src *Text) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { @@ -121,7 +122,7 @@ func (dst *Text) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/text_array.go b/pgtype/text_array.go index dab7d36e..2609a2cc 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type TextArray struct { @@ -40,7 +40,7 @@ func (dst *TextArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Text", value) + return errors.Errorf("cannot convert %v to Text", value) } return nil @@ -80,7 +80,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { if dt, ok := ci.DataTypeForName("text"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "text") + return nil, errors.Errorf("unable to find oid for type name %v", "text") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *TextArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/tid.go b/pgtype/tid.go index d44ea3a6..21852a14 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) // TID is PostgreSQL's Tuple Identifier type. @@ -28,7 +29,7 @@ type TID struct { } func (dst *TID) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to TID", src) + return errors.Errorf("cannot convert %v to TID", src) } func (dst *TID) Get() interface{} { @@ -43,7 +44,7 @@ func (dst *TID) Get() interface{} { } func (src *TID) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { @@ -53,12 +54,12 @@ func (dst *TID) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) < 5 { - return fmt.Errorf("invalid length for tid: %v", len(src)) + return errors.Errorf("invalid length for tid: %v", len(src)) } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { - return fmt.Errorf("invalid format for tid") + return errors.Errorf("invalid format for tid") } blockNumber, err := strconv.ParseUint(parts[0], 10, 32) @@ -82,7 +83,7 @@ func (dst *TID) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 6 { - return fmt.Errorf("invalid length for tid: %v", len(src)) + return errors.Errorf("invalid length for tid: %v", len(src)) } *dst = TID{ @@ -134,7 +135,7 @@ func (dst *TID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 75c6cffa..d906f467 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const pgTimestampFormat = "2006-01-02 15:04:05.999999999" @@ -37,7 +37,7 @@ func (dst *Timestamp) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Timestamp", value) + return errors.Errorf("cannot convert %v to Timestamp", value) } return nil @@ -63,7 +63,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -76,7 +76,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } // DecodeText decodes from src into dst. The decoded time is considered to @@ -114,7 +114,7 @@ func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for timestamp: %v", len(src)) + return errors.Errorf("invalid length for timestamp: %v", len(src)) } microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) @@ -143,7 +143,7 @@ func (src *Timestamp) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } if src.Time.Location() != time.UTC { - return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") + return nil, errors.Errorf("cannot encode non-UTC time into timestamp") } var s string @@ -170,7 +170,7 @@ func (src *Timestamp) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } if src.Time.Location() != time.UTC { - return nil, fmt.Errorf("cannot encode non-UTC time into timestamp") + return nil, errors.Errorf("cannot encode non-UTC time into timestamp") } var microsecSinceY2K int64 @@ -206,7 +206,7 @@ func (dst *Timestamp) Scan(src interface{}) error { return nil } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index fca9ad93..be281f2e 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type TimestampArray struct { @@ -41,7 +41,7 @@ func (dst *TimestampArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Timestamp", value) + return errors.Errorf("cannot convert %v to Timestamp", value) } return nil @@ -81,7 +81,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -234,7 +234,7 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error if dt, ok := ci.DataTypeForName("timestamp"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "timestamp") + return nil, errors.Errorf("unable to find oid for type name %v", "timestamp") } for i := range src.Elements { @@ -278,7 +278,7 @@ func (dst *TimestampArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 97b0de2a..74fe4954 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" @@ -38,7 +38,7 @@ func (dst *Timestamptz) Set(src interface{}) error { if originalSrc, ok := underlyingTimeType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Timestamptz", value) + return errors.Errorf("cannot convert %v to Timestamptz", value) } return nil @@ -64,7 +64,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { switch v := dst.(type) { case *time.Time: if src.InfinityModifier != None { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } *v = src.Time return nil @@ -77,7 +77,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 8 { - return fmt.Errorf("invalid length for timestamptz: %v", len(src)) + return errors.Errorf("invalid length for timestamptz: %v", len(src)) } microsecSinceY2K := int64(binary.BigEndian.Uint64(src)) @@ -202,7 +202,7 @@ func (dst *Timestamptz) Scan(src interface{}) error { return nil } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index e0866d69..086a4ef0 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -3,10 +3,10 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "time" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type TimestamptzArray struct { @@ -41,7 +41,7 @@ func (dst *TimestamptzArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Timestamptz", value) + return errors.Errorf("cannot convert %v to Timestamptz", value) } return nil @@ -81,7 +81,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -234,7 +234,7 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err if dt, ok := ci.DataTypeForName("timestamptz"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + return nil, errors.Errorf("unable to find oid for type name %v", "timestamptz") } for i := range src.Elements { @@ -278,7 +278,7 @@ func (dst *TimestamptzArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go index 783fb086..8a67d65e 100644 --- a/pgtype/tsrange.go +++ b/pgtype/tsrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Tsrange struct { @@ -16,7 +16,7 @@ type Tsrange struct { } func (dst *Tsrange) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Tsrange", src) + return errors.Errorf("cannot convert %v to Tsrange", src) } func (dst *Tsrange) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Tsrange) Get() interface{} { } func (src *Tsrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Tsrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Tsrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Tsrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Tsrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go index 8fd3fd68..b5129093 100644 --- a/pgtype/tstzrange.go +++ b/pgtype/tstzrange.go @@ -2,9 +2,9 @@ package pgtype import ( "database/sql/driver" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Tstzrange struct { @@ -16,7 +16,7 @@ type Tstzrange struct { } func (dst *Tstzrange) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Tstzrange", src) + return errors.Errorf("cannot convert %v to Tstzrange", src) } func (dst *Tstzrange) Get() interface{} { @@ -31,7 +31,7 @@ func (dst *Tstzrange) Get() interface{} { } func (src *Tstzrange) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Tstzrange) DecodeText(ci *ConnInfo, src []byte) error { @@ -120,7 +120,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -130,7 +130,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -141,7 +141,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -151,7 +151,7 @@ func (src Tstzrange) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -175,7 +175,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -185,7 +185,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -201,7 +201,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -216,7 +216,7 @@ func (src Tstzrange) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -241,7 +241,7 @@ func (dst *Tstzrange) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 01072549..7a69d0ab 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -40,7 +40,7 @@ func (dst *<%= pgtype_array_type %>) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) + return errors.Errorf("cannot convert %v to <%= pgtype_element_type %>", value) } return nil @@ -80,7 +80,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { @@ -236,7 +236,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byt if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + return nil, errors.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") } for i := range src.Elements { @@ -281,7 +281,7 @@ func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb index 90c23991..91a5cb97 100644 --- a/pgtype/typed_range.go.erb +++ b/pgtype/typed_range.go.erb @@ -18,7 +18,7 @@ type <%= range_type %> struct { } func (dst *<%= range_type %>) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to <%= range_type %>", src) + return errors.Errorf("cannot convert %v to <%= range_type %>", src) } func (dst *<%= range_type %>) Get() interface{} { @@ -33,7 +33,7 @@ func (dst *<%= range_type %>) Get() interface{} { } func (src *<%= range_type %>) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *<%= range_type %>) DecodeText(ci *ConnInfo, src []byte) error { @@ -122,7 +122,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error case Empty: return append(buf, "empty"...), nil default: - return nil, fmt.Errorf("unknown lower bound type %v", src.LowerType) + return nil, errors.Errorf("unknown lower bound type %v", src.LowerType) } var err error @@ -132,7 +132,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } } @@ -143,7 +143,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error if err != nil { return nil, err } else if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } } @@ -153,7 +153,7 @@ func (src <%= range_type %>) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error case Inclusive: buf = append(buf, ']') default: - return nil, fmt.Errorf("unknown upper bound type %v", src.UpperType) + return nil, errors.Errorf("unknown upper bound type %v", src.UpperType) } return buf, nil @@ -177,7 +177,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err case Empty: return append(buf, emptyMask), nil default: - return nil, fmt.Errorf("unknown LowerType: %v", src.LowerType) + return nil, errors.Errorf("unknown LowerType: %v", src.LowerType) } switch src.UpperType { @@ -187,7 +187,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err rangeType |= upperUnboundedMask case Exclusive: default: - return nil, fmt.Errorf("unknown UpperType: %v", src.UpperType) + return nil, errors.Errorf("unknown UpperType: %v", src.UpperType) } buf = append(buf, rangeType) @@ -203,7 +203,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err return nil, err } if buf == nil { - return nil, fmt.Errorf("Lower cannot be null unless LowerType is Unbounded") + return nil, errors.Errorf("Lower cannot be null unless LowerType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -218,7 +218,7 @@ func (src <%= range_type %>) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, err return nil, err } if buf == nil { - return nil, fmt.Errorf("Upper cannot be null unless UpperType is Unbounded") + return nil, errors.Errorf("Upper cannot be null unless UpperType is Unbounded") } pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) @@ -243,7 +243,7 @@ func (dst *<%= range_type %>) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/uuid.go b/pgtype/uuid.go index d1ab1a38..33e79536 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -4,6 +4,8 @@ import ( "database/sql/driver" "encoding/hex" "fmt" + + "github.com/pkg/errors" ) type UUID struct { @@ -17,7 +19,7 @@ func (dst *UUID) Set(src interface{}) error { *dst = UUID{Bytes: value, Status: Present} case []byte: if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) } *dst = UUID{Status: Present} copy(dst.Bytes[:], value) @@ -31,7 +33,7 @@ func (dst *UUID) Set(src interface{}) error { if originalSrc, ok := underlyingPtrType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to UUID", value) + return errors.Errorf("cannot convert %v to UUID", value) } return nil @@ -71,7 +73,7 @@ func (src *UUID) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot assign %v into %T", src, dst) + return errors.Errorf("cannot assign %v into %T", src, dst) } // parseUUID converts a string UUID in standard form to a byte array. @@ -98,7 +100,7 @@ func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { } if len(src) != 36 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) + return errors.Errorf("invalid length for UUID: %v", len(src)) } buf, err := parseUUID(string(src)) @@ -117,7 +119,7 @@ func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) != 16 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) + return errors.Errorf("invalid length for UUID: %v", len(src)) } *dst = UUID{Status: Present} @@ -163,7 +165,7 @@ func (dst *UUID) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/varbit.go b/pgtype/varbit.go index 9a9fe1e1..dfa194d2 100644 --- a/pgtype/varbit.go +++ b/pgtype/varbit.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type Varbit struct { @@ -15,7 +15,7 @@ type Varbit struct { } func (dst *Varbit) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Varbit", src) + return errors.Errorf("cannot convert %v to Varbit", src) } func (dst *Varbit) Get() interface{} { @@ -30,7 +30,7 @@ func (dst *Varbit) Get() interface{} { } func (src *Varbit) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) + return errors.Errorf("cannot assign %v to %T", src, dst) } func (dst *Varbit) DecodeText(ci *ConnInfo, src []byte) error { @@ -65,7 +65,7 @@ func (dst *Varbit) DecodeBinary(ci *ConnInfo, src []byte) error { } if len(src) < 4 { - return fmt.Errorf("invalid length for varbit: %v", len(src)) + return errors.Errorf("invalid length for varbit: %v", len(src)) } bitLen := int32(binary.BigEndian.Uint32(src)) @@ -124,7 +124,7 @@ func (dst *Varbit) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 95b5cfc1..fecbb2e5 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -3,9 +3,9 @@ package pgtype import ( "database/sql/driver" "encoding/binary" - "fmt" "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" ) type VarcharArray struct { @@ -40,7 +40,7 @@ func (dst *VarcharArray) Set(src interface{}) error { if originalSrc, ok := underlyingSliceType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Varchar", value) + return errors.Errorf("cannot convert %v to Varchar", value) } return nil @@ -80,7 +80,7 @@ func (src *VarcharArray) AssignTo(dst interface{}) error { return NullAssignTo(dst) } - return fmt.Errorf("cannot decode %v into %T", src, dst) + return errors.Errorf("cannot decode %v into %T", src, dst) } func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { @@ -233,7 +233,7 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) if dt, ok := ci.DataTypeForName("varchar"); ok { arrayHeader.ElementOID = int32(dt.OID) } else { - return nil, fmt.Errorf("unable to find oid for type name %v", "varchar") + return nil, errors.Errorf("unable to find oid for type name %v", "varchar") } for i := range src.Elements { @@ -277,7 +277,7 @@ func (dst *VarcharArray) Scan(src interface{}) error { return dst.DecodeText(nil, srcCopy) } - return fmt.Errorf("cannot scan %T", src) + return errors.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface. diff --git a/query.go b/query.go index c12d64f0..811e95b1 100644 --- a/query.go +++ b/query.go @@ -3,10 +3,11 @@ package pgx import ( "context" "database/sql" - "errors" "fmt" "time" + "github.com/pkg/errors" + "github.com/jackc/pgx/internal/sanitize" "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" @@ -135,7 +136,7 @@ func (rows *Rows) Next() bool { rows.fields[i].DataTypeName = dt.Name rows.fields[i].FormatCode = TextFormatCode } else { - rows.fatal(fmt.Errorf("unknown oid: %d", rows.fields[i].DataType)) + rows.fatal(errors.Errorf("unknown oid: %d", rows.fields[i].DataType)) return false } } @@ -191,7 +192,7 @@ func (e scanArgError) Error() string { // copy the raw bytes received from PostgreSQL. nil will skip the value entirely. func (rows *Rows) Scan(dest ...interface{}) (err error) { if len(rows.fields) != len(dest) { - err = fmt.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) + err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields)) rows.fatal(err) return err } @@ -224,7 +225,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { rows.fatal(scanArgError{col: i, err: err}) } } else { - rows.fatal(scanArgError{col: i, err: fmt.Errorf("%T is not a pgtype.TextDecoder", value)}) + rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)}) } case BinaryFormatCode: if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { @@ -233,10 +234,10 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { rows.fatal(scanArgError{col: i, err: err}) } } else { - rows.fatal(scanArgError{col: i, err: fmt.Errorf("%T is not a pgtype.BinaryDecoder", value)}) + rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)}) } default: - rows.fatal(scanArgError{col: i, err: fmt.Errorf("unknown format code: %v", fd.FormatCode)}) + rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)}) } if rows.Err() == nil { @@ -254,7 +255,7 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } } } else { - rows.fatal(scanArgError{col: i, err: fmt.Errorf("unknown oid: %v", fd.DataType)}) + rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)}) } } @@ -464,11 +465,11 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, func (c *Conn) buildOneRoundTripQueryEx(buf []byte, sql string, options *QueryExOptions, arguments []interface{}) ([]byte, error) { if len(arguments) != len(options.ParameterOIDs) { - return nil, fmt.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) + return nil, errors.Errorf("mismatched number of arguments (%d) and options.ParameterOIDs (%d)", len(arguments), len(options.ParameterOIDs)) } if len(options.ParameterOIDs) > 65535 { - return nil, fmt.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) + return nil, errors.Errorf("Number of QueryExOptions ParameterOIDs must be between 0 and 65535, received %d", len(options.ParameterOIDs)) } buf = appendParse(buf, "", sql, options.ParameterOIDs) @@ -497,7 +498,7 @@ func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { if dt, ok := c.ConnInfo.DataTypeForOID(fieldDescriptions[i].DataType); ok { fieldDescriptions[i].DataTypeName = dt.Name } else { - return nil, fmt.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) + return nil, errors.Errorf("unknown oid: %d", fieldDescriptions[i].DataType) } } return fieldDescriptions, nil diff --git a/replication.go b/replication.go index 1bf69c4e..bfa81e54 100644 --- a/replication.go +++ b/replication.go @@ -3,10 +3,11 @@ package pgx import ( "context" "encoding/binary" - "errors" "fmt" "time" + "github.com/pkg/errors" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) diff --git a/stdlib/sql.go b/stdlib/sql.go index b9cd3295..0c140343 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -68,12 +68,13 @@ import ( "database/sql" "database/sql/driver" "encoding/binary" - "errors" "fmt" "io" "strings" "sync" + "github.com/pkg/errors" + "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" ) @@ -260,7 +261,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e case sql.LevelSerializable: pgxOpts.IsoLevel = pgx.Serializable default: - return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation) + return nil, errors.Errorf("unsupported isolation: %v", opts.Isolation) } if opts.ReadOnly { @@ -546,7 +547,7 @@ func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { driver.fakeTxMutex.Unlock() } else { driver.fakeTxMutex.Unlock() - return fmt.Errorf("can't release conn that is not acquired") + return errors.Errorf("can't release conn that is not acquired") } return tx.Rollback() diff --git a/stress_test.go b/stress_test.go index 93752c29..114bec81 100644 --- a/stress_test.go +++ b/stress_test.go @@ -2,7 +2,6 @@ package pgx_test import ( "context" - "errors" "fmt" "math/rand" "os" @@ -10,6 +9,8 @@ import ( "testing" "time" + "github.com/pkg/errors" + "github.com/jackc/fake" "github.com/jackc/pgx" ) @@ -73,7 +74,7 @@ func TestStressConnPool(t *testing.T) { action := actions[rand.Intn(len(actions))] err := action.fn(pool, n) if err != nil { - errChan <- fmt.Errorf("%s: %v", action.name, err) + errChan <- errors.Errorf("%s: %v", action.name, err) break } } @@ -235,7 +236,7 @@ func poolPrepareUseAndDeallocate(pool *pgx.ConnPool, actionNum int) error { } if s != "hello" { - return fmt.Errorf("Prepared statement did not return expected value: %v", s) + return errors.Errorf("Prepared statement did not return expected value: %v", s) } return pool.Deallocate(psName) @@ -328,7 +329,7 @@ func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error { if err == context.Canceled { return nil } else if err != nil { - return fmt.Errorf("Only allowed error is context.Canceled, got %v", err) + return errors.Errorf("Only allowed error is context.Canceled, got %v", err) } for rows.Next() { @@ -336,7 +337,7 @@ func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error { } if rows.Err() != context.Canceled { - return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err()) + return errors.Errorf("Expected context.Canceled error, got %v", rows.Err()) } return nil @@ -351,7 +352,7 @@ func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error { _, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil) if err != context.Canceled { - return fmt.Errorf("Expected context.Canceled error, got %v", err) + return errors.Errorf("Expected context.Canceled error, got %v", err) } return nil diff --git a/tx.go b/tx.go index e144337d..f9607f70 100644 --- a/tx.go +++ b/tx.go @@ -3,9 +3,10 @@ package pgx import ( "bytes" "context" - "errors" "fmt" "time" + + "github.com/pkg/errors" ) type TxIsoLevel string diff --git a/v3.md b/v3.md index 993f9e24..d2afbf39 100644 --- a/v3.md +++ b/v3.md @@ -56,6 +56,8 @@ Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CID Add OnNotice +Use github.com/pkg/errors + ## TODO / Possible / Investigate Organize errors better diff --git a/values.go b/values.go index a6c350f6..86ae3afe 100644 --- a/values.go +++ b/values.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" + "github.com/pkg/errors" ) // PostgreSQL format codes @@ -69,12 +70,12 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return int64(arg), nil case uint64: if arg > math.MaxInt64 { - return nil, fmt.Errorf("arg too big for int64: %v", arg) + return nil, errors.Errorf("arg too big for int64: %v", arg) } return int64(arg), nil case uint: if arg > math.MaxInt64 { - return nil, fmt.Errorf("arg too big for int64: %v", arg) + return nil, errors.Errorf("arg too big for int64: %v", arg) } return int64(arg), nil case float32: From 784489d998af76222feeb6594a09003e7a699ef2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 5 Jun 2017 08:54:34 -0500 Subject: [PATCH 248/264] Update README.md with v3 --- README.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 877c2a37..37ab4558 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,14 @@ # Pgx -## Experimental Branch +## Version 3 Beta Branch -This is the experimental v3 branch. v2 is the stable branch. +This is the `v3` branch which is currently in beta. General release is planned +for July. `v2` is the current release branch. `v3` is considered to be stable in +the sense of lack of known bugs, but the API is not considered stable until +general release. No further changes are planned, but the beta process may +surface desirable changes. If possible API changes are acceptable, then `v3` is +the recommented branch for new development. Pgx is a pure Go database connection library designed specifically for PostgreSQL. Pgx is different from other drivers such as @@ -32,6 +37,10 @@ Pgx supports many additional features beyond what is available through database/ * Null mapping to Null* struct or pointer to pointer. * Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types * Logical replication connections, including receiving WAL and sending standby status updates +* Notice response handling (this is different than listen / notify) +* Batch queries +* Single-round trip query mode +* pgtype package includes support for approximately 60 different PostgreSQL types - these are usable in pgx native and any database/sql PostgreSQL adapter ## Performance From 29c017d75010df377cc99522223a5ebd4de4c7d3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 5 Jun 2017 09:11:19 -0500 Subject: [PATCH 249/264] Add v3 to changelog --- CHANGELOG.md | 62 ++++++++++++++++++++++++++++++++++- v3.md | 92 ---------------------------------------------------- 2 files changed, 61 insertions(+), 93 deletions(-) delete mode 100644 v3.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 126baef4..70d67f26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,62 @@ -# Unreleased +# Unreleased V3 + +## Changes + +* Pid to PID in accordance with Go naming conventions. +* Conn.Pid changed to accessor method Conn.PID() +* Conn.SecretKey removed +* Remove Conn.TxStatus +* Logger interface reduced to single Log method. +* Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode. +* Transaction isolation level constants are now typed strings instead of bare strings. +* Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support. +* Conn.WaitForNotification no longer automatically pings internally every 15 seconds. +* ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support. +* Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228 +* No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed. +* Remove CopyTo (functionality is now in CopyFrom) +* OID constants moved from pgx to pgtype package +* Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system +* Removed ValueReader +* ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset. +* Removed Rows.Fatal(error) +* Removed Rows.AfterClose() +* Removed Rows.Conn() +* Removed Tx.AfterClose() +* Removed Tx.Conn() +* Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR +* Replaced stdlib.OpenFromConnPool with DriverConfig system + +## Features + +* Entirely revamped pluggable type sytem that supports approximately 60 PostgreSQL types. +* Types support database/sql interfaces and therefore can be used with other drivers +* Added context methods supporting cancelation where appropriate +* Added simple query protocol support +* Added single round-trip query mode +* Added batch query operations +* Added OnNotice +* github.com/pkg/errors used where possible for errors +* Added stdlib.DriverConfig which allows directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool +* Added AcquireConn and ReleaseConn to stdlib to allow acquiring a connection from a database/sql connection. + +# 2.11.0 (June 5, 2017) + +## Fixes + +* Fix race with concurrent execution of stdlib.OpenFromConnPool (Terin Stock) + +## Features + +* .pgpass support (j7b) +* Add missing CopyFrom delegators to Tx and ConnPool (Jack Christensen) +* Add ParseConnectionString (James Lawrence) + +## Performance + +* Optimize HStore encoding (René Kroon) + +# 2.10.0 (March 17, 2017) ## Fixes @@ -16,10 +74,12 @@ * Add named error ErrAcquireTimeout (Alexander Staubo) * Add logical replication decoding (Kris Wehner) * Add PgxScanner interface to allow types to simultaneously support database/sql and pgx (Jack Christensen) +* Add CopyFrom with schema support (Jack Christensen) ## Compatibility * jsonb now defaults to binary format. This means passing a []byte to a jsonb column will no longer work. +* CopyTo is now deprecated but will continue to work. # 2.9.0 (August 26, 2016) diff --git a/v3.md b/v3.md deleted file mode 100644 index d2afbf39..00000000 --- a/v3.md +++ /dev/null @@ -1,92 +0,0 @@ -# V3 Experimental - -## Changes - -Rename Pid to PID in accordance with Go naming conventions. - -Logger interface reduced to single Log method. - -Replace BeginIso with BeginEx. BeginEx adds support for read/write mode and deferrable mode. - -Transaction isolation level constants are now typed strings instead of bare strings. - -Conn.Pid changed to accessor method Conn.PID() - -Conn.SecretKey removed - -Remove Conn.TxStatus - -Added Context methods - -Conn.WaitForNotification now takes context.Context instead of time.Duration for cancellation support. - -Conn.WaitForNotification no longer automatically pings internally every 15 seconds. (Reconsider this later...) - -ReplicationConn.WaitForReplicationMessage now takes context.Context instead of time.Duration for cancellation support. - -Reject scanning binary format values into a string (e.g. binary encoded timestamptz to string). See https://github.com/jackc/pgx/issues/219 and https://github.com/jackc/pgx/issues/228 - -Remove CopyTo - -No longer can read raw bytes of any value into a []byte. Use pgtype.GenericBinary if this functionality is needed. - -OID constants moved from pgx to pgtype package - -Removed ValueReader - -Replaced Scanner, Encoder, and PgxScanner interfaces with pgtype system - -ConnPool.Close no longer waits for all acquired connections to be released. Instead, it immediately closes all available connections, and closes acquired connections when they are released in the same manner as ConnPool.Reset. - -Removed Rows.Fatal(error) - -Removed Rows.AfterClose() - -Removed Rows.Conn() - -Removed Tx.AfterClose() - -Removed Tx.Conn() - -Added ctx parameter to (Conn/Tx/ConnPool).PrepareEx - -Added batch operations - -Use Go casing convention for OID, UUID, JSON(B), ACLItem, CID, TID, XID, and CIDR - -Add OnNotice - -Use github.com/pkg/errors - -## TODO / Possible / Investigate - -Organize errors better - -Remove circular dependency between Conn and ConnPool such that ConnPool depends on Conn, but Conn doesn't know anything about ConnPool - -Or maybe double-down on conn/pool coupling and improve connpool - -Add auto-idle pinging to conns in pool - -Remove names from prepared statements - use database/sql style objects - -Better way of handling text/binary protocol choice than pgx.DefaultTypeFormats or manually editing a PreparedStatement. Possibly an optional part of preparing a statement is specifying the format and/or a decoder. Or maybe it is part of a QueryEx call... Could be very interesting to make encoding and decoding possible without being a method of the type. This could drastically clean up those huge type switches. - -Make easier / possible to mock Conn or ConnPool (https://github.com/jackc/pgx/pull/162) - -Every field that should not be set by user should be replaced by accessor method (only ones left are Conn.RuntimeParams and Conn.PgTypes) - -Investigate strongly typed queries. i.e. Some sort of interface where varargs of Query, Exec, and Scan wouldn't happen. Need to be some low-level interface where (probably generated) functions could (more or less) directly read and write to the connection. Clean code and type-safety / control would be the benefits. Row scanning performance is already so fast there is little to improve (go_db_bench shows under 1 microsecond per row). - -Further clean up logging interface -- still some pre-loglevel code in place -Possibly integrate internal logging support with context. Possibly add method that adds arbitrary pgx log data to context. Or add ability to configure what key(s) pgx looks at for additional log context. -Consider whether to switch to logrus style or stick with log15 style logs -Keep ability to change logging while running - -consider test to ensure that AssignTo makes copy of reference types -something like: -select array[1,2,3], array[4,5,6,7] - -Reconsider synonym types like varchar/text and numeric/decimal. - -integrate logging and context - should be able to replace logger via context OR inject params into log from context From 1c452a4a1e4585d60f7d2059a0b279fe991fd650 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 5 Jun 2017 09:19:29 -0500 Subject: [PATCH 250/264] Spelling corrections --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70d67f26..f1ef2e49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,9 +29,9 @@ ## Features -* Entirely revamped pluggable type sytem that supports approximately 60 PostgreSQL types. +* Entirely revamped pluggable type system that supports approximately 60 PostgreSQL types. * Types support database/sql interfaces and therefore can be used with other drivers -* Added context methods supporting cancelation where appropriate +* Added context methods supporting cancellation where appropriate * Added simple query protocol support * Added single round-trip query mode * Added batch query operations From 2509082c0e66e27ac530370f3312db69cea1a8ee Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Fri, 30 Jun 2017 12:52:24 -0400 Subject: [PATCH 251/264] Add missing `pgx.Identifier` to `CopyFrom` example --- doc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc.go b/doc.go index a9b9e461..c909cf18 100644 --- a/doc.go +++ b/doc.go @@ -206,8 +206,8 @@ implement CopyToSource to avoid buffering the entire data set in memory. {"Jane", "Doe", int32(29)}, } - copyCount, err := conn.CopyTo( - "people", + copyCount, err := conn.CopyFrom( + pgx.Identifier{"people"}, []string{"first_name", "last_name", "age"}, pgx.CopyToRows(rows), ) From 340452945738ed9714433a995d610c6feeb77aa3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Jul 2017 10:52:20 -0500 Subject: [PATCH 252/264] Fix docs CopyTo -> CopyFrom --- doc.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/doc.go b/doc.go index c909cf18..0be66a05 100644 --- a/doc.go +++ b/doc.go @@ -196,10 +196,10 @@ can create a transaction with a specified isolation level. Copy Protocol -Use CopyTo to efficiently insert multiple rows at a time using the PostgreSQL -copy protocol. CopyTo accepts a CopyToSource interface. If the data is already -in a [][]interface{} use CopyToRows to wrap it in a CopyToSource interface. Or -implement CopyToSource to avoid buffering the entire data set in memory. +Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL +copy protocol. CopyFrom accepts a CopyFromSource interface. If the data is already +in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource interface. Or +implement CopyFromSource to avoid buffering the entire data set in memory. rows := [][]interface{}{ {"John", "Smith", int32(36)}, @@ -209,10 +209,10 @@ implement CopyToSource to avoid buffering the entire data set in memory. copyCount, err := conn.CopyFrom( pgx.Identifier{"people"}, []string{"first_name", "last_name", "age"}, - pgx.CopyToRows(rows), + pgx.CopyFromRows(rows), ) -CopyTo can be faster than an insert with as few as 5 rows. +CopyFrom can be faster than an insert with as few as 5 rows. Listen and Notify From 53b4280456841e3c4cb554fd34fd611a58bf715e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 4 Jul 2017 11:36:08 -0500 Subject: [PATCH 253/264] Automatically register enum types fixes #287 --- conn.go | 2 +- pgmock/pgmock.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index da8ed655..0d51228d 100644 --- a/conn.go +++ b/conn.go @@ -388,7 +388,7 @@ func (c *Conn) initConnInfo() error { from pg_type t left join pg_type base_type on t.typelem=base_type.oid where ( - t.typtype in('b', 'p', 'r') + t.typtype in('b', 'p', 'r', 'e') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) )`) if err != nil { diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 5e340881..fe78b009 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -204,7 +204,7 @@ func AcceptUnauthenticatedConnRequestSteps() []Step { func PgxInitSteps() []Step { steps := []Step{ ExpectMessage(&pgproto3.Parse{ - Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)", + Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r', 'e')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)", }), ExpectMessage(&pgproto3.Describe{ ObjectType: 'S', From 9a1ab885af2b64b5923e6b1fe9d5dd8768cb7e47 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Jul 2017 08:57:25 -0500 Subject: [PATCH 254/264] Use insert on conflict for url shortener example fixes #290 --- examples/url_shortener/main.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/examples/url_shortener/main.go b/examples/url_shortener/main.go index 8380ef3f..c6576a3a 100644 --- a/examples/url_shortener/main.go +++ b/examples/url_shortener/main.go @@ -28,19 +28,9 @@ func afterConnect(conn *pgx.Conn) (err error) { return } - // There technically is a small race condition in doing an upsert with a CTE - // where one of two simultaneous requests to the shortened URL would fail - // with a unique index violation. As the point of this demo is pgx usage and - // not how to perfectly upsert in PostgreSQL it is deemed acceptable. _, err = conn.Prepare("putUrl", ` - with upsert as ( - update shortened_urls - set url=$2 - where id=$1 - returning * - ) - insert into shortened_urls(id, url) - select $1, $2 where not exists(select 1 from upsert) + insert into shortened_urls(id, url) values ($1, $2) + on conflict (id) do update set url=excluded.url `) return } From e2dae9f4acb8f94c4bcf54ca9cd12518c9aa90dd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jul 2017 08:31:27 -0500 Subject: [PATCH 255/264] Fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 37ab4558..cb206b3a 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ for July. `v2` is the current release branch. `v3` is considered to be stable in the sense of lack of known bugs, but the API is not considered stable until general release. No further changes are planned, but the beta process may surface desirable changes. If possible API changes are acceptable, then `v3` is -the recommented branch for new development. +the recommended branch for new development. Pgx is a pure Go database connection library designed specifically for PostgreSQL. Pgx is different from other drivers such as From 4d6fe2d5faf823da75c3139a38874d034782d61e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 15 Jul 2017 08:41:26 -0500 Subject: [PATCH 256/264] Doc updates --- README.md | 2 +- doc.go | 17 ++++++----------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index cb206b3a..439f71be 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Pgx supports many additional features beyond what is available through database/ * Full TLS connection control * Binary format support for custom types (can be much faster) * Copy protocol support for faster bulk data loads -* Logging support +* Extendable logging support including builtin support for log15 and logrus * Configurable connection pool with after connect hooks to do arbitrary connection setup * PostgreSQL array to Go slice mapping for integers, floats, and strings * Hstore support diff --git a/doc.go b/doc.go index 0be66a05..6e26b8b9 100644 --- a/doc.go +++ b/doc.go @@ -117,11 +117,11 @@ particular: Null Mapping -pgx can map nulls in two ways. The first is Null* types that have a data field -and a valid field. They work in a similar fashion to database/sql. The second -is to use a pointer to a pointer. +pgx can map nulls in two ways. The first is package pgtype provides types that +have a data field and a null indicator field. They work in a similar fashion to +database/sql. The second is to use a pointer to a pointer. - var foo pgx.NullString + var foo pgtype.Varchar var bar *string err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&a, &b) if err != nil { @@ -133,13 +133,8 @@ Array Mapping pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a -native Go slice an error will occur. - -Hstore Mapping - -pgx includes an Hstore type and a NullHstore type. Hstore is simply a -map[string]string and is preferred when the hstore contains no nulls. NullHstore -follows the Null* pattern and supports null values. +native Go slice an error will occur. The pgtype package includes many more +array types for PostgreSQL types that do not directly map to native Go types. JSON and JSONB Mapping From 062d97deb248942457b36f4042d9c6381beb65a9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jul 2017 14:09:38 -0500 Subject: [PATCH 257/264] Doc tweaks --- doc.go | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/doc.go b/doc.go index 6e26b8b9..c61329d9 100644 --- a/doc.go +++ b/doc.go @@ -62,17 +62,15 @@ Use Exec to execute a query that does not return a result set. Connection Pool -Connection pool usage is explicit and configurable. In pgx, a connection can -be created and managed directly, or a connection pool with a configurable -maximum connections can be used. Also, the connection pool offers an after -connect hook that allows every connection to be automatically setup before -being made available in the connection pool. This is especially useful to -ensure all connections have the same prepared statements available or to -change any other connection settings. +Connection pool usage is explicit and configurable. In pgx, a connection can be +created and managed directly, or a connection pool with a configurable maximum +connections can be used. The connection pool offers an after connect hook that +allows every connection to be automatically setup before being made available in +the connection pool. -It delegates Query, QueryRow, Exec, and Begin functions to an automatically -checked out and released connection so you can avoid manually acquiring and -releasing connections when you do not need that level of control. +It delegates methods such as QueryRow to an automatically checked out and +released connection so you can avoid manually acquiring and releasing +connections when you do not need that level of control. var name string var weight int64 @@ -118,7 +116,7 @@ particular: Null Mapping pgx can map nulls in two ways. The first is package pgtype provides types that -have a data field and a null indicator field. They work in a similar fashion to +have a data field and a status field. They work in a similar fashion to database/sql. The second is to use a pointer to a pointer. var foo pgtype.Varchar From 05509e1f6f33f5febd1fc06f1972cde042c17710 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jul 2017 14:10:33 -0500 Subject: [PATCH 258/264] Remove unused files --- conn-lock-todo.txt | 11 ----------- context-todo.txt | 12 ------------ 2 files changed, 23 deletions(-) delete mode 100644 conn-lock-todo.txt delete mode 100644 context-todo.txt diff --git a/conn-lock-todo.txt b/conn-lock-todo.txt deleted file mode 100644 index ab5eac95..00000000 --- a/conn-lock-todo.txt +++ /dev/null @@ -1,11 +0,0 @@ -Extract all locking state into a separate struct that will encapsulate locking and state change behavior. - -This struct should add or subsume at least the following: -* alive -* closingLock -* ctxInProgress (though this may be restructured because it's possible a Tx may have a ctx and a query run in that Tx could have one) -* busy -* lock/unlock -* Tx in-progress -* Rows in-progress -* ConnPool checked-out or checked-in - maybe include reference to conn pool diff --git a/context-todo.txt b/context-todo.txt deleted file mode 100644 index b5a20d0a..00000000 --- a/context-todo.txt +++ /dev/null @@ -1,12 +0,0 @@ -Add more testing -- stress test style -- pgmock - -Add documentation - -Add PrepareContext -Add context methods to ConnPool -Add context methods to Tx -Add context support database/sql - -Benchmark - possibly cache done channel on Conn From 88c7dd8da25c06a6a9ea232e64d131a4f518347b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jul 2017 14:12:19 -0500 Subject: [PATCH 259/264] Fix typo in changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1ef2e49..cb96eaf3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,7 +37,7 @@ * Added batch query operations * Added OnNotice * github.com/pkg/errors used where possible for errors -* Added stdlib.DriverConfig which allows directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool +* Added stdlib.DriverConfig which directly allows full configuration of underlying pgx connections without needing to use a pgx.ConnPool * Added AcquireConn and ReleaseConn to stdlib to allow acquiring a connection from a database/sql connection. # 2.11.0 (June 5, 2017) From dde965bc9d069c041807abaccb11f246d65bae86 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jul 2017 14:19:45 -0500 Subject: [PATCH 260/264] README updates --- README.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 439f71be..6b71ba50 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx) -# Pgx +# Pgx - PostgreSQL Driver and Toolkit ## Version 3 Beta Branch @@ -22,25 +22,26 @@ performance and more features. Pgx supports many additional features beyond what is available through database/sql. -* Listen / notify -* Transaction isolation level control +* pgtype package includes support for approximately 60 different PostgreSQL types - these are usable in pgx native and any database/sql PostgreSQL adapter +* Batch queries +* Single-round trip query mode * Full TLS connection control * Binary format support for custom types (can be much faster) * Copy protocol support for faster bulk data loads * Extendable logging support including builtin support for log15 and logrus * Configurable connection pool with after connect hooks to do arbitrary connection setup +* Listen / notify +* Transaction isolation level control * PostgreSQL array to Go slice mapping for integers, floats, and strings * Hstore support * JSON and JSONB support * Maps inet and cidr PostgreSQL types to net.IPNet and net.IP * Large object support -* Null mapping to Null* struct or pointer to pointer. +* NULL mapping to Null* struct or pointer to pointer. * Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types * Logical replication connections, including receiving WAL and sending standby status updates * Notice response handling (this is different than listen / notify) -* Batch queries -* Single-round trip query mode -* pgtype package includes support for approximately 60 different PostgreSQL types - these are usable in pgx native and any database/sql PostgreSQL adapter +* pgproto3 package can encode and decode the PostgreSQL version 3 wire protocol ## Performance From 79517aaa0e5d9cccf3bf8c7398eb0fa6b03c0c8a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jul 2017 15:22:32 -0500 Subject: [PATCH 261/264] Fix batch query with query syntax error --- batch.go | 15 +++++++++------ batch_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/batch.go b/batch.go index 3c16fd13..fc6f0d03 100644 --- a/batch.go +++ b/batch.go @@ -167,25 +167,28 @@ func (b *Batch) ExecResults() (CommandTag, error) { // QueryResults reads the results from the next query in the batch as if the // query has been sent with Query. func (b *Batch) QueryResults() (*Rows, error) { + rows := b.conn.getRows("batch query", nil) + if b.err != nil { - return nil, b.err + rows.fatal(b.err) + return rows, b.err } select { case <-b.ctx.Done(): b.die(b.ctx.Err()) - return nil, b.ctx.Err() + rows.fatal(b.err) + return rows, b.ctx.Err() default: } b.resultsRead++ - rows := b.conn.getRows("batch query", nil) - fieldDescriptions, err := b.conn.readUntilRowDescription() if err != nil { - b.die(b.ctx.Err()) - return nil, err + b.die(err) + rows.fatal(b.err) + return rows, err } rows.batch = b diff --git a/batch_test.go b/batch_test.go index ffd3cc50..e12e4f32 100644 --- a/batch_test.go +++ b/batch_test.go @@ -442,3 +442,37 @@ func TestConnBeginBatchQueryError(t *testing.T) { t.Error("conn should be dead, but was alive") } } + +func TestConnBeginBatchQuerySyntaxError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + batch := conn.BeginBatch() + batch.Queue("select 1 1", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + var n int32 + err = batch.QueryRowResults().Scan(&n) + if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "42601") { + t.Errorf("rows.Err() => %v, want error code %v", err, 42601) + } + + err = batch.Close() + if err == nil { + t.Error("Expected error") + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} From 2e5f5e0c9de9ccf06c7cbf4f80d52451d23ddf07 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jul 2017 16:35:54 -0500 Subject: [PATCH 262/264] More README tweaks --- README.md | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 6b71ba50..42436913 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://godoc.org/github.com/jackc/pgx) -# Pgx - PostgreSQL Driver and Toolkit +# pgx - PostgreSQL Driver and Toolkit ## Version 3 Beta Branch @@ -11,16 +11,15 @@ general release. No further changes are planned, but the beta process may surface desirable changes. If possible API changes are acceptable, then `v3` is the recommended branch for new development. -Pgx is a pure Go database connection library designed specifically for -PostgreSQL. Pgx is different from other drivers such as -[pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a -database/sql compatible driver, pgx is primarily intended to be used directly. -It offers a native interface similar to database/sql that offers better -performance and more features. +pgx is a pure Go driver and toolkit for PostgreSQL. pgx is different from other +drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can +operate as a database/sql compatible driver, pgx is primarily intended to be +used directly. It offers a native interface similar to database/sql that offers +better performance and more features. ## Features -Pgx supports many additional features beyond what is available through database/sql. +pgx supports many additional features beyond what is available through database/sql. * pgtype package includes support for approximately 60 different PostgreSQL types - these are usable in pgx native and any database/sql PostgreSQL adapter * Batch queries @@ -45,13 +44,20 @@ Pgx supports many additional features beyond what is available through database/ ## Performance -Pgx performs roughly equivalent to [pq](http://godoc.org/github.com/lib/pq) and -[go-pg](https://github.com/go-pg/pg) for selecting a single column from a single -row, but it is substantially faster when selecting multiple entire rows (6893 -queries/sec for pgx vs. 3968 queries/sec for pq -- 73% faster). +pgx performs roughly equivalent to [go-pg](https://github.com/go-pg/pg) and is +almost always faster than [pq](http://godoc.org/github.com/lib/pq). When parsing +large result sets the percentage difference can be significant (16483 +queries/sec for pgx vs. 10106 queries/sec for pq -- 63% faster). -See this [gist](https://gist.github.com/jackc/d282f39e088b495fba3e) for the -underlying benchmark results or checkout +In many use cases a significant cause of latency is network round trips between +the application and the server. pgx supports query batching to bundle multiple +queries into a single round trip. Even in the case of the fastest possible +connection, a local Unix domain socket, batching as few as three queries +together can yield an improvement of 57%. With a typical network connection the +results can be even more substantial. + +See this [gist](https://gist.github.com/jackc/4996e8648a0c59839bff644f49d6e434) +for the underlying benchmark results or checkout [go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself. ## database/sql From a147e0f3b8a574776d8bb9f4ce8b6fe4504cbfc2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jul 2017 08:41:13 -0500 Subject: [PATCH 263/264] Fix test on Travis --- stdlib/sql_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index bf99a8bb..65f80ac4 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1211,6 +1211,7 @@ func TestStmtQueryContextCancel(t *testing.T) { pgmock.ExpectMessage(&pgproto3.Sync{}), pgmock.SendMessage(&pgproto3.BindComplete{}), + pgmock.WaitForClose(), ) server, err := pgmock.NewServer(script) From 534ea4a9cbf312e9026894ac9197a9744d7e7cc9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 24 Jul 2017 08:04:01 -0500 Subject: [PATCH 264/264] Update README for v3 release --- README.md | 62 +++++++++++++++++++++---------------------------------- 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 42436913..570afce3 100644 --- a/README.md +++ b/README.md @@ -2,35 +2,21 @@ # pgx - PostgreSQL Driver and Toolkit -## Version 3 Beta Branch - -This is the `v3` branch which is currently in beta. General release is planned -for July. `v2` is the current release branch. `v3` is considered to be stable in -the sense of lack of known bugs, but the API is not considered stable until -general release. No further changes are planned, but the beta process may -surface desirable changes. If possible API changes are acceptable, then `v3` is -the recommended branch for new development. - -pgx is a pure Go driver and toolkit for PostgreSQL. pgx is different from other -drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can -operate as a database/sql compatible driver, pgx is primarily intended to be -used directly. It offers a native interface similar to database/sql that offers -better performance and more features. +pgx is a pure Go driver and toolkit for PostgreSQL. pgx is different from other drivers such as [pq](http://godoc.org/github.com/lib/pq) because, while it can operate as a database/sql compatible driver, pgx is also usable directly. It offers a native interface similar to database/sql that offers better performance and more features. ## Features pgx supports many additional features beyond what is available through database/sql. -* pgtype package includes support for approximately 60 different PostgreSQL types - these are usable in pgx native and any database/sql PostgreSQL adapter +* Support for approximately 60 different PostgreSQL types * Batch queries * Single-round trip query mode * Full TLS connection control * Binary format support for custom types (can be much faster) * Copy protocol support for faster bulk data loads -* Extendable logging support including builtin support for log15 and logrus -* Configurable connection pool with after connect hooks to do arbitrary connection setup +* Extendable logging support including built-in support for log15 and logrus +* Connection pool with after connect hook to do arbitrary connection setup * Listen / notify -* Transaction isolation level control * PostgreSQL array to Go slice mapping for integers, floats, and strings * Hstore support * JSON and JSONB support @@ -40,32 +26,32 @@ pgx supports many additional features beyond what is available through database/ * Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types * Logical replication connections, including receiving WAL and sending standby status updates * Notice response handling (this is different than listen / notify) -* pgproto3 package can encode and decode the PostgreSQL version 3 wire protocol ## Performance -pgx performs roughly equivalent to [go-pg](https://github.com/go-pg/pg) and is -almost always faster than [pq](http://godoc.org/github.com/lib/pq). When parsing -large result sets the percentage difference can be significant (16483 -queries/sec for pgx vs. 10106 queries/sec for pq -- 63% faster). +pgx performs roughly equivalent to [go-pg](https://github.com/go-pg/pg) and is almost always faster than [pq](http://godoc.org/github.com/lib/pq). When parsing large result sets the percentage difference can be significant (16483 queries/sec for pgx vs. 10106 queries/sec for pq -- 63% faster). -In many use cases a significant cause of latency is network round trips between -the application and the server. pgx supports query batching to bundle multiple -queries into a single round trip. Even in the case of the fastest possible -connection, a local Unix domain socket, batching as few as three queries -together can yield an improvement of 57%. With a typical network connection the -results can be even more substantial. +In many use cases a significant cause of latency is network round trips between the application and the server. pgx supports query batching to bundle multiple queries into a single round trip. Even in the case of a connection with the lowest possible latency, a local Unix domain socket, batching as few as three queries together can yield an improvement of 57%. With a typical network connection the results can be even more substantial. -See this [gist](https://gist.github.com/jackc/4996e8648a0c59839bff644f49d6e434) -for the underlying benchmark results or checkout -[go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself. +See this [gist](https://gist.github.com/jackc/4996e8648a0c59839bff644f49d6e434) for the underlying benchmark results or checkout [go_db_bench](https://github.com/jackc/go_db_bench) to run tests for yourself. -## database/sql +In addition to the native driver, pgx also includes a number of packages that provide additional functionality. -Import the ```github.com/jackc/pgx/stdlib``` package to use pgx as a driver for -database/sql. It is possible to retrieve a pgx connection from database/sql on -demand. This allows using the database/sql interface in most places, but using -pgx directly when more performance or PostgreSQL specific features are needed. +## github.com/jackc/pgxstdlib + +database/sql compatibility layer for pgx. pgx can be used as a normal database/sql driver, but at any time the native interface may be acquired for more performance or PostgreSQL specific functionality. + +## github.com/jackc/pgx/pgtype + +Approximately 60 PostgreSQL types are supported including uuid, hstore, json, bytea, numeric, interval, inet, and arrays. These types support database/sql interfaces and are usable even outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. + +## github.com/jackc/pgx/pgproto3 + +pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling. + +## github.com/jackc/pgx/pgmock + +pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler). ## Documentation @@ -144,4 +130,4 @@ Set `replicationConnConfig` appropriately in `conn_config_test.go`. ## Version Policy -pgx follows semantic versioning for the documented public API on stable releases. Branch `v2` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v2` branch (in practice, this occurs very rarely). +pgx follows semantic versioning for the documented public API on stable releases. Branch `v3` is the latest stable release. `master` can contain new features or behavior that will change or be removed before being merged to the stable `v3` branch (in practice, this occurs very rarely). `v2` is the previous stable release.