Merge branch 'numeric-nan-support' of git://github.com/leighhopcroft/pgtype into leighhopcroft-numeric-nan-support
This commit is contained in:
+47
-6
@@ -15,6 +15,11 @@ import (
|
|||||||
// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000
|
// PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000
|
||||||
const nbase = 10000
|
const nbase = 10000
|
||||||
|
|
||||||
|
const (
|
||||||
|
pgNumericNaN = 0x000000000c000000
|
||||||
|
pgNumericNaNSign = 0x0c00
|
||||||
|
)
|
||||||
|
|
||||||
var big0 *big.Int = big.NewInt(0)
|
var big0 *big.Int = big.NewInt(0)
|
||||||
var big1 *big.Int = big.NewInt(1)
|
var big1 *big.Int = big.NewInt(1)
|
||||||
var big10 *big.Int = big.NewInt(10)
|
var big10 *big.Int = big.NewInt(10)
|
||||||
@@ -47,6 +52,7 @@ type Numeric struct {
|
|||||||
Int *big.Int
|
Int *big.Int
|
||||||
Exp int32
|
Exp int32
|
||||||
Status Status
|
Status Status
|
||||||
|
IsNaN bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dst *Numeric) Set(src interface{}) error {
|
func (dst *Numeric) Set(src interface{}) error {
|
||||||
@@ -64,12 +70,20 @@ func (dst *Numeric) Set(src interface{}) error {
|
|||||||
|
|
||||||
switch value := src.(type) {
|
switch value := src.(type) {
|
||||||
case float32:
|
case float32:
|
||||||
|
if math.IsNaN(float64(value)) {
|
||||||
|
*dst = Numeric{Status: Present, IsNaN: true}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64))
|
num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
*dst = Numeric{Int: num, Exp: exp, Status: Present}
|
*dst = Numeric{Int: num, Exp: exp, Status: Present}
|
||||||
case float64:
|
case float64:
|
||||||
|
if math.IsNaN(value) {
|
||||||
|
*dst = Numeric{Status: Present, IsNaN: true}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64))
|
num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -291,6 +305,10 @@ func (dst *Numeric) toBigInt() (*big.Int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (src *Numeric) toFloat64() (float64, error) {
|
func (src *Numeric) toFloat64() (float64, error) {
|
||||||
|
if src.IsNaN {
|
||||||
|
return math.NaN(), nil
|
||||||
|
}
|
||||||
|
|
||||||
buf := make([]byte, 0, 32)
|
buf := make([]byte, 0, 32)
|
||||||
|
|
||||||
buf = append(buf, src.Int.String()...)
|
buf = append(buf, src.Int.String()...)
|
||||||
@@ -310,6 +328,11 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if string(src) == "'NaN'" { // includes single quotes, see EncodeText for details.
|
||||||
|
*dst = Numeric{Status: Present, IsNaN: true}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
num, exp, err := parseNumericString(string(src))
|
num, exp, err := parseNumericString(string(src))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -353,12 +376,6 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error {
|
|||||||
rp := 0
|
rp := 0
|
||||||
ndigits := int16(binary.BigEndian.Uint16(src[rp:]))
|
ndigits := int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
rp += 2
|
rp += 2
|
||||||
|
|
||||||
if ndigits == 0 {
|
|
||||||
*dst = Numeric{Int: big.NewInt(0), Status: Present}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
weight := int16(binary.BigEndian.Uint16(src[rp:]))
|
weight := int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
rp += 2
|
rp += 2
|
||||||
sign := int16(binary.BigEndian.Uint16(src[rp:]))
|
sign := int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
@@ -366,6 +383,16 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error {
|
|||||||
dscale := int16(binary.BigEndian.Uint16(src[rp:]))
|
dscale := int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
rp += 2
|
rp += 2
|
||||||
|
|
||||||
|
if sign == pgNumericNaNSign {
|
||||||
|
*dst = Numeric{Status: Present, IsNaN: true}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ndigits == 0 {
|
||||||
|
*dst = Numeric{Int: big.NewInt(0), Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if len(src[rp:]) < int(ndigits)*2 {
|
if len(src[rp:]) < int(ndigits)*2 {
|
||||||
return errors.Errorf("numeric incomplete %v", src)
|
return errors.Errorf("numeric incomplete %v", src)
|
||||||
}
|
}
|
||||||
@@ -467,6 +494,15 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
|||||||
return nil, errUndefined
|
return nil, errUndefined
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if src.IsNaN {
|
||||||
|
// encode as 'NaN' including single quotes,
|
||||||
|
// "When writing this value [NaN] as a constant in an SQL command,
|
||||||
|
// you must put quotes around it, for example UPDATE table SET x = 'NaN'"
|
||||||
|
// https://www.postgresql.org/docs/9.3/datatype-numeric.html
|
||||||
|
buf = append(buf, "'NaN'"...)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
buf = append(buf, src.Int.String()...)
|
buf = append(buf, src.Int.String()...)
|
||||||
buf = append(buf, 'e')
|
buf = append(buf, 'e')
|
||||||
buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
|
buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...)
|
||||||
@@ -481,6 +517,11 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
|||||||
return nil, errUndefined
|
return nil, errUndefined
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if src.IsNaN {
|
||||||
|
buf = pgio.AppendUint64(buf, pgNumericNaN)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
var sign int16
|
var sign int16
|
||||||
if src.Int.Cmp(big0) < 0 {
|
if src.Int.Cmp(big0) < 0 {
|
||||||
sign = 16384
|
sign = 16384
|
||||||
|
|||||||
+32
-5
@@ -210,6 +210,8 @@ func TestNumericSet(t *testing.T) {
|
|||||||
{source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}},
|
{source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}},
|
||||||
{source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}},
|
{source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}},
|
||||||
{source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}},
|
{source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}},
|
||||||
|
{source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}},
|
||||||
|
{source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range successfulTests {
|
for i, tt := range successfulTests {
|
||||||
@@ -267,6 +269,8 @@ func TestNumericAssignTo(t *testing.T) {
|
|||||||
{src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
|
{src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
|
||||||
{src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
|
{src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
|
||||||
{src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27
|
{src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27
|
||||||
|
{src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f64, expected: math.NaN()},
|
||||||
|
{src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f32, expected: float32(math.NaN())},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range simpleTests {
|
for i, tt := range simpleTests {
|
||||||
@@ -275,8 +279,26 @@ func TestNumericAssignTo(t *testing.T) {
|
|||||||
t.Errorf("%d: %v", i, err)
|
t.Errorf("%d: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
dst := reflect.ValueOf(tt.dst).Elem().Interface()
|
||||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
switch dstTyped := dst.(type) {
|
||||||
|
case float32:
|
||||||
|
nanExpected := math.IsNaN(float64(tt.expected.(float32)))
|
||||||
|
if nanExpected && !math.IsNaN(float64(dstTyped)) {
|
||||||
|
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||||
|
} else if !nanExpected && dst != tt.expected {
|
||||||
|
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||||
|
}
|
||||||
|
case float64:
|
||||||
|
nanExpected := math.IsNaN(tt.expected.(float64))
|
||||||
|
if nanExpected && !math.IsNaN(dstTyped) {
|
||||||
|
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||||
|
} else if !nanExpected && dst != tt.expected {
|
||||||
|
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if dst != tt.expected {
|
||||||
|
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,6 +350,8 @@ func TestNumericEncodeDecodeBinary(t *testing.T) {
|
|||||||
123,
|
123,
|
||||||
0.000012345,
|
0.000012345,
|
||||||
1.00002345,
|
1.00002345,
|
||||||
|
math.NaN(),
|
||||||
|
float32(math.NaN()),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
@@ -335,7 +359,7 @@ func TestNumericEncodeDecodeBinary(t *testing.T) {
|
|||||||
ci := pgtype.NewConnInfo()
|
ci := pgtype.NewConnInfo()
|
||||||
text, err := n.EncodeText(ci, nil)
|
text, err := n.EncodeText(ci, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%d: %v", i, err)
|
t.Errorf("%d (EncodeText): %v", i, err)
|
||||||
}
|
}
|
||||||
return string(text)
|
return string(text)
|
||||||
}
|
}
|
||||||
@@ -344,10 +368,13 @@ func TestNumericEncodeDecodeBinary(t *testing.T) {
|
|||||||
|
|
||||||
encoded, err := numeric.EncodeBinary(ci, nil)
|
encoded, err := numeric.EncodeBinary(ci, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%d: %v", i, err)
|
t.Errorf("%d (EncodeBinary): %v", i, err)
|
||||||
}
|
}
|
||||||
decoded := &pgtype.Numeric{}
|
decoded := &pgtype.Numeric{}
|
||||||
decoded.DecodeBinary(ci, encoded)
|
err = decoded.DecodeBinary(ci, encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d (DecodeBinary): %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
text0 := toString(numeric)
|
text0 := toString(numeric)
|
||||||
text1 := toString(decoded)
|
text1 := toString(decoded)
|
||||||
|
|||||||
Reference in New Issue
Block a user