From 869213a315c957e4210efb09b640ae68b5d671ee Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 18 Jan 2022 11:38:35 -0600 Subject: [PATCH] Convert lseg to Codec --- pgtype/lseg.go | 247 +++++++++++++++++++++++++++++--------------- pgtype/lseg_test.go | 32 ++++-- pgtype/pgtype.go | 2 +- 3 files changed, 189 insertions(+), 92 deletions(-) diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 649863ca..26730e85 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -11,34 +11,177 @@ import ( "github.com/jackc/pgio" ) +type LsegScanner interface { + ScanLseg(v Lseg) error +} + +type LsegValuer interface { + LsegValue() (Lseg, error) +} + type Lseg struct { P [2]Vec2 Valid bool } -func (dst *Lseg) Set(src interface{}) error { - return fmt.Errorf("cannot convert %v to Lseg", src) +func (lseg *Lseg) ScanLseg(v Lseg) error { + *lseg = v + return nil } -func (dst Lseg) Get() interface{} { - if !dst.Valid { +func (lseg Lseg) LsegValue() (Lseg, error) { + return lseg, nil +} + +// Scan implements the database/sql Scanner interface. +func (lseg *Lseg) Scan(src interface{}) error { + if src == nil { + *lseg = Lseg{} return nil } - return dst + + switch src := src.(type) { + case string: + return scanPlanTextAnyToLsegScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), lseg) + } + + return fmt.Errorf("cannot scan %T", src) } -func (src *Lseg) AssignTo(dst interface{}) error { - return fmt.Errorf("cannot assign %v to %T", src, dst) +// Value implements the database/sql/driver Valuer interface. +func (lseg Lseg) Value() (driver.Value, error) { + if !lseg.Valid { + return nil, nil + } + + buf, err := LsegCodec{}.PlanEncode(nil, 0, TextFormatCode, lseg).Encode(lseg, nil) + if err != nil { + return nil, err + } + return string(buf), err } -func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { - if src == nil { - *dst = Lseg{} +type LsegCodec struct{} + +func (LsegCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (LsegCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (LsegCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(LsegValuer); !ok { return nil } + switch format { + case BinaryFormatCode: + return encodePlanLsegCodecBinary{} + case TextFormatCode: + return encodePlanLsegCodecText{} + } + + return nil +} + +type encodePlanLsegCodecBinary struct{} + +func (encodePlanLsegCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + lseg, err := value.(LsegValuer).LsegValue() + if err != nil { + return nil, err + } + + if !lseg.Valid { + return nil, nil + } + + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[0].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[0].Y)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[1].X)) + buf = pgio.AppendUint64(buf, math.Float64bits(lseg.P[1].Y)) + return buf, nil +} + +type encodePlanLsegCodecText struct{} + +func (encodePlanLsegCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + lseg, err := value.(LsegValuer).LsegValue() + if err != nil { + return nil, err + } + + if !lseg.Valid { + return nil, nil + } + + buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, + strconv.FormatFloat(lseg.P[0].X, 'f', -1, 64), + strconv.FormatFloat(lseg.P[0].Y, 'f', -1, 64), + strconv.FormatFloat(lseg.P[1].X, 'f', -1, 64), + strconv.FormatFloat(lseg.P[1].Y, 'f', -1, 64), + )...) + return buf, nil +} + +func (LsegCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case LsegScanner: + return scanPlanBinaryLsegToLsegScanner{} + } + case TextFormatCode: + switch target.(type) { + case LsegScanner: + return scanPlanTextAnyToLsegScanner{} + } + } + + return nil +} + +type scanPlanBinaryLsegToLsegScanner struct{} + +func (scanPlanBinaryLsegToLsegScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(LsegScanner) + + if src == nil { + return scanner.ScanLseg(Lseg{}) + } + + if len(src) != 32 { + return fmt.Errorf("invalid length for lseg: %v", len(src)) + } + + x1 := binary.BigEndian.Uint64(src) + y1 := binary.BigEndian.Uint64(src[8:]) + x2 := binary.BigEndian.Uint64(src[16:]) + y2 := binary.BigEndian.Uint64(src[24:]) + + return scanner.ScanLseg(Lseg{ + P: [2]Vec2{ + {math.Float64frombits(x1), math.Float64frombits(y1)}, + {math.Float64frombits(x2), math.Float64frombits(y2)}, + }, + Valid: true, + }) +} + +type scanPlanTextAnyToLsegScanner struct{} + +func (scanPlanTextAnyToLsegScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(LsegScanner) + + if src == nil { + return scanner.ScanLseg(Lseg{}) + } + if len(src) < 11 { - return fmt.Errorf("invalid length for Lseg: %v", len(src)) + return fmt.Errorf("invalid length for lseg: %v", len(src)) } str := string(src[2:]) @@ -74,82 +217,22 @@ func (dst *Lseg) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true} - return nil + return scanner.ScanLseg(Lseg{P: [2]Vec2{{x1, y1}, {x2, y2}}, Valid: true}) } -func (dst *Lseg) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c LsegCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c LsegCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Lseg{} - return nil - } - - if len(src) != 32 { - return fmt.Errorf("invalid length for Lseg: %v", len(src)) - } - - x1 := binary.BigEndian.Uint64(src) - y1 := binary.BigEndian.Uint64(src[8:]) - x2 := binary.BigEndian.Uint64(src[16:]) - y2 := binary.BigEndian.Uint64(src[24:]) - - *dst = Lseg{ - P: [2]Vec2{ - {math.Float64frombits(x1), math.Float64frombits(y1)}, - {math.Float64frombits(x2), math.Float64frombits(y2)}, - }, - Valid: true, - } - return nil -} - -func (src Lseg) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`, - strconv.FormatFloat(src.P[0].X, 'f', -1, 64), - strconv.FormatFloat(src.P[0].Y, 'f', -1, 64), - strconv.FormatFloat(src.P[1].X, 'f', -1, 64), - strconv.FormatFloat(src.P[1].Y, 'f', -1, 64), - )...) - - return buf, nil -} - -func (src Lseg) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil + var lseg Lseg + err := codecScan(c, ci, oid, format, src, &lseg) + if err != nil { + return nil, err } - - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X)) - buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y)) - return buf, nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Lseg) Scan(src interface{}) error { - if src == nil { - *dst = Lseg{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Lseg) Value() (driver.Value, error) { - return EncodeValueText(src) + return lseg, nil } diff --git a/pgtype/lseg_test.go b/pgtype/lseg_test.go index ce128784..1866439f 100644 --- a/pgtype/lseg_test.go +++ b/pgtype/lseg_test.go @@ -4,19 +4,33 @@ import ( "testing" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) func TestLsegTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "lseg", []interface{}{ - &pgtype.Lseg{ - P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, - Valid: true, + testPgxCodec(t, "lseg", []PgxTranscodeTestCase{ + { + pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }, + new(pgtype.Lseg), + isExpectedEq(pgtype.Lseg{ + P: [2]pgtype.Vec2{{3.14, 1.678}, {7.1, 5.2345678901}}, + Valid: true, + }), }, - &pgtype.Lseg{ - P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, - Valid: true, + { + pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }, + new(pgtype.Lseg), + isExpectedEq(pgtype.Lseg{ + P: [2]pgtype.Vec2{{7.1, 1.678}, {-13.14, -5.234}}, + Valid: true, + }), }, - &pgtype.Lseg{}, + {pgtype.Lseg{}, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, + {nil, new(pgtype.Lseg), isExpectedEq(pgtype.Lseg{})}, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 7c94c809..9948c87a 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -314,7 +314,7 @@ func NewConnInfo() *ConnInfo { ci.RegisterDataType(DataType{Value: &JSONB{}, Name: "jsonb", OID: JSONBOID}) ci.RegisterDataType(DataType{Value: &JSONBArray{}, Name: "_jsonb", OID: JSONBArrayOID}) ci.RegisterDataType(DataType{Name: "line", OID: LineOID, Codec: LineCodec{}}) - ci.RegisterDataType(DataType{Value: &Lseg{}, Name: "lseg", OID: LsegOID}) + ci.RegisterDataType(DataType{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) ci.RegisterDataType(DataType{Value: &Macaddr{}, Name: "macaddr", OID: MacaddrOID}) ci.RegisterDataType(DataType{Name: "name", OID: NameOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Value: &Numeric{}, Name: "numeric", OID: NumericOID})