From 19c668975218ca857f07e0506cdbcaa83f68fb24 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 Mar 2017 12:01:16 -0500 Subject: [PATCH] Add pgtype.Record and prerequisite restructuring Because reading a record type requires the decoder to be able to look up oid to type mapping and types such as hstore have types that are not fixed between different PostgreSQL servers it was necessary to restructure the pgtype system so all encoders and decodes take a *ConnInfo that includes oid/name/type information. --- conn.go | 148 ++++-------- conn_pool.go | 8 +- copy_from_test.go | 5 +- example_custom_type_test.go | 4 +- pgtype/aclitem.go | 4 +- pgtype/aclitem_array.go | 8 +- pgtype/array.go | 4 +- pgtype/bool.go | 8 +- pgtype/bool_array.go | 24 +- pgtype/bytea.go | 8 +- pgtype/bytea_array.go | 24 +- pgtype/cid.go | 16 +- pgtype/cidr.go | 35 +++ pgtype/cidr_array.go | 317 +++++++++++++++++++++++- pgtype/cidr_array_test.go | 164 +++++++++++++ pgtype/database_sql.go | 66 +++++ pgtype/date.go | 11 +- pgtype/date_array.go | 24 +- pgtype/float4.go | 8 +- pgtype/float4_array.go | 24 +- pgtype/float8.go | 8 +- pgtype/float8_array.go | 24 +- pgtype/generic_binary.go | 8 +- pgtype/generic_text.go | 8 +- pgtype/hstore.go | 14 +- pgtype/inet.go | 8 +- pgtype/inet_array.go | 24 +- pgtype/int2.go | 8 +- pgtype/int2_array.go | 24 +- pgtype/int4.go | 8 +- pgtype/int4_array.go | 24 +- pgtype/int8.go | 8 +- pgtype/int8_array.go | 24 +- pgtype/json.go | 12 +- pgtype/jsonb.go | 12 +- pgtype/name.go | 16 +- pgtype/oid.go | 8 +- pgtype/oid_value.go | 16 +- pgtype/pgtype.go | 129 +++++++++- pgtype/pgtype_test.go | 10 +- pgtype/pguint32.go | 8 +- pgtype/qchar.go | 4 +- pgtype/record.go | 123 ++++++++++ pgtype/record_test.go | 150 ++++++++++++ pgtype/text.go | 12 +- pgtype/text_array.go | 24 +- pgtype/tid.go | 8 +- pgtype/timestamp.go | 8 +- pgtype/timestamp_array.go | 24 +- pgtype/timestamptz.go | 8 +- pgtype/timestamptz_array.go | 24 +- pgtype/typed_array.go.erb | 24 +- pgtype/typed_array_gen.sh | 2 + pgtype/unknown.go | 32 +++ pgtype/varchar.go | 40 ++++ pgtype/varchar_array.go | 285 +++++++++++++++++++++- pgtype/varchar_array_test.go | 151 ++++++++++++ pgtype/xid.go | 16 +- query.go | 208 ++++++++-------- query_test.go | 2 +- values.go | 451 +---------------------------------- values_test.go | 265 ++++++++++---------- 62 files changed, 2067 insertions(+), 1105 deletions(-) create mode 100644 pgtype/cidr.go create mode 100644 pgtype/cidr_array_test.go create mode 100644 pgtype/database_sql.go create mode 100644 pgtype/record.go create mode 100644 pgtype/record_test.go create mode 100644 pgtype/unknown.go create mode 100644 pgtype/varchar.go create mode 100644 pgtype/varchar_array_test.go diff --git a/conn.go b/conn.go index 0c86d169..3414d7cf 100644 --- a/conn.go +++ b/conn.go @@ -31,6 +31,20 @@ const ( connStatusBusy ) +// minimalConnInfo has just enough static type information to establish the +// connection and retrieve the type data. +var minimalConnInfo *pgtype.ConnInfo + +func init() { + minimalConnInfo = pgtype.NewConnInfo() + minimalConnInfo.InitializeDataTypes(map[string]pgtype.Oid{ + "int4": Int4Oid, + "name": NameOid, + "oid": OidOid, + "text": TextOid, + }) +} + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -74,11 +88,10 @@ type Conn struct { lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf - pid int32 // backend pid - secretKey int32 // key to use to send a cancel query message to the server - RuntimeParams map[string]string // parameters that have been reported by the server - PgTypes map[pgtype.Oid]PgType // oids to PgTypes - config ConnConfig // config used when establishing this connection + pid int32 // backend pid + secretKey int32 // key to use to send a cancel query message to the server + RuntimeParams map[string]string // parameters that have been reported by the server + config ConnConfig // config used when establishing this connection txStatus byte preparedStatements map[string]*PreparedStatement channels map[string]struct{} @@ -102,7 +115,7 @@ type Conn struct { doneChan chan struct{} closedChan chan error - oidPgtypeValues map[pgtype.Oid]pgtype.Value + ConnInfo *pgtype.ConnInfo } // PreparedStatement is a description of a prepared statement @@ -125,12 +138,6 @@ type Notification struct { Payload string } -// PgType is information about PostgreSQL type and how to encode and decode it -type PgType struct { - Name string // name of type e.g. int4, text, date - DefaultFormat int16 // default format (text or binary) this type will be requested in -} - // CommandTag is the result of an Exec function type CommandTag string @@ -190,20 +197,14 @@ func (e ProtocolError) Error() string { // config.Host must be specified. config.User will default to the OS user name. // Other config fields are optional. func Connect(config ConnConfig) (c *Conn, err error) { - return connect(config, nil) + return connect(config, minimalConnInfo) } -func connect(config ConnConfig, pgTypes map[pgtype.Oid]PgType) (c *Conn, err error) { +func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) { c = new(Conn) c.config = config - - if pgTypes != nil { - c.PgTypes = make(map[pgtype.Oid]PgType, len(pgTypes)) - for k, v := range pgTypes { - c.PgTypes[k] = v - } - } + c.ConnInfo = connInfo if c.config.LogLevel != 0 { c.logLevel = c.config.LogLevel @@ -289,8 +290,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.loadStaticOidPgtypeValues() - c.mr.cr = chunkreader.NewChunkReader(c.conn) msg := newStartupMessage() @@ -344,13 +343,12 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl return nil } - if c.PgTypes == nil { - err = c.loadPgTypes() + if c.ConnInfo == minimalConnInfo { + err = c.initConnInfo() if err != nil { return err } } - c.loadDynamicOidPgtypeValues() return nil default: @@ -361,88 +359,37 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } -func (c *Conn) loadPgTypes() error { +func (c *Conn) initConnInfo() error { + nameOids := make(map[string]pgtype.Oid, 256) + rows, err := c.Query(`select t.oid, t.typname from pg_type t left join pg_type base_type on t.typelem=base_type.oid where ( - t.typtype='b' - and (base_type.oid is null or base_type.typtype='b') - ) - or t.typname in('record');`) + t.typtype in('b', 'p') + and (base_type.oid is null or base_type.typtype in('b', 'p')) + )`) if err != nil { return err } - c.PgTypes = make(map[pgtype.Oid]PgType, 128) - for rows.Next() { - var oid uint32 - var t PgType + var oid pgtype.Oid + var name pgtype.Text + if err := rows.Scan(&oid, &name); err != nil { + return err + } - rows.Scan(&oid, &t.Name) - - // The zero value is text format so we ignore any types without a default type format - t.DefaultFormat, _ = DefaultTypeFormats[t.Name] - - c.PgTypes[pgtype.Oid(oid)] = t + nameOids[name.String] = oid } - return rows.Err() -} - -func (c *Conn) loadStaticOidPgtypeValues() { - c.oidPgtypeValues = map[pgtype.Oid]pgtype.Value{ - AclitemArrayOid: &pgtype.AclitemArray{}, - AclitemOid: &pgtype.Aclitem{}, - BoolArrayOid: &pgtype.BoolArray{}, - BoolOid: &pgtype.Bool{}, - ByteaArrayOid: &pgtype.ByteaArray{}, - ByteaOid: &pgtype.Bytea{}, - CharOid: &pgtype.QChar{}, - CidOid: &pgtype.Cid{}, - CidrArrayOid: &pgtype.CidrArray{}, - CidrOid: &pgtype.Inet{}, - DateArrayOid: &pgtype.DateArray{}, - DateOid: &pgtype.Date{}, - Float4ArrayOid: &pgtype.Float4Array{}, - Float4Oid: &pgtype.Float4{}, - Float8ArrayOid: &pgtype.Float8Array{}, - Float8Oid: &pgtype.Float8{}, - InetArrayOid: &pgtype.InetArray{}, - InetOid: &pgtype.Inet{}, - Int2ArrayOid: &pgtype.Int2Array{}, - Int2Oid: &pgtype.Int2{}, - Int4ArrayOid: &pgtype.Int4Array{}, - Int4Oid: &pgtype.Int4{}, - Int8ArrayOid: &pgtype.Int8Array{}, - Int8Oid: &pgtype.Int8{}, - JsonbOid: &pgtype.Jsonb{}, - JsonOid: &pgtype.Json{}, - NameOid: &pgtype.Name{}, - OidOid: &pgtype.OidValue{}, - TextArrayOid: &pgtype.TextArray{}, - TextOid: &pgtype.Text{}, - TidOid: &pgtype.Tid{}, - TimestampArrayOid: &pgtype.TimestampArray{}, - TimestampOid: &pgtype.Timestamp{}, - TimestampTzArrayOid: &pgtype.TimestamptzArray{}, - TimestampTzOid: &pgtype.Timestamptz{}, - VarcharArrayOid: &pgtype.VarcharArray{}, - VarcharOid: &pgtype.Text{}, - XidOid: &pgtype.Xid{}, - } -} - -func (c *Conn) loadDynamicOidPgtypeValues() { - nameOids := make(map[string]pgtype.Oid, len(c.PgTypes)) - for k, v := range c.PgTypes { - nameOids[v.Name] = k + if rows.Err() != nil { + return rows.Err() } - if oid, ok := nameOids["hstore"]; ok { - c.oidPgtypeValues[oid] = &pgtype.Hstore{} - } + c.ConnInfo = pgtype.NewConnInfo() + c.ConnInfo.InitializeDataTypes(nameOids) + return nil } // PID returns the backend PID for this connection. @@ -805,9 +752,16 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared case rowDescription: ps.FieldDescriptions = c.rxRowDescription(r) for i := range ps.FieldDescriptions { - t, _ := c.PgTypes[ps.FieldDescriptions[i].DataType] - ps.FieldDescriptions[i].DataTypeName = t.Name - ps.FieldDescriptions[i].FormatCode = t.DefaultFormat + if dt, ok := c.ConnInfo.DataTypeForOid(ps.FieldDescriptions[i].DataType); ok { + ps.FieldDescriptions[i].DataTypeName = dt.Name + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + ps.FieldDescriptions[i].FormatCode = BinaryFormatCode + } else { + ps.FieldDescriptions[i].FormatCode = TextFormatCode + } + } else { + return nil, fmt.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) + } } case readyForQuery: c.rxReadyForQuery(r) diff --git a/conn_pool.go b/conn_pool.go index 653ed0ba..44559ea8 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -30,7 +30,7 @@ type ConnPool struct { closed bool preparedStatements map[string]*PreparedStatement acquireTimeout time.Duration - pgTypes map[pgtype.Oid]PgType + connInfo *pgtype.ConnInfo txAfterClose func(tx *Tx) rowsAfterClose func(rows *Rows) } @@ -49,6 +49,7 @@ var ErrAcquireTimeout = errors.New("timeout acquiring connection from pool") func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { p = new(ConnPool) p.config = config.ConnConfig + p.connInfo = minimalConnInfo p.maxConnections = config.MaxConnections if p.maxConnections == 0 { p.maxConnections = 5 @@ -95,6 +96,7 @@ func NewConnPool(config ConnPoolConfig) (p *ConnPool, err error) { } p.allConnections = append(p.allConnections, c) p.availableConnections = append(p.availableConnections, c) + p.connInfo = c.ConnInfo.DeepCopy() return } @@ -294,7 +296,7 @@ func (p *ConnPool) Stat() (s ConnPoolStat) { } func (p *ConnPool) createConnection() (*Conn, error) { - c, err := connect(p.config, p.pgTypes) + c, err := connect(p.config, p.connInfo) if err != nil { return nil, err } @@ -329,8 +331,6 @@ func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { // afterConnectionCreated executes (if it is) afterConnect() callback and prepares // all the known statements for the new connection. func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { - p.pgTypes = c.PgTypes - if p.afterConnect != nil { err := p.afterConnect(c) if err != nil { diff --git a/copy_from_test.go b/copy_from_test.go index e17575de..6df4ebb1 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" ) func TestConnCopyFromSmall(t *testing.T) { @@ -126,8 +125,8 @@ func TestConnCopyFromJSON(t *testing.T) { conn := mustConnect(t, *defaultConnConfig) defer closeConn(t, conn) - for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { + for _, typeName := range []string{"json", "jsonb"} { + if _, ok := conn.ConnInfo.DataTypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } diff --git a/example_custom_type_test.go b/example_custom_type_test.go index 71110f85..1c21c7e6 100644 --- a/example_custom_type_test.go +++ b/example_custom_type_test.go @@ -18,7 +18,7 @@ type Point struct { Status pgtype.Status } -func (dst *Point) DecodeText(src []byte) error { +func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { *dst = Point{Status: pgtype.Null} return nil @@ -44,7 +44,7 @@ func (dst *Point) DecodeText(src []byte) error { return nil } -func (src Point) EncodeText(w io.Writer) (bool, error) { +func (src Point) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { switch src.Status { case pgtype.Null: return true, nil diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index b8a1549e..f9faab20 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -90,7 +90,7 @@ func (src *Aclitem) AssignTo(dst interface{}) error { return nil } -func (dst *Aclitem) DecodeText(src []byte) error { +func (dst *Aclitem) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Aclitem{Status: Null} return nil @@ -100,7 +100,7 @@ func (dst *Aclitem) DecodeText(src []byte) error { return nil } -func (src Aclitem) EncodeText(w io.Writer) (bool, error) { +func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 5e3647b7..f02d339e 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -82,7 +82,7 @@ func (src *AclitemArray) AssignTo(dst interface{}) error { return nil } -func (dst *AclitemArray) DecodeText(src []byte) error { +func (dst *AclitemArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = AclitemArray{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,7 +118,7 @@ func (dst *AclitemArray) DecodeText(src []byte) error { return nil } -func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { +func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -165,7 +165,7 @@ func (src *AclitemArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/array.go b/pgtype/array.go index dff0fe81..9561afe5 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -27,7 +27,7 @@ type ArrayDimension struct { LowerBound int32 } -func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { +func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) } @@ -60,7 +60,7 @@ func (dst *ArrayHeader) DecodeBinary(src []byte) (int, error) { return rp, nil } -func (src *ArrayHeader) EncodeBinary(w io.Writer) error { +func (src *ArrayHeader) EncodeBinary(ci *ConnInfo, w io.Writer) error { _, err := pgio.WriteInt32(w, int32(len(src.Dimensions))) if err != nil { return err diff --git a/pgtype/bool.go b/pgtype/bool.go index a8e9b8e1..87316381 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -79,7 +79,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } -func (dst *Bool) DecodeText(src []byte) error { +func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Bool) DecodeText(src []byte) error { return nil } -func (dst *Bool) DecodeBinary(src []byte) error { +func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bool{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Bool) DecodeBinary(src []byte) error { return nil } -func (src Bool) EncodeText(w io.Writer) (bool, error) { +func (src Bool) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -126,7 +126,7 @@ func (src Bool) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bool) EncodeBinary(w io.Writer) (bool, error) { +func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 4c5fc563..1cb46cf6 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -83,7 +83,7 @@ func (src *BoolArray) AssignTo(dst interface{}) error { return nil } -func (dst *BoolArray) DecodeText(src []byte) error { +func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *BoolArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *BoolArray) DecodeText(src []byte) error { return nil } -func (dst *BoolArray) DecodeBinary(src []byte) error { +func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = BoolArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *BoolArray) DecodeBinary(src []byte) error { return nil } -func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { +func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *BoolArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, BoolOid) +func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, BoolOid) } -func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *BoolArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 5df05360..dc1e9c07 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -78,7 +78,7 @@ func (src *Bytea) AssignTo(dst interface{}) error { // DecodeText only supports the hex format. This has been the default since // PostgreSQL 9.0. -func (dst *Bytea) DecodeText(src []byte) error { +func (dst *Bytea) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -98,7 +98,7 @@ func (dst *Bytea) DecodeText(src []byte) error { return nil } -func (dst *Bytea) DecodeBinary(src []byte) error { +func (dst *Bytea) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Bytea{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Bytea) DecodeBinary(src []byte) error { return nil } -func (src Bytea) EncodeText(w io.Writer) (bool, error) { +func (src Bytea) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -128,7 +128,7 @@ func (src Bytea) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Bytea) EncodeBinary(w io.Writer) (bool, error) { +func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index c6f676a4..30405509 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -83,7 +83,7 @@ func (src *ByteaArray) AssignTo(dst interface{}) error { return nil } -func (dst *ByteaArray) DecodeText(src []byte) error { +func (dst *ByteaArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *ByteaArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *ByteaArray) DecodeText(src []byte) error { return nil } -func (dst *ByteaArray) DecodeBinary(src []byte) error { +func (dst *ByteaArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = ByteaArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *ByteaArray) DecodeBinary(src []byte) error { return nil } -func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { +func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, ByteaOid) +func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, ByteaOid) } -func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *ByteaArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/cid.go b/pgtype/cid.go index 20957f36..d86e8063 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -34,18 +34,18 @@ func (src *Cid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Cid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Cid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Cid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Cid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Cid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Cid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/pgtype/cidr.go b/pgtype/cidr.go new file mode 100644 index 00000000..463b279d --- /dev/null +++ b/pgtype/cidr.go @@ -0,0 +1,35 @@ +package pgtype + +import ( + "io" +) + +type Cidr Inet + +func (dst *Cidr) Set(src interface{}) error { + return (*Inet)(dst).Set(src) +} + +func (dst *Cidr) Get() interface{} { + return (*Inet)(dst).Get() +} + +func (src *Cidr) AssignTo(dst interface{}) error { + return (*Inet)(src).AssignTo(dst) +} + +func (dst *Cidr) DecodeText(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeText(ci, src) +} + +func (dst *Cidr) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Inet)(dst).DecodeBinary(ci, src) +} + +func (src Cidr) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeText(ci, w) +} + +func (src Cidr) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Inet)(src).EncodeBinary(ci, w) +} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index c30c53d3..32d2e7bf 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -1,35 +1,328 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + "net" + + "github.com/jackc/pgx/pgio" ) -type CidrArray InetArray +type CidrArray struct { + Elements []Cidr + Dimensions []ArrayDimension + Status Status +} func (dst *CidrArray) Set(src interface{}) error { - return (*InetArray)(dst).Set(src) + switch value := src.(type) { + + case []*net.IPNet: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []net.IP: + if value == nil { + *dst = CidrArray{Status: Null} + } else if len(value) == 0 { + *dst = CidrArray{Status: Present} + } else { + elements := make([]Cidr, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = CidrArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Cidr", value) + } + + return nil } func (dst *CidrArray) Get() interface{} { - return (*InetArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *CidrArray) AssignTo(dst interface{}) error { - return (*InetArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]*net.IPNet: + if src.Status == Present { + *v = make([]*net.IPNet, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + case *[]net.IP: + if src.Status == Present { + *v = make([]net.IP, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *CidrArray) DecodeText(src []byte) error { - return (*InetArray)(dst).DecodeText(src) +func (dst *CidrArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Cidr + + if len(uta.Elements) > 0 { + elements = make([]Cidr, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Cidr + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = CidrArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *CidrArray) DecodeBinary(src []byte) error { - return (*InetArray)(dst).DecodeBinary(src) +func (dst *CidrArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = CidrArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = CidrArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Cidr, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = CidrArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *CidrArray) EncodeText(w io.Writer) (bool, error) { - return (*InetArray)(src).EncodeText(w) +func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *CidrArray) EncodeBinary(w io.Writer) (bool, error) { - return (*InetArray)(src).encodeBinary(w, CidrOid) +func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, CidrOid) +} + +func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/pgtype/cidr_array_test.go b/pgtype/cidr_array_test.go new file mode 100644 index 00000000..ec105914 --- /dev/null +++ b/pgtype/cidr_array_test.go @@ -0,0 +1,164 @@ +package pgtype_test + +import ( + "net" + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestCidrArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "cidr[]", []interface{}{ + &pgtype.CidrArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{Status: pgtype.Null}, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + pgtype.Cidr{Status: pgtype.Null}, + pgtype.Cidr{IPNet: mustParseCidr(t, "255.0.0.0/8"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.CidrArray{ + Elements: []pgtype.Cidr{ + pgtype.Cidr{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "12.34.56.0/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "192.168.0.1/32"), Status: pgtype.Present}, + pgtype.Cidr{IPNet: mustParseCidr(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestCidrArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.CidrArray + }{ + { + source: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]*net.IPNet)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + { + source: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + result: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]net.IP)(nil)), + result: pgtype.CidrArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.CidrArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestCidrArrayAssignTo(t *testing.T) { + var ipnetSlice []*net.IPNet + var ipSlice []net.IP + + simpleTests := []struct { + src pgtype.CidrArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{mustParseCidr(t, "127.0.0.1/32")}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipnetSlice, + expected: []*net.IPNet{nil}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{IPNet: mustParseCidr(t, "127.0.0.1/32"), Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{mustParseCidr(t, "127.0.0.1/32").IP}, + }, + { + src: pgtype.CidrArray{ + Elements: []pgtype.Cidr{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &ipSlice, + expected: []net.IP{nil}, + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipnetSlice, + expected: (([]*net.IPNet)(nil)), + }, + { + src: pgtype.CidrArray{Status: pgtype.Null}, + dst: &ipSlice, + expected: (([]net.IP)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go new file mode 100644 index 00000000..969d6542 --- /dev/null +++ b/pgtype/database_sql.go @@ -0,0 +1,66 @@ +package pgtype + +import ( + "bytes" + "errors" +) + +func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { + switch src := src.(type) { + case *Bool: + return src.Bool, nil + case *Bytea: + return src.Bytes, nil + case *Date: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Float4: + return float64(src.Float), nil + case *Float8: + return src.Float, nil + case *GenericBinary: + return src.Bytes, nil + case *GenericText: + return src.String, nil + case *Int2: + return int64(src.Int), nil + case *Int4: + return int64(src.Int), nil + case *Int8: + return int64(src.Int), nil + case *Text: + return src.String, nil + case *Timestamp: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Timestamptz: + if src.InfinityModifier == None { + return src.Time, nil + } + case *Unknown: + return src.String, nil + case *Varchar: + return src.String, nil + } + + buf := &bytes.Buffer{} + if textEncoder, ok := src.(TextEncoder); ok { + _, err := textEncoder.EncodeText(ci, buf) + if err != nil { + return nil, err + } + return buf.String(), nil + } + + if binaryEncoder, ok := src.(BinaryEncoder); ok { + _, err := binaryEncoder.EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil + } + + return nil, errors.New("cannot convert to database/sql compatible value") +} diff --git a/pgtype/date.go b/pgtype/date.go index d0481637..b6cc8329 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -38,6 +38,9 @@ func (dst *Date) Set(src interface{}) error { func (dst *Date) Get() interface{} { switch dst.Status { case Present: + if dst.InfinityModifier != None { + return dst.InfinityModifier + } return dst.Time case Null: return nil @@ -76,7 +79,7 @@ func (src *Date) AssignTo(dst interface{}) error { return nil } -func (dst *Date) DecodeText(src []byte) error { +func (dst *Date) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -100,7 +103,7 @@ func (dst *Date) DecodeText(src []byte) error { return nil } -func (dst *Date) DecodeBinary(src []byte) error { +func (dst *Date) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Date{Status: Null} return nil @@ -125,7 +128,7 @@ func (dst *Date) DecodeBinary(src []byte) error { return nil } -func (src Date) EncodeText(w io.Writer) (bool, error) { +func (src Date) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -148,7 +151,7 @@ func (src Date) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Date) EncodeBinary(w io.Writer) (bool, error) { +func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/date_array.go b/pgtype/date_array.go index 7f602d83..ba68d561 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -84,7 +84,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { return nil } -func (dst *DateArray) DecodeText(src []byte) error { +func (dst *DateArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *DateArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *DateArray) DecodeText(src []byte) error { return nil } -func (dst *DateArray) DecodeBinary(src []byte) error { +func (dst *DateArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = DateArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *DateArray) DecodeBinary(src []byte) error { return nil } -func (src *DateArray) EncodeText(w io.Writer) (bool, error) { +func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *DateArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, DateOid) +func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, DateOid) } -func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *DateArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/float4.go b/pgtype/float4.go index 053af44b..94b7b7a1 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -102,7 +102,7 @@ func (src *Float4) AssignTo(dst interface{}) error { return float64AssignTo(float64(src.Float), src.Status, dst) } -func (dst *Float4) DecodeText(src []byte) error { +func (dst *Float4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Float4) DecodeText(src []byte) error { return nil } -func (dst *Float4) DecodeBinary(src []byte) error { +func (dst *Float4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4{Status: Null} return nil @@ -133,7 +133,7 @@ func (dst *Float4) DecodeBinary(src []byte) error { return nil } -func (src Float4) EncodeText(w io.Writer) (bool, error) { +func (src Float4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -145,7 +145,7 @@ func (src Float4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float4) EncodeBinary(w io.Writer) (bool, error) { +func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 0e815e0b..40152bcf 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -83,7 +83,7 @@ func (src *Float4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float4Array) DecodeText(src []byte) error { +func (dst *Float4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float4Array) DecodeText(src []byte) error { return nil } -func (dst *Float4Array) DecodeBinary(src []byte) error { +func (dst *Float4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float4Array) DecodeBinary(src []byte) error { return nil } -func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float4Oid) +func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float4Oid) } -func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/float8.go b/pgtype/float8.go index 635b7a09..dd2d592d 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -92,7 +92,7 @@ func (src *Float8) AssignTo(dst interface{}) error { return float64AssignTo(src.Float, src.Status, dst) } -func (dst *Float8) DecodeText(src []byte) error { +func (dst *Float8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -107,7 +107,7 @@ func (dst *Float8) DecodeText(src []byte) error { return nil } -func (dst *Float8) DecodeBinary(src []byte) error { +func (dst *Float8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8{Status: Null} return nil @@ -123,7 +123,7 @@ func (dst *Float8) DecodeBinary(src []byte) error { return nil } -func (src Float8) EncodeText(w io.Writer) (bool, error) { +func (src Float8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -135,7 +135,7 @@ func (src Float8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Float8) EncodeBinary(w io.Writer) (bool, error) { +func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index 811c5a1f..d0ee0d70 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -83,7 +83,7 @@ func (src *Float8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Float8Array) DecodeText(src []byte) error { +func (dst *Float8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *Float8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *Float8Array) DecodeText(src []byte) error { return nil } -func (dst *Float8Array) DecodeBinary(src []byte) error { +func (dst *Float8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Float8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *Float8Array) DecodeBinary(src []byte) error { return nil } -func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Float8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Float8Oid) +func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Float8Oid) } -func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *Float8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go index ac35ea60..aa28bb62 100644 --- a/pgtype/generic_binary.go +++ b/pgtype/generic_binary.go @@ -20,10 +20,10 @@ func (src *GenericBinary) AssignTo(dst interface{}) error { return (*Bytea)(src).AssignTo(dst) } -func (dst *GenericBinary) DecodeBinary(src []byte) error { - return (*Bytea)(dst).DecodeBinary(src) +func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Bytea)(dst).DecodeBinary(ci, src) } -func (src GenericBinary) EncodeBinary(w io.Writer) (bool, error) { - return (Bytea)(src).EncodeBinary(w) +func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Bytea)(src).EncodeBinary(ci, w) } diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go index 19f41059..bd75e0d0 100644 --- a/pgtype/generic_text.go +++ b/pgtype/generic_text.go @@ -20,10 +20,10 @@ func (src *GenericText) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *GenericText) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (src GenericText) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } diff --git a/pgtype/hstore.go b/pgtype/hstore.go index c48ae6da..d771d6e6 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -70,7 +70,7 @@ func (src *Hstore) AssignTo(dst interface{}) error { return nil } -func (dst *Hstore) DecodeText(src []byte) error { +func (dst *Hstore) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -90,7 +90,7 @@ func (dst *Hstore) DecodeText(src []byte) error { return nil } -func (dst *Hstore) DecodeBinary(src []byte) error { +func (dst *Hstore) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Hstore{Status: Null} return nil @@ -132,7 +132,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { rp += valueLen var value Text - err := value.DecodeBinary(valueBuf) + err := value.DecodeBinary(ci, valueBuf) if err != nil { return err } @@ -144,7 +144,7 @@ func (dst *Hstore) DecodeBinary(src []byte) error { return nil } -func (src Hstore) EncodeText(w io.Writer) (bool, error) { +func (src Hstore) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -175,7 +175,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -196,7 +196,7 @@ func (src Hstore) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { +func (src Hstore) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -220,7 +220,7 @@ func (src Hstore) EncodeBinary(w io.Writer) (bool, error) { return false, err } - null, err := v.EncodeText(elemBuf) + null, err := v.EncodeText(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/inet.go b/pgtype/inet.go index 87d675f9..b83bd1c9 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -100,7 +100,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return nil } -func (dst *Inet) DecodeText(src []byte) error { +func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Inet) DecodeText(src []byte) error { return nil } -func (dst *Inet) DecodeBinary(src []byte) error { +func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -153,7 +153,7 @@ func (dst *Inet) DecodeBinary(src []byte) error { return nil } -func (src Inet) EncodeText(w io.Writer) (bool, error) { +func (src Inet) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Inet) EncodeText(w io.Writer) (bool, error) { } // EncodeBinary encodes src into w. -func (src Inet) EncodeBinary(w io.Writer) (bool, error) { +func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 1d1cf3fd..6cad82e7 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -115,7 +115,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { return nil } -func (dst *InetArray) DecodeText(src []byte) error { +func (dst *InetArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil @@ -137,7 +137,7 @@ func (dst *InetArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -151,14 +151,14 @@ func (dst *InetArray) DecodeText(src []byte) error { return nil } -func (dst *InetArray) DecodeBinary(src []byte) error { +func (dst *InetArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = InetArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -183,7 +183,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -193,7 +193,7 @@ func (dst *InetArray) DecodeBinary(src []byte) error { return nil } -func (src *InetArray) EncodeText(w io.Writer) (bool, error) { +func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -240,7 +240,7 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -269,11 +269,11 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *InetArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, InetOid) +func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, InetOid) } -func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -293,7 +293,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -303,7 +303,7 @@ func (src *InetArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int2.go b/pgtype/int2.go index 62e1bc69..6996cd4f 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -98,7 +98,7 @@ func (src *Int2) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int2) DecodeText(src []byte) error { +func (dst *Int2) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -113,7 +113,7 @@ func (dst *Int2) DecodeText(src []byte) error { return nil } -func (dst *Int2) DecodeBinary(src []byte) error { +func (dst *Int2) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2{Status: Null} return nil @@ -128,7 +128,7 @@ func (dst *Int2) DecodeBinary(src []byte) error { return nil } -func (src Int2) EncodeText(w io.Writer) (bool, error) { +func (src Int2) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -140,7 +140,7 @@ func (src Int2) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int2) EncodeBinary(w io.Writer) (bool, error) { +func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 3d06c018..2bf1c237 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -114,7 +114,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int2Array) DecodeText(src []byte) error { +func (dst *Int2Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int2Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int2Array) DecodeText(src []byte) error { return nil } -func (dst *Int2Array) DecodeBinary(src []byte) error { +func (dst *Int2Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int2Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int2Array) DecodeBinary(src []byte) error { return nil } -func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int2Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int2Oid) +func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int2Oid) } -func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int2Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int4.go b/pgtype/int4.go index 8eaf5094..62ee366f 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -89,7 +89,7 @@ func (src *Int4) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int4) DecodeText(src []byte) error { +func (dst *Int4) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *Int4) DecodeText(src []byte) error { return nil } -func (dst *Int4) DecodeBinary(src []byte) error { +func (dst *Int4) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4{Status: Null} return nil @@ -119,7 +119,7 @@ func (dst *Int4) DecodeBinary(src []byte) error { return nil } -func (src Int4) EncodeText(w io.Writer) (bool, error) { +func (src Int4) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -131,7 +131,7 @@ func (src Int4) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int4) EncodeBinary(w io.Writer) (bool, error) { +func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 5cd91c04..dda88eaf 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -114,7 +114,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int4Array) DecodeText(src []byte) error { +func (dst *Int4Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int4Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int4Array) DecodeText(src []byte) error { return nil } -func (dst *Int4Array) DecodeBinary(src []byte) error { +func (dst *Int4Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int4Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int4Array) DecodeBinary(src []byte) error { return nil } -func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int4Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int4Oid) +func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int4Oid) } -func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int4Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/int8.go b/pgtype/int8.go index 2416500d..7ed54f8e 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -80,7 +80,7 @@ func (src *Int8) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *Int8) DecodeText(src []byte) error { +func (dst *Int8) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -95,7 +95,7 @@ func (dst *Int8) DecodeText(src []byte) error { return nil } -func (dst *Int8) DecodeBinary(src []byte) error { +func (dst *Int8) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Int8) DecodeBinary(src []byte) error { return nil } -func (src Int8) EncodeText(w io.Writer) (bool, error) { +func (src Int8) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -123,7 +123,7 @@ func (src Int8) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Int8) EncodeBinary(w io.Writer) (bool, error) { +func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 5efc0f45..468c126b 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -114,7 +114,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { return nil } -func (dst *Int8Array) DecodeText(src []byte) error { +func (dst *Int8Array) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil @@ -136,7 +136,7 @@ func (dst *Int8Array) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -150,14 +150,14 @@ func (dst *Int8Array) DecodeText(src []byte) error { return nil } -func (dst *Int8Array) DecodeBinary(src []byte) error { +func (dst *Int8Array) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Int8Array{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -182,7 +182,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -192,7 +192,7 @@ func (dst *Int8Array) DecodeBinary(src []byte) error { return nil } -func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { +func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -239,7 +239,7 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -268,11 +268,11 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *Int8Array) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, Int8Oid) +func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, Int8Oid) } -func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -292,7 +292,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -302,7 +302,7 @@ func (src *Int8Array) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/json.go b/pgtype/json.go index ecdb3dab..bfffae14 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -84,7 +84,7 @@ func (src *Json) AssignTo(dst interface{}) error { return nil } -func (dst *Json) DecodeText(src []byte) error { +func (dst *Json) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Json{Status: Null} return nil @@ -97,11 +97,11 @@ func (dst *Json) DecodeText(src []byte) error { return nil } -func (dst *Json) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Json) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Json) EncodeText(w io.Writer) (bool, error) { +func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -113,6 +113,6 @@ func (src Json) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Json) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 13062e8e..e44f3c41 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -19,11 +19,11 @@ func (src *Jsonb) AssignTo(dst interface{}) error { return (*Json)(src).AssignTo(dst) } -func (dst *Jsonb) DecodeText(src []byte) error { - return (*Json)(dst).DecodeText(src) +func (dst *Jsonb) DecodeText(ci *ConnInfo, src []byte) error { + return (*Json)(dst).DecodeText(ci, src) } -func (dst *Jsonb) DecodeBinary(src []byte) error { +func (dst *Jsonb) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Jsonb{Status: Null} return nil @@ -46,11 +46,11 @@ func (dst *Jsonb) DecodeBinary(src []byte) error { } -func (src Jsonb) EncodeText(w io.Writer) (bool, error) { - return (Json)(src).EncodeText(w) +func (src Jsonb) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Json)(src).EncodeText(ci, w) } -func (src Jsonb) EncodeBinary(w io.Writer) (bool, error) { +func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/name.go b/pgtype/name.go index 9eb12ece..9ebf63d3 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -31,18 +31,18 @@ func (src *Name) AssignTo(dst interface{}) error { return (*Text)(src).AssignTo(dst) } -func (dst *Name) DecodeText(src []byte) error { - return (*Text)(dst).DecodeText(src) +func (dst *Name) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) } -func (dst *Name) DecodeBinary(src []byte) error { - return (*Text)(dst).DecodeBinary(src) +func (dst *Name) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) } -func (src Name) EncodeText(w io.Writer) (bool, error) { - return (Text)(src).EncodeText(w) +func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) } -func (src Name) EncodeBinary(w io.Writer) (bool, error) { - return (Text)(src).EncodeBinary(w) +func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) } diff --git a/pgtype/oid.go b/pgtype/oid.go index eab1fbcb..3edd7f3c 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -18,7 +18,7 @@ import ( // allow for NULL Oids use OidValue. type Oid uint32 -func (dst *Oid) DecodeText(src []byte) error { +func (dst *Oid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -32,7 +32,7 @@ func (dst *Oid) DecodeText(src []byte) error { return nil } -func (dst *Oid) DecodeBinary(src []byte) error { +func (dst *Oid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { return fmt.Errorf("cannot decode nil into Oid") } @@ -46,12 +46,12 @@ func (dst *Oid) DecodeBinary(src []byte) error { return nil } -func (src Oid) EncodeText(w io.Writer) (bool, error) { +func (src Oid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) return false, err } -func (src Oid) EncodeBinary(w io.Writer) (bool, error) { +func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go index a2b2dcbe..1bce6e11 100644 --- a/pgtype/oid_value.go +++ b/pgtype/oid_value.go @@ -28,18 +28,18 @@ func (src *OidValue) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *OidValue) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *OidValue) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *OidValue) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *OidValue) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src OidValue) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src OidValue) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7b1470b7..674c0db7 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -3,6 +3,7 @@ package pgtype import ( "errors" "io" + "reflect" ) // PostgreSQL oids for common types @@ -83,14 +84,14 @@ type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder MUST not retain a reference to // src. It MUST make a copy if it needs to retain the raw bytes. - DecodeBinary(src []byte) error + DecodeBinary(ci *ConnInfo, src []byte) error } type TextDecoder interface { // DecodeText decodes src into TextDecoder. If src is nil then the original // SQL value is NULL. TextDecoder MUST not retain a reference to src. It MUST // make a copy if it needs to retain the raw bytes. - DecodeText(src []byte) error + DecodeText(ci *ConnInfo, src []byte) error } // BinaryEncoder is implemented by types that can encode themselves into the @@ -100,7 +101,7 @@ type BinaryEncoder interface { // SQL value NULL then write nothing and return (true, nil). The caller of // EncodeBinary is responsible for writing the correct NULL value or the // length of the data written. - EncodeBinary(w io.Writer) (null bool, err error) + EncodeBinary(ci *ConnInfo, w io.Writer) (null bool, err error) } // TextEncoder is implemented by types that can encode themselves into the @@ -110,7 +111,127 @@ type TextEncoder interface { // value NULL then write nothing and return (true, nil). The caller of // EncodeText is responsible for writing the correct NULL value or the length // of the data written. - EncodeText(w io.Writer) (null bool, err error) + EncodeText(ci *ConnInfo, w io.Writer) (null bool, err error) } var errUndefined = errors.New("cannot encode status undefined") + +type DataType struct { + Value Value + Name string + Oid Oid +} + +type ConnInfo struct { + oidToDataType map[Oid]*DataType + nameToDataType map[string]*DataType + reflectTypeToDataType map[reflect.Type]*DataType +} + +func NewConnInfo() *ConnInfo { + return &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, 256), + nameToDataType: make(map[string]*DataType, 256), + reflectTypeToDataType: make(map[reflect.Type]*DataType, 256), + } +} + +func (ci *ConnInfo) InitializeDataTypes(nameOids map[string]Oid) { + for name, oid := range nameOids { + var value Value + if t, ok := nameValues[name]; ok { + value = reflect.New(reflect.ValueOf(t).Elem().Type()).Interface().(Value) + } else { + value = &GenericText{} + } + ci.RegisterDataType(DataType{Value: value, Name: name, Oid: oid}) + } +} + +func (ci *ConnInfo) RegisterDataType(t DataType) { + ci.oidToDataType[t.Oid] = &t + ci.nameToDataType[t.Name] = &t + ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t +} + +func (ci *ConnInfo) DataTypeForOid(oid Oid) (*DataType, bool) { + dt, ok := ci.oidToDataType[oid] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { + dt, ok := ci.nameToDataType[name] + return dt, ok +} + +func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { + dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] + return dt, ok +} + +// DeepCopy makes a deep copy of the ConnInfo. +func (ci *ConnInfo) DeepCopy() *ConnInfo { + ci2 := &ConnInfo{ + oidToDataType: make(map[Oid]*DataType, len(ci.oidToDataType)), + nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), + reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), + } + + for _, dt := range ci.oidToDataType { + ci2.RegisterDataType(DataType{ + Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), + Name: dt.Name, + Oid: dt.Oid, + }) + } + + return ci2 +} + +var nameValues map[string]Value + +func init() { + nameValues = map[string]Value{ + "_aclitem": &AclitemArray{}, + "_bool": &BoolArray{}, + "_bytea": &ByteaArray{}, + "_cidr": &CidrArray{}, + "_date": &DateArray{}, + "_float4": &Float4Array{}, + "_float8": &Float8Array{}, + "_inet": &InetArray{}, + "_int2": &Int2Array{}, + "_int4": &Int4Array{}, + "_int8": &Int8Array{}, + "_text": &TextArray{}, + "_timestamp": &TimestampArray{}, + "_timestamptz": &TimestamptzArray{}, + "_varchar": &VarcharArray{}, + "aclitem": &Aclitem{}, + "bool": &Bool{}, + "bytea": &Bytea{}, + "char": &QChar{}, + "cid": &Cid{}, + "cidr": &Cidr{}, + "date": &Date{}, + "float4": &Float4{}, + "float8": &Float8{}, + "hstore": &Hstore{}, + "inet": &Inet{}, + "int2": &Int2{}, + "int4": &Int4{}, + "int8": &Int8{}, + "json": &Json{}, + "jsonb": &Jsonb{}, + "name": &Name{}, + "oid": &OidValue{}, + "record": &Record{}, + "text": &Text{}, + "tid": &Tid{}, + "timestamp": &Timestamp{}, + "timestamptz": &Timestamptz{}, + "unknown": &Unknown{}, + "varchar": &Varchar{}, + "xid": &Xid{}, + } +} diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index f9b6f56d..391fed57 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -60,16 +60,16 @@ type forceTextEncoder struct { e pgtype.TextEncoder } -func (f forceTextEncoder) EncodeText(w io.Writer) (bool, error) { - return f.e.EncodeText(w) +func (f forceTextEncoder) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeText(ci, w) } type forceBinaryEncoder struct { e pgtype.BinaryEncoder } -func (f forceBinaryEncoder) EncodeBinary(w io.Writer) (bool, error) { - return f.e.EncodeBinary(w) +func (f forceBinaryEncoder) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + return f.e.EncodeBinary(ci, w) } func forceEncoder(e interface{}, formatCode int16) interface{} { @@ -114,7 +114,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%v does not implement %v", fc.name) + t.Logf("%#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 05c79c0e..3f9e7bf7 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -63,7 +63,7 @@ func (src *pguint32) AssignTo(dst interface{}) error { return nil } -func (dst *pguint32) DecodeText(src []byte) error { +func (dst *pguint32) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -78,7 +78,7 @@ func (dst *pguint32) DecodeText(src []byte) error { return nil } -func (dst *pguint32) DecodeBinary(src []byte) error { +func (dst *pguint32) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = pguint32{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *pguint32) DecodeBinary(src []byte) error { return nil } -func (src pguint32) EncodeText(w io.Writer) (bool, error) { +func (src pguint32) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src pguint32) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src pguint32) EncodeBinary(w io.Writer) (bool, error) { +func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/qchar.go b/pgtype/qchar.go index d46e716d..4b32ee4a 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -115,7 +115,7 @@ func (src *QChar) AssignTo(dst interface{}) error { return int64AssignTo(int64(src.Int), src.Status, dst) } -func (dst *QChar) DecodeBinary(src []byte) error { +func (dst *QChar) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = QChar{Status: Null} return nil @@ -129,7 +129,7 @@ func (dst *QChar) DecodeBinary(src []byte) error { return nil } -func (src QChar) EncodeBinary(w io.Writer) (bool, error) { +func (src QChar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/record.go b/pgtype/record.go new file mode 100644 index 00000000..1bfd05b9 --- /dev/null +++ b/pgtype/record.go @@ -0,0 +1,123 @@ +package pgtype + +import ( + "encoding/binary" + "fmt" +) + +// Record is the generic PostgreSQL record type such as is created with the +// "row" function. Record only implements BinaryEncoder and Value. The text +// format output format from PostgreSQL does not include type information and is +// therefore impossible to decode. No encoders are implemented because +// PostgreSQL does not support input of generic records. +type Record struct { + Fields []Value + Status Status +} + +func (dst *Record) Set(src interface{}) error { + switch value := src.(type) { + case []Value: + *dst = Record{Fields: value, Status: Present} + default: + return fmt.Errorf("cannot convert %v to Record", src) + } + + return nil +} + +func (dst *Record) Get() interface{} { + switch dst.Status { + case Present: + return dst.Fields + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Record) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]Value: + switch src.Status { + case Present: + *v = make([]Value, len(src.Fields)) + copy(*v, src.Fields) + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + case *[]interface{}: + switch src.Status { + case Present: + *v = make([]interface{}, len(src.Fields)) + for i := range *v { + (*v)[i] = src.Fields[i].Get() + } + case Null: + *v = nil + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + default: + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Record) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + + rp := 0 + + if len(src[rp:]) < 4 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldCount := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + fields := make([]Value, fieldCount) + + for i := 0; i < fieldCount; i++ { + if len(src[rp:]) < 8 { + return fmt.Errorf("Record incomplete %v", src) + } + fieldOid := Oid(binary.BigEndian.Uint32(src[rp:])) + rp += 4 + + fieldLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + + var binaryDecoder BinaryDecoder + if dt, ok := ci.DataTypeForOid(fieldOid); ok { + if binaryDecoder, ok = dt.Value.(BinaryDecoder); !ok { + return fmt.Errorf("unknown oid while decoding record: %v", fieldOid) + } + } + + var fieldBytes []byte + if fieldLen >= 0 { + if len(src[rp:]) < fieldLen { + return fmt.Errorf("Record incomplete %v", src) + } + fieldBytes = src[rp : rp+fieldLen] + rp += fieldLen + } + + if err := binaryDecoder.DecodeBinary(ci, fieldBytes); err != nil { + return err + } + + fields[i] = binaryDecoder.(Value) + } + + *dst = Record{Fields: fields, Status: Present} + + return nil +} diff --git a/pgtype/record_test.go b/pgtype/record_test.go new file mode 100644 index 00000000..bc6e5893 --- /dev/null +++ b/pgtype/record_test.go @@ -0,0 +1,150 @@ +package pgtype_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestRecordTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + tests := []struct { + sql string + expected pgtype.Record + }{ + { + sql: `select row()`, + expected: pgtype.Record{ + Fields: []pgtype.Value{}, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row('foo'::text, array[1, 2, null, 4]::int4[], 42::int4)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4Array{ + Elements: []pgtype.Int4{ + pgtype.Int4{Int: 1, Status: pgtype.Present}, + pgtype.Int4{Int: 2, Status: pgtype.Present}, + pgtype.Int4{Status: pgtype.Null}, + pgtype.Int4{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 4, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select row(null)`, + expected: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Unknown{Status: pgtype.Null}, + }, + Status: pgtype.Present, + }, + }, + { + sql: `select null::record`, + expected: pgtype.Record{ + Status: pgtype.Null, + }, + }, + } + + for i, tt := range tests { + psName := fmt.Sprintf("test%d", i) + ps, err := conn.Prepare(psName, tt.sql) + if err != nil { + t.Fatal(err) + } + ps.FieldDescriptions[0].FormatCode = pgx.BinaryFormatCode + + var result pgtype.Record + if err := conn.QueryRow(psName).Scan(&result); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + if !reflect.DeepEqual(tt.expected, result) { + t.Errorf("%d: expected %v, got %v", i, tt.expected, result) + } + } +} + +func TestRecordAssignTo(t *testing.T) { + var valueSlice []pgtype.Value + var interfaceSlice []interface{} + + simpleTests := []struct { + src pgtype.Record + dst interface{} + expected interface{} + }{ + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &valueSlice, + expected: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + }, + { + src: pgtype.Record{ + Fields: []pgtype.Value{ + &pgtype.Text{String: "foo", Status: pgtype.Present}, + &pgtype.Int4{Int: 42, Status: pgtype.Present}, + }, + Status: pgtype.Present, + }, + dst: &interfaceSlice, + expected: []interface{}{"foo", int32(42)}, + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &valueSlice, + expected: (([]pgtype.Value)(nil)), + }, + { + src: pgtype.Record{Status: pgtype.Null}, + dst: &interfaceSlice, + expected: (([]interface{})(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/text.go b/pgtype/text.go index 3dd082c9..f1a76b6e 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -78,7 +78,7 @@ func (src *Text) AssignTo(dst interface{}) error { return nil } -func (dst *Text) DecodeText(src []byte) error { +func (dst *Text) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Text{Status: Null} return nil @@ -88,11 +88,11 @@ func (dst *Text) DecodeText(src []byte) error { return nil } -func (dst *Text) DecodeBinary(src []byte) error { - return dst.DecodeText(src) +func (dst *Text) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) } -func (src Text) EncodeText(w io.Writer) (bool, error) { +func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -104,6 +104,6 @@ func (src Text) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Text) EncodeBinary(w io.Writer) (bool, error) { - return src.EncodeText(w) +func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.EncodeText(ci, w) } diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 1e6677a9..6e89708f 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -83,7 +83,7 @@ func (src *TextArray) AssignTo(dst interface{}) error { return nil } -func (dst *TextArray) DecodeText(src []byte) error { +func (dst *TextArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil @@ -105,7 +105,7 @@ func (dst *TextArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -119,14 +119,14 @@ func (dst *TextArray) DecodeText(src []byte) error { return nil } -func (dst *TextArray) DecodeBinary(src []byte) error { +func (dst *TextArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TextArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -151,7 +151,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -161,7 +161,7 @@ func (dst *TextArray) DecodeBinary(src []byte) error { return nil } -func (src *TextArray) EncodeText(w io.Writer) (bool, error) { +func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -208,7 +208,7 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -237,11 +237,11 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TextArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TextOid) +func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TextOid) } -func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -261,7 +261,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -271,7 +271,7 @@ func (src *TextArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/tid.go b/pgtype/tid.go index 20d962df..b91711d3 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -46,7 +46,7 @@ func (src *Tid) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } -func (dst *Tid) DecodeText(src []byte) error { +func (dst *Tid) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -75,7 +75,7 @@ func (dst *Tid) DecodeText(src []byte) error { return nil } -func (dst *Tid) DecodeBinary(src []byte) error { +func (dst *Tid) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Tid{Status: Null} return nil @@ -93,7 +93,7 @@ func (dst *Tid) DecodeBinary(src []byte) error { return nil } -func (src Tid) EncodeText(w io.Writer) (bool, error) { +func (src Tid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -105,7 +105,7 @@ func (src Tid) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Tid) EncodeBinary(w io.Writer) (bool, error) { +func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 3bb8f080..9a9e74ea 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -85,7 +85,7 @@ func (src *Timestamp) AssignTo(dst interface{}) error { // DecodeText decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeText(src []byte) error { +func (dst *Timestamp) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -111,7 +111,7 @@ func (dst *Timestamp) DecodeText(src []byte) error { // DecodeBinary decodes from src into dst. The decoded time is considered to // be in UTC. -func (dst *Timestamp) DecodeBinary(src []byte) error { +func (dst *Timestamp) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamp{Status: Null} return nil @@ -139,7 +139,7 @@ func (dst *Timestamp) DecodeBinary(src []byte) error { // EncodeText writes the text encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeText(w io.Writer) (bool, error) { +func (src Timestamp) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -167,7 +167,7 @@ func (src Timestamp) EncodeText(w io.Writer) (bool, error) { // EncodeBinary writes the binary encoding of src into w. If src.Time is not in // the UTC time zone it returns an error. -func (src Timestamp) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index c955dc42..064ad483 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -84,7 +84,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestampArray) DecodeText(src []byte) error { +func (dst *TimestampArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestampArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestampArray) DecodeText(src []byte) error { return nil } -func (dst *TimestampArray) DecodeBinary(src []byte) error { +func (dst *TimestampArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestampArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestampArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestampArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestampOid) +func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestampOid) } -func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestampArray) encodeBinary(w io.Writer, elementOid int32) (bool, er for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 5b9f5038..7f57f4b7 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -84,7 +84,7 @@ func (src *Timestamptz) AssignTo(dst interface{}) error { return nil } -func (dst *Timestamptz) DecodeText(src []byte) error { +func (dst *Timestamptz) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -117,7 +117,7 @@ func (dst *Timestamptz) DecodeText(src []byte) error { return nil } -func (dst *Timestamptz) DecodeBinary(src []byte) error { +func (dst *Timestamptz) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = Timestamptz{Status: Null} return nil @@ -143,7 +143,7 @@ func (dst *Timestamptz) DecodeBinary(src []byte) error { return nil } -func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -166,7 +166,7 @@ func (src Timestamptz) EncodeText(w io.Writer) (bool, error) { return false, err } -func (src Timestamptz) EncodeBinary(w io.Writer) (bool, error) { +func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index cd63e02e..4af1460b 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -84,7 +84,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { return nil } -func (dst *TimestamptzArray) DecodeText(src []byte) error { +func (dst *TimestamptzArray) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil @@ -106,7 +106,7 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -120,14 +120,14 @@ func (dst *TimestamptzArray) DecodeText(src []byte) error { return nil } -func (dst *TimestamptzArray) DecodeBinary(src []byte) error { +func (dst *TimestamptzArray) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = TimestamptzArray{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -152,7 +152,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { elemSrc = src[rp : rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -162,7 +162,7 @@ func (dst *TimestamptzArray) DecodeBinary(src []byte) error { return nil } -func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { +func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -209,7 +209,7 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -238,11 +238,11 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *TimestamptzArray) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, TimestamptzOid) +func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, TimestamptzOid) } -func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -262,7 +262,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -272,7 +272,7 @@ func (src *TimestamptzArray) encodeBinary(w io.Writer, elementOid int32) (bool, for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index a56097c0..2a46a658 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -82,7 +82,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil @@ -104,7 +104,7 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { if s != "NULL" { elemSrc = []byte(s) } - err = elem.DecodeText(elemSrc) + err = elem.DecodeText(ci, elemSrc) if err != nil { return err } @@ -118,14 +118,14 @@ func (dst *<%= pgtype_array_type %>) DecodeText(src []byte) error { return nil } -func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { +func (dst *<%= pgtype_array_type %>) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { *dst = <%= pgtype_array_type %>{Status: Null} return nil } var arrayHeader ArrayHeader - rp, err := arrayHeader.DecodeBinary(src) + rp, err := arrayHeader.DecodeBinary(ci, src) if err != nil { return err } @@ -150,7 +150,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { elemSrc = src[rp:rp+elemLen] rp += elemLen } - err = elements[i].DecodeBinary(elemSrc) + err = elements[i].DecodeBinary(ci, elemSrc) if err != nil { return err } @@ -160,7 +160,7 @@ func (dst *<%= pgtype_array_type %>) DecodeBinary(src []byte) error { return nil } -func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { +func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { switch src.Status { case Null: return true, nil @@ -207,7 +207,7 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { } elemBuf := &bytes.Buffer{} - null, err := elem.EncodeText(elemBuf) + null, err := elem.EncodeText(ci, elemBuf) if err != nil { return false, err } @@ -236,11 +236,11 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { return false, nil } -func (src *<%= pgtype_array_type %>) EncodeBinary(w io.Writer) (bool, error) { - return src.encodeBinary(w, <%= element_oid %>) +func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, <%= element_oid %>) } -func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) (bool, error) { +func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -260,7 +260,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) } } - err := arrayHeader.EncodeBinary(w) + err := arrayHeader.EncodeBinary(ci, w) if err != nil { return false, err } @@ -270,7 +270,7 @@ func (src *<%= pgtype_array_type %>) encodeBinary(w io.Writer, elementOid int32) for i := range src.Elements { elemBuf.Reset() - null, err := src.Elements[i].EncodeBinary(elemBuf) + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) if err != nil { return false, err } diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 41c1313f..5fde32aa 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -8,6 +8,8 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_oid=CidrOid text_null=NULL typed_array.go.erb > cidr_array.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_oid=VarcharOid text_null='"NULL"' typed_array.go.erb > varchar_array.go erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go diff --git a/pgtype/unknown.go b/pgtype/unknown.go new file mode 100644 index 00000000..b951ad99 --- /dev/null +++ b/pgtype/unknown.go @@ -0,0 +1,32 @@ +package pgtype + +// Unknown represents the PostgreSQL unknown type. It is either a string literal +// or NULL. It is used when PostgreSQL does not know the type of a value. In +// general, this will only be used in pgx when selecting a null value without +// type information. e.g. SELECT NULL; +type Unknown struct { + String string + Status Status +} + +func (dst *Unknown) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Unknown) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Unknown is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Unknown) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} diff --git a/pgtype/varchar.go b/pgtype/varchar.go new file mode 100644 index 00000000..adda6c49 --- /dev/null +++ b/pgtype/varchar.go @@ -0,0 +1,40 @@ +package pgtype + +import ( + "io" +) + +type Varchar Text + +// Set converts from src to dst. Note that as Varchar is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *Varchar) Set(src interface{}) error { + return (*Text)(dst).Set(src) +} + +func (dst *Varchar) Get() interface{} { + return (*Text)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as Varchar is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *Varchar) AssignTo(dst interface{}) error { + return (*Text)(src).AssignTo(dst) +} + +func (dst *Varchar) DecodeText(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeText(ci, src) +} + +func (dst *Varchar) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*Text)(dst).DecodeBinary(ci, src) +} + +func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeText(ci, w) +} + +func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (Text)(src).EncodeBinary(ci, w) +} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 693b9a61..21e9ccff 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -1,35 +1,296 @@ package pgtype import ( + "bytes" + "encoding/binary" + "fmt" "io" + + "github.com/jackc/pgx/pgio" ) -type VarcharArray TextArray +type VarcharArray struct { + Elements []Varchar + Dimensions []ArrayDimension + Status Status +} func (dst *VarcharArray) Set(src interface{}) error { - return (*TextArray)(dst).Set(src) + switch value := src.(type) { + + case []string: + if value == nil { + *dst = VarcharArray{Status: Null} + } else if len(value) == 0 { + *dst = VarcharArray{Status: Present} + } else { + elements := make([]Varchar, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = VarcharArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Varchar", value) + } + + return nil } func (dst *VarcharArray) Get() interface{} { - return (*TextArray)(dst).Get() + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } } func (src *VarcharArray) AssignTo(dst interface{}) error { - return (*TextArray)(src).AssignTo(dst) + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil } -func (dst *VarcharArray) DecodeText(src []byte) error { - return (*TextArray)(dst).DecodeText(src) +func (dst *VarcharArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Varchar + + if len(uta.Elements) > 0 { + elements = make([]Varchar, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Varchar + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil } -func (dst *VarcharArray) DecodeBinary(src []byte) error { - return (*TextArray)(dst).DecodeBinary(src) +func (dst *VarcharArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = VarcharArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = VarcharArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Varchar, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = VarcharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil } -func (src *VarcharArray) EncodeText(w io.Writer) (bool, error) { - return (*TextArray)(src).EncodeText(w) +func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `"NULL"`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil } -func (src *VarcharArray) EncodeBinary(w io.Writer) (bool, error) { - return (*TextArray)(src).encodeBinary(w, VarcharOid) +func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return src.encodeBinary(ci, w, VarcharOid) +} + +func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOid: elementOid, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err } diff --git a/pgtype/varchar_array_test.go b/pgtype/varchar_array_test.go new file mode 100644 index 00000000..4a8b09b8 --- /dev/null +++ b/pgtype/varchar_array_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestVarcharArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "varchar[]", []interface{}{ + &pgtype.VarcharArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{Status: pgtype.Null}, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar ", Status: pgtype.Present}, + pgtype.Varchar{String: "NuLL", Status: pgtype.Present}, + pgtype.Varchar{String: `wow"quz\`, Status: pgtype.Present}, + pgtype.Varchar{String: "", Status: pgtype.Present}, + pgtype.Varchar{Status: pgtype.Null}, + pgtype.Varchar{String: "null", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.VarcharArray{ + Elements: []pgtype.Varchar{ + pgtype.Varchar{String: "bar", Status: pgtype.Present}, + pgtype.Varchar{String: "baz", Status: pgtype.Present}, + pgtype.Varchar{String: "quz", Status: pgtype.Present}, + pgtype.Varchar{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestVarcharArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.VarcharArray + }{ + { + source: []string{"foo"}, + result: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.VarcharArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.VarcharArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestVarcharArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.VarcharArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.VarcharArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.VarcharArray + dst interface{} + }{ + { + src: pgtype.VarcharArray{ + Elements: []pgtype.Varchar{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/xid.go b/pgtype/xid.go index a53120de..c76548a4 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -37,18 +37,18 @@ func (src *Xid) AssignTo(dst interface{}) error { return (*pguint32)(src).AssignTo(dst) } -func (dst *Xid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) +func (dst *Xid) DecodeText(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeText(ci, src) } -func (dst *Xid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) +func (dst *Xid) DecodeBinary(ci *ConnInfo, src []byte) error { + return (*pguint32)(dst).DecodeBinary(ci, src) } -func (src Xid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) +func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(ci, w) } -func (src Xid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) +func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(ci, w) } diff --git a/query.go b/query.go index 63ce91ed..48a657f9 100644 --- a/query.go +++ b/query.go @@ -212,74 +212,86 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } } } else if s, ok := d.(pgtype.BinaryDecoder); ok && vr.Type().FormatCode == BinaryFormatCode { - err = s.DecodeBinary(vr.bytes()) + err = s.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(pgtype.TextDecoder); ok && vr.Type().FormatCode == TextFormatCode { - err = s.DecodeText(vr.bytes()) + err = s.DecodeText(rows.conn.ConnInfo, vr.bytes()) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else if s, ok := d.(sql.Scanner); ok { - var val interface{} + var sqlSrc interface{} if 0 <= vr.Len() { - switch vr.Type().DataType { - case BoolOid: - val = decodeBool(vr) - case Int8Oid: - val = int64(decodeInt8(vr)) - case Int2Oid: - val = int64(decodeInt2(vr)) - case Int4Oid: - val = int64(decodeInt4(vr)) - case TextOid, VarcharOid: - val = decodeText(vr) - case Float4Oid: - val = float64(decodeFloat4(vr)) - case Float8Oid: - val = decodeFloat8(vr) - case DateOid: - val = decodeDate(vr) - case TimestampOid: - val = decodeTimestamp(vr) - case TimestampTzOid: - val = decodeTimestampTz(vr) - default: - val = vr.ReadBytes(vr.Len()) + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + default: + rows.Fatal(errors.New("Unknown format code")) + } + + sqlSrc, err = pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.Fatal(err) + } + } else { + rows.Fatal(errors.New("Unknown type")) } } - err = s.Scan(val) + err = s.Scan(sqlSrc) if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } } else { - if pgVal, present := rows.conn.oidPgtypeValues[vr.Type().DataType]; present { + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value switch vr.Type().FormatCode { case TextFormatCode: - if textDecoder, ok := pgVal.(pgtype.TextDecoder); ok { - err = textDecoder.DecodeText(vr.bytes()) + if textDecoder, ok := value.(pgtype.TextDecoder); ok { + err = textDecoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) if err != nil { vr.Fatal(err) } } else { - vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", pgVal)) + vr.Fatal(fmt.Errorf("%T is not a pgtype.TextDecoder", value)) } case BinaryFormatCode: - if binaryDecoder, ok := pgVal.(pgtype.BinaryDecoder); ok { - err = binaryDecoder.DecodeBinary(vr.bytes()) + if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok { + err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) if err != nil { vr.Fatal(err) } } else { - vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", pgVal)) + vr.Fatal(fmt.Errorf("%T is not a pgtype.BinaryDecoder", value)) } default: vr.Fatal(fmt.Errorf("unknown format code: %v", vr.Type().FormatCode)) } - if err := pgVal.AssignTo(d); err != nil { - vr.Fatal(err) + if vr.Err() == nil { + if err := value.AssignTo(d); err != nil { + vr.Fatal(err) + } } } else { if err := Decode(vr, d); err != nil { @@ -315,29 +327,35 @@ func (rows *Rows) Values() ([]interface{}, error) { continue } - switch vr.Type().FormatCode { - case TextFormatCode: - decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.TextDecoder) - if decoder == nil { - decoder = &pgtype.GenericText{} + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + values = append(values, decoder.(pgtype.Value).Get()) + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + values = append(values, value.Get()) + default: + rows.Fatal(errors.New("Unknown format code")) } - err := decoder.DecodeText(vr.bytes()) - if err != nil { - rows.Fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - case BinaryFormatCode: - decoder := rows.conn.oidPgtypeValues[vr.Type().DataType].(pgtype.BinaryDecoder) - if decoder == nil { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(vr.bytes()) - if err != nil { - rows.Fatal(err) - } - values = append(values, decoder.(pgtype.Value).Get()) - default: - rows.Fatal(errors.New("Unknown format code")) + } else { + rows.Fatal(errors.New("Unknown type")) } if vr.Err() != nil { @@ -368,49 +386,41 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, nil) continue } - // TODO - consider what are the implications of returning complex types since database/sql uses this method - switch vr.Type().FormatCode { - // All intrinsic types (except string) are encoded with binary - // encoding so anything else should be treated as a string - case TextFormatCode: - values = append(values, vr.ReadString(vr.Len())) - case BinaryFormatCode: - switch vr.Type().DataType { - case TextOid, VarcharOid: - values = append(values, decodeText(vr)) - case BoolOid: - values = append(values, decodeBool(vr)) - case ByteaOid: - values = append(values, decodeBytea(vr)) - case Int8Oid: - values = append(values, decodeInt8(vr)) - case Int2Oid: - values = append(values, decodeInt2(vr)) - case Int4Oid: - values = append(values, decodeInt4(vr)) - case Float4Oid: - values = append(values, decodeFloat4(vr)) - case Float8Oid: - values = append(values, decodeFloat8(vr)) - case DateOid: - values = append(values, decodeDate(vr)) - case TimestampTzOid: - values = append(values, decodeTimestampTz(vr)) - case TimestampOid: - values = append(values, decodeTimestamp(vr)) - case JsonOid: - var d interface{} - decodeJSON(vr, &d) - values = append(values, d) - case JsonbOid: - var d interface{} - decodeJSONB(vr, &d) - values = append(values, d) + + if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { + value := dt.Value + + switch vr.Type().FormatCode { + case TextFormatCode: + decoder := value.(pgtype.TextDecoder) + if decoder == nil { + decoder = &pgtype.GenericText{} + } + err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } + case BinaryFormatCode: + decoder := value.(pgtype.BinaryDecoder) + if decoder == nil { + decoder = &pgtype.GenericBinary{} + } + err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) + if err != nil { + rows.Fatal(err) + } default: - rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) + rows.Fatal(errors.New("Unknown format code")) } - default: - rows.Fatal(errors.New("Unknown format code")) + + sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.Fatal(err) + } + + values = append(values, sqlSrc) + } else { + rows.Fatal(errors.New("Unknown type")) } if vr.Err() != nil { diff --git a/query_test.go b/query_test.go index 01889444..480959e8 100644 --- a/query_test.go +++ b/query_test.go @@ -776,7 +776,7 @@ func TestQueryRowErrors(t *testing.T) { {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "cannot decode"}, - {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "Cannot encode int into oid 600"}, + {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Text"}, } for i, tt := range tests { diff --git a/values.go b/values.go index d90c363b..4eb24eef 100644 --- a/values.go +++ b/values.go @@ -5,9 +5,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "math" "reflect" - "time" "github.com/jackc/pgx/pgtype" ) @@ -167,7 +165,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { switch arg := arg.(type) { case pgtype.BinaryEncoder: buf := &bytes.Buffer{} - null, err := arg.EncodeBinary(buf) + null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -180,7 +178,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { return nil case pgtype.TextEncoder: buf := &bytes.Buffer{} - null, err := arg.EncodeText(buf) + null, err := arg.EncodeText(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -214,14 +212,15 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { return Encode(wbuf, oid, arg) } - if value, ok := wbuf.conn.oidPgtypeValues[oid]; ok { + if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok { + value := dt.Value err := value.Set(arg) if err != nil { return err } buf := &bytes.Buffer{} - null, err := value.(pgtype.BinaryEncoder).EncodeBinary(buf) + null, err := value.(pgtype.BinaryEncoder).EncodeBinary(wbuf.conn.ConnInfo, buf) if err != nil { return err } @@ -287,8 +286,6 @@ func Decode(vr *ValueReader, d interface{}) error { switch v := d.(type) { case *string: *v = decodeText(vr) - case *[]interface{}: - *v = decodeRecord(vr) default: if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { el := v.Elem() @@ -320,232 +317,6 @@ func Decode(vr *ValueReader, d interface{}) error { return nil } -func decodeBool(vr *ValueReader) bool { - if vr.Type().DataType != BoolOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into bool", vr.Type().DataType))) - return false - } - - var b pgtype.Bool - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = b.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = b.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return false - } - - if err != nil { - vr.Fatal(err) - return false - } - - if b.Status != pgtype.Present { - vr.Fatal(fmt.Errorf("Cannot decode null into bool")) - return false - } - - return b.Bool -} - -func decodeInt(vr *ValueReader) int64 { - switch vr.Type().DataType { - case Int2Oid: - return int64(decodeInt2(vr)) - case Int4Oid: - return int64(decodeInt4(vr)) - case Int8Oid: - return int64(decodeInt8(vr)) - } - - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType))) - return 0 -} - -func decodeInt8(vr *ValueReader) int64 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int64")) - return 0 - } - - if vr.Type().DataType != Int8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int8", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int8 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeInt2(vr *ValueReader) int16 { - - if vr.Type().DataType != Int2Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int16", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int2 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeInt4(vr *ValueReader) int32 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into int32")) - return 0 - } - - if vr.Type().DataType != Int4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into int32", vr.Type().DataType))) - return 0 - } - - var n pgtype.Int4 - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = n.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = n.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if err != nil { - vr.Fatal(err) - return 0 - } - - if n.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return 0 - } - - return n.Int -} - -func decodeFloat4(vr *ValueReader) float32 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into float32")) - return 0 - } - - if vr.Type().DataType != Float4Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float32", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt32() - return math.Float32frombits(uint32(i)) -} - -func encodeFloat32(w *WriteBuf, oid pgtype.Oid, value float32) error { - switch oid { - case Float4Oid: - w.WriteInt32(4) - w.WriteInt32(int32(math.Float32bits(value))) - case Float8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(float64(value)))) - default: - return fmt.Errorf("cannot encode %s into oid %v", "float32", oid) - } - - return nil -} - -func decodeFloat8(vr *ValueReader) float64 { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into float64")) - return 0 - } - - if vr.Type().DataType != Float8Oid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into float64", vr.Type().DataType))) - return 0 - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return 0 - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", vr.Len()))) - return 0 - } - - i := vr.ReadInt64() - return math.Float64frombits(uint64(i)) -} - -func encodeFloat64(w *WriteBuf, oid pgtype.Oid, value float64) error { - switch oid { - case Float8Oid: - w.WriteInt32(8) - w.WriteInt64(int64(math.Float64bits(value))) - default: - return fmt.Errorf("cannot encode %s into oid %v", "float64", oid) - } - - return nil -} - func decodeText(vr *ValueReader) string { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into string")) @@ -677,215 +448,3 @@ func encodeJSONB(w *WriteBuf, oid pgtype.Oid, value interface{}) error { return nil } - -func decodeDate(vr *ValueReader) time.Time { - if vr.Type().DataType != DateOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return time.Time{} - } - - var d pgtype.Date - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = d.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = d.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return time.Time{} - } - - if err != nil { - vr.Fatal(err) - return time.Time{} - } - - if d.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into int16")) - return time.Time{} - } - - return d.Time -} - -func encodeTime(w *WriteBuf, oid pgtype.Oid, value time.Time) error { - switch oid { - case DateOid: - var d pgtype.Date - err := d.Set(value) - if err != nil { - return err - } - - buf := &bytes.Buffer{} - null, err := d.EncodeBinary(buf) - if err != nil { - return err - } - if null { - w.WriteInt32(-1) - } else { - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - } - return nil - - case TimestampTzOid, TimestampOid: - var t pgtype.Timestamptz - err := t.Set(value) - if err != nil { - return err - } - - buf := &bytes.Buffer{} - null, err := t.EncodeBinary(buf) - if err != nil { - return err - } - if null { - w.WriteInt32(-1) - } else { - w.WriteInt32(int32(buf.Len())) - w.WriteBytes(buf.Bytes()) - } - return nil - default: - return fmt.Errorf("cannot encode %s into oid %v", "time.Time", oid) - } -} - -const microsecFromUnixEpochToY2K = 946684800 * 1000000 - -func decodeTimestampTz(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return zeroTime - } - - if vr.Type().DataType != TimestampTzOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - var t pgtype.Timestamptz - var err error - switch vr.Type().FormatCode { - case TextFormatCode: - err = t.DecodeText(vr.bytes()) - case BinaryFormatCode: - err = t.DecodeBinary(vr.bytes()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return time.Time{} - } - - if err != nil { - vr.Fatal(err) - return time.Time{} - } - - if t.Status == pgtype.Null { - vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return time.Time{} - } - - return t.Time -} - -func decodeTimestamp(vr *ValueReader) time.Time { - var zeroTime time.Time - - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into timestamp")) - return zeroTime - } - - if vr.Type().DataType != TimestampOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime - } - - if vr.Len() != 8 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamp: %d", vr.Len()))) - return zeroTime - } - - microsecSinceY2K := vr.ReadInt64() - microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K - return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) -} - -func decodeRecord(vr *ValueReader) []interface{} { - if vr.Len() == -1 { - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - if vr.Type().DataType != RecordOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType))) - return nil - } - - valueCount := vr.ReadInt32() - record := make([]interface{}, 0, int(valueCount)) - - for i := int32(0); i < valueCount; i++ { - fd := FieldDescription{FormatCode: BinaryFormatCode} - fieldVR := ValueReader{mr: vr.mr, fd: &fd} - fd.DataType = vr.ReadOid() - fieldVR.valueBytesRemaining = vr.ReadInt32() - vr.valueBytesRemaining -= fieldVR.valueBytesRemaining - - switch fd.DataType { - case BoolOid: - record = append(record, decodeBool(&fieldVR)) - case ByteaOid: - record = append(record, decodeBytea(&fieldVR)) - case Int8Oid: - record = append(record, decodeInt8(&fieldVR)) - case Int2Oid: - record = append(record, decodeInt2(&fieldVR)) - case Int4Oid: - record = append(record, decodeInt4(&fieldVR)) - case Float4Oid: - record = append(record, decodeFloat4(&fieldVR)) - case Float8Oid: - record = append(record, decodeFloat8(&fieldVR)) - case DateOid: - record = append(record, decodeDate(&fieldVR)) - case TimestampTzOid: - record = append(record, decodeTimestampTz(&fieldVR)) - case TimestampOid: - record = append(record, decodeTimestamp(&fieldVR)) - case TextOid, VarcharOid, UnknownOid: - record = append(record, decodeTextAllowBinary(&fieldVR)) - default: - vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) - return nil - } - - // Consume any remaining data - if fieldVR.Len() > 0 { - fieldVR.ReadBytes(fieldVR.Len()) - } - - if fieldVR.Err() != nil { - vr.Fatal(fieldVR.Err()) - return nil - } - } - - return record -} diff --git a/values_test.go b/values_test.go index e7ae7e1d..1d09eb18 100644 --- a/values_test.go +++ b/values_test.go @@ -6,9 +6,6 @@ import ( "reflect" "testing" "time" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" ) func TestDateTranscode(t *testing.T) { @@ -78,159 +75,161 @@ func TestTimestampTzTranscode(t *testing.T) { } } -func TestJSONAndJSONBTranscode(t *testing.T) { - t.Parallel() +// TODO - move these tests to pgtype - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) +// func TestJSONAndJSONBTranscode(t *testing.T) { +// t.Parallel() - for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { - if _, ok := conn.PgTypes[oid]; !ok { - return // No JSON/JSONB type -- must be running against old PostgreSQL - } +// conn := mustConnect(t, *defaultConnConfig) +// defer closeConn(t, conn) - for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { - pgtype := conn.PgTypes[oid] - pgtype.DefaultFormat = format - conn.PgTypes[oid] = pgtype +// for _, oid := range []pgtype.Oid{pgx.JsonOid, pgx.JsonbOid} { +// if _, ok := conn.ConnInfo.DataTypeForOid(oid); !ok { +// return // No JSON/JSONB type -- must be running against old PostgreSQL +// } - typename := conn.PgTypes[oid].Name +// for _, format := range []int16{pgx.TextFormatCode, pgx.BinaryFormatCode} { +// pgtype := conn.PgTypes[oid] +// pgtype.DefaultFormat = format +// conn.PgTypes[oid] = pgtype - testJSONString(t, conn, typename, format) - testJSONStringPointer(t, conn, typename, format) - testJSONSingleLevelStringMap(t, conn, typename, format) - testJSONNestedMap(t, conn, typename, format) - testJSONStringArray(t, conn, typename, format) - testJSONInt64Array(t, conn, typename, format) - testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format) - testJSONStruct(t, conn, typename, format) - } - } -} +// typename := conn.PgTypes[oid].Name -func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := `{"key": "value"}` - expectedOutput := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// testJSONString(t, conn, typename, format) +// testJSONStringPointer(t, conn, typename, format) +// testJSONSingleLevelStringMap(t, conn, typename, format) +// testJSONNestedMap(t, conn, typename, format) +// testJSONStringArray(t, conn, typename, format) +// testJSONInt64Array(t, conn, typename, format) +// testJSONInt16ArrayFailureDueToOverflow(t, conn, typename, format) +// testJSONStruct(t, conn, typename, format) +// } +// } +// } - if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) - return - } -} +// func testJSONString(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := `{"key": "value"}` +// expectedOutput := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := `{"key": "value"}` - expectedOutput := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(expectedOutput, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) +// return +// } +// } - if !reflect.DeepEqual(expectedOutput, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) - return - } -} +// func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := `{"key": "value"}` +// expectedOutput := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, &input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := map[string]string{"key": "value"} - var output map[string]string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(expectedOutput, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, expectedOutput, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) - return - } -} +// func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := map[string]string{"key": "value"} +// var output map[string]string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := map[string]interface{}{ - "name": "Uncanny", - "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, - "inventory": []interface{}{"phone", "key"}, - } - var output map[string]interface{} - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - return - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode map[string]string successfully: %v is not %v", typename, format, input, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) - return - } -} +// func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := map[string]interface{}{ +// "name": "Uncanny", +// "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, +// "inventory": []interface{}{"phone", "key"}, +// } +// var output map[string]interface{} +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// return +// } -func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []string{"foo", "bar", "baz"} - var output []string - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode map[string]interface{} successfully: %v is not %v", typename, format, input, output) +// return +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) - } -} +// func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []string{"foo", "bar", "baz"} +// var output []string +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } -func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []int64{1, 2, 234432} - var output []int64 - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode []string successfully: %v is not %v", typename, format, input, output) +// } +// } - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) - } -} +// func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []int64{1, 2, 234432} +// var output []int64 +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } -func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { - input := []int{1, 2, 234432} - var output []int16 - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { - t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) - } -} +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode []int64 successfully: %v is not %v", typename, format, input, output) +// } +// } -func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { - type person struct { - Name string `json:"name"` - Age int `json:"age"` - } +// func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// input := []int{1, 2, 234432} +// var output []int16 +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { +// t.Errorf("%s %d: Expected *json.UnmarkalTypeError, but got %v", typename, format, err) +// } +// } - input := person{ - Name: "John", - Age: 42, - } +// func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string, format int16) { +// type person struct { +// Name string `json:"name"` +// Age int `json:"age"` +// } - var output person +// input := person{ +// Name: "John", +// Age: 42, +// } - err := conn.QueryRow("select $1::"+typename, input).Scan(&output) - if err != nil { - t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) - } +// var output person - if !reflect.DeepEqual(input, output) { - t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) - } -} +// err := conn.QueryRow("select $1::"+typename, input).Scan(&output) +// if err != nil { +// t.Errorf("%s %d: QueryRow Scan failed: %v", typename, format, err) +// } + +// if !reflect.DeepEqual(input, output) { +// t.Errorf("%s %d: Did not transcode struct successfully: %v is not %v", typename, format, input, output) +// } +// } func mustParseCidr(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s)