From f2a2797a88765112814ec47c2b02bec97451278a Mon Sep 17 00:00:00 2001 From: leighhopcroft Date: Tue, 2 Jun 2020 20:14:51 +0100 Subject: [PATCH] support NaN in Numeric encode and decode methods --- numeric.go | 32 ++++++++++++++++++++++++-------- numeric_test.go | 11 ++++++++--- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/numeric.go b/numeric.go index 644ee23f..7ee517be 100644 --- a/numeric.go +++ b/numeric.go @@ -15,6 +15,11 @@ import ( // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 const nbase = 10000 +const ( + pgNumericNaN = 0x000000000c000000 + pgNumericNaNSign = 0x0c00 +) + var big0 *big.Int = big.NewInt(0) var big1 *big.Int = big.NewInt(1) var big10 *big.Int = big.NewInt(10) @@ -323,6 +328,11 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return nil } + if string(src) == "NaN" { + *dst = Numeric{} + return nil + } + num, exp, err := parseNumericString(string(src)) if err != nil { return err @@ -366,12 +376,6 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp := 0 ndigits := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 - - if ndigits == 0 { - *dst = Numeric{Int: big.NewInt(0), Status: Present} - return nil - } - weight := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 sign := int16(binary.BigEndian.Uint16(src[rp:])) @@ -379,6 +383,16 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { dscale := int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 + if sign == pgNumericNaNSign { + *dst = Numeric{} + return nil + } + + if ndigits == 0 { + *dst = Numeric{Int: big.NewInt(0), Status: Present} + return nil + } + if len(src[rp:]) < int(ndigits)*2 { return errors.Errorf("numeric incomplete %v", src) } @@ -477,7 +491,8 @@ func (src Numeric) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { case Null: return nil, nil case Undefined: - return nil, errUndefined + buf = append(buf, []byte("NaN")...) + return buf, nil } buf = append(buf, src.Int.String()...) @@ -491,7 +506,8 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { case Null: return nil, nil case Undefined: - return nil, errUndefined + buf = pgio.AppendUint64(buf, pgNumericNaN) + return buf, nil } var sign int16 diff --git a/numeric_test.go b/numeric_test.go index ee72ff5e..259f397e 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -344,6 +344,8 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { 123, 0.000012345, 1.00002345, + math.NaN(), + float32(math.NaN()), } for i, tt := range tests { @@ -351,7 +353,7 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { ci := pgtype.NewConnInfo() text, err := n.EncodeText(ci, nil) if err != nil { - t.Errorf("%d: %v", i, err) + t.Errorf("%d (EncodeText): %v", i, err) } return string(text) } @@ -360,10 +362,13 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { encoded, err := numeric.EncodeBinary(ci, nil) if err != nil { - t.Errorf("%d: %v", i, err) + t.Errorf("%d (EncodeBinary): %v", i, err) } decoded := &pgtype.Numeric{} - decoded.DecodeBinary(ci, encoded) + err = decoded.DecodeBinary(ci, encoded) + if err != nil { + t.Errorf("%d (DecodeBinary): %v", i, err) + } text0 := toString(numeric) text1 := toString(decoded)