diff --git a/pgtype/aclitem.go b/pgtype/aclitem.go index e8386ae7..77e385e6 100644 --- a/pgtype/aclitem.go +++ b/pgtype/aclitem.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -93,3 +94,32 @@ func (src Aclitem) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { _, err := io.WriteString(w, src.String) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Aclitem) Scan(src interface{}) error { + if src == nil { + *dst = Aclitem{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Aclitem) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/aclitem_array.go b/pgtype/aclitem_array.go index 1c97e74f..20a7636a 100644 --- a/pgtype/aclitem_array.go +++ b/pgtype/aclitem_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "fmt" "io" @@ -194,3 +195,33 @@ func (src *AclitemArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return false, nil } + +// Scan implements the database/sql Scanner interface. +func (dst *AclitemArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *AclitemArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/bool.go b/pgtype/bool.go index 608a6f95..736d19cf 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" "strconv" @@ -126,3 +127,35 @@ func (src Bool) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(buf) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Bool) Scan(src interface{}) error { + if src == nil { + *dst = Bool{Status: Null} + return nil + } + + switch src := src.(type) { + case bool: + *dst = Bool{Bool: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bool) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bool, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index cdfe9685..4705d734 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *BoolArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *BoolArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/bytea.go b/pgtype/bytea.go index 00bed8e8..9f0266e7 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/hex" "fmt" "io" @@ -12,6 +13,11 @@ type Bytea struct { } func (dst *Bytea) Set(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + switch value := src.(type) { case []byte: if value != nil { @@ -124,3 +130,35 @@ func (src Bytea) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(src.Bytes) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Bytea) Scan(src interface{}) error { + if src == nil { + *dst = Bytea{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + buf := make([]byte, len(src)) + copy(buf, src) + *dst = Bytea{Bytes: buf, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Bytea) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 175ca2f6..268364c1 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *ByteaArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *ByteaArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/cid.go b/pgtype/cid.go index d86e8063..63ba6a2f 100644 --- a/pgtype/cid.go +++ b/pgtype/cid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -49,3 +50,13 @@ func (src Cid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Cid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Cid) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Cid) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 49a2728b..6643bb47 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -325,3 +326,33 @@ func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *CidrArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *CidrArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go index 969d6542..2ddd842d 100644 --- a/pgtype/database_sql.go +++ b/pgtype/database_sql.go @@ -2,47 +2,13 @@ package pgtype import ( "bytes" + "database/sql/driver" "errors" ) func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { - switch src := src.(type) { - case *Bool: - return src.Bool, nil - case *Bytea: - return src.Bytes, nil - case *Date: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Float4: - return float64(src.Float), nil - case *Float8: - return src.Float, nil - case *GenericBinary: - return src.Bytes, nil - case *GenericText: - return src.String, nil - case *Int2: - return int64(src.Int), nil - case *Int4: - return int64(src.Int), nil - case *Int8: - return int64(src.Int), nil - case *Text: - return src.String, nil - case *Timestamp: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Timestamptz: - if src.InfinityModifier == None { - return src.Time, nil - } - case *Unknown: - return src.String, nil - case *Varchar: - return src.String, nil + if valuer, ok := src.(driver.Valuer); ok { + return valuer.Value() } buf := &bytes.Buffer{} @@ -64,3 +30,15 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return nil, errors.New("cannot convert to database/sql compatible value") } + +func encodeValueText(src TextEncoder) (interface{}, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + return buf.String(), err +} diff --git a/pgtype/date.go b/pgtype/date.go index ab854eb2..7dd2c4f0 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -10,9 +11,9 @@ import ( ) type Date struct { - Time time.Time - Status Status - InfinityModifier + Time time.Time + Status Status + InfinityModifier InfinityModifier } const ( @@ -21,6 +22,11 @@ const ( ) func (dst *Date) Set(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Date{Time: value, Status: Present} @@ -167,3 +173,38 @@ func (src Date) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, daysSinceDateEpoch) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Date) Scan(src interface{}) error { + if src == nil { + *dst = Date{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Date{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Date) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/date_array.go b/pgtype/date_array.go index bf791677..f58de011 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *DateArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *DateArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/date_test.go b/pgtype/date_test.go index cfc3dd70..1832b5b4 100644 --- a/pgtype/date_test.go +++ b/pgtype/date_test.go @@ -9,7 +9,7 @@ import ( ) func TestDateTranscode(t *testing.T) { - testSuccessfulTranscode(t, "date", []interface{}{ + testSuccessfulTranscodeEqFunc(t, "date", []interface{}{ pgtype.Date{Time: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, pgtype.Date{Time: time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, @@ -19,6 +19,11 @@ func TestDateTranscode(t *testing.T) { pgtype.Date{Status: pgtype.Null}, pgtype.Date{Status: pgtype.Present, InfinityModifier: pgtype.Infinity}, pgtype.Date{Status: pgtype.Present, InfinityModifier: -pgtype.Infinity}, + }, func(a, b interface{}) bool { + at := a.(pgtype.Date) + bt := b.(pgtype.Date) + + return at.Time.Equal(bt.Time) && at.Status == bt.Status && at.InfinityModifier == bt.InfinityModifier }) } diff --git a/pgtype/float4.go b/pgtype/float4.go index 94b7b7a1..e92149a6 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Float4 struct { } func (dst *Float4) Set(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + switch value := src.(type) { case float32: *dst = Float4{Float: value, Status: Present} @@ -156,3 +162,35 @@ func (src Float4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, int32(math.Float32bits(src.Float))) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float4) Scan(src interface{}) error { + if src == nil { + *dst = Float4{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float4{Float: float32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return float64(src.Float), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index b4d05c55..b9ee4b9e 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float4Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/float8.go b/pgtype/float8.go index dd2d592d..4d094757 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Float8 struct { } func (dst *Float8) Set(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + switch value := src.(type) { case float32: *dst = Float8{Float: float64(value), Status: Present} @@ -146,3 +152,35 @@ func (src Float8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, int64(math.Float64bits(src.Float))) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float8) Scan(src interface{}) error { + if src == nil { + *dst = Float8{Status: Null} + return nil + } + + switch src := src.(type) { + case float64: + *dst = Float8{Float: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Float8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Float, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index e000807e..d49f18a7 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Float8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Float8Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/generic_binary.go b/pgtype/generic_binary.go index aa28bb62..f834bfb2 100644 --- a/pgtype/generic_binary.go +++ b/pgtype/generic_binary.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -27,3 +28,13 @@ func (dst *GenericBinary) DecodeBinary(ci *ConnInfo, src []byte) error { func (src GenericBinary) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Bytea)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *GenericBinary) Scan(src interface{}) error { + return (*Bytea)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericBinary) Value() (driver.Value, error) { + return (Bytea)(src).Value() +} diff --git a/pgtype/generic_text.go b/pgtype/generic_text.go index bd75e0d0..053ec504 100644 --- a/pgtype/generic_text.go +++ b/pgtype/generic_text.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -27,3 +28,13 @@ func (dst *GenericText) DecodeText(ci *ConnInfo, src []byte) error { func (src GenericText) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *GenericText) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src GenericText) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 8dc5b4d8..b8b0c6f3 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "errors" "fmt" @@ -21,6 +22,11 @@ type Hstore struct { } func (dst *Hstore) Set(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + switch value := src.(type) { case map[string]string: m := make(map[string]Text, len(value)) @@ -437,3 +443,25 @@ func parseHstore(s string) (k []string, v []Text, err error) { v = values return } + +// Scan implements the database/sql Scanner interface. +func (dst *Hstore) Scan(src interface{}) error { + if src == nil { + *dst = Hstore{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Hstore) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go index 9bd0ed3b..097fec7b 100644 --- a/pgtype/hstore_array.go +++ b/pgtype/hstore_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *HstoreArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *HstoreArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/inet.go b/pgtype/inet.go index 13764814..0ca3ee7a 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" "net" @@ -23,6 +24,11 @@ type Inet struct { } func (dst *Inet) Set(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + switch value := src.(type) { case net.IPNet: *dst = Inet{IPNet: &value, Status: Present} @@ -189,3 +195,25 @@ func (src Inet) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := w.Write(src.IPNet.IP) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Inet) Scan(src interface{}) error { + if src == nil { + *dst = Inet{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Inet) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 1988a145..a108d75b 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -325,3 +326,33 @@ func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *InetArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *InetArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/int2.go b/pgtype/int2.go index 6996cd4f..3bcac63c 100644 --- a/pgtype/int2.go +++ b/pgtype/int2.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int2 struct { } func (dst *Int2) Set(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int2{Int: int16(value), Status: Present} @@ -151,3 +157,41 @@ func (src Int2) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt16(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int2) Scan(src interface{}) error { + if src == nil { + *dst = Int2{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + if src > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", src) + } + *dst = Int2{Int: int16(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int2) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 531e7dd6..bddb5ac2 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int2Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int2Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/int4.go b/pgtype/int4.go index 62ee366f..5069dab4 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int4 struct { } func (dst *Int4) Set(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int4{Int: int32(value), Status: Present} @@ -68,7 +74,7 @@ func (dst *Int4) Set(src interface{}) error { if originalSrc, ok := underlyingNumberType(src); ok { return dst.Set(originalSrc) } - return fmt.Errorf("cannot convert %v to Int8", value) + return fmt.Errorf("cannot convert %v to Int4", value) } return nil @@ -142,3 +148,41 @@ func (src Int4) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt32(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int4) Scan(src interface{}) error { + if src == nil { + *dst = Int4{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + if src < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + if src > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", src) + } + *dst = Int4{Int: int32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int4) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index 3617050f..d5c8f911 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int4Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int4Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/int8.go b/pgtype/int8.go index 7ed54f8e..cf701dc6 100644 --- a/pgtype/int8.go +++ b/pgtype/int8.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -16,6 +17,11 @@ type Int8 struct { } func (dst *Int8) Set(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = Int8{Int: int64(value), Status: Present} @@ -134,3 +140,35 @@ func (src Int8) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, src.Int) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int8) Scan(src interface{}) error { + if src == nil { + *dst = Int8{Status: Null} + return nil + } + + switch src := src.(type) { + case int64: + *dst = Int8{Int: src, Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Int8) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Int), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 4f04b660..ae2521fa 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -324,3 +325,33 @@ func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Int8Array) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Int8Array) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/json.go b/pgtype/json.go index bfffae14..05d965ca 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -1,7 +1,9 @@ package pgtype import ( + "database/sql/driver" "encoding/json" + "fmt" "io" ) @@ -11,6 +13,11 @@ type Json struct { } func (dst *Json) Set(src interface{}) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + switch value := src.(type) { case string: *dst = Json{Bytes: []byte(value), Status: Present} @@ -116,3 +123,32 @@ func (src Json) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Json) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Json) Scan(src interface{}) error { + if src == nil { + *dst = Json{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Json) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index e44f3c41..f47476d6 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -66,3 +67,13 @@ func (src Jsonb) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err = w.Write(src.Bytes) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Jsonb) Scan(src interface{}) error { + return (*Json)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Jsonb) Value() (driver.Value, error) { + return (Json)(src).Value() +} diff --git a/pgtype/name.go b/pgtype/name.go index 9ebf63d3..cc4ae23b 100644 --- a/pgtype/name.go +++ b/pgtype/name.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -46,3 +47,13 @@ func (src Name) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Name) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Name) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Name) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/oid.go b/pgtype/oid.go index 3edd7f3c..339dee0f 100644 --- a/pgtype/oid.go +++ b/pgtype/oid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -55,3 +56,27 @@ func (src Oid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, uint32(src)) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Oid) Scan(src interface{}) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", src) + } + + switch src := src.(type) { + case int64: + *dst = Oid(src) + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Oid) Value() (driver.Value, error) { + return int64(src), nil +} diff --git a/pgtype/oid_value.go b/pgtype/oid_value.go index 1bce6e11..cb03802e 100644 --- a/pgtype/oid_value.go +++ b/pgtype/oid_value.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -43,3 +44,13 @@ func (src OidValue) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src OidValue) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *OidValue) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src OidValue) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 674c0db7..7e6633d9 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -67,6 +67,19 @@ const ( NegativeInfinity InfinityModifier = -Infinity ) +func (im InfinityModifier) String() string { + switch im { + case None: + return "none" + case Infinity: + return "infinity" + case NegativeInfinity: + return "-infinity" + default: + return "invalid" + } +} + type Value interface { // Set converts and assigns src to itself. Set(src interface{}) error diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 391fed57..16cabfd1 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "database/sql" "fmt" "io" "net" @@ -10,6 +11,8 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" + _ "github.com/jackc/pgx/stdlib" + _ "github.com/lib/pq" ) // Test for renamed types @@ -24,6 +27,25 @@ type _float32Slice []float32 type _float64Slice []float64 type _byteSlice []byte +func mustConnectDatabaseSQL(t testing.TB, driverName string) *sql.DB { + var sqlDriverName string + switch driverName { + case "github.com/lib/pq": + sqlDriverName = "postgres" + case "github.com/jackc/pgx/stdlib": + sqlDriverName = "pgx" + default: + t.Fatalf("Unknown driver %v", driverName) + } + + db, err := sql.Open(sqlDriverName, os.Getenv("DATABASE_URL")) + if err != nil { + t.Fatal(err) + } + + return db +} + func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL")) if err != nil { @@ -93,6 +115,13 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface } func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) + for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { + testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) + } +} + +func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := mustConnectPgx(t) defer mustClose(t, conn) @@ -114,7 +143,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int ps.FieldDescriptions[0].FormatCode = fc.formatCode vEncoder := forceEncoder(v, fc.formatCode) if vEncoder == nil { - t.Logf("%#v does not implement %v", v, fc.name) + t.Logf("Skipping: %#v does not implement %v", v, fc.name) continue } // Derefence value if it is a pointer @@ -136,3 +165,33 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int } } } + +func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { + conn := mustConnectDatabaseSQL(t, driverName) + defer mustClose(t, conn) + + ps, err := conn.Prepare(fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + for i, v := range values { + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } + + result := reflect.New(reflect.TypeOf(derefV)) + err := ps.QueryRow(v).Scan(result.Interface()) + if err != nil { + t.Errorf("%v %d: %v", driverName, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("%v %d: expected %v, got %v", driverName, i, derefV, result.Elem().Interface()) + } + } + +} diff --git a/pgtype/pguint32.go b/pgtype/pguint32.go index 3f9e7bf7..7138a409 100644 --- a/pgtype/pguint32.go +++ b/pgtype/pguint32.go @@ -1,9 +1,11 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" + "math" "strconv" "github.com/jackc/pgx/pgio" @@ -21,6 +23,14 @@ type pguint32 struct { // types do. func (dst *pguint32) Set(src interface{}) error { switch value := src.(type) { + case int64: + if value < 0 { + return fmt.Errorf("%d is less than minimum value for pguint32", value) + } + if value > math.MaxUint32 { + return fmt.Errorf("%d is greater than maximum value for pguint32", value) + } + *dst = pguint32{Uint: uint32(value), Status: Present} case uint32: *dst = pguint32{Uint: value, Status: Present} default: @@ -116,3 +126,38 @@ func (src pguint32) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteUint32(w, src.Uint) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *pguint32) Scan(src interface{}) error { + if src == nil { + *dst = pguint32{Status: Null} + return nil + } + + switch src := src.(type) { + case uint32: + *dst = pguint32{Uint: src, Status: Present} + return nil + case int64: + *dst = pguint32{Uint: uint32(src), Status: Present} + return nil + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src pguint32) Value() (driver.Value, error) { + switch src.Status { + case Present: + return int64(src.Uint), nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/qchar.go b/pgtype/qchar.go index 4b32ee4a..49475bd3 100644 --- a/pgtype/qchar.go +++ b/pgtype/qchar.go @@ -17,13 +17,20 @@ import ( // standard type char. // // Not all possible values of QChar are representable in the text format. -// Therefore, QChar does not implement TextEncoder and TextDecoder. +// Therefore, QChar does not implement TextEncoder and TextDecoder. In +// addition, database/sql Scanner and database/sql/driver Value are not +// implemented. type QChar struct { Int int8 Status Status } func (dst *QChar) Set(src interface{}) error { + if src == nil { + *dst = QChar{Status: Null} + return nil + } + switch value := src.(type) { case int8: *dst = QChar{Int: value, Status: Present} diff --git a/pgtype/qchar_test.go b/pgtype/qchar_test.go index a1b6d22e..afac5016 100644 --- a/pgtype/qchar_test.go +++ b/pgtype/qchar_test.go @@ -9,13 +9,15 @@ import ( ) func TestQCharTranscode(t *testing.T) { - testSuccessfulTranscode(t, `"char"`, []interface{}{ + testPgxSuccessfulTranscodeEqFunc(t, `"char"`, []interface{}{ pgtype.QChar{Int: math.MinInt8, Status: pgtype.Present}, pgtype.QChar{Int: -1, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Present}, pgtype.QChar{Int: 1, Status: pgtype.Present}, pgtype.QChar{Int: math.MaxInt8, Status: pgtype.Present}, pgtype.QChar{Int: 0, Status: pgtype.Null}, + }, func(a, b interface{}) bool { + return reflect.DeepEqual(a, b) }) } diff --git a/pgtype/record.go b/pgtype/record.go index 89e081ca..9c42c907 100644 --- a/pgtype/record.go +++ b/pgtype/record.go @@ -16,6 +16,11 @@ type Record struct { } func (dst *Record) Set(src interface{}) error { + if src == nil { + *dst = Record{Status: Null} + return nil + } + switch value := src.(type) { case []Value: *dst = Record{Fields: value, Status: Present} diff --git a/pgtype/text.go b/pgtype/text.go index dbc9362b..482c9023 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "fmt" "io" ) @@ -11,6 +12,11 @@ type Text struct { } func (dst *Text) Set(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + switch value := src.(type) { case string: *dst = Text{String: value, Status: Present} @@ -20,6 +26,12 @@ func (dst *Text) Set(src interface{}) error { } else { *dst = Text{String: *value, Status: Present} } + case []byte: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: string(value), Status: Present} + } default: if originalSrc, ok := underlyingStringType(src); ok { return dst.Set(originalSrc) @@ -93,3 +105,32 @@ func (src Text) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Text) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return src.EncodeText(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Text) Scan(src interface{}) error { + if src == nil { + *dst = Text{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Text) Value() (driver.Value, error) { + switch src.Status { + case Present: + return src.String, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 6e8ead26..64728048 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TextArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TextArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/tid.go b/pgtype/tid.go index b91711d3..b363c1f9 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -121,3 +122,25 @@ func (src Tid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err = pgio.WriteUint16(w, src.OffsetNumber) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Tid) Scan(src interface{}) error { + if src == nil { + *dst = Tid{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Tid) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 4b42f3cf..78c6355e 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -17,14 +18,19 @@ const pgTimestampFormat = "2006-01-02 15:04:05.999999999" // recommended to use timestamptz whenever possible. Timestamp methods either // convert to UTC or return an error on non-UTC times. type Timestamp struct { - Time time.Time // Time must always be in UTC. - Status Status - InfinityModifier + Time time.Time // Time must always be in UTC. + Status Status + InfinityModifier InfinityModifier } // Set converts src into a Timestamp and stores in dst. If src is a // time.Time in a non-UTC time zone, the time zone is discarded. func (dst *Timestamp) Set(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Timestamp{Time: time.Date(value.Year(), value.Month(), value.Day(), value.Hour(), value.Minute(), value.Second(), value.Nanosecond(), time.UTC), Status: Present} @@ -183,3 +189,38 @@ func (src Timestamp) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, microsecSinceY2K) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamp) Scan(src interface{}) error { + if src == nil { + *dst = Timestamp{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Timestamp{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamp) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 6a6950c7..5d08f9cc 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TimestampArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TimestampArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index ba849ac8..50370335 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -20,12 +21,17 @@ const ( ) type Timestamptz struct { - Time time.Time - Status Status - InfinityModifier + Time time.Time + Status Status + InfinityModifier InfinityModifier } func (dst *Timestamptz) Set(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + switch value := src.(type) { case time.Time: *dst = Timestamptz{Time: value, Status: Present} @@ -179,3 +185,38 @@ func (src Timestamptz) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { _, err := pgio.WriteInt64(w, microsecSinceY2K) return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *Timestamptz) Scan(src interface{}) error { + if src == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + case time.Time: + *dst = Timestamptz{Time: src, Status: Present} + return nil + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Timestamptz) Value() (driver.Value, error) { + switch src.Status { + case Present: + if src.InfinityModifier != None { + return src.InfinityModifier.String(), nil + } + return src.Time, nil + case Null: + return nil, nil + default: + return nil, errUndefined + } +} diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 347d0b8b..107be06a 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -297,3 +298,33 @@ func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, erro return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *TimestamptzArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *TimestamptzArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 0e5725ce..4b8f1a28 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -299,3 +299,33 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool return false, err } <% end %> + +// Scan implements the database/sql Scanner interface. +func (dst *<%= pgtype_array_type %>) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *<%= pgtype_array_type %>) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/unknown.go b/pgtype/unknown.go index b951ad99..2dca0f87 100644 --- a/pgtype/unknown.go +++ b/pgtype/unknown.go @@ -1,5 +1,7 @@ package pgtype +import "database/sql/driver" + // Unknown represents the PostgreSQL unknown type. It is either a string literal // or NULL. It is used when PostgreSQL does not know the type of a value. In // general, this will only be used in pgx when selecting a null value without @@ -30,3 +32,13 @@ func (dst *Unknown) DecodeText(ci *ConnInfo, src []byte) error { func (dst *Unknown) DecodeBinary(ci *ConnInfo, src []byte) error { return (*Text)(dst).DecodeBinary(ci, src) } + +// Scan implements the database/sql Scanner interface. +func (dst *Unknown) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Unknown) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/varchar.go b/pgtype/varchar.go index adda6c49..f25ada5d 100644 --- a/pgtype/varchar.go +++ b/pgtype/varchar.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -38,3 +39,13 @@ func (src Varchar) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Varchar) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (Text)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Varchar) Scan(src interface{}) error { + return (*Text)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Varchar) Value() (driver.Value, error) { + return (Text)(src).Value() +} diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index e1dd3910..2712b4d2 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -2,6 +2,7 @@ package pgtype import ( "bytes" + "database/sql/driver" "encoding/binary" "fmt" "io" @@ -296,3 +297,33 @@ func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return false, err } + +// Scan implements the database/sql Scanner interface. +func (dst *VarcharArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *VarcharArray) Value() (driver.Value, error) { + buf := &bytes.Buffer{} + null, err := src.EncodeText(nil, buf) + if err != nil { + return nil, err + } + if null { + return nil, nil + } + + return buf.String(), nil +} diff --git a/pgtype/xid.go b/pgtype/xid.go index c76548a4..0a7fc7d9 100644 --- a/pgtype/xid.go +++ b/pgtype/xid.go @@ -1,6 +1,7 @@ package pgtype import ( + "database/sql/driver" "io" ) @@ -52,3 +53,13 @@ func (src Xid) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { func (src Xid) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { return (pguint32)(src).EncodeBinary(ci, w) } + +// Scan implements the database/sql Scanner interface. +func (dst *Xid) Scan(src interface{}) error { + return (*pguint32)(dst).Scan(src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Xid) Value() (driver.Value, error) { + return (pguint32)(src).Value() +} diff --git a/query.go b/query.go index 0b5cc911..e820fabc 100644 --- a/query.go +++ b/query.go @@ -208,47 +208,6 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { if err != nil { rows.Fatal(scanArgError{col: i, err: err}) } - } else if s, ok := d.(sql.Scanner); ok { - var sqlSrc interface{} - if 0 <= vr.Len() { - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { - value := dt.Value - - switch vr.Type().FormatCode { - case TextFormatCode: - decoder := value.(pgtype.TextDecoder) - if decoder == nil { - decoder = &pgtype.GenericText{} - } - err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) - if err != nil { - rows.Fatal(err) - } - case BinaryFormatCode: - decoder := value.(pgtype.BinaryDecoder) - if decoder == nil { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) - if err != nil { - rows.Fatal(err) - } - default: - rows.Fatal(errors.New("Unknown format code")) - } - - sqlSrc, err = pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) - if err != nil { - rows.Fatal(err) - } - } else { - rows.Fatal(errors.New("Unknown type")) - } - } - err = s.Scan(sqlSrc) - if err != nil { - rows.Fatal(scanArgError{col: i, err: err}) - } } else { if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { value := dt.Value @@ -276,7 +235,16 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { } if vr.Err() == nil { - if err := value.AssignTo(d); err != nil { + if scanner, ok := d.(sql.Scanner); ok { + sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) + if err != nil { + rows.Fatal(err) + } + err = scanner.Scan(sqlSrc) + if err != nil { + rows.Fatal(scanArgError{col: i, err: err}) + } + } else if err := value.AssignTo(d); err != nil { vr.Fatal(err) } } @@ -355,71 +323,6 @@ func (rows *Rows) Values() ([]interface{}, error) { return values, rows.Err() } -// ValuesForStdlib is a temporary function to keep all systems operational -// while refactoring. Do not use. -func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { - if rows.closed { - return nil, errors.New("rows is closed") - } - - values := make([]interface{}, 0, len(rows.fields)) - - for range rows.fields { - vr, _ := rows.nextColumn() - - if vr.Len() == -1 { - values = append(values, nil) - continue - } - - if dt, ok := rows.conn.ConnInfo.DataTypeForOid(vr.Type().DataType); ok { - value := dt.Value - - switch vr.Type().FormatCode { - case TextFormatCode: - decoder := value.(pgtype.TextDecoder) - if decoder == nil { - decoder = &pgtype.GenericText{} - } - err := decoder.DecodeText(rows.conn.ConnInfo, vr.bytes()) - if err != nil { - rows.Fatal(err) - } - case BinaryFormatCode: - decoder := value.(pgtype.BinaryDecoder) - if decoder == nil { - decoder = &pgtype.GenericBinary{} - } - err := decoder.DecodeBinary(rows.conn.ConnInfo, vr.bytes()) - if err != nil { - rows.Fatal(err) - } - default: - rows.Fatal(errors.New("Unknown format code")) - } - - sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value) - if err != nil { - rows.Fatal(err) - } - - values = append(values, sqlSrc) - } else { - rows.Fatal(errors.New("Unknown type")) - } - - if vr.Err() != nil { - rows.Fatal(vr.Err()) - } - - if rows.Err() != nil { - return nil, rows.Err() - } - } - - return values, rows.Err() -} - // AfterClose adds f to a LILO queue of functions that will be called when // rows is closed. func (rows *Rows) AfterClose(f func(*Rows)) { diff --git a/query_test.go b/query_test.go index b053e26d..25347ec5 100644 --- a/query_test.go +++ b/query_test.go @@ -704,30 +704,6 @@ func TestQueryRowCoreByteSlice(t *testing.T) { } } -func TestQueryRowByteSliceArgument(t *testing.T) { - t.Parallel() - - conn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, conn) - - sql := "select $1::int4" - queryArg := []byte{14, 63, 53, 49} - expected := int32(239023409) - - var actual int32 - - err := conn.QueryRow(sql, queryArg).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } - - if expected != actual { - t.Errorf("Expected %v, got %v (sql -> %v)", expected, actual, sql) - } - - ensureConnValid(t, conn) -} - func TestQueryRowUnknownType(t *testing.T) { t.Parallel() diff --git a/stdlib/sql.go b/stdlib/sql.go index 6889a2b6..affa93b6 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -68,14 +68,17 @@ func init() { databaseSqlOids = make(map[pgtype.Oid]bool) databaseSqlOids[pgtype.BoolOid] = true databaseSqlOids[pgtype.ByteaOid] = true + databaseSqlOids[pgtype.CidOid] = true + databaseSqlOids[pgtype.DateOid] = true + databaseSqlOids[pgtype.Float4Oid] = true + databaseSqlOids[pgtype.Float8Oid] = true databaseSqlOids[pgtype.Int2Oid] = true databaseSqlOids[pgtype.Int4Oid] = true databaseSqlOids[pgtype.Int8Oid] = true - databaseSqlOids[pgtype.Float4Oid] = true - databaseSqlOids[pgtype.Float8Oid] = true - databaseSqlOids[pgtype.DateOid] = true - databaseSqlOids[pgtype.TimestamptzOid] = true + databaseSqlOids[pgtype.OidOid] = true databaseSqlOids[pgtype.TimestampOid] = true + databaseSqlOids[pgtype.TimestamptzOid] = true + databaseSqlOids[pgtype.XidOid] = true } type Driver struct { @@ -292,9 +295,9 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return s.conn.queryPrepared(s.ps.Name, argsV) } -// TODO - rename to avoid alloc type Rows struct { - rows *pgx.Rows + rows *pgx.Rows + values []interface{} } func (r *Rows) Columns() []string { @@ -312,6 +315,42 @@ func (r *Rows) Close() error { } func (r *Rows) Next(dest []driver.Value) error { + if r.values == nil { + r.values = make([]interface{}, len(r.rows.FieldDescriptions())) + for i, fd := range r.rows.FieldDescriptions() { + switch fd.DataType { + case pgtype.BoolOid: + r.values[i] = &pgtype.Bool{} + case pgtype.ByteaOid: + r.values[i] = &pgtype.Bytea{} + case pgtype.CidOid: + r.values[i] = &pgtype.Cid{} + case pgtype.DateOid: + r.values[i] = &pgtype.Date{} + case pgtype.Float4Oid: + r.values[i] = &pgtype.Float4{} + case pgtype.Float8Oid: + r.values[i] = &pgtype.Float8{} + case pgtype.Int2Oid: + r.values[i] = &pgtype.Int2{} + case pgtype.Int4Oid: + r.values[i] = &pgtype.Int4{} + case pgtype.Int8Oid: + r.values[i] = &pgtype.Int8{} + case pgtype.OidOid: + r.values[i] = &pgtype.OidValue{} + case pgtype.TimestampOid: + r.values[i] = &pgtype.Timestamp{} + case pgtype.TimestamptzOid: + r.values[i] = &pgtype.Timestamptz{} + case pgtype.XidOid: + r.values[i] = &pgtype.Xid{} + default: + r.values[i] = &pgtype.GenericText{} + } + } + } + more := r.rows.Next() if !more { if r.rows.Err() == nil { @@ -321,19 +360,16 @@ func (r *Rows) Next(dest []driver.Value) error { } } - values, err := r.rows.ValuesForStdlib() + err := r.rows.Scan(r.values...) if err != nil { return err } - if len(dest) < len(values) { - fmt.Printf("%d: %#v\n", len(dest), dest) - fmt.Printf("%d: %#v\n", len(values), values) - return errors.New("expected more values than were received") - } - - for i, v := range values { - dest[i] = driver.Value(v) + for i, v := range r.values { + dest[i], err = v.(driver.Valuer).Value() + if err != nil { + return err + } } return nil diff --git a/values.go b/values.go index c399b42c..5370bf47 100644 --- a/values.go +++ b/values.go @@ -65,10 +65,6 @@ func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interfa wbuf.WriteInt32(int32(len(arg))) wbuf.WriteBytes([]byte(arg)) return nil - case []byte: - wbuf.WriteInt32(int32(len(arg))) - wbuf.WriteBytes(arg) - return nil } refVal := reflect.ValueOf(arg)