2
0

Add NullDecimal to shopspring-numeric

The shopspring/decimal package provides a NullDecimal struct intended
for use with nullable SQL NUMERICs and numbers. It has Scanner and
Valuer implementations already, but adding it to this package allows
it to be used with the binary encoding as well.

The implementation is very straightforward, but the tests have been made
slightly more complicated. The previous version wasn't testing the
decimal.Decimal cases, and this change adds those as well as new
NullDecimal cases. I've added some logic to the test harness to catch
these as you need to use the Equals method to properly compare Decimals.
This commit is contained in:
Eli Treuherz
2021-08-02 09:26:24 +01:00
committed by Jack Christensen
parent 6bda09691d
commit db84905b7f
2 changed files with 47 additions and 3 deletions
+14 -1
View File
@@ -34,6 +34,12 @@ func (dst *Numeric) Set(src interface{}) error {
switch value := src.(type) { switch value := src.(type) {
case decimal.Decimal: case decimal.Decimal:
*dst = Numeric{Decimal: value, Status: pgtype.Present} *dst = Numeric{Decimal: value, Status: pgtype.Present}
case decimal.NullDecimal:
if value.Valid {
*dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present}
} else {
*dst = Numeric{Status: pgtype.Null}
}
case float32: case float32:
*dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present}
case float64: case float64:
@@ -113,6 +119,9 @@ func (src *Numeric) AssignTo(dst interface{}) error {
switch v := dst.(type) { switch v := dst.(type) {
case *decimal.Decimal: case *decimal.Decimal:
*v = src.Decimal *v = src.Decimal
case *decimal.NullDecimal:
(*v).Valid = true
(*v).Decimal = src.Decimal
case *float32: case *float32:
f, _ := src.Decimal.Float64() f, _ := src.Decimal.Float64()
*v = float32(f) *v = float32(f)
@@ -216,7 +225,11 @@ func (src *Numeric) AssignTo(dst interface{}) error {
return fmt.Errorf("unable to assign to %T", dst) return fmt.Errorf("unable to assign to %T", dst)
} }
case pgtype.Null: case pgtype.Null:
return pgtype.NullAssignTo(dst) if v, ok := dst.(*decimal.NullDecimal); ok {
(*v).Valid = false
} else {
return pgtype.NullAssignTo(dst)
}
} }
return nil return nil
+33 -2
View File
@@ -153,6 +153,9 @@ func TestNumericSet(t *testing.T) {
source interface{} source interface{}
result *shopspring.Numeric result *shopspring.Numeric
}{ }{
{source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{Status: pgtype.Null}},
{source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
{source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
@@ -208,6 +211,8 @@ func TestNumericAssignTo(t *testing.T) {
var f64 float64 var f64 float64
var pf32 *float32 var pf32 *float32
var pf64 *float64 var pf64 *float64
var d decimal.Decimal
var nd decimal.NullDecimal
simpleTests := []struct { simpleTests := []struct {
src *shopspring.Numeric src *shopspring.Numeric
@@ -231,16 +236,42 @@ func TestNumericAssignTo(t *testing.T) {
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 0)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 3)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, -3)},
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}},
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &nd, expected: decimal.NullDecimal{Valid: false}},
} }
for i, tt := range simpleTests { for i, tt := range simpleTests {
// Zero out the destination variable
reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem()))
err := tt.src.AssignTo(tt.dst) err := tt.src.AssignTo(tt.dst)
if err != nil { if err != nil {
t.Errorf("%d: %v", i, err) t.Errorf("%d: %v", i, err)
} }
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { // Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) // we end up checking reference equality on the *big.Int they contain.
switch dst := tt.dst.(type) {
case *decimal.Decimal:
if !dst.Equal(tt.expected.(decimal.Decimal)) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d)
}
case *decimal.NullDecimal:
expected := tt.expected.(decimal.NullDecimal)
if dst.Valid != expected.Valid {
t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid)
}
if !dst.Decimal.Equal(expected.Decimal) {
t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal)
}
default:
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
} }
} }