2
0
Files
pgx/ext/shopspring-numeric/decimal.go
T
Eli Treuherz db84905b7f Add NullDecimal to shopspring-numeric
The shopspring/decimal package provides a NullDecimal struct intended
for use with nullable SQL NUMERICs and numbers. It has Scanner and
Valuer implementations already, but adding it to this package allows
it to be used with the binary encoding as well.

The implementation is very straightforward, but the tests have been made
slightly more complicated. The previous version wasn't testing the
decimal.Decimal cases, and this change adds those as well as new
NullDecimal cases. I've added some logic to the test harness to handle
these cases, since Decimals must be compared with the Equals method to be compared properly.
2021-08-07 08:23:02 -05:00

359 lines
9.2 KiB
Go

package numeric
import (
"database/sql/driver"
"errors"
"fmt"
"strconv"
"github.com/jackc/pgtype"
"github.com/shopspring/decimal"
)
// Sentinel errors returned by Numeric's encoding and marshaling methods.
var (
	// errUndefined is returned when a Numeric whose Status is not Present or
	// Null (e.g. pgtype.Undefined) is asked to encode itself.
	errUndefined = errors.New("cannot encode status undefined")
	// errBadStatus is returned by MarshalJSON for a Status value it does not
	// recognise.
	errBadStatus = errors.New("invalid status")
)
// Numeric wraps a shopspring decimal.Decimal together with a pgtype.Status so
// it can be used with pgtype's Set/Get/AssignTo, text and binary
// encode/decode, database/sql Scan/Value, and JSON marshaling.
type Numeric struct {
// Decimal holds the numeric value; it is only meaningful when Status is pgtype.Present.
Decimal decimal.Decimal
// Status records whether the value is Present, Null or Undefined.
Status pgtype.Status
}
// Set converts src into dst.
//
// nil, and a decimal.NullDecimal that is not Valid, produce a Null Numeric.
// Values implementing Get() interface{} are unwrapped first. decimal.Decimal,
// decimal.NullDecimal, floats, built-in integers and decimal strings are
// converted directly. Any other type is converted by pgtype.Numeric and then
// re-parsed from its text form.
//
// NaN and ±Inf floats are rejected with an error. They never reach the
// shopspring float constructors, which panic on them.
func (dst *Numeric) Set(src interface{}) error {
	if src == nil {
		*dst = Numeric{Status: pgtype.Null}
		return nil
	}

	if value, ok := src.(interface{ Get() interface{} }); ok {
		value2 := value.Get()
		if value2 != value {
			return dst.Set(value2)
		}
	}

	switch value := src.(type) {
	case decimal.Decimal:
		*dst = Numeric{Decimal: value, Status: pgtype.Present}
	case decimal.NullDecimal:
		if value.Valid {
			*dst = Numeric{Decimal: value.Decimal, Status: pgtype.Present}
		} else {
			*dst = Numeric{Status: pgtype.Null}
		}
	case float32:
		// x-x is 0 for every finite x and NaN for NaN and ±Inf.
		if value-value != 0 {
			return fmt.Errorf("cannot convert %v to Numeric", value)
		}
		// NewFromFloat32 picks the shortest decimal that round-trips as a
		// float32 (0.1). Widening to float64 first would expose binary noise
		// (0.100000001490116...).
		*dst = Numeric{Decimal: decimal.NewFromFloat32(value), Status: pgtype.Present}
	case float64:
		// x-x is 0 for every finite x and NaN for NaN and ±Inf.
		if value-value != 0 {
			return fmt.Errorf("cannot convert %v to Numeric", value)
		}
		*dst = Numeric{Decimal: decimal.NewFromFloat(value), Status: pgtype.Present}
	case int8:
		*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
	case uint8:
		*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
	case int16:
		*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
	case uint16:
		*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
	case int32:
		*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
	case uint32:
		*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
	case int64:
		*dst = Numeric{Decimal: decimal.New(value, 0), Status: pgtype.Present}
	case uint64:
		// uint64 can exceed the int64 range accepted by decimal.New, so go via text.
		dec, err := decimal.NewFromString(strconv.FormatUint(value, 10))
		if err != nil {
			return err
		}
		*dst = Numeric{Decimal: dec, Status: pgtype.Present}
	case int:
		*dst = Numeric{Decimal: decimal.New(int64(value), 0), Status: pgtype.Present}
	case uint:
		// uint can exceed the int64 range accepted by decimal.New, so go via text.
		dec, err := decimal.NewFromString(strconv.FormatUint(uint64(value), 10))
		if err != nil {
			return err
		}
		*dst = Numeric{Decimal: dec, Status: pgtype.Present}
	case string:
		dec, err := decimal.NewFromString(value)
		if err != nil {
			return err
		}
		*dst = Numeric{Decimal: dec, Status: pgtype.Present}
	default:
		// If all else fails see if pgtype.Numeric can handle it. If so, translate through that.
		num := &pgtype.Numeric{}
		if err := num.Set(value); err != nil {
			return fmt.Errorf("cannot convert %v to Numeric", value)
		}
		buf, err := num.EncodeText(nil, nil)
		if err != nil {
			return fmt.Errorf("cannot convert %v to Numeric", value)
		}
		dec, err := decimal.NewFromString(string(buf))
		if err != nil {
			return fmt.Errorf("cannot convert %v to Numeric", value)
		}
		*dst = Numeric{Decimal: dec, Status: pgtype.Present}
	}

	return nil
}
// Get returns the wrapped value.
//
// It returns the decimal.Decimal when Status is pgtype.Present and nil when
// Status is pgtype.Null. For any other status it returns the Status itself.
func (dst Numeric) Get() interface{} {
	if dst.Status == pgtype.Present {
		return dst.Decimal
	}
	if dst.Status == pgtype.Null {
		return nil
	}
	return dst.Status
}
// AssignTo copies src into the value pointed to by dst.
//
// For a Present src, dst may point to decimal.Decimal, decimal.NullDecimal,
// float32, float64, or any built-in signed or unsigned integer type.
// Integer targets accept any integral value, including one stored with a
// negative exponent such as 100.00. They fail when src has a fractional part
// or does not fit the target type. Other destinations are retried through
// pgtype.GetAssignToDstType.
//
// For a Null src, a *decimal.NullDecimal is marked invalid and every other
// dst is handed to pgtype.NullAssignTo.
//
// NOTE(review): any other Status (e.g. pgtype.Undefined) is silently ignored
// and nil is returned; consider reporting an error instead.
func (src *Numeric) AssignTo(dst interface{}) error {
	switch src.Status {
	case pgtype.Present:
		switch v := dst.(type) {
		case *decimal.Decimal:
			*v = src.Decimal
		case *decimal.NullDecimal:
			*v = decimal.NullDecimal{Decimal: src.Decimal, Valid: true}
		case *float32:
			// The second result of Float64 is discarded: the nearest float is
			// accepted even when the conversion is inexact.
			f, _ := src.Decimal.Float64()
			*v = float32(f)
		case *float64:
			f, _ := src.Decimal.Float64()
			*v = f
		case *int:
			n, ok := decimalToInt64(src.Decimal, strconv.IntSize)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = int(n)
		case *int8:
			n, ok := decimalToInt64(src.Decimal, 8)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = int8(n)
		case *int16:
			n, ok := decimalToInt64(src.Decimal, 16)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = int16(n)
		case *int32:
			n, ok := decimalToInt64(src.Decimal, 32)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = int32(n)
		case *int64:
			n, ok := decimalToInt64(src.Decimal, 64)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = n
		case *uint:
			n, ok := decimalToUint64(src.Decimal, strconv.IntSize)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = uint(n)
		case *uint8:
			n, ok := decimalToUint64(src.Decimal, 8)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = uint8(n)
		case *uint16:
			n, ok := decimalToUint64(src.Decimal, 16)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = uint16(n)
		case *uint32:
			n, ok := decimalToUint64(src.Decimal, 32)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = uint32(n)
		case *uint64:
			n, ok := decimalToUint64(src.Decimal, 64)
			if !ok {
				return fmt.Errorf("cannot convert %v to %T", src.Decimal, *v)
			}
			*v = n
		default:
			if nextDst, retry := pgtype.GetAssignToDstType(dst); retry {
				return src.AssignTo(nextDst)
			}
			return fmt.Errorf("unable to assign to %T", dst)
		}
	case pgtype.Null:
		if v, ok := dst.(*decimal.NullDecimal); ok {
			v.Valid = false
		} else {
			return pgtype.NullAssignTo(dst)
		}
	}
	return nil
}

