diff --git a/ext/shopspring-numeric/decimal.go b/ext/shopspring-numeric/decimal.go index c035b15b..259fa54d 100644 --- a/ext/shopspring-numeric/decimal.go +++ b/ext/shopspring-numeric/decimal.go @@ -11,6 +11,7 @@ import ( ) var errUndefined = errors.New("cannot encode status undefined") +var errBadStatus = errors.New("invalid status") type Numeric struct { Decimal decimal.Decimal @@ -316,3 +317,32 @@ func (src Numeric) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Numeric) MarshalJSON() ([]byte, error) { + switch src.Status { + case pgtype.Present: + return src.Decimal.MarshalJSON() + case pgtype.Null: + return []byte("null"), nil + case pgtype.Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Numeric) UnmarshalJSON(b []byte) error { + d := decimal.NullDecimal{} + err := d.UnmarshalJSON(b) + if err != nil { + return err + } + + status := pgtype.Null + if d.Valid { + status = pgtype.Present + } + *dst = Numeric{Decimal: d.Decimal, Status: status} + + return nil +}