2
0

Encode and decode between all integer types

fixes #138
This commit is contained in:
Jack Christensen
2016-04-28 15:28:38 -05:00
parent 623ba1eeb1
commit 88acc7e19f
2 changed files with 423 additions and 106 deletions
+159 -84
View File
@@ -53,6 +53,10 @@ const (
BinaryFormatCode = 1
)
const maxUint = ^uint(0)
const maxInt = int(maxUint >> 1)
const minInt = -maxInt - 1
// DefaultTypeFormats maps type names to their default requested format (text
// or binary). In theory the Scanner interface should be the one to determine
// the format of the returned values. However, the query has already been
@@ -623,6 +627,10 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
return encodeBool(wbuf, oid, arg)
case []bool:
return encodeBoolSlice(wbuf, oid, arg)
case int:
return encodeInt(wbuf, oid, arg)
case uint:
return encodeUInt(wbuf, oid, arg)
case int8:
return encodeInt8(wbuf, oid, arg)
case uint8:
@@ -651,8 +659,6 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
return encodeUInt64(wbuf, oid, arg)
case []uint64:
return encodeUInt64Slice(wbuf, oid, arg)
case int:
return encodeInt(wbuf, oid, arg)
case float32:
return encodeFloat32(wbuf, oid, arg)
case []float32:
@@ -687,54 +693,84 @@ func Decode(vr *ValueReader, d interface{}) error {
switch v := d.(type) {
case *bool:
*v = decodeBool(vr)
case *int64:
*v = decodeInt8(vr)
case *int:
n := decodeInt(vr)
if n < int64(minInt) {
return fmt.Errorf("%d is less than minimum value for int", n)
} else if n > int64(maxInt) {
return fmt.Errorf("%d is greater than maximum value for int", n)
}
*v = int(n)
case *int8:
n := decodeInt(vr)
if n < math.MinInt8 {
return fmt.Errorf("%d is less than minimum value for int8", n)
} else if n > math.MaxInt8 {
return fmt.Errorf("%d is greater than maximum value for int8", n)
}
*v = int8(n)
case *int16:
*v = decodeInt2(vr)
n := decodeInt(vr)
if n < math.MinInt16 {
return fmt.Errorf("%d is less than minimum value for int16", n)
} else if n > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for int16", n)
}
*v = int16(n)
case *int32:
*v = decodeInt4(vr)
n := decodeInt(vr)
if n < math.MinInt32 {
return fmt.Errorf("%d is less than minimum value for int32", n)
} else if n > math.MaxInt32 {
return fmt.Errorf("%d is greater than maximum value for int32", n)
}
*v = int32(n)
case *int64:
n := decodeInt(vr)
if n < math.MinInt64 {
return fmt.Errorf("%d is less than minimum value for int64", n)
} else if n > math.MaxInt64 {
return fmt.Errorf("%d is greater than maximum value for int64", n)
}
*v = int64(n)
case *uint:
n := decodeInt(vr)
if n < 0 {
return fmt.Errorf("%d is less than zero for uint8", n)
} else if maxInt == math.MaxInt32 && n > math.MaxUint32 {
return fmt.Errorf("%d is greater than maximum value for uint", n)
}
*v = uint(n)
case *uint8:
n := decodeInt(vr)
if n < 0 {
return fmt.Errorf("%d is less than zero for uint8", n)
} else if n > math.MaxUint8 {
return fmt.Errorf("%d is greater than maximum value for uint8", n)
}
*v = uint8(n)
case *uint16:
var valInt int16
switch vr.Type().DataType {
case Int2Oid:
valInt = int16(decodeInt2(vr))
default:
return fmt.Errorf("Can't convert OID %v to uint16", vr.Type().DataType)
n := decodeInt(vr)
if n < 0 {
return fmt.Errorf("%d is less than zero for uint16", n)
} else if n > math.MaxUint16 {
return fmt.Errorf("%d is greater than maximum value for uint16", n)
}
if valInt < 0 {
return fmt.Errorf("%d is less than zero for uint16", valInt)
}
*v = uint16(valInt)
*v = uint16(n)
case *uint32:
var valInt int32
switch vr.Type().DataType {
case Int2Oid:
valInt = int32(decodeInt2(vr))
case Int4Oid:
valInt = decodeInt4(vr)
default:
return fmt.Errorf("Can't convert OID %v to uint32", vr.Type().DataType)
n := decodeInt(vr)
if n < 0 {
return fmt.Errorf("%d is less than zero for uint32", n)
} else if n > math.MaxUint32 {
return fmt.Errorf("%d is greater than maximum value for uint32", n)
}
if valInt < 0 {
return fmt.Errorf("%d is less than zero for uint32", valInt)
}
*v = uint32(valInt)
*v = uint32(n)
case *uint64:
var valInt int64
switch vr.Type().DataType {
case Int2Oid:
valInt = int64(decodeInt2(vr))
case Int4Oid:
valInt = int64(decodeInt4(vr))
case Int8Oid:
valInt = decodeInt8(vr)
default:
return fmt.Errorf("Can't convert OID %v to uint64", vr.Type().DataType)
n := decodeInt(vr)
if n < 0 {
return fmt.Errorf("%d is less than zero for uint64", n)
}
if valInt < 0 {
return fmt.Errorf("%d is less than zero for uint64", valInt)
}
*v = uint64(valInt)
*v = uint64(n)
case *Oid:
*v = decodeOid(vr)
case *string:
@@ -865,6 +901,20 @@ func encodeBool(w *WriteBuf, oid Oid, value bool) error {
return nil
}
func decodeInt(vr *ValueReader) int64 {
switch vr.Type().DataType {
case Int2Oid:
return int64(decodeInt2(vr))
case Int4Oid:
return int64(decodeInt4(vr))
case Int8Oid:
return int64(decodeInt8(vr))
}
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into any integer type", vr.Type().DataType)))
return 0
}
func decodeInt8(vr *ValueReader) int64 {
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into int64"))
@@ -913,6 +963,61 @@ func decodeInt2(vr *ValueReader) int16 {
return vr.ReadInt16()
}
func encodeInt(w *WriteBuf, oid Oid, value int) error {
switch oid {
case Int2Oid:
if value < math.MinInt16 {
return fmt.Errorf("%d is less than min pg:int2", value)
} else if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than max pg:int2", value)
}
w.WriteInt32(2)
w.WriteInt16(int16(value))
case Int4Oid:
if value < math.MinInt32 {
return fmt.Errorf("%d is less than min pg:int4", value)
} else if value > math.MaxInt32 {
return fmt.Errorf("%d is greater than max pg:int4", value)
}
w.WriteInt32(4)
w.WriteInt32(int32(value))
case Int8Oid:
w.WriteInt32(8)
w.WriteInt64(int64(value))
default:
return fmt.Errorf("cannot encode %s into oid %v", "int8", oid)
}
return nil
}
func encodeUInt(w *WriteBuf, oid Oid, value uint) error {
switch oid {
case Int2Oid:
if value > math.MaxInt16 {
return fmt.Errorf("%d is greater than max pg:int2", value)
}
w.WriteInt32(2)
w.WriteInt16(int16(value))
case Int4Oid:
if value > math.MaxInt32 {
return fmt.Errorf("%d is greater than max pg:int4", value)
}
w.WriteInt32(4)
w.WriteInt32(int32(value))
case Int8Oid:
if value > math.MaxInt64 {
return fmt.Errorf("%d is greater than max pg:int8", value)
}
w.WriteInt32(8)
w.WriteInt64(int64(value))
default:
return fmt.Errorf("cannot encode %s into oid %v", "uint8", oid)
}
return nil
}
func encodeInt8(w *WriteBuf, oid Oid, value int8) error {
switch oid {
case Int2Oid:
@@ -974,7 +1079,7 @@ func encodeUInt16(w *WriteBuf, oid Oid, value uint16) error {
w.WriteInt32(2)
w.WriteInt16(int16(value))
} else {
return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16)
return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
}
case Int4Oid:
w.WriteInt32(4)
@@ -996,7 +1101,7 @@ func encodeInt32(w *WriteBuf, oid Oid, value int32) error {
w.WriteInt32(2)
w.WriteInt16(int16(value))
} else {
return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16)
return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
}
case Int4Oid:
w.WriteInt32(4)
@@ -1018,14 +1123,14 @@ func encodeUInt32(w *WriteBuf, oid Oid, value uint32) error {
w.WriteInt32(2)
w.WriteInt16(int16(value))
} else {
return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16)
return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
}
case Int4Oid:
if value <= math.MaxInt32 {
w.WriteInt32(4)
w.WriteInt32(int32(value))
} else {
return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32)
return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
}
case Int8Oid:
w.WriteInt32(8)
@@ -1044,14 +1149,14 @@ func encodeInt64(w *WriteBuf, oid Oid, value int64) error {
w.WriteInt32(2)
w.WriteInt16(int16(value))
} else {
return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16)
return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
}
case Int4Oid:
if value <= math.MaxInt32 {
w.WriteInt32(4)
w.WriteInt32(int32(value))
} else {
return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32)
return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
}
case Int8Oid:
w.WriteInt32(8)
@@ -1070,14 +1175,14 @@ func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error {
w.WriteInt32(2)
w.WriteInt16(int16(value))
} else {
return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16)
return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16)
}
case Int4Oid:
if value <= math.MaxInt32 {
w.WriteInt32(4)
w.WriteInt32(int32(value))
} else {
return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32)
return fmt.Errorf("%d is greater than max int32 %d", value, math.MaxInt32)
}
case Int8Oid:
@@ -1085,37 +1190,7 @@ func encodeUInt64(w *WriteBuf, oid Oid, value uint64) error {
w.WriteInt32(8)
w.WriteInt64(int64(value))
} else {
return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64))
}
default:
return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid)
}
return nil
}
func encodeInt(w *WriteBuf, oid Oid, value int) error {
switch oid {
case Int2Oid:
if value <= math.MaxInt16 {
w.WriteInt32(2)
w.WriteInt16(int16(value))
} else {
return fmt.Errorf("%d is larger than max int16 %d", value, math.MaxInt16)
}
case Int4Oid:
if value <= math.MaxInt32 {
w.WriteInt32(4)
w.WriteInt32(int32(value))
} else {
return fmt.Errorf("%d is larger than max int32 %d", value, math.MaxInt32)
}
case Int8Oid:
if int64(value) <= int64(math.MaxInt64) {
w.WriteInt32(8)
w.WriteInt64(int64(value))
} else {
return fmt.Errorf("%d is larger than max int64 %d", value, int64(math.MaxInt64))
return fmt.Errorf("%d is greater than max int64 %d", value, int64(math.MaxInt64))
}
default:
return fmt.Errorf("cannot encode %s into oid %v", "uint64", oid)
@@ -1716,7 +1791,7 @@ func encodeUInt16Slice(w *WriteBuf, oid Oid, slice []uint16) error {
w.WriteInt32(2)
w.WriteInt16(int16(v))
} else {
return fmt.Errorf("%d is larger than max smallint %d", v, math.MaxInt16)
return fmt.Errorf("%d is greater than max smallint %d", v, math.MaxInt16)
}
}
@@ -1831,7 +1906,7 @@ func encodeUInt32Slice(w *WriteBuf, oid Oid, slice []uint32) error {
w.WriteInt32(4)
w.WriteInt32(int32(v))
} else {
return fmt.Errorf("%d is larger than max integer %d", v, math.MaxInt32)
return fmt.Errorf("%d is greater than max integer %d", v, math.MaxInt32)
}
}
@@ -1946,7 +2021,7 @@ func encodeUInt64Slice(w *WriteBuf, oid Oid, slice []uint64) error {
w.WriteInt32(8)
w.WriteInt64(int64(v))
} else {
return fmt.Errorf("%d is larger than max bigint %d", v, int64(math.MaxInt64))
return fmt.Errorf("%d is greater than max bigint %d", v, int64(math.MaxInt64))
}
}