From af8519991e37719ae5288d4eba78026cdd814910 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 5 Mar 2017 13:05:49 -0600 Subject: [PATCH] Move OID to pgtype --- conn.go | 5 +- pgtype/oid.go | 41 ++++++++++++++ pgtype/oid_test.go | 94 +++++++++++++++++++++++++++++++ pgtype/pguint32.go | 2 +- query.go | 6 -- query_test.go | 7 +-- values.go | 135 +++++++++++++++++++-------------------------- values_test.go | 4 -- 8 files changed, 199 insertions(+), 95 deletions(-) create mode 100644 pgtype/oid.go create mode 100644 pgtype/oid_test.go diff --git a/conn.go b/conn.go index 2b826dad..c55d5618 100644 --- a/conn.go +++ b/conn.go @@ -287,6 +287,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl Int4OID: &pgtype.Int4{}, Int8ArrayOID: &pgtype.Int8Array{}, Int8OID: &pgtype.Int8{}, + OIDOID: &pgtype.OID{}, TextArrayOID: &pgtype.TextArray{}, TextOID: &pgtype.Text{}, TimestampArrayOID: &pgtype.TimestampArray{}, @@ -392,7 +393,7 @@ where ( c.PgTypes = make(map[OID]PgType, 128) for rows.Next() { - var oid OID + var oid uint32 var t PgType rows.Scan(&oid, &t.Name) @@ -400,7 +401,7 @@ where ( // The zero value is text format so we ignore any types without a default type format t.DefaultFormat, _ = DefaultTypeFormats[t.Name] - c.PgTypes[oid] = t + c.PgTypes[OID(oid)] = t } return rows.Err() diff --git a/pgtype/oid.go b/pgtype/oid.go new file mode 100644 index 00000000..d137f352 --- /dev/null +++ b/pgtype/oid.go @@ -0,0 +1,41 @@ +package pgtype + +import ( + "io" +) + +// OID (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-oid.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OID pguint32 + +// ConvertFrom converts from src to dst. Note that as OID is not a general +// number type ConvertFrom does not do automatic type conversion as other number +// types do. +func (dst *OID) ConvertFrom(src interface{}) error { + return (*pguint32)(dst).ConvertFrom(src) +} + +// AssignTo assigns from src to dst. Note that as OID is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OID) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OID) DecodeText(r io.Reader) error { + return (*pguint32)(dst).DecodeText(r) +} + +func (dst *OID) DecodeBinary(r io.Reader) error { + return (*pguint32)(dst).DecodeBinary(r) +} + +func (src OID) EncodeText(w io.Writer) error { + return (pguint32)(src).EncodeText(w) +} + +func (src OID) EncodeBinary(w io.Writer) error { + return (pguint32)(src).EncodeBinary(w) +} diff --git a/pgtype/oid_test.go b/pgtype/oid_test.go new file mode 100644 index 00000000..c8e0b2d6 --- /dev/null +++ b/pgtype/oid_test.go @@ -0,0 +1,94 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestOIDTranscode(t *testing.T) { + testSuccessfulTranscode(t, "oid", []interface{}{ + pgtype.OID{Uint: 42, Status: pgtype.Present}, + pgtype.OID{Status: pgtype.Null}, + }) +} + +func TestOIDConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.OID + }{ + {source: uint32(1), result: pgtype.OID{Uint: 1, Status: pgtype.Present}}, + } + + for i, tt := range successfulTests { + var r pgtype.OID + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestOIDAssignTo(t *testing.T) { + var ui32 uint32 + var pui32 *uint32 + + simpleTests := []struct { + src pgtype.OID + dst interface{} + expected interface{} + }{ + {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OID{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.OID + dst interface{} + expected interface{} + }{ + {src: pgtype.OID{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.OID + dst interface{} + }{ + {src: pgtype.OID{Status: pgtype.Null}, dst: &ui32}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 66b385fb..9c1ccd6c 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -89,7 +89,7 @@ func (dst *pguint32) DecodeBinary(r io.Reader) error { } if size != 4 { - return fmt.Errorf("invalid length for cid: %v", size) + return fmt.Errorf("invalid length: %v", size) } n, err := pgio.ReadUint32(r) diff --git a/query.go b/query.go index ffe51ecc..965f3913 100644 --- a/query.go +++ b/query.go @@ -256,8 +256,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { val = int64(decodeInt4(vr)) case TextOID, VarcharOID: val = decodeText(vr) - case OIDOID: - val = int64(decodeOID(vr)) case Float4OID: val = float64(decodeFloat4(vr)) case Float8OID: @@ -382,8 +380,6 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeInt2(vr)) case Int4OID: values = append(values, decodeInt4(vr)) - case OIDOID: - values = append(values, decodeOID(vr)) case Float4OID: values = append(values, decodeFloat4(vr)) case Float8OID: @@ -457,8 +453,6 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, decodeInt2(vr)) case Int4OID: values = append(values, decodeInt4(vr)) - case OIDOID: - values = append(values, decodeOID(vr)) case Float4OID: values = append(values, decodeFloat4(vr)) case Float8OID: diff --git a/query_test.go b/query_test.go index 801ba851..bbd7871e 100644 --- a/query_test.go +++ b/query_test.go @@ -53,7 +53,7 @@ func TestConnQueryValues(t *testing.T) { var rowCount int32 - rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n::oid from generate_series(1,$1) n", 10) + rows, err := conn.Query("select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } @@ -84,7 +84,7 @@ func TestConnQueryValues(t *testing.T) { t.Errorf(`Expected values[3] to be %v, but it was %d`, nil, values[3]) } - if values[4] != pgx.OID(rowCount) { + if values[4] != rowCount { t.Errorf(`Expected values[4] to be %d, but it was %d`, rowCount, values[4]) } } @@ -478,9 +478,6 @@ func TestQueryRowCoreTypes(t *testing.T) { if err == nil { t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") && !strings.Contains(err.Error(), "cannot assign") { - t.Errorf(`%d. Expected null to cause error "Cannot decode null..." but it was %v (sql -> %v)`, i, err, tt.sql) - } ensureConnValid(t, conn) } diff --git a/values.go b/values.go index b6848cf5..59d6f3c4 100644 --- a/values.go +++ b/values.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgtype" ) @@ -548,46 +549,75 @@ func (n NullInt32) Encode(w *WriteBuf, oid OID) error { // OID (Object Identifier Type) is, according to https://www.postgresql.org/docs/current/static/datatype-oid.html, // used internally by PostgreSQL as a primary key for various system tables. It is currently implemented // as an unsigned four-byte integer. Its definition can be found in src/include/postgres_ext.h -// in the PostgreSQL sources. +// in the PostgreSQL sources. OID cannot be NULL. To allow for NULL OIDs use pgtype.OID. type OID uint32 -// NullOID represents a Command Identifier (OID) that may be null. NullOID implements the -// Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. -// -// If Valid is false then the value is NULL. -type NullOID struct { - OID OID - Valid bool // Valid is true if OID is not NULL +func (dst *OID) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + return fmt.Errorf("cannot decode nil into OID") + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + n, err := strconv.ParseUint(string(buf), 10, 32) + if err != nil { + return err + } + + *dst = OID(n) + return nil } -func (n *NullOID) Scan(vr *ValueReader) error { - if vr.Type().DataType != OIDOID { - return SerializationError(fmt.Sprintf("NullOID.Scan cannot decode OID %d", vr.Type().DataType)) +func (dst *OID) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err } - if vr.Len() == -1 { - n.OID, n.Valid = 0, false - return nil + if size == -1 { + return fmt.Errorf("cannot decode nil into OID") } - n.Valid = true - n.OID = decodeOID(vr) - return vr.Err() + + if size != 4 { + return fmt.Errorf("invalid length for OID: %v", size) + } + + n, err := pgio.ReadUint32(r) + if err != nil { + return err + } + + *dst = OID(n) + return nil } -func (n NullOID) FormatCode() int16 { return BinaryFormatCode } - -func (n NullOID) Encode(w *WriteBuf, oid OID) error { - if oid != OIDOID { - return SerializationError(fmt.Sprintf("NullOID.Encode cannot encode into OID %d", oid)) - } - - if !n.Valid { - w.WriteInt32(-1) +func (src OID) EncodeText(w io.Writer) error { + s := strconv.FormatUint(uint64(src), 10) + _, err := pgio.WriteInt32(w, int32(len(s))) + if err != nil { return nil } + _, err = w.Write([]byte(s)) + return err +} - return encodeOID(w, oid, n.OID) +func (src OID) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = pgio.WriteUint32(w, uint32(src)) + return err } // Tid is PostgreSQL's Tuple Identifier type. @@ -976,8 +1006,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { // The name data type goes over the wire using the same format as string, // so just cast to string and use encodeString return encodeString(wbuf, oid, string(arg)) - case OID: - return encodeOID(wbuf, oid, arg) default: if strippedArg, ok := stripNamedType(&refVal); ok { return Encode(wbuf, oid, strippedArg) @@ -1053,8 +1081,6 @@ func Decode(vr *ValueReader, d interface{}) error { case *Name: // name goes over the wire just like text *v = Name(decodeText(vr)) - case *OID: - *v = decodeOID(vr) case *Tid: *v = decodeTid(vr) case *string: @@ -1292,49 +1318,6 @@ func decodeInt4(vr *ValueReader) int32 { return n.Int } -func decodeOID(vr *ValueReader) OID { - if vr.Len() == -1 { - vr.Fatal(ProtocolError("Cannot decode null into OID")) - return OID(0) - } - - if vr.Type().DataType != OIDOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into pgx.OID", vr.Type().DataType))) - return OID(0) - } - - // OID needs to decode text format because it is used in loadPgTypes - switch vr.Type().FormatCode { - case TextFormatCode: - s := vr.ReadString(vr.Len()) - n, err := strconv.ParseUint(s, 10, 32) - if err != nil { - vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid OID: %v", s))) - } - return OID(n) - case BinaryFormatCode: - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an OID: %d", vr.Len()))) - return OID(0) - } - return OID(vr.ReadInt32()) - default: - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return OID(0) - } -} - -func encodeOID(w *WriteBuf, oid OID, value OID) error { - if oid != OIDOID { - return fmt.Errorf("cannot encode Go %s into oid %d", "pgx.OID", oid) - } - - w.WriteInt32(4) - w.WriteUint32(uint32(value)) - - return nil -} - // Note that we do not match negative numbers, because neither the // BlockNumber nor OffsetNumber of a Tid can be negative. var tidRegexp *regexp.Regexp = regexp.MustCompile(`^\((\d*),(\d*)\)$`) @@ -1764,8 +1747,6 @@ func decodeRecord(vr *ValueReader) []interface{} { record = append(record, decodeInt2(&fieldVR)) case Int4OID: record = append(record, decodeInt4(&fieldVR)) - case OIDOID: - record = append(record, decodeOID(&fieldVR)) case Float4OID: record = append(record, decodeFloat4(&fieldVR)) case Float8OID: diff --git a/values_test.go b/values_test.go index 0283f17d..65811959 100644 --- a/values_test.go +++ b/values_test.go @@ -571,7 +571,6 @@ func TestNullX(t *testing.T) { c pgx.NullChar a pgx.NullAclItem n pgx.NullName - oid pgx.NullOID tid pgx.NullTid i64 pgx.NullInt64 f32 pgx.NullFloat32 @@ -594,9 +593,6 @@ func TestNullX(t *testing.T) { {"select $1::int2", []interface{}{pgx.NullInt16{Int16: 1, Valid: false}}, []interface{}{&actual.i16}, allTypes{i16: pgx.NullInt16{Int16: 0, Valid: false}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: true}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 1, Valid: true}}}, {"select $1::int4", []interface{}{pgx.NullInt32{Int32: 1, Valid: false}}, []interface{}{&actual.i32}, allTypes{i32: pgx.NullInt32{Int32: 0, Valid: false}}}, - {"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 1, Valid: true}}}, - {"select $1::oid", []interface{}{pgx.NullOID{OID: 1, Valid: false}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 0, Valid: false}}}, - {"select $1::oid", []interface{}{pgx.NullOID{OID: 4294967295, Valid: true}}, []interface{}{&actual.oid}, allTypes{oid: pgx.NullOID{OID: 4294967295, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 1, Valid: true}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 1, Valid: false}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 0, Valid: false}}}, {"select $1::\"char\"", []interface{}{pgx.NullChar{Char: 255, Valid: true}}, []interface{}{&actual.c}, allTypes{c: pgx.NullChar{Char: 255, Valid: true}}},