diff --git a/ext/gofrs-uuid/uuid.go b/ext/gofrs-uuid/uuid.go deleted file mode 100644 index 0e0ebed3..00000000 --- a/ext/gofrs-uuid/uuid.go +++ /dev/null @@ -1,176 +0,0 @@ -package uuid - -import ( - "database/sql/driver" - "fmt" - - "github.com/gofrs/uuid" - "github.com/jackc/pgtype" -) - -type UUID struct { - UUID uuid.UUID - Valid bool -} - -func (dst *UUID) Set(src interface{}) error { - if src == nil { - *dst = UUID{} - return nil - } - - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } - } - - switch value := src.(type) { - case uuid.UUID: - *dst = UUID{UUID: value, Valid: true} - case [16]byte: - *dst = UUID{UUID: uuid.UUID(value), Valid: true} - case []byte: - if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) - } - *dst = UUID{Valid: true} - copy(dst.UUID[:], value) - case string: - uuid, err := uuid.FromString(value) - if err != nil { - return err - } - *dst = UUID{UUID: uuid, Valid: true} - default: - // If all else fails see if pgtype.UUID can handle it. If so, translate through that. - pgUUID := &pgtype.UUID{} - if err := pgUUID.Set(value); err != nil { - return fmt.Errorf("cannot convert %v to UUID", value) - } - - *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Valid: pgUUID.Valid} - } - - return nil -} - -func (dst UUID) Get() interface{} { - if !dst.Valid { - return nil - } - return dst.UUID -} - -func (src *UUID) AssignTo(dst interface{}) error { - if !src.Valid { - return pgtype.NullAssignTo(dst) - } - - switch v := dst.(type) { - case *uuid.UUID: - *v = src.UUID - return nil - case *[16]byte: - *v = [16]byte(src.UUID) - return nil - case *[]byte: - *v = make([]byte, 16) - copy(*v, src.UUID[:]) - return nil - case *string: - *v = src.UUID.String() - return nil - default: - if nextDst, retry := pgtype.GetAssignToDstType(v); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) - } -} - -func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = UUID{} - return nil - } - - u, err := uuid.FromString(string(src)) - if err != nil { - return err - } - - *dst = UUID{UUID: u, Valid: true} - return nil -} - -func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { - if src == nil { - *dst = UUID{} - return nil - } - - if len(src) != 16 { - return fmt.Errorf("invalid length for UUID: %v", len(src)) - } - - *dst = UUID{Valid: true} - copy(dst.UUID[:], src) - return nil -} - -func (src UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - return append(buf, src.UUID.String()...), nil -} - -func (src UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - return append(buf, src.UUID[:]...), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *UUID) Scan(src interface{}) error { - if src == nil { - *dst = UUID{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - return dst.DecodeText(nil, src) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src UUID) Value() (driver.Value, error) { - return pgtype.EncodeValueText(src) -} - -func (src UUID) MarshalJSON() ([]byte, error) { - if !src.Valid { - return []byte("null"), nil - } - return []byte(`"` + src.UUID.String() + `"`), nil -} - -func (dst *UUID) UnmarshalJSON(b []byte) error { - u := uuid.NullUUID{} - err := u.UnmarshalJSON(b) - if err != nil { - return err - } - - *dst = UUID{UUID: u.UUID, Valid: u.Valid} - - return nil -} diff --git a/ext/gofrs-uuid/uuid_test.go b/ext/gofrs-uuid/uuid_test.go deleted file mode 100644 index 3e5e4d82..00000000 --- a/ext/gofrs-uuid/uuid_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package uuid_test - -import ( - "bytes" - "testing" - - gofrs "github.com/jackc/pgtype/ext/gofrs-uuid" - "github.com/jackc/pgtype/testutil" -) - -func TestUUIDTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ - &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - &gofrs.UUID{}, - }) -} - -func TestUUIDSet(t *testing.T) { - successfulTests := []struct { - source interface{} - result gofrs.UUID - }{ - { - source: &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - { - source: "00010203-0405-0607-0809-0a0b0c0d0e0f", - result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}, - }, - } - - for i, tt := range successfulTests { - var r gofrs.UUID - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestUUIDAssignTo(t *testing.T) { - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst [16]byte - expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst []byte - expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if bytes.Compare(dst, expected) != 0 { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - - { - src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true} - var dst string - expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" - - err := src.AssignTo(&dst) - if err != nil { - t.Error(err) - } - - if dst != expected { - t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) - } - } - -} diff --git a/go.mod b/go.mod index 99c5b26e..b2f1cc10 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,8 @@ module github.com/jackc/pgtype go 1.13 require ( - github.com/gofrs/uuid v4.0.0+incompatible - github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 + github.com/jackc/pgconn v1.10.1 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c - github.com/shopspring/decimal v1.2.0 + github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index 8f2d760e..2a835726 100644 --- a/go.sum +++ b/go.sum @@ -25,8 +25,9 @@ github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= -github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530 h1:dUJ578zuPEsXjtzOfEF0q9zDAfljJ9oFnTHcQaNkccw= github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.10.1 h1:DzdIHIjG1AxGwoEEqS+mGsURyjt4enSmqzACXvVzOT8= +github.com/jackc/pgconn v1.10.1/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= @@ -42,22 +43,26 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.2.0 h1:r7JypeP2D3onoQTCxWdTpCtJ4D+qpKr0TxvoyMhZ5ns= +github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.9.1/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c h1:Dznn52SgVIVst9UyOT9brctYUgxs+CvVfPaC3jKrA50= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f h1:Y3Es3mIYatTvP4CXPXfmJtHWe8eq4E8owY6Fq61hEik= +github.com/jackc/pgx/v4 v4.14.2-0.20211129172902-cf0de913ee8f/go.mod h1:RgDuE4Z34o7XE92RpLsvFiOEfrAUT0Xt2KxvX73W06M= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= diff --git a/uuid.go b/uuid.go index d46111d3..4533aa06 100644 --- a/uuid.go +++ b/uuid.go @@ -10,11 +10,58 @@ import ( type UUID struct { Bytes [16]byte Valid bool + + UUIDDecoderWrapper func(interface{}) UUIDDecoder + Getter func(UUID) interface{} +} + +func (n *UUID) NewTypeValue() Value { + return &UUID{ + UUIDDecoderWrapper: n.UUIDDecoderWrapper, + Getter: n.Getter, + } +} + +func (n *UUID) TypeName() string { + return "uuid" +} + +func (dst *UUID) setNil() { + dst.Bytes = [16]byte{} + dst.Valid = false +} + +func (dst *UUID) setByteArray(value [16]byte) { + dst.Bytes = value + dst.Valid = true +} + +func (dst *UUID) setByteSlice(value []byte) error { + if value != nil { + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + copy(dst.Bytes[:], value) + dst.Valid = true + } else { + dst.setNil() + } + + return nil +} + +func (dst *UUID) setString(value string) error { + uuid, err := parseUUID(value) + if err != nil { + return err + } + dst.setByteArray(uuid) + return nil } func (dst *UUID) Set(src interface{}) error { if src == nil { - *dst = UUID{} + dst.setNil() return nil } @@ -27,28 +74,16 @@ func (dst *UUID) Set(src interface{}) error { switch value := src.(type) { case [16]byte: - *dst = UUID{Bytes: value, Valid: true} + dst.setByteArray(value) case []byte: - if value != nil { - if len(value) != 16 { - return fmt.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) - } - *dst = UUID{Valid: true} - copy(dst.Bytes[:], value) - } else { - *dst = UUID{} - } + return dst.setByteSlice(value) case string: - uuid, err := parseUUID(value) - if err != nil { - return err - } - *dst = UUID{Bytes: uuid, Valid: true} + return dst.setString(value) case *string: if value == nil { - *dst = UUID{} + dst.setNil() } else { - return dst.Set(*value) + return dst.setString(*value) } default: if originalSrc, ok := underlyingUUIDType(src); ok { @@ -61,13 +96,33 @@ func (dst *UUID) Set(src interface{}) error { } func (dst UUID) Get() interface{} { + if dst.Getter != nil { + return dst.Getter(dst) + } + if !dst.Valid { return nil } + return dst.Bytes } +type UUIDDecoder interface { + DecodeUUID(*UUID) error +} + func (src *UUID) AssignTo(dst interface{}) error { + if d, ok := dst.(UUIDDecoder); ok { + return d.DecodeUUID(src) + } else { + if src.UUIDDecoderWrapper != nil { + d = src.UUIDDecoderWrapper(dst) + if d != nil { + return d.DecodeUUID(src) + } + } + } + if !src.Valid { return NullAssignTo(dst) } @@ -120,7 +175,7 @@ func encodeUUID(src [16]byte) string { func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = UUID{} + dst.setNil() return nil } @@ -133,13 +188,13 @@ func (dst *UUID) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = UUID{Bytes: buf, Valid: true} + dst.setByteArray(buf) return nil } func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = UUID{} + dst.setNil() return nil } @@ -147,9 +202,7 @@ func (dst *UUID) DecodeBinary(ci *ConnInfo, src []byte) error { return fmt.Errorf("invalid length for UUID: %v", len(src)) } - *dst = UUID{Valid: true} - copy(dst.Bytes[:], src) - return nil + return dst.setByteSlice(src) } func (src UUID) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { @@ -171,7 +224,7 @@ func (src UUID) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *UUID) Scan(src interface{}) error { if src == nil { - *dst = UUID{} + dst.setNil() return nil } diff --git a/uuid_test.go b/uuid_test.go index 887f45dd..63797178 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -65,7 +65,7 @@ func TestUUIDSet(t *testing.T) { t.Errorf("%d: %v", i, err) } - if r != tt.result { + if r.Bytes != tt.result.Bytes || r.Valid != tt.result.Valid { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) } }