2
0

Encode / decode named types with compatible underlying type

Handle string, int, int8, int16, int32, int64, uint, uint8, uint16,
uint32, uint64.
This commit is contained in:
Jack Christensen
2016-07-05 18:01:44 -05:00
parent 30cb421551
commit 71d8b5b438
3 changed files with 166 additions and 5 deletions
+61 -5
View File
@@ -615,12 +615,14 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
return encodeByteSliceSlice(wbuf, oid, arg)
}
if v := reflect.ValueOf(arg); v.Kind() == reflect.Ptr {
if v.IsNil() {
refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr {
if refVal.IsNil() {
wbuf.WriteInt32(-1)
return nil
} else {
arg = v.Elem().Interface()
arg = refVal.Elem().Interface()
return Encode(wbuf, oid, arg)
}
}
@@ -691,10 +693,42 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
case Oid:
return encodeOid(wbuf, oid, arg)
default:
if strippedArg, ok := stripNamedType(&refVal); ok {
return Encode(wbuf, oid, strippedArg)
}
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
}
func stripNamedType(val *reflect.Value) (interface{}, bool) {
switch val.Kind() {
case reflect.Int:
return int(val.Int()), true
case reflect.Int8:
return int8(val.Int()), true
case reflect.Int16:
return int16(val.Int()), true
case reflect.Int32:
return int32(val.Int()), true
case reflect.Int64:
return int64(val.Int()), true
case reflect.Uint:
return uint(val.Uint()), true
case reflect.Uint8:
return uint8(val.Uint()), true
case reflect.Uint16:
return uint16(val.Uint()), true
case reflect.Uint32:
return uint32(val.Uint()), true
case reflect.Uint64:
return uint64(val.Uint()), true
case reflect.String:
return val.String(), true
}
return nil, false
}
// Decode decodes from vr into d. d must be a pointer. This allows
// implementations of the Decoder interface to delegate the actual work of
// decoding to the built-in functionality.
@@ -846,9 +880,11 @@ func Decode(vr *ValueReader, d interface{}) error {
case *[]net.IPNet:
*v = decodeInetArray(vr)
default:
// if d is a pointer to pointer, strip the pointer and try again
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
if el := v.Elem(); el.Kind() == reflect.Ptr {
el := v.Elem()
switch el.Kind() {
// if d is a pointer to pointer, strip the pointer and try again
case reflect.Ptr:
// -1 is a null value
if vr.Len() == -1 {
if !el.IsNil() {
@@ -864,6 +900,26 @@ func Decode(vr *ValueReader, d interface{}) error {
d = el.Interface()
return Decode(vr, d)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n := decodeInt(vr)
if el.OverflowInt(n) {
return fmt.Errorf("Scan cannot decode %d into %T", n, d)
}
el.SetInt(n)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n := decodeInt(vr)
if n < 0 {
return fmt.Errorf("%d is less than zero for %T", n, d)
}
if el.OverflowUint(uint64(n)) {
return fmt.Errorf("Scan cannot decode %d into %T", n, d)
}
el.SetUint(uint64(n))
return nil
case reflect.String:
el.SetString(decodeText(vr))
return nil
}
}
return fmt.Errorf("Scan cannot decode into %T", d)