diff --git a/pgtype.go b/pgtype.go index 208b1f00..911ab70e 100644 --- a/pgtype.go +++ b/pgtype.go @@ -245,6 +245,7 @@ func init() { "numeric": &Numeric{}, "numrange": &Numrange{}, "oid": &OidValue{}, + "point": &Point{}, "record": &Record{}, "text": &Text{}, "tid": &Tid{}, diff --git a/point.go b/point.go new file mode 100644 index 00000000..1b40bc44 --- /dev/null +++ b/point.go @@ -0,0 +1,139 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/pgio" +) + +type Point struct { + X float64 + Y float64 + Status Status +} + +func (dst *Point) Set(src interface{}) error { + return fmt.Errorf("cannot convert %v to Point", src) +} + +func (dst *Point) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *Point) AssignTo(dst interface{}) error { + return fmt.Errorf("cannot assign %v to %T", src, dst) +} + +func (dst *Point) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) < 5 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid format for point") + } + + x, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return err + } + + y, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return err + } + + *dst = Point{X: x, Y: y, Status: Present} + return nil +} + +func (dst *Point) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for point: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + + *dst = Point{ + X: math.Float64frombits(x), + Y: math.Float64frombits(y), + Status: Present, + } + return nil +} + +func (src *Point) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, fmt.Sprintf(`(%f,%f)`, src.X, src.Y)) + return false, err +} + +func (src *Point) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + _, err := pgio.WriteUint64(w, math.Float64bits(src.X)) + if err != nil { + return false, err + } + + _, err = pgio.WriteUint64(w, math.Float64bits(src.Y)) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Point) Scan(src interface{}) error { + if src == nil { + *dst = Point{Status: Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Point) Value() (driver.Value, error) { + return encodeValueText(src) +} diff --git a/point_test.go b/point_test.go new file mode 100644 index 00000000..4ddb8009 --- /dev/null +++ b/point_test.go @@ -0,0 +1,15 @@ +package pgtype_test + +import ( + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestPointTranscode(t *testing.T) { + testSuccessfulTranscode(t, "point", []interface{}{ + &pgtype.Point{X: 1.234, Y: 5.6789, Status: pgtype.Present}, + &pgtype.Point{X: -1.234, Y: -5.6789, Status: pgtype.Present}, + &pgtype.Point{Status: pgtype.Null}, + }) +}