2
0

Use binary transcoding for inet/cidr

fixes #87
This commit is contained in:
Jack Christensen
2015-09-03 11:39:32 -05:00
parent 9af068add0
commit fd39261551
6 changed files with 85 additions and 50 deletions
+42 -33
View File
@@ -54,21 +54,23 @@ var DefaultTypeFormats map[string]int16
func init() {
DefaultTypeFormats = make(map[string]int16)
DefaultTypeFormats["_bool"] = BinaryFormatCode
DefaultTypeFormats["_float4"] = BinaryFormatCode
DefaultTypeFormats["_float8"] = BinaryFormatCode
DefaultTypeFormats["_bool"] = BinaryFormatCode
DefaultTypeFormats["_int2"] = BinaryFormatCode
DefaultTypeFormats["_int4"] = BinaryFormatCode
DefaultTypeFormats["_int8"] = BinaryFormatCode
DefaultTypeFormats["_text"] = BinaryFormatCode
DefaultTypeFormats["_varchar"] = BinaryFormatCode
DefaultTypeFormats["_timestamp"] = BinaryFormatCode
DefaultTypeFormats["_timestamptz"] = BinaryFormatCode
DefaultTypeFormats["_varchar"] = BinaryFormatCode
DefaultTypeFormats["bool"] = BinaryFormatCode
DefaultTypeFormats["bytea"] = BinaryFormatCode
DefaultTypeFormats["cidr"] = BinaryFormatCode
DefaultTypeFormats["date"] = BinaryFormatCode
DefaultTypeFormats["float4"] = BinaryFormatCode
DefaultTypeFormats["float8"] = BinaryFormatCode
DefaultTypeFormats["inet"] = BinaryFormatCode
DefaultTypeFormats["int2"] = BinaryFormatCode
DefaultTypeFormats["int4"] = BinaryFormatCode
DefaultTypeFormats["int8"] = BinaryFormatCode
@@ -1112,41 +1114,32 @@ func decodeInet(vr *ValueReader) net.IPNet {
return zero
}
if vr.Type().FormatCode != BinaryFormatCode {
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return zero
}
pgType := vr.Type()
if vr.Len() != 8 && vr.Len() != 20 {
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
return zero
}
if pgType.DataType != InetOid && pgType.DataType != CidrOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, vr.Type().Name)))
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
return zero
}
s := vr.ReadString(vr.Len())
hasNetmask := strings.ContainsRune(s, '/')
if !hasNetmask {
isIpv6 := strings.ContainsRune(s, ':')
if isIpv6 {
s += "/128"
} else {
s += "/32"
}
}
vr.ReadByte() // ignore family
bits := vr.ReadByte()
vr.ReadByte() // ignore is_cidr
addressLength := vr.ReadByte()
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
vr.Fatal(err)
return zero
}
// if vr.Type().FormatCode != BinaryFormatCode {
// vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
// return zero
// }
// if vr.Len() != 4 {
// vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
// return zero
// }
return *ipnet
var ipnet net.IPNet
ipnet.IP = vr.ReadBytes(int32(addressLength))
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
return ipnet
}
func encodeInet(w *WriteBuf, value interface{}) error {
@@ -1159,10 +1152,26 @@ func encodeInet(w *WriteBuf, value interface{}) error {
return fmt.Errorf("Expected net.IPNet, received %T %v", value, value)
}
s := ipnet.String()
var size int32
var family byte
switch len(ipnet.IP) {
case net.IPv4len:
size = 8
family = w.conn.pgsql_af_inet
case net.IPv6len:
size = 20
family = w.conn.pgsql_af_inet6
default:
return fmt.Errorf("Unexpected IP length: %v", len(ipnet.IP))
}
w.WriteInt32(int32(len(s)))
w.WriteBytes([]byte(s))
w.WriteInt32(size)
w.WriteByte(family)
ones, _ := ipnet.Mask.Size()
w.WriteByte(byte(ones))
w.WriteByte(0) // is_cidr is ignored on server
w.WriteByte(byte(len(ipnet.IP)))
w.WriteBytes(ipnet.IP)
return nil
}