Add NullDecimal to shopspring-numeric
The shopspring/decimal package provides a NullDecimal struct intended for use with nullable SQL NUMERICs and numbers. It has Scanner and Valuer implementations already, but adding it to this package allows it to be used with the binary encoding as well. The implementation is very straightforward, but the tests have been made slightly more complicated. The previous version wasn't testing the decimal.Decimal cases, and this change adds those as well as new NullDecimal cases. I've added some logic to the test harness to catch these as you need to use the Equals method to properly compare Decimals.
This commit is contained in:
committed by
Jack Christensen
parent
6bda09691d
commit
db84905b7f
@@ -34,6 +34,12 @@ func (dst *Numeric) Set(src interface{}) error {
|
|||||||
switch value := src.(type) {
|
switch value := src.(type) {
|
||||||
case decimal.Decimal:
|
case decimal.Decimal:
|
||||||
*dst = Numeric{Decimal: value, Status: pgtype.Present}
|
*dst = Numeric{Decimal: value, Status: pgtype.Present}
|
||||||
|
case decimal.NullDecimal:
|
||||||
|
if value.Valid {
|
||||||
|
*dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present}
|
||||||
|
} else {
|
||||||
|
*dst = Numeric{Status: pgtype.Null}
|
||||||
|
}
|
||||||
case float32:
|
case float32:
|
||||||
*dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present}
|
*dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present}
|
||||||
case float64:
|
case float64:
|
||||||
@@ -113,6 +119,9 @@ func (src *Numeric) AssignTo(dst interface{}) error {
|
|||||||
switch v := dst.(type) {
|
switch v := dst.(type) {
|
||||||
case *decimal.Decimal:
|
case *decimal.Decimal:
|
||||||
*v = src.Decimal
|
*v = src.Decimal
|
||||||
|
case *decimal.NullDecimal:
|
||||||
|
(*v).Valid = true
|
||||||
|
(*v).Decimal = src.Decimal
|
||||||
case *float32:
|
case *float32:
|
||||||
f, _ := src.Decimal.Float64()
|
f, _ := src.Decimal.Float64()
|
||||||
*v = float32(f)
|
*v = float32(f)
|
||||||
@@ -216,7 +225,11 @@ func (src *Numeric) AssignTo(dst interface{}) error {
|
|||||||
return fmt.Errorf("unable to assign to %T", dst)
|
return fmt.Errorf("unable to assign to %T", dst)
|
||||||
}
|
}
|
||||||
case pgtype.Null:
|
case pgtype.Null:
|
||||||
return pgtype.NullAssignTo(dst)
|
if v, ok := dst.(*decimal.NullDecimal); ok {
|
||||||
|
(*v).Valid = false
|
||||||
|
} else {
|
||||||
|
return pgtype.NullAssignTo(dst)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -153,6 +153,9 @@ func TestNumericSet(t *testing.T) {
|
|||||||
source interface{}
|
source interface{}
|
||||||
result *shopspring.Numeric
|
result *shopspring.Numeric
|
||||||
}{
|
}{
|
||||||
|
{source: decimal.New(1, 0), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||||
|
{source: decimal.NullDecimal{Valid: true, Decimal: decimal.New(1, 0)}, result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||||
|
{source: decimal.NullDecimal{Valid: false}, result: &shopspring.Numeric{Status: pgtype.Null}},
|
||||||
{source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
{source: float32(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||||
{source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
{source: float64(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||||
{source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
{source: int8(1), result: &shopspring.Numeric{Decimal: mustParseDecimal(t, "1"), Status: pgtype.Present}},
|
||||||
@@ -208,6 +211,8 @@ func TestNumericAssignTo(t *testing.T) {
|
|||||||
var f64 float64
|
var f64 float64
|
||||||
var pf32 *float32
|
var pf32 *float32
|
||||||
var pf64 *float64
|
var pf64 *float64
|
||||||
|
var d decimal.Decimal
|
||||||
|
var nd decimal.NullDecimal
|
||||||
|
|
||||||
simpleTests := []struct {
|
simpleTests := []struct {
|
||||||
src *shopspring.Numeric
|
src *shopspring.Numeric
|
||||||
@@ -231,16 +236,42 @@ func TestNumericAssignTo(t *testing.T) {
|
|||||||
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)},
|
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &_i8, expected: _int8(42)},
|
||||||
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
|
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &pi8, expected: ((*int8)(nil))},
|
||||||
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
|
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &_pi8, expected: ((*_int8)(nil))},
|
||||||
|
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 0)},
|
||||||
|
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42000"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, 3)},
|
||||||
|
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "0.042"), Status: pgtype.Present}, dst: &d, expected: decimal.New(42, -3)},
|
||||||
|
{src: &shopspring.Numeric{Decimal: mustParseDecimal(t, "42"), Status: pgtype.Present}, dst: &nd, expected: decimal.NullDecimal{Valid: true, Decimal: decimal.New(42, 0)}},
|
||||||
|
{src: &shopspring.Numeric{Status: pgtype.Null}, dst: &nd, expected: decimal.NullDecimal{Valid: false}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range simpleTests {
|
for i, tt := range simpleTests {
|
||||||
|
// Zero out the destination variable
|
||||||
|
reflect.ValueOf(tt.dst).Elem().Set(reflect.Zero(reflect.TypeOf(tt.dst).Elem()))
|
||||||
|
|
||||||
err := tt.src.AssignTo(tt.dst)
|
err := tt.src.AssignTo(tt.dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%d: %v", i, err)
|
t.Errorf("%d: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
// Need to specially handle Decimal or NullDecimal methods so we can use their Equal method. Without this
|
||||||
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
// we end up checking reference equality on the *big.Int they contain.
|
||||||
|
switch dst := tt.dst.(type) {
|
||||||
|
case *decimal.Decimal:
|
||||||
|
if !dst.Equal(tt.expected.(decimal.Decimal)) {
|
||||||
|
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, d)
|
||||||
|
}
|
||||||
|
case *decimal.NullDecimal:
|
||||||
|
expected := tt.expected.(decimal.NullDecimal)
|
||||||
|
|
||||||
|
if dst.Valid != expected.Valid {
|
||||||
|
t.Errorf("%d: expected %v to assign NullDecimal.Valid = %v, but result was NullDecimal.Valid = %v", i, tt.src, expected.Valid, dst.Valid)
|
||||||
|
}
|
||||||
|
if !dst.Decimal.Equal(expected.Decimal) {
|
||||||
|
t.Errorf("%d: expected %v to assign NullDecimal.Decimal = %v, but result was NullDecimal.Decimal = %v", i, tt.src, expected.Decimal, dst.Decimal)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected {
|
||||||
|
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user