diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index 70906806..148589a4 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -250,17 +250,7 @@ func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { return err } - buf, err := num.EncodeText(ci, nil) - if err != nil { - return err - } - - dec, err := decimal.NewFromString(string(buf)) - if err != nil { - return err - } - - *dst = Numeric{Decimal: dec, Status: pgtype.Present} + *dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Status: pgtype.Present} return nil } diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index 0b256b37..bf34e0dd 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -11,6 +11,7 @@ import ( shopspring "github.com/jackc/pgtype/ext/shopspring-numeric" "github.com/jackc/pgtype/testutil" "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" ) func mustParseDecimal(t *testing.T, src string) decimal.Decimal { @@ -284,3 +285,46 @@ func TestNumericAssignTo(t *testing.T) { } } } + +func BenchmarkDecode(b *testing.B) { + benchmarks := []struct { + name string + numberStr string + }{ + {"Zero", "0"}, + {"Small", "12345"}, + {"Medium", "12345.12345"}, + {"Large", "123457890.1234567890"}, + {"Huge", "123457890123457890123457890.1234567890123457890123457890"}, + } + + for _, bm := range benchmarks { + src := &shopspring.Numeric{} + err := src.Set(bm.numberStr) + require.NoError(b, err) + textFormat, err := src.EncodeText(nil, nil) + require.NoError(b, err) + binaryFormat, err := src.EncodeBinary(nil, nil) + require.NoError(b, err) + + b.Run(fmt.Sprintf("%s-Text", bm.name), func(b *testing.B) { + dst := &shopspring.Numeric{} + for i := 0; i < b.N; i++ { + err := dst.DecodeText(nil, textFormat) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run(fmt.Sprintf("%s-Binary", bm.name), func(b *testing.B) { + dst := &shopspring.Numeric{} + for i := 0; i < b.N; i++ { + err := dst.DecodeBinary(nil, binaryFormat) + if err != nil { + b.Fatal(err) + } + } + }) + } +}