// decimalToInt64 converts d to a signed integer that fits in bitSize bits.
// ok is false when d has a non-zero fractional part or is out of range.
func decimalToInt64(d decimal.Decimal, bitSize int) (n int64, ok bool) {
	whole := d.Truncate(0)
	if !whole.Equal(d) {
		return 0, false
	}
	n, err := strconv.ParseInt(whole.String(), 10, bitSize)
	return n, err == nil
}

// decimalToUint64 converts d to an unsigned integer that fits in bitSize bits.
// ok is false when d is negative, has a non-zero fractional part, or is out of range.
func decimalToUint64(d decimal.Decimal, bitSize int) (n uint64, ok bool) {
	if d.Sign() < 0 {
		return 0, false
	}
	whole := d.Truncate(0)
	if !whole.Equal(d) {
		return 0, false
	}
	n, err := strconv.ParseUint(whole.String(), 10, bitSize)
	return n, err == nil
}
// DecodeText parses the text representation of a numeric held in src.
//
// A nil src yields a Null Numeric. A parse error from decimal.NewFromString is
// returned unchanged, and dst is left untouched in that case.
func (dst *Numeric) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
	if src == nil {
		*dst = Numeric{Status: pgtype.Null}
		return nil
	}

	parsed, err := decimal.NewFromString(string(src))
	if err != nil {
		return err
	}

	dst.Decimal, dst.Status = parsed, pgtype.Present
	return nil
}
// DecodeBinary decodes the binary NUMERIC representation held in src.
//
// A nil src yields a Null Numeric. Otherwise decoding is delegated to
// pgtype.Numeric and its integer/exponent pair is converted to a
// decimal.Decimal. Special values that pgtype.Numeric represents without an
// integer part, such as NaN, cannot be expressed as a decimal.Decimal. They are
// reported as an error rather than causing a nil-pointer panic.
func (dst *Numeric) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
	if src == nil {
		*dst = Numeric{Status: pgtype.Null}
		return nil
	}

	// For now at least, implement this in terms of pgtype.Numeric
	num := &pgtype.Numeric{}
	if err := num.DecodeBinary(ci, src); err != nil {
		return err
	}
	// NOTE(review): pgtype.Numeric leaves Int nil for NaN (and, in newer
	// versions, ±Infinity). decimal.NewFromBigInt would dereference the nil
	// pointer and panic.
	if num.Int == nil {
		return errors.New("cannot decode NaN or infinite numeric into Numeric")
	}
	*dst = Numeric{Decimal: decimal.NewFromBigInt(num.Int, num.Exp), Status: pgtype.Present}
	return nil
}
// EncodeText appends the decimal's string form to buf.
//
// A Null value returns nil bytes and no error. An Undefined value returns
// errUndefined.
func (src Numeric) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
	if src.Status == pgtype.Null {
		return nil, nil
	}
	if src.Status == pgtype.Undefined {
		return nil, errUndefined
	}

	text := src.Decimal.String()
	return append(buf, text...), nil
}
// EncodeBinary appends the binary NUMERIC representation of src to buf.
//
// A Null value returns nil bytes and an Undefined value returns errUndefined.
// Otherwise the decimal is passed through its string form into a
// pgtype.Numeric, which performs the actual binary encoding.
func (src Numeric) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
	if src.Status == pgtype.Null {
		return nil, nil
	}
	if src.Status == pgtype.Undefined {
		return nil, errUndefined
	}

	// For now at least, implement this in terms of pgtype.Numeric.
	var intermediate pgtype.Numeric
	text := []byte(src.Decimal.String())
	if err := intermediate.DecodeText(ci, text); err != nil {
		return nil, err
	}
	return intermediate.EncodeBinary(ci, buf)
}
// Scan implements the database/sql Scanner interface.
//
// It accepts nil (producing a Null Numeric), float64, int64, string and
// []byte. Strings and byte slices are parsed as decimal text. NaN and ±Inf
// floats are rejected with an error. They never reach decimal.NewFromFloat,
// which panics on them.
func (dst *Numeric) Scan(src interface{}) error {
	if src == nil {
		*dst = Numeric{Status: pgtype.Null}
		return nil
	}

	switch src := src.(type) {
	case float64:
		// x-x is 0 for every finite x and NaN for NaN and ±Inf.
		if src-src != 0 {
			return fmt.Errorf("cannot scan %v into Numeric", src)
		}
		*dst = Numeric{Decimal: decimal.NewFromFloat(src), Status: pgtype.Present}
		return nil
	case int64:
		// int64 is one of the standard driver.Value types; some drivers
		// deliver integral values this way.
		*dst = Numeric{Decimal: decimal.New(src, 0), Status: pgtype.Present}
		return nil
	case string:
		return dst.DecodeText(nil, []byte(src))
	case []byte:
		return dst.DecodeText(nil, src)
	}

	return fmt.Errorf("cannot scan %T", src)
}
// Value implements the database/sql/driver Valuer interface.
//
// A Null value becomes nil. A Present value is delegated to
// decimal.Decimal.Value. Every other status is reported as errUndefined.
func (src Numeric) Value() (driver.Value, error) {
	if src.Status == pgtype.Null {
		return nil, nil
	}
	if src.Status != pgtype.Present {
		return nil, errUndefined
	}
	return src.Decimal.Value()
}
// MarshalJSON encodes src as JSON.
//
// Present values use decimal.Decimal's own JSON form and Null becomes the
// literal null. Undefined returns errUndefined, and any unrecognised status
// returns errBadStatus.
func (src Numeric) MarshalJSON() ([]byte, error) {
	if src.Status == pgtype.Present {
		return src.Decimal.MarshalJSON()
	}
	if src.Status == pgtype.Null {
		return []byte("null"), nil
	}
	if src.Status == pgtype.Undefined {
		return nil, errUndefined
	}
	return nil, errBadStatus
}
// UnmarshalJSON decodes b by way of decimal.NullDecimal.
//
// A decoding error is returned unchanged and leaves dst untouched. Otherwise
// dst receives the decoded Decimal, with Status Present when the NullDecimal
// is Valid and Null when it is not. NOTE(review): NullDecimal presumably
// reports JSON null as !Valid; confirm against the pinned shopspring version.
func (dst *Numeric) UnmarshalJSON(b []byte) error {
	var nullable decimal.NullDecimal
	if err := nullable.UnmarshalJSON(b); err != nil {
		return err
	}

	result := Numeric{Decimal: nullable.Decimal, Status: pgtype.Null}
	if nullable.Valid {
		result.Status = pgtype.Present
	}
	*dst = result
	return nil
}