diff --git a/pgtype/circle.go b/pgtype/circle.go index 7524d7b9..ec136438 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -11,33 +11,184 @@ import ( "github.com/jackc/pgio" ) +type CircleScanner interface { + ScanCircle(v Circle) error +} + +type CircleValuer interface { + CircleValue() (Circle, error) +} + type Circle struct { P Vec2 R float64 Valid bool } -func (dst *Circle) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Circle", src) +func (c *Circle) ScanCircle(v Circle) error { + *c = v + return nil } -func (dst Circle) Get() interface{} { - if !dst.Valid { - return nil - } - return dst +func (c Circle) CircleValue() (Circle, error) { + return c, nil } -func (src *Circle) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) -} - -func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { +// Scan implements the database/sql Scanner interface. +func (dst *Circle) Scan(src interface{}) error { if src == nil { *dst = Circle{} return nil } + switch src := src.(type) { + case string: + return scanPlanTextAnyToCircleScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src Circle) Value() (driver.Value, error) { + buf, err := CircleCodec{}.Encode(nil, 0, TextFormatCode, src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type CircleCodec struct{} + +func (CircleCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (CircleCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (CircleCodec) Encode(ci *ConnInfo, oid uint32, format int16, value interface{}, buf []byte) (newBuf []byte, err error) { + if value == nil { + return nil, nil + } + + var circle Circle + if v, ok := value.(CircleValuer); ok { + c, err := v.CircleValue() + if err != nil { + return nil, err + } + circle = c + } else { + return nil, fmt.Errorf("cannot convert %v to circle: %v", value, err) + } + + if !circle.Valid { + return nil, nil + } + + switch format { + case BinaryFormatCode: + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.X)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.P.Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(circle.R)) + return buf, nil + case TextFormatCode: + buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, + strconv.FormatFloat(circle.P.X, 'f', -1, 64), + strconv.FormatFloat(circle.P.Y, 'f', -1, 64), + strconv.FormatFloat(circle.R, 'f', -1, 64), + )...) + return buf, nil + default: + return nil, fmt.Errorf("unknown format code: %v", format) + } +} + +func (CircleCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case CircleScanner: + return scanPlanBinaryCircleToCircleScanner{} + } + case TextFormatCode: + switch target.(type) { + case CircleScanner: + return scanPlanTextAnyToCircleScanner{} + } + } + + return nil +} + +func (c CircleCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + if format == TextFormatCode { + return string(src), nil + } else { + circle, err := c.DecodeValue(ci, oid, format, src) + if err != nil { + return nil, err + } + buf, err := c.Encode(ci, oid, TextFormatCode, circle, nil) + if err != nil { + return nil, err + } + return string(buf), nil + } +} + +func (c CircleCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { + if src == nil { + return nil, nil + } + + var circle Circle + scanPlan := c.PlanScan(ci, oid, format, &circle, true) + if scanPlan == nil { + return nil, fmt.Errorf("PlanScan did not find a plan") + } + err := scanPlan.Scan(ci, oid, format, src, &circle) + if err != nil { + return nil, err + } + return circle, nil +} + +type scanPlanBinaryCircleToCircleScanner struct{} + +func (scanPlanBinaryCircleToCircleScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(CircleScanner) + + if src == nil { + return scanner.ScanCircle(Circle{}) + } + + if len(src) != 24 { + return fmt.Errorf("invalid length for Circle: %v", len(src)) + } + + x := binary.BigEndian.Uint64(src) + y := binary.BigEndian.Uint64(src[8:]) + r := binary.BigEndian.Uint64(src[16:]) + + return scanner.ScanCircle(Circle{ + P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, + R: math.Float64frombits(r), + Valid: true, + }) +} + +type scanPlanTextAnyToCircleScanner struct{} + +func (scanPlanTextAnyToCircleScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(CircleScanner) + + if src == nil { + return scanner.ScanCircle(Circle{}) + } + if len(src) < 9 { return fmt.Errorf("invalid length for Circle: %v", len(src)) } @@ -64,77 +215,5 @@ func (dst *Circle) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Circle{P: Vec2{x, y}, R: r, Valid: true} - return nil -} - -func (dst *Circle) DecodeBinary(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Circle{} - return nil - } - - if len(src) != 24 { - return fmt.Errorf("invalid length for Circle: %v", len(src)) - } - - x := binary.BigEndian.Uint64(src) - y := binary.BigEndian.Uint64(src[8:]) - r := binary.BigEndian.Uint64(src[16:]) - - *dst = Circle{ - P: Vec2{math.Float64frombits(x), math.Float64frombits(y)}, - R: math.Float64frombits(r), - Valid: true, - } - return nil -} - -func (src Circle) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = append(buf, fmt.Sprintf(`<(%s,%s),%s>`, - strconv.FormatFloat(src.P.X, 'f', -1, 64), - strconv.FormatFloat(src.P.Y, 'f', -1, 64), - strconv.FormatFloat(src.R, 'f', -1, 64), - )...) - - return buf, nil -} - -func (src Circle) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil - } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P.Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.R)) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Circle) Scan(src interface{}) error { - if src == nil { - *dst = Circle{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Circle) Value() (driver.Value, error) { - return EncodeValueText(src) + return scanner.ScanCircle(Circle{P: Vec2{x, y}, R: r, Valid: true}) } diff --git a/pgtype/circle_test.go b/pgtype/circle_test.go index 8f39644b..742ac688 100644 --- a/pgtype/circle_test.go +++ b/pgtype/circle_test.go @@ -4,13 +4,20 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestCircleTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "circle", []interface{}{ - &pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, - &pgtype.Circle{P: pgtype.Vec2{-1.234, -5.6789}, R: 12.9, Valid: true}, - &pgtype.Circle{}, + testPgxCodec(t, "circle", []PgxTranscodeTestCase{ + { + pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + new(pgtype.Circle), + isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), + }, + { + pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}, + new(pgtype.Circle), + isExpectedEq(pgtype.Circle{P: pgtype.Vec2{1.234, 5.67890123}, R: 3.5, Valid: true}), + }, + {nil, new(pgtype.Circle), isExpectedEq(pgtype.Circle{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index aedc6dd5..11f7ce0b 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -320,7 +320,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &QChar{}, Name: "char", OID: QCharOID}) ci.RegisterDataType(DataType{Value: &CID{}, Name: "cid", OID: CIDOID}) ci.RegisterDataType(DataType{Value: &CIDR{}, Name: "cidr", OID: CIDROID}) - ci.RegisterDataType(DataType{Value: &Circle{}, Name: "circle", OID: CircleOID}) + ci.RegisterDataType(DataType{Name: "circle", OID: CircleOID, Codec: CircleCodec{}}) ci.RegisterDataType(DataType{Value: &Date{}, Name: "date", OID: DateOID}) // ci.RegisterDataType(DataType{Value: &Daterange{}, Name: "daterange", OID: DaterangeOID}) ci.RegisterDataType(DataType{Value: &Float4{}, Name: "float4", OID: Float4OID}) diff --git a/pgtype/zzz.circle.go b/pgtype/zzz.circle.go deleted file mode 100644 index b111c06d..00000000 --- a/pgtype/zzz.circle.go +++ /dev/null @@ -1,35 +0,0 @@ -package pgtype - -import "fmt" - -func (Circle) BinaryFormatSupported() bool { - return true -} - -func (Circle) TextFormatSupported() bool { - return true -} - -func (Circle) PreferredFormat() int16 { - return BinaryFormatCode -} - -func (dst *Circle) DecodeResult(ci *ConnInfo, oid uint32, format int16, src []byte) error { - switch format { - case BinaryFormatCode: - return dst.DecodeBinary(ci, src) - case TextFormatCode: - return dst.DecodeText(ci, src) - } - return fmt.Errorf("unknown format code %d", format) -} - -func (src Circle) EncodeParam(ci *ConnInfo, oid uint32, format int16, buf []byte) (newBuf []byte, err error) { - switch format { - case BinaryFormatCode: - return src.EncodeBinary(ci, buf) - case TextFormatCode: - return src.EncodeText(ci, buf) - } - return nil, fmt.Errorf("unknown format code %d", format) -}