@@ -54,21 +54,23 @@ var DefaultTypeFormats map[string]int16
|
||||
|
||||
func init() {
|
||||
DefaultTypeFormats = make(map[string]int16)
|
||||
DefaultTypeFormats["_bool"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_float4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_float8"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_bool"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int2"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_int8"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_text"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_varchar"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_timestamp"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_timestamptz"] = BinaryFormatCode
|
||||
DefaultTypeFormats["_varchar"] = BinaryFormatCode
|
||||
DefaultTypeFormats["bool"] = BinaryFormatCode
|
||||
DefaultTypeFormats["bytea"] = BinaryFormatCode
|
||||
DefaultTypeFormats["cidr"] = BinaryFormatCode
|
||||
DefaultTypeFormats["date"] = BinaryFormatCode
|
||||
DefaultTypeFormats["float4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["float8"] = BinaryFormatCode
|
||||
DefaultTypeFormats["inet"] = BinaryFormatCode
|
||||
DefaultTypeFormats["int2"] = BinaryFormatCode
|
||||
DefaultTypeFormats["int4"] = BinaryFormatCode
|
||||
DefaultTypeFormats["int8"] = BinaryFormatCode
|
||||
@@ -1112,41 +1114,32 @@ func decodeInet(vr *ValueReader) net.IPNet {
|
||||
return zero
|
||||
}
|
||||
|
||||
if vr.Type().FormatCode != BinaryFormatCode {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
return zero
|
||||
}
|
||||
|
||||
pgType := vr.Type()
|
||||
if vr.Len() != 8 && vr.Len() != 20 {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
|
||||
return zero
|
||||
}
|
||||
|
||||
if pgType.DataType != InetOid && pgType.DataType != CidrOid {
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, vr.Type().Name)))
|
||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
|
||||
return zero
|
||||
}
|
||||
|
||||
s := vr.ReadString(vr.Len())
|
||||
hasNetmask := strings.ContainsRune(s, '/')
|
||||
if !hasNetmask {
|
||||
isIpv6 := strings.ContainsRune(s, ':')
|
||||
if isIpv6 {
|
||||
s += "/128"
|
||||
} else {
|
||||
s += "/32"
|
||||
}
|
||||
}
|
||||
vr.ReadByte() // ignore family
|
||||
bits := vr.ReadByte()
|
||||
vr.ReadByte() // ignore is_cidr
|
||||
addressLength := vr.ReadByte()
|
||||
|
||||
_, ipnet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
vr.Fatal(err)
|
||||
return zero
|
||||
}
|
||||
|
||||
// if vr.Type().FormatCode != BinaryFormatCode {
|
||||
// vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||
// return zero
|
||||
// }
|
||||
|
||||
// if vr.Len() != 4 {
|
||||
// vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
|
||||
// return zero
|
||||
// }
|
||||
|
||||
return *ipnet
|
||||
var ipnet net.IPNet
|
||||
ipnet.IP = vr.ReadBytes(int32(addressLength))
|
||||
ipnet.Mask = net.CIDRMask(int(bits), int(addressLength)*8)
|
||||
|
||||
return ipnet
|
||||
}
|
||||
|
||||
func encodeInet(w *WriteBuf, value interface{}) error {
|
||||
@@ -1159,10 +1152,26 @@ func encodeInet(w *WriteBuf, value interface{}) error {
|
||||
return fmt.Errorf("Expected net.IPNet, received %T %v", value, value)
|
||||
}
|
||||
|
||||
s := ipnet.String()
|
||||
var size int32
|
||||
var family byte
|
||||
switch len(ipnet.IP) {
|
||||
case net.IPv4len:
|
||||
size = 8
|
||||
family = w.conn.pgsql_af_inet
|
||||
case net.IPv6len:
|
||||
size = 20
|
||||
family = w.conn.pgsql_af_inet6
|
||||
default:
|
||||
return fmt.Errorf("Unexpected IP length: %v", len(ipnet.IP))
|
||||
}
|
||||
|
||||
w.WriteInt32(int32(len(s)))
|
||||
w.WriteBytes([]byte(s))
|
||||
w.WriteInt32(size)
|
||||
w.WriteByte(family)
|
||||
ones, _ := ipnet.Mask.Size()
|
||||
w.WriteByte(byte(ones))
|
||||
w.WriteByte(0) // is_cidr is ignored on server
|
||||
w.WriteByte(byte(len(ipnet.IP)))
|
||||
w.WriteBytes(ipnet.IP)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user