db84905b7f
The shopspring/decimal package provides a NullDecimal struct intended for use with nullable SQL NUMERICs and numbers. It has Scanner and Valuer implementations already, but adding it to this package allows it to be used with the binary encoding as well. The implementation is very straightforward, but the tests have been made slightly more complicated. The previous version wasn't testing the decimal.Decimal cases, and this change adds those as well as new NullDecimal cases. I've added some logic to the test harness to catch these as you need to use the Equals method to properly compare Decimals.
359 lines
9.2 KiB
Go
359 lines
9.2 KiB
Go
package numeric
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
|
|
"github.com/jackc/pgtype"
|
|
"github.com/shopspring/decimal"
|
|
)
|
|
|
|
// Sentinel errors shared by the encoding and marshalling methods in this file.
var (
	// errUndefined is returned when a value with an undefined status is encoded.
	errUndefined = errors.New("cannot encode status undefined")
	// errBadStatus is returned when a status is none of the handled values.
	errBadStatus = errors.New("invalid status")
)
|
|
|
|
type Numeric struct {
|
|
Decimal decimal.Decimal
|
|
Status pgtype.Status
|
|
}
|
|
|
|
func (dst *Numeric) Set(src interface{}) error {
|
|
if src == nil {
|
|
*dst = Numeric{Status: pgtype.Null}
|
|
return nil
|
|
}
|
|
|
|
if value, ok := src.(interface{ Get() interface{} }); ok {
|
|
value2 := value.Get()
|
|
if value2 != value {
|
|
return dst.Set(value2)
|
|
}
|
|
}
|
|
|
|
switch value := src.(type) {
|
|
case decimal.Decimal:
|
|
*dst = Numeric{Decimal: value, Status: pgtype.Present}
|
|
case decimal.NullDecimal:
|
|
if value.Valid {
|
|
*dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present}
|
|
} else {
|
|
*dst = Numeric{Status: pgtype.Null}
|
|
}
|
|
case float32:
|
|
*dst = Numeric{Decimal: decimal.NewFromFloat(float64(value)), Status: pgtype.Present}
|
|
case float64:
|
|
*dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present}
|
|
case int8:
|
|
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
|
case uint8:
|
|
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
|
case int16:
|
|
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
|
case uint16:
|
|
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
|
case int32:
|
|
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
|
case uint32:
|
|
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
|
case int64:
|
|
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
|
case uint64:
|
|
// uint64 could be greater than int64 so convert to string then to decimal
|
|
dec, err := decimal.NewFromString(strconv.FormatUint(value, 10))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
|
case int:
|
|
*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
|
|
case uint:
|
|
// uint could be greater than int64 so convert to string then to decimal
|
|
dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
|
case string:
|
|
dec, err := decimal.NewFromString(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
|
default:
|
|
// If all else fails see if pgtype.Numeric can handle it. If so, translate through that.
|
|
num := &pgtype.Numeric{}
|
|
if err := num.Set(value); err != nil {
|
|
return fmt.Errorf("cannot convert %v to Numeric", value)
|
|
}
|
|
|
|
buf, err := num.EncodeText(nil, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to Numeric", value)
|
|
}
|
|
|
|
dec, err := decimal.NewFromString(string(buf))
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to Numeric", value)
|
|
}
|
|
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (dst Numeric) Get() interface{} {
|
|
switch dst.Status {
|
|
case pgtype.Present:
|
|
return dst.Decimal
|
|
case pgtype.Null:
|
|
return nil
|
|
default:
|
|
return dst.Status
|
|
}
|
|
}
|
|
|
|
func (src *Numeric) AssignTo(dst interface{}) error {
|
|
switch src.Status {
|
|
case pgtype.Present:
|
|
switch v := dst.(type) {
|
|
case *decimal.Decimal:
|
|
*v = src.Decimal
|
|
case *decimal.NullDecimal:
|
|
(*v).Valid = true
|
|
(*v).Decimal = src.Decimal
|
|
case *float32:
|
|
f, _ := src.Decimal.Float64()
|
|
*v = float32(f)
|
|
case *float64:
|
|
f, _ := src.Decimal.Float64()
|
|
*v = f
|
|
case *int:
|
|
if src.Decimal.Exponent() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseInt(src.Decimal.String(), 10, strconv.IntSize)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = int(n)
|
|
case *int8:
|
|
if src.Decimal.Exponent() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseInt(src.Decimal.String(), 10, 8)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = int8(n)
|
|
case *int16:
|
|
if src.Decimal.Exponent() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseInt(src.Decimal.String(), 10, 16)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = int16(n)
|
|
case *int32:
|
|
if src.Decimal.Exponent() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseInt(src.Decimal.String(), 10, 32)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = int32(n)
|
|
case *int64:
|
|
if src.Decimal.Exponent() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseInt(src.Decimal.String(), 10, 64)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = int64(n)
|
|
case *uint:
|
|
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseUint(src.Decimal.String(), 10, strconv.IntSize)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = uint(n)
|
|
case *uint8:
|
|
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseUint(src.Decimal.String(), 10, 8)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = uint8(n)
|
|
case *uint16:
|
|
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseUint(src.Decimal.String(), 10, 16)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = uint16(n)
|
|
case *uint32:
|
|
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseUint(src.Decimal.String(), 10, 32)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = uint32(n)
|
|
case *uint64:
|
|
if src.Decimal.Exponent() < 0 || src.Decimal.Sign() < 0 {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
n, err := strconv.ParseUint(src.Decimal.String(), 10, 64)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot convert %v to %T", dst, *v)
|
|
}
|
|
*v = uint64(n)
|
|
default:
|
|
if nextDst, retry := pgtype.GetAssignToDstType(dst); retry {
|
|
return src.AssignTo(nextDst)
|
|
}
|
|
return fmt.Errorf("unable to assign to %T", dst)
|
|
}
|
|
case pgtype.Null:
|
|
if v, ok := dst.(*decimal.NullDecimal); ok {
|
|
(*v).Valid = false
|
|
} else {
|
|
return pgtype.NullAssignTo(dst)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
|
|
if src == nil {
|
|
*dst = Numeric{Status: pgtype.Null}
|
|
return nil
|
|
}
|
|
|
|
dec, err := decimal.NewFromString(string(src))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
*dst = Numeric{Decimal: dec, Status: pgtype.Present}
|
|
return nil
|
|
}
|
|
|
|
func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
|
|
if src == nil {
|
|
*dst = Numeric{Status: pgtype.Null}
|
|
return nil
|
|
}
|
|
|
|
// For now at least, implement this in terms of pgtype.Numeric
|
|
|
|
num := &pgtype.Numeric{}
|
|
if err := num.DecodeBinary(ci, src); err != nil {
|
|
return err
|
|
}
|
|
|
|
*dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Status: pgtype.Present}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (src Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
|
switch src.Status {
|
|
case pgtype.Null:
|
|
return nil, nil
|
|
case pgtype.Undefined:
|
|
return nil, errUndefined
|
|
}
|
|
|
|
return append(buf, src.Decimal.String()...), nil
|
|
}
|
|
|
|
func (src Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
|
|
switch src.Status {
|
|
case pgtype.Null:
|
|
return nil, nil
|
|
case pgtype.Undefined:
|
|
return nil, errUndefined
|
|
}
|
|
|
|
// For now at least, implement this in terms of pgtype.Numeric
|
|
num := &pgtype.Numeric{}
|
|
if err := num.DecodeText(ci, []byte(src.Decimal.String())); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return num.EncodeBinary(ci, buf)
|
|
}
|
|
|
|
// Scan implements the database/sql Scanner interface.
|
|
func (dst *Numeric) Scan(src interface{}) error {
|
|
if src == nil {
|
|
*dst = Numeric{Status: pgtype.Null}
|
|
return nil
|
|
}
|
|
|
|
switch src := src.(type) {
|
|
case float64:
|
|
*dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present}
|
|
return nil
|
|
case string:
|
|
return dst.DecodeText(nil, []byte(src))
|
|
case []byte:
|
|
return dst.DecodeText(nil, src)
|
|
}
|
|
|
|
return fmt.Errorf("cannot scan %T", src)
|
|
}
|
|
|
|
// Value implements the database/sql/driver Valuer interface.
|
|
func (src Numeric) Value() (driver.Value, error) {
|
|
switch src.Status {
|
|
case pgtype.Present:
|
|
return src.Decimal.Value()
|
|
case pgtype.Null:
|
|
return nil, nil
|
|
default:
|
|
return nil, errUndefined
|
|
}
|
|
}
|
|
|
|
func (src Numeric) MarshalJSON() ([]byte, error) {
|
|
switch src.Status {
|
|
case pgtype.Present:
|
|
return src.Decimal.MarshalJSON()
|
|
case pgtype.Null:
|
|
return []byte("null"), nil
|
|
case pgtype.Undefined:
|
|
return nil, errUndefined
|
|
}
|
|
|
|
return nil, errBadStatus
|
|
}
|
|
|
|
func (dst *Numeric) UnmarshalJSON(b []byte) error {
|
|
d := decimal.NullDecimal{}
|
|
err := d.UnmarshalJSON(b)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
status := pgtype.Null
|
|
if d.Valid {
|
|
status = pgtype.Present
|
|
}
|
|
*dst = Numeric{Decimal: d.Decimal, Status: status}
|
|
|
|
return nil
|
|
}
|