diff --git a/numeric.go b/numeric.go index 83791d4e..85648dc2 100644 --- a/numeric.go +++ b/numeric.go @@ -57,14 +57,47 @@ var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) type Numeric struct { Int *big.Int Exp int32 - Valid bool NaN bool InfinityModifier InfinityModifier + Valid bool + + NumericDecoderWrapper func(interface{}) NumericDecoder +} + +func (n *Numeric) NewTypeValue() Value { + return &Numeric{ + NumericDecoderWrapper: n.NumericDecoderWrapper, + } +} + +func (n *Numeric) TypeName() string { + return "numeric" +} + +func (dst *Numeric) setNil() { + dst.Int = nil + dst.Exp = 0 + dst.NaN = false + dst.Valid = false +} + +func (dst *Numeric) setNaN() { + dst.Int = nil + dst.Exp = 0 + dst.NaN = true + dst.Valid = true +} + +func (dst *Numeric) setNumber(i *big.Int, exp int32) { + dst.Int = i + dst.Exp = exp + dst.NaN = false + dst.Valid = true } func (dst *Numeric) Set(src interface{}) error { if src == nil { - *dst = Numeric{} + dst.setNil() return nil } @@ -78,7 +111,7 @@ func (dst *Numeric) Set(src interface{}) error { switch value := src.(type) { case float32: if math.IsNaN(float64(value)) { - *dst = Numeric{Valid: true, NaN: true} + dst.setNaN() return nil } else if math.IsInf(float64(value), 1) { *dst = Numeric{Valid: true, InfinityModifier: Infinity} @@ -91,10 +124,10 @@ func (dst *Numeric) Set(src interface{}) error { if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Valid: true} + dst.setNumber(num, exp) case float64: if math.IsNaN(value) { - *dst = Numeric{Valid: true, NaN: true} + dst.setNaN() return nil } else if math.IsInf(value, 1) { *dst = Numeric{Valid: true, InfinityModifier: Infinity} @@ -107,108 +140,108 @@ func (dst *Numeric) Set(src interface{}) error { if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Valid: true} + dst.setNumber(num, exp) case int8: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case uint8: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case int16: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case uint16: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case int32: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case uint32: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case int64: - *dst = Numeric{Int: big.NewInt(value), Valid: true} + dst.setNumber(big.NewInt(value), 0) case uint64: - *dst = Numeric{Int: (&big.Int{}).SetUint64(value), Valid: true} + dst.setNumber((&big.Int{}).SetUint64(value), 0) case int: - *dst = Numeric{Int: big.NewInt(int64(value)), Valid: true} + dst.setNumber(big.NewInt(int64(value)), 0) case uint: - *dst = Numeric{Int: (&big.Int{}).SetUint64(uint64(value)), Valid: true} + dst.setNumber((&big.Int{}).SetUint64(uint64(value)), 0) case string: num, exp, err := parseNumericString(value) if err != nil { return err } - *dst = Numeric{Int: num, Exp: exp, Valid: true} + dst.setNumber(num, exp) case *float64: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *float32: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int8: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint8: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int16: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint16: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int32: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint32: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int64: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint64: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *int: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *uint: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } case *string: if value == nil { - *dst = Numeric{} + dst.setNil() } else { return dst.Set(*value) } @@ -235,8 +268,6 @@ func (dst Numeric) Get() interface{} { return dst } -var NumericDecoderWrapper func(interface{}) NumericDecoder - type NumericDecoder interface { DecodeNumeric(*Numeric) error } @@ -245,8 +276,8 @@ func (src *Numeric) AssignTo(dst interface{}) error { if d, ok := dst.(NumericDecoder); ok { return d.DecodeNumeric(src) } else { - if NumericDecoderWrapper != nil { - d = NumericDecoderWrapper(dst) + if src.NumericDecoderWrapper != nil { + d = src.NumericDecoderWrapper(dst) if d != nil { return d.DecodeNumeric(src) } @@ -443,12 +474,12 @@ func (src *Numeric) toFloat64() (float64, error) { func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Numeric{} + dst.setNil() return nil } if string(src) == "NaN" { - *dst = Numeric{Valid: true, NaN: true} + dst.setNaN() return nil } else if string(src) == "Infinity" { *dst = Numeric{Valid: true, InfinityModifier: Infinity} @@ -463,7 +494,7 @@ func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error { return err } - *dst = Numeric{Int: num, Exp: exp, Valid: true} + dst.setNumber(num, exp) return nil } @@ -490,7 +521,7 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { if src == nil { - *dst = Numeric{} + dst.setNil() return nil } @@ -509,7 +540,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { rp += 2 if sign == pgNumericNaNSign { - *dst = Numeric{Valid: true, NaN: true} + dst.setNaN() return nil } else if sign == pgNumericPosInfSign { *dst = Numeric{Valid: true, InfinityModifier: Infinity} @@ -520,7 +551,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { } if ndigits == 0 { - *dst = Numeric{Int: big.NewInt(0), Valid: true} + dst.setNumber(big.NewInt(0), 0) return nil } @@ -592,7 +623,7 @@ func (dst *Numeric) DecodeBinary(ci *ConnInfo, src []byte) error { accum.Neg(accum) } - *dst = Numeric{Int: accum, Exp: exp, Valid: true} + dst.setNumber(accum, exp) return nil @@ -741,7 +772,7 @@ func (src Numeric) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { // Scan implements the database/sql Scanner interface. func (dst *Numeric) Scan(src interface{}) error { if src == nil { - *dst = Numeric{} + dst.setNil() return nil }