New marshalers have been added
This commit is contained in:
@@ -0,0 +1 @@
|
||||
.idea
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -22,8 +23,62 @@ type Point struct {
|
||||
Status Status
|
||||
}
|
||||
|
||||
var nullRE = regexp.MustCompile("^null$")
|
||||
|
||||
func (dst *Point) Set(src interface{}) error {
|
||||
return errors.Errorf("cannot convert %v to Point", src)
|
||||
if src == nil {
|
||||
dst.Status = Null
|
||||
return nil
|
||||
}
|
||||
err := errors.Errorf("cannot convert %v to Point", src)
|
||||
var p *Point
|
||||
switch value := src.(type) {
|
||||
case string:
|
||||
p, err = parsePoint([]byte(value))
|
||||
case []byte:
|
||||
if nullRE.Match(value) {
|
||||
dst.Status = Null
|
||||
return nil
|
||||
}
|
||||
p, err = parsePoint(value)
|
||||
default:
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dst = *p
|
||||
return nil
|
||||
}
|
||||
|
||||
var pointRE = regexp.MustCompile("^\\(\\d+\\.\\d+,\\s?\\d+\\.\\d+\\)$")
|
||||
var chunkRE = regexp.MustCompile("\\d+\\.\\d+")
|
||||
|
||||
func parsePoint(p []byte) (*Point, error) {
|
||||
err := errors.Errorf("cannot parse %s", p)
|
||||
if pointRE.Match(p) {
|
||||
chunks := chunkRE.FindAll(p, 2)
|
||||
if len(chunks) != 2 {
|
||||
return nil, err
|
||||
}
|
||||
x, xErr := strconv.ParseFloat(string(chunks[0]), 64)
|
||||
y, yErr := strconv.ParseFloat(string(chunks[1]), 64)
|
||||
if xErr != nil || yErr != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Point{
|
||||
P: Vec2{
|
||||
X: x,
|
||||
Y: y,
|
||||
},
|
||||
Status: Present,
|
||||
}, nil
|
||||
} else if nullRE.Match(p) {
|
||||
return &Point{
|
||||
Status: Null,
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (dst Point) Get() interface{} {
|
||||
@@ -140,3 +195,24 @@ func (dst *Point) Scan(src interface{}) error {
|
||||
func (src Point) Value() (driver.Value, error) {
|
||||
return EncodeValueText(src)
|
||||
}
|
||||
|
||||
func (src Point) MarshalJSON() ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
return []byte(fmt.Sprintf("(%g, %g)", src.P.X, src.P.Y)), nil
|
||||
case Null:
|
||||
return []byte("null"), nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
return nil, errBadStatus
|
||||
}
|
||||
|
||||
func (dst *Point) UnmarshalJSON(point []byte) error {
|
||||
p, err := parsePoint(point)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dst = *p
|
||||
return nil
|
||||
}
|
||||
|
||||
+134
@@ -1,6 +1,7 @@
|
||||
package pgtype_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
@@ -14,3 +15,136 @@ func TestPointTranscode(t *testing.T) {
|
||||
&pgtype.Point{Status: pgtype.Null},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPoint_Set(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg interface{}
|
||||
status pgtype.Status
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "first",
|
||||
arg: "(12312.123123, 123123.123123)",
|
||||
status: pgtype.Present,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second",
|
||||
arg: "(1231s2.123123, 123123.123123)",
|
||||
status: pgtype.Undefined,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "third",
|
||||
arg: []byte("(122.123123,123.123123)"),
|
||||
status: pgtype.Present,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "third",
|
||||
arg: nil,
|
||||
status: pgtype.Null,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dst := &pgtype.Point{}
|
||||
if err := dst.Set(tt.arg); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if dst.Status != tt.status {
|
||||
t.Errorf("Expected status: %v; got: %v", tt.status, dst.Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoint_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
point pgtype.Point
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "first",
|
||||
point: pgtype.Point{
|
||||
P: pgtype.Vec2{},
|
||||
Status: 0,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "second",
|
||||
point: pgtype.Point{
|
||||
P: pgtype.Vec2{X: 12.245, Y: 432.12},
|
||||
Status: pgtype.Present,
|
||||
},
|
||||
want: []byte("(12.245, 432.12)"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "third",
|
||||
point: pgtype.Point{
|
||||
P: pgtype.Vec2{},
|
||||
Status: pgtype.Null,
|
||||
},
|
||||
want: []byte("null"),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.point.MarshalJSON()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoint_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status pgtype.Status
|
||||
arg []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "first",
|
||||
status: pgtype.Present,
|
||||
arg: []byte("(123.123, 54.12)"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "second",
|
||||
status: pgtype.Undefined,
|
||||
arg: []byte("(123.123, 54.1sad2)"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "third",
|
||||
status: pgtype.Null,
|
||||
arg: []byte("null"),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dst := &pgtype.Point{}
|
||||
if err := dst.UnmarshalJSON(tt.arg); (err != nil) != tt.wantErr {
|
||||
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if dst.Status != tt.status {
|
||||
t.Errorf("Status mismatch: %v != %v", dst.Status, tt.status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,3 +203,19 @@ func (dst *UUID) Scan(src interface{}) error {
|
||||
func (src UUID) Value() (driver.Value, error) {
|
||||
return EncodeValueText(src)
|
||||
}
|
||||
|
||||
func (src UUID) MarshalJSON() ([]byte, error) {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
return []byte(encodeUUID(src.Bytes)), nil
|
||||
case Null:
|
||||
return []byte("null"), nil
|
||||
case Undefined:
|
||||
return nil, errUndefined
|
||||
}
|
||||
return nil, errBadStatus
|
||||
}
|
||||
|
||||
func (dst *UUID) UnmarshalJSON(bytes []byte) error {
|
||||
return dst.Set(bytes)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user