2
0

New marshalers have been added

This commit is contained in:
bakmataliev
2020-09-11 16:24:48 +03:00
parent e7d2b057a7
commit d540ca39be
4 changed files with 228 additions and 1 deletions
+1
View File
@@ -0,0 +1 @@
.idea
+77 -1
View File
@@ -5,6 +5,7 @@ import (
"encoding/binary"
"fmt"
"math"
"regexp"
"strconv"
"strings"
@@ -22,8 +23,62 @@ type Point struct {
Status Status
}
var nullRE = regexp.MustCompile("^null$")
func (dst *Point) Set(src interface{}) error {
return errors.Errorf("cannot convert %v to Point", src)
if src == nil {
dst.Status = Null
return nil
}
err := errors.Errorf("cannot convert %v to Point", src)
var p *Point
switch value := src.(type) {
case string:
p, err = parsePoint([]byte(value))
case []byte:
if nullRE.Match(value) {
dst.Status = Null
return nil
}
p, err = parsePoint(value)
default:
return err
}
if err != nil {
return err
}
*dst = *p
return nil
}
var pointRE = regexp.MustCompile("^\\(\\d+\\.\\d+,\\s?\\d+\\.\\d+\\)$")
var chunkRE = regexp.MustCompile("\\d+\\.\\d+")
func parsePoint(p []byte) (*Point, error) {
err := errors.Errorf("cannot parse %s", p)
if pointRE.Match(p) {
chunks := chunkRE.FindAll(p, 2)
if len(chunks) != 2 {
return nil, err
}
x, xErr := strconv.ParseFloat(string(chunks[0]), 64)
y, yErr := strconv.ParseFloat(string(chunks[1]), 64)
if xErr != nil || yErr != nil {
return nil, err
}
return &Point{
P: Vec2{
X: x,
Y: y,
},
Status: Present,
}, nil
} else if nullRE.Match(p) {
return &Point{
Status: Null,
}, nil
}
return nil, err
}
func (dst Point) Get() interface{} {
@@ -140,3 +195,24 @@ func (dst *Point) Scan(src interface{}) error {
func (src Point) Value() (driver.Value, error) {
return EncodeValueText(src)
}
func (src Point) MarshalJSON() ([]byte, error) {
switch src.Status {
case Present:
return []byte(fmt.Sprintf("(%g, %g)", src.P.X, src.P.Y)), nil
case Null:
return []byte("null"), nil
case Undefined:
return nil, errUndefined
}
return nil, errBadStatus
}
func (dst *Point) UnmarshalJSON(point []byte) error {
p, err := parsePoint(point)
if err != nil {
return err
}
*dst = *p
return nil
}
+134
View File
@@ -1,6 +1,7 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgtype"
@@ -14,3 +15,136 @@ func TestPointTranscode(t *testing.T) {
&pgtype.Point{Status: pgtype.Null},
})
}
func TestPoint_Set(t *testing.T) {
tests := []struct {
name string
arg interface{}
status pgtype.Status
wantErr bool
}{
{
name: "first",
arg: "(12312.123123, 123123.123123)",
status: pgtype.Present,
wantErr: false,
},
{
name: "second",
arg: "(1231s2.123123, 123123.123123)",
status: pgtype.Undefined,
wantErr: true,
},
{
name: "third",
arg: []byte("(122.123123,123.123123)"),
status: pgtype.Present,
wantErr: false,
},
{
name: "third",
arg: nil,
status: pgtype.Null,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dst := &pgtype.Point{}
if err := dst.Set(tt.arg); (err != nil) != tt.wantErr {
t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr)
}
if dst.Status != tt.status {
t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status)
}
})
}
}
func TestPoint_MarshalJSON(t *testing.T) {
tests := []struct {
name string
point pgtype.Point
want []byte
wantErr bool
}{
{
name: "first",
point: pgtype.Point{
P: pgtype.Vec2{},
Status: 0,
},
want: nil,
wantErr: true,
},
{
name: "second",
point: pgtype.Point{
P: pgtype.Vec2{X: 12.245, Y: 432.12},
Status: pgtype.Present,
},
want: []byte("(12.245, 432.12)"),
wantErr: false,
},
{
name: "third",
point: pgtype.Point{
P: pgtype.Vec2{},
Status: pgtype.Null,
},
want: []byte("null"),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.point.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want)
}
})
}
}
func TestPoint_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
status pgtype.Status
arg []byte
wantErr bool
}{
{
name: "first",
status: pgtype.Present,
arg: []byte("(123.123, 54.12)"),
wantErr: false,
},
{
name: "second",
status: pgtype.Undefined,
arg: []byte("(123.123, 54.1sad2)"),
wantErr: true,
},
{
name: "third",
status: pgtype.Null,
arg: []byte("null"),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dst := &pgtype.Point{}
if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if dst.Status != tt.status {
t.Errorf("Status mismatch: %v != %v", dst.Status, tt.status)
}
})
}
}
+16
View File
@@ -203,3 +203,19 @@ func (dst *UUID) Scan(src interface{}) error {
func (src UUID) Value() (driver.Value, error) {
return EncodeValueText(src)
}
func (src UUID) MarshalJSON() ([]byte, error) {
switch src.Status {
case Present:
return []byte(encodeUUID(src.Bytes)), nil
case Null:
return []byte("null"), nil
case Undefined:
return nil, errUndefined
}
return nil, errBadStatus
}
func (dst *UUID) UnmarshalJSON(bytes []byte) error {
return dst.Set(bytes)
}