diff --git a/oid.go b/oid.go index e57bb2e6..eab1fbcb 100644 --- a/oid.go +++ b/oid.go @@ -1,45 +1,57 @@ package pgtype import ( + "encoding/binary" + "fmt" "io" + "strconv" + + "github.com/jackc/pgx/pgio" ) // Oid (Object Identifier Type) is, according to // https://www.postgresql.org/docs/current/static/datatype-oid.html, used // internally by PostgreSQL as a primary key for various system tables. It is // currently implemented as an unsigned four-byte integer. Its definition can be -// found in src/include/postgres_ext.h in the PostgreSQL sources. -type Oid pguint32 - -// Set converts from src to dst. Note that as Oid is not a general -// number type Set does not do automatic type conversion as other number -// types do. -func (dst *Oid) Set(src interface{}) error { - return (*pguint32)(dst).Set(src) -} - -func (dst *Oid) Get() interface{} { - return (*pguint32)(dst).Get() -} - -// AssignTo assigns from src to dst. Note that as Oid is not a general number -// type AssignTo does not do automatic type conversion as other number types do. -func (src *Oid) AssignTo(dst interface{}) error { - return (*pguint32)(src).AssignTo(dst) -} +// found in src/include/postgres_ext.h in the PostgreSQL sources. Because it is +// so frequently required to be in a NOT NULL condition Oid cannot be NULL. To +// allow for NULL Oids use OidValue. +type Oid uint32 func (dst *Oid) DecodeText(src []byte) error { - return (*pguint32)(dst).DecodeText(src) + if src == nil { + return fmt.Errorf("cannot decode nil into Oid") + } + + n, err := strconv.ParseUint(string(src), 10, 32) + if err != nil { + return err + } + + *dst = Oid(n) + return nil } func (dst *Oid) DecodeBinary(src []byte) error { - return (*pguint32)(dst).DecodeBinary(src) + if src == nil { + return fmt.Errorf("cannot decode nil into Oid") + } + + if len(src) != 4 { + return fmt.Errorf("invalid length: %v", len(src)) + } + + n := binary.BigEndian.Uint32(src) + *dst = Oid(n) + return nil } func (src Oid) EncodeText(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeText(w) + _, err := io.WriteString(w, strconv.FormatUint(uint64(src), 10)) + return false, err } func (src Oid) EncodeBinary(w io.Writer) (bool, error) { - return (pguint32)(src).EncodeBinary(w) + _, err := pgio.WriteUint32(w, uint32(src)) + return false, err } diff --git a/oid_value.go b/oid_value.go new file mode 100644 index 00000000..a2b2dcbe --- /dev/null +++ b/oid_value.go @@ -0,0 +1,45 @@ +package pgtype + +import ( + "io" +) + +// OidValue (Object Identifier Type) is, according to +// https://www.postgresql.org/docs/current/static/datatype-OidValue.html, used +// internally by PostgreSQL as a primary key for various system tables. It is +// currently implemented as an unsigned four-byte integer. Its definition can be +// found in src/include/postgres_ext.h in the PostgreSQL sources. +type OidValue pguint32 + +// Set converts from src to dst. Note that as OidValue is not a general +// number type Set does not do automatic type conversion as other number +// types do. +func (dst *OidValue) Set(src interface{}) error { + return (*pguint32)(dst).Set(src) +} + +func (dst *OidValue) Get() interface{} { + return (*pguint32)(dst).Get() +} + +// AssignTo assigns from src to dst. Note that as OidValue is not a general number +// type AssignTo does not do automatic type conversion as other number types do. +func (src *OidValue) AssignTo(dst interface{}) error { + return (*pguint32)(src).AssignTo(dst) +} + +func (dst *OidValue) DecodeText(src []byte) error { + return (*pguint32)(dst).DecodeText(src) +} + +func (dst *OidValue) DecodeBinary(src []byte) error { + return (*pguint32)(dst).DecodeBinary(src) +} + +func (src OidValue) EncodeText(w io.Writer) (bool, error) { + return (pguint32)(src).EncodeText(w) +} + +func (src OidValue) EncodeBinary(w io.Writer) (bool, error) { + return (pguint32)(src).EncodeBinary(w) +} diff --git a/oid_test.go b/oid_value_test.go similarity index 66% rename from oid_test.go rename to oid_value_test.go index b3b96959..21dd6f9d 100644 --- a/oid_test.go +++ b/oid_value_test.go @@ -7,23 +7,23 @@ import ( "github.com/jackc/pgx/pgtype" ) -func TestOidTranscode(t *testing.T) { +func TestOidValueTranscode(t *testing.T) { testSuccessfulTranscode(t, "oid", []interface{}{ - pgtype.Oid{Uint: 42, Status: pgtype.Present}, - pgtype.Oid{Status: pgtype.Null}, + pgtype.OidValue{Uint: 42, Status: pgtype.Present}, + pgtype.OidValue{Status: pgtype.Null}, }) } -func TestOidSet(t *testing.T) { +func TestOidValueSet(t *testing.T) { successfulTests := []struct { source interface{} - result pgtype.Oid + result pgtype.OidValue }{ - {source: uint32(1), result: pgtype.Oid{Uint: 1, Status: pgtype.Present}}, + {source: uint32(1), result: pgtype.OidValue{Uint: 1, Status: pgtype.Present}}, } for i, tt := range successfulTests { - var r pgtype.Oid + var r pgtype.OidValue err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) @@ -35,17 +35,17 @@ func TestOidSet(t *testing.T) { } } -func TestOidAssignTo(t *testing.T) { +func TestOidValueAssignTo(t *testing.T) { var ui32 uint32 var pui32 *uint32 simpleTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} expected interface{} }{ - {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, - {src: pgtype.Oid{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, + {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &ui32, expected: uint32(42)}, + {src: pgtype.OidValue{Status: pgtype.Null}, dst: &pui32, expected: ((*uint32)(nil))}, } for i, tt := range simpleTests { @@ -60,11 +60,11 @@ func TestOidAssignTo(t *testing.T) { } pointerAllocTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} expected interface{} }{ - {src: pgtype.Oid{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, + {src: pgtype.OidValue{Uint: 42, Status: pgtype.Present}, dst: &pui32, expected: uint32(42)}, } for i, tt := range pointerAllocTests { @@ -79,10 +79,10 @@ func TestOidAssignTo(t *testing.T) { } errorTests := []struct { - src pgtype.Oid + src pgtype.OidValue dst interface{} }{ - {src: pgtype.Oid{Status: pgtype.Null}, dst: &ui32}, + {src: pgtype.OidValue{Status: pgtype.Null}, dst: &ui32}, } for i, tt := range errorTests {