From d540ca39be4f4f307552b89256cb6afe998221ff Mon Sep 17 00:00:00 2001 From: bakmataliev Date: Fri, 11 Sep 2020 16:24:48 +0300 Subject: [PATCH] New marshalers have been added --- .gitignore | 1 + point.go | 78 ++++++++++++++++++++++++++++- point_test.go | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++ uuid.go | 16 ++++++ 4 files changed, 228 insertions(+), 1 deletion(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..723ef36f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/point.go b/point.go index 87993656..9961f624 100644 --- a/point.go +++ b/point.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "math" + "regexp" "strconv" "strings" @@ -22,8 +23,62 @@ type Point struct { Status Status } +var nullRE = regexp.MustCompile("^null$") + func (dst *Point) Set(src interface{}) error { - return errors.Errorf("cannot convert %v to Point", src) + if src == nil { + dst.Status = Null + return nil + } + err := errors.Errorf("cannot convert %v to Point", src) + var p *Point + switch value := src.(type) { + case string: + p, err = parsePoint([]byte(value)) + case []byte: + if nullRE.Match(value) { + dst.Status = Null + return nil + } + p, err = parsePoint(value) + default: + return err + } + if err != nil { + return err + } + *dst = *p + return nil +} + +var pointRE = regexp.MustCompile("^\\(\\d+\\.\\d+,\\s?\\d+\\.\\d+\\)$") +var chunkRE = regexp.MustCompile("\\d+\\.\\d+") + +func parsePoint(p []byte) (*Point, error) { + err := errors.Errorf("cannot parse %s", p) + if pointRE.Match(p) { + chunks := chunkRE.FindAll(p, 2) + if len(chunks) != 2 { + return nil, err + } + x, xErr := strconv.ParseFloat(string(chunks[0]), 64) + y, yErr := strconv.ParseFloat(string(chunks[1]), 64) + if xErr != nil || yErr != nil { + return nil, err + } + return &Point{ + P: Vec2{ + X: x, + Y: y, + }, + Status: Present, + }, nil + } else if nullRE.Match(p) { + return &Point{ + Status: Null, + }, nil + } + return nil, err } func (dst Point) Get() interface{} { @@ -140,3 +195,24 @@ func (dst *Point) Scan(src interface{}) error { func (src Point) Value() (driver.Value, error) { return EncodeValueText(src) } + +func (src Point) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(fmt.Sprintf("(%g, %g)", src.P.X, src.P.Y)), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + return nil, errBadStatus +} + +func (dst *Point) UnmarshalJSON(point []byte) error { + p, err := parsePoint(point) + if err != nil { + return err + } + *dst = *p + return nil +} diff --git a/point_test.go b/point_test.go index 0d191b5e..9a659cbc 100644 --- a/point_test.go +++ b/point_test.go @@ -1,6 +1,7 @@ package pgtype_test import ( + "reflect" "testing" "github.com/jackc/pgtype" @@ -14,3 +15,136 @@ func TestPointTranscode(t *testing.T) { &pgtype.Point{Status: pgtype.Null}, }) } + +func TestPoint_Set(t *testing.T) { + tests := []struct { + name string + arg interface{} + status pgtype.Status + wantErr bool + }{ + { + name: "first", + arg: "(12312.123123, 123123.123123)", + status: pgtype.Present, + wantErr: false, + }, + { + name: "second", + arg: "(1231s2.123123, 123123.123123)", + status: pgtype.Undefined, + wantErr: true, + }, + { + name: "third", + arg: []byte("(122.123123,123.123123)"), + status: pgtype.Present, + wantErr: false, + }, + { + name: "third", + arg: nil, + status: pgtype.Null, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.Set(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Status != tt.status { + t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status) + } + }) + } +} + +func TestPoint_MarshalJSON(t *testing.T) { + tests := []struct { + name string + point pgtype.Point + want []byte + wantErr bool + }{ + { + name: "first", + point: pgtype.Point{ + P: pgtype.Vec2{}, + Status: 0, + }, + want: nil, + wantErr: true, + }, + { + name: "second", + point: pgtype.Point{ + P: pgtype.Vec2{X: 12.245, Y: 432.12}, + Status: pgtype.Present, + }, + want: []byte("(12.245, 432.12)"), + wantErr: false, + }, + { + name: "third", + point: pgtype.Point{ + P: pgtype.Vec2{}, + Status: pgtype.Null, + }, + want: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.point.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPoint_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + status pgtype.Status + arg []byte + wantErr bool + }{ + { + name: "first", + status: pgtype.Present, + arg: []byte("(123.123, 54.12)"), + wantErr: false, + }, + { + name: "second", + status: pgtype.Undefined, + arg: []byte("(123.123, 54.1sad2)"), + wantErr: true, + }, + { + name: "third", + status: pgtype.Null, + arg: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dst := &pgtype.Point{} + if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if dst.Status != tt.status { + t.Errorf("Status mismatch: %v != %v", dst.Status, tt.status) + } + }) + } +} diff --git a/uuid.go b/uuid.go index 9f9bbefd..caaef2a7 100644 --- a/uuid.go +++ b/uuid.go @@ -203,3 +203,19 @@ func (dst *UUID) Scan(src interface{}) error { func (src UUID) Value() (driver.Value, error) { return EncodeValueText(src) } + +func (src UUID) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return []byte(encodeUUID(src.Bytes)), nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + return nil, errBadStatus +} + +func (dst *UUID) UnmarshalJSON(bytes []byte) error { + return dst.Set(bytes) +}