diff --git a/numeric.go b/numeric.go index 4cfbb657..b24f433c 100644 --- a/numeric.go +++ b/numeric.go @@ -1,6 +1,7 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/binary" "fmt" @@ -807,3 +808,41 @@ func (src Numeric) Value() (driver.Value, error) { return string(buf), nil } + +func (src Numeric) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + if src.NaN { + return []byte(`"NaN"`), nil + } + + intStr := src.Int.String() + buf := &bytes.Buffer{} + exp := int(src.Exp) + if exp > 0 { + buf.WriteString(intStr) + for i := 0; i < exp; i++ { + buf.WriteByte('0') + } + } else if exp < 0 { + if len(intStr) <= -exp { + buf.WriteString("0.") + leadingZeros := -exp - len(intStr) + for i := 0; i < leadingZeros; i++ { + buf.WriteByte('0') + } + buf.WriteString(intStr) + } else if len(intStr) > -exp { + dpPos := len(intStr) + exp + buf.WriteString(intStr[:dpPos]) + buf.WriteByte('.') + buf.WriteString(intStr[dpPos:]) + } + } else { + buf.WriteString(intStr) + } + + return buf.Bytes(), nil +} diff --git a/numeric_test.go b/numeric_test.go index 58ce5c0f..7f0734d0 100644 --- a/numeric_test.go +++ b/numeric_test.go @@ -1,6 +1,8 @@ package pgtype_test import ( + "context" + "encoding/json" "math" "math/big" "math/rand" @@ -9,6 +11,7 @@ import ( "github.com/jackc/pgtype" "github.com/jackc/pgtype/testutil" + "github.com/stretchr/testify/require" ) // For test purposes only. Note that it does not normalize values. e.g. (Int: 1, Exp: 3) will not equal (Int: 1000, Exp: 0) @@ -410,3 +413,35 @@ func TestNumericEncodeDecodeBinary(t *testing.T) { } } } + +func TestNumericMarshalJSON(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + for i, tt := range []struct { + decString string + }{ + {"NaN"}, + {"0"}, + {"1"}, + {"-1"}, + {"1000000000000000000"}, + {"1234.56789"}, + {"1.56789"}, + {"0.00000000000056789"}, + {"0.00123000"}, + {"123e-3"}, + {"243723409723490243842378942378901237502734019231380123e23790"}, + {"3409823409243892349028349023482934092340892390101e-14021"}, + } { + var num pgtype.Numeric + var pgJSON string + err := conn.QueryRow(context.Background(), `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON) + require.NoErrorf(t, err, "%d", i) + + goJSON, err := json.Marshal(num) + require.NoErrorf(t, err, "%d", i) + + require.Equal(t, pgJSON, string(goJSON)) + } +}