From db84905b7f4388608bbc5362c4bf9b9423b4cdd6 Mon Sep 17 00:00:00 2001 From: Eli Treuherz Date: Mon, 2 Aug 2021 09:26:24 +0100 Subject: [PATCH] Add NullDecimal to shopspring-numeric The shopspring/decimal package provides a NullDecimal struct intended for use with nullable SQL NUMERICs and numbers. It has Scanner and Valuer implementations already, but adding it to this package allows it to be used with the binary encoding as well. The implementation is very straightforward, but the tests have been made slightly more complicated. The previous version wasn't testing the decimal.Decimal cases, and this change adds those as well as new NullDecimal cases. I've added some logic to the test harness to catch these as you need to use the Equals method to properly compare Decimals. --- ext/shopspring-numeric/decimal.go | 15 ++++++++++- ext/shopspring-numeric/decimal_test.go | 35 ++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index e8694111..ef3ce201 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -34,6 +34,12 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case decimal.Decimal: *dst = Numeric{Decimal: value, Status: pgtype.Present} + case decimal.NullDecimal: + if value.Valid { + *dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present} + } else { + *dst = Numeric{Status: pgtype.Null} + } case float32: *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} case float64: @@ -113,6 +119,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { switch v := dst.(type) { case *decimal.Decimal: *v = src.Decimal + case *decimal.NullDecimal: + (*v).Valid = true + (*v).Decimal = src.Decimal case *float32: f, _ := src.Decimal.Float64() *v = float32(f) @@ -216,7 +225,11 @@ func (src *Numeric) AssignTo(dst interface{}) error { return fmt.Errorf("unable to assign to %T", dst) } case pgtype.Null: - return pgtype.NullAssignTo(dst) + if v, ok := dst.(*decimal.NullDecimal); ok { + (*v).Valid = false + } else { + return pgtype.NullAssignTo(dst) + } } return nil diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index bf34e0dd..e635da41 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -153,6 +153,9 @@ func TestNumericSet(t *testing.T) { source interface{} result *shopspring.Numeric }{ + {source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{Status: pgtype.Null}}, {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, @@ -208,6 +211,8 @@ func TestNumericAssignTo(t *testing.T) { var f64 float64 var pf32 *float32 var pf64 *float64 + var d decimal.Decimal + var nd decimal.NullDecimal simpleTests := []struct { src *shopspring.Numeric @@ -231,16 +236,42 @@ func TestNumericAssignTo(t *testing.T) { {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 0)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 3)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, -3)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &nd, expected: decimal.NullDecimal{Valid: false}}, } for i, tt := range simpleTests { + // Zero out the destination variable + reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem())) + err := tt.src.AssignTo(tt.dst) if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + // Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this + // we end up checking reference equality on the *big.Int they contain. + switch dst := tt.dst.(type) { + case *decimal.Decimal: + if !dst.Equal(tt.expected.(decimal.Decimal)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d) + } + case *decimal.NullDecimal: + expected := tt.expected.(decimal.NullDecimal) + + if dst.Valid != expected.Valid { + t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid) + } + if !dst.Decimal.Equal(expected.Decimal) { + t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal) + } + default: + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } } }