diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index e8694111..ef3ce201 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -34,6 +34,12 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case decimal.Decimal: *dst = Numeric{Decimal: value, Status: pgtype.Present} + case decimal.NullDecimal: + if value.Valid { + *dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present} + } else { + *dst = Numeric{Status: pgtype.Null} + } case float32: *dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present} case float64: @@ -113,6 +119,9 @@ func (src *Numeric) AssignTo(dst interface{}) error { switch v := dst.(type) { case *decimal.Decimal: *v = src.Decimal + case *decimal.NullDecimal: + (*v).Valid = true + (*v).Decimal = src.Decimal case *float32: f, _ := src.Decimal.Float64() *v = float32(f) @@ -216,7 +225,11 @@ func (src *Numeric) AssignTo(dst interface{}) error { return fmt.Errorf("unable to assign to %T", dst) } case pgtype.Null: - return pgtype.NullAssignTo(dst) + if v, ok := dst.(*decimal.NullDecimal); ok { + (*v).Valid = false + } else { + return pgtype.NullAssignTo(dst) + } } return nil diff --git a/ext/shopspring-numeric/decimal_test.go b/ext/shopspring-numeric/decimal_test.go index bf34e0dd..e635da41 100644 --- a/ext/shopspring-numeric/decimal_test.go +++ b/ext/shopspring-numeric/decimal_test.go @@ -153,6 +153,9 @@ func TestNumericSet(t *testing.T) { source interface{} result *shopspring.Numeric }{ + {source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, + {source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{Status: pgtype.Null}}, {source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, {source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}}, @@ -208,6 +211,8 @@ func TestNumericAssignTo(t *testing.T) { var f64 float64 var pf32 *float32 var pf64 *float64 + var d decimal.Decimal + var nd decimal.NullDecimal simpleTests := []struct { src *shopspring.Numeric @@ -231,16 +236,42 @@ func TestNumericAssignTo(t *testing.T) { {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)}, {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))}, {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 0)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 3)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, -3)}, + {src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}}, + {src: &shopspring.Numeric{Status: pgtype.Null}, dst: &nd, expected: decimal.NullDecimal{Valid: false}}, } for i, tt := range simpleTests { + // Zero out the destination variable + reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem())) + err := tt.src.AssignTo(tt.dst) if err != nil { t.Errorf("%d: %v", i, err) } - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + // Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this + // we end up checking reference equality on the *big.Int they contain. + switch dst := tt.dst.(type) { + case *decimal.Decimal: + if !dst.Equal(tt.expected.(decimal.Decimal)) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d) + } + case *decimal.NullDecimal: + expected := tt.expected.(decimal.NullDecimal) + + if dst.Valid != expected.Valid { + t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid) + } + if !dst.Decimal.Equal(expected.Decimal) { + t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal) + } + default: + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } } }