diff --git a/numeric.go b/numeric.go index fc8e1789..074c2edc 100644 --- a/numeric.go +++ b/numeric.go @@ -15,6 +15,11 @@ import ( // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 const nbase = 10000 +const ( + pgNumericNaN = 0x000000000c000000 + pgNumericNaNSign = 0x0c00 +) + var big0 *big.Int = big.NewInt(0) var big1 *big.Int = big.NewInt(1) var big10 *big.Int = big.NewInt(10) @@ -47,6 +52,7 @@ type Numeric struct { Int *big.Int Exp int32 Status Status + IsNaN bool } func (dst *Numeric) Set(src interface{}) error { @@ -64,12 +70,20 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case float32: + if math.IsNaN(float64(value)) { + *dst = Numeric{Status: Present, IsNaN: true} + return nil + } num, exp, err := parseNumericString(strconv.FormatFloat(float64(value), 'f', -1, 64)) if err != nil { return err } *dst = Numeric{Int: num, Exp: exp, Status: Present} case float64: + if math.IsNaN(value) { + *dst = Numeric{Status: Present, IsNaN: true} + return nil + } num, exp, err := parseNumericString(strconv.FormatFloat(value, 'f', -1, 64)) if err != nil { return err @@ -291,6 +305,10 @@ func (dst *Numeric) toBigInt() (*big.Int, error) { } func (src *Numeric) toFloat64() (float64, error) { + if src.IsNaN { + return math.NaN(), nil + } + buf := make([]byte, 0, 32) buf = append(buf, src.Int.String()...) @@ -310,6 +328,11 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return nil } + if string(src) == "'NaN'" { // includes single quotes, see EncodeText for details. + *dst = Numeric{Status: Present, IsNaN: true} + return nil + } + num, exp, err := parseNumericString(string(src)) if err != nil { return err @@ -353,12 +376,6 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 0 ndigits := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 - - if ndigits == 0 { - *dst = Numeric{Int: big.NewInt(0), Status: Present} - return nil - } - weight := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 sign := int16(binary.BigEndian.Uint16(src[rp:])) @@ -366,6 +383,16 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { dscale := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 + if sign == pgNumericNaNSign { + *dst = Numeric{Status: Present, IsNaN: true} + return nil + } + + if ndigits == 0 { + *dst = Numeric{Int: big.NewInt(0), Status: Present} + return nil + } + if len(src[rp:]) < int(ndigits)*2 { return errors.Errorf("numeric incomplete %v", src) } @@ -467,6 +494,15 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } + if src.IsNaN { + // encode as 'NaN' including single quotes, + // "When writing this value [NaN] as a constant in an SQL command, + // you must put quotes around it, for example UPDATE table SET x = 'NaN'" + // https://www.postgresql.org/docs/9.3/datatype-numeric.html + buf = append(buf, "'NaN'"...) + return buf, nil + } + buf = append(buf, src.Int.String()...) buf = append(buf, 'e') buf = append(buf, strconv.FormatInt(int64(src.Exp), 10)...) @@ -481,6 +517,11 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return nil, errUndefined } + if src.IsNaN { + buf = pgio.AppendUint64(buf, pgNumericNaN) + return buf, nil + } + var sign int16 if src.Int.Cmp(big0) < 0 { sign = 16384 diff --git a/numeric_test.go b/numeric_test.go index 263c78b6..4d9c5252 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -210,6 +210,8 @@ func TestNumericSet(t *testing.T) { {source: float64(1234), result: &pgtype.Numeric{Int: big.NewInt(1234), Exp: 0, Status: pgtype.Present}}, {source: float64(12345678900), result: &pgtype.Numeric{Int: big.NewInt(123456789), Exp: 2, Status: pgtype.Present}}, {source: float64(12345.678901), result: &pgtype.Numeric{Int: big.NewInt(12345678901), Exp: -6, Status: pgtype.Present}}, + {source: math.NaN(), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}}, + {source: float32(math.NaN()), result: &pgtype.Numeric{Int: nil, Exp: 0, Status: pgtype.Present, IsNaN: true}}, } for i, tt := range successfulTests { @@ -267,6 +269,8 @@ func TestNumericAssignTo(t *testing.T) { {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(0), Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, {src: &pgtype.Numeric{Int: big.NewInt(1006), Exp: -2, Status: pgtype.Present}, dst: &f64, expected: float64(10.06)}, // https://github.com/jackc/pgtype/issues/27 + {src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f64, expected: math.NaN()}, + {src: &pgtype.Numeric{Status: pgtype.Present, IsNaN: true}, dst: &f32, expected: float32(math.NaN())}, } for i, tt := range simpleTests { @@ -275,8 +279,26 @@ func TestNumericAssignTo(t *testing.T) { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + dst := reflect.ValueOf(tt.dst).Elem().Interface() + switch dstTyped := dst.(type) { + case float32: + nanExpected := math.IsNaN(float64(tt.expected.(float32))) + if nanExpected && !math.IsNaN(float64(dstTyped)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + case float64: + nanExpected := math.IsNaN(tt.expected.(float64)) + if nanExpected && !math.IsNaN(dstTyped) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } else if !nanExpected && dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + default: + if dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } } } @@ -328,6 +350,8 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { 123, 0.000012345, 1.00002345, + math.NaN(), + float32(math.NaN()), } for i, tt := range tests { @@ -335,7 +359,7 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { ci := pgtype.NewConnInfo() text, err := n.EncodeText(ci, nil) if err != nil { - t.Errorf("%d: %v", i, err) + t.Errorf("%d (EncodeText): %v", i, err) } return string(text) } @@ -344,10 +368,13 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { encoded, err := numeric.EncodeBinary(ci, nil) if err != nil { - t.Errorf("%d: %v", i, err) + t.Errorf("%d (EncodeBinary): %v", i, err) } decoded := &pgtype.Numeric{} - decoded.DecodeBinary(ci, encoded) + err = decoded.DecodeBinary(ci, encoded) + if err != nil { + t.Errorf("%d (DecodeBinary): %v", i, err) + } text0 := toString(numeric) text1 := toString(decoded)