From 9af068add07396b0086c24524fab30cdcd55e2d4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 3 Sep 2015 09:42:01 -0500 Subject: [PATCH] Add cidr support --- conn.go | 2 +- query.go | 2 +- values.go | 6 ++++-- values_test.go | 12 +++++++++++- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 00cee9fb..8b8ac7a8 100644 --- a/conn.go +++ b/conn.go @@ -751,7 +751,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTimestampTz(wbuf, arguments[i]) case TimestampOid: err = encodeTimestamp(wbuf, arguments[i]) - case InetOid: + case InetOid, CidrOid: err = encodeInet(wbuf, arguments[i]) case BoolArrayOid: err = encodeBoolArray(wbuf, arguments[i]) diff --git a/query.go b/query.go index 3de04111..d5d0b636 100644 --- a/query.go +++ b/query.go @@ -358,7 +358,7 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeTimestampTz(vr)) case TimestampOid: values = append(values, decodeTimestamp(vr)) - case InetOid: + case InetOid, CidrOid: values = append(values, decodeInet(vr)) default: rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) diff --git a/values.go b/values.go index 4718187a..22797d81 100644 --- a/values.go +++ b/values.go @@ -19,6 +19,7 @@ const ( Int4Oid = 23 TextOid = 25 OidOid = 26 + CidrOid = 650 Float4Oid = 700 Float8Oid = 701 InetOid = 869 @@ -1111,8 +1112,9 @@ func decodeInet(vr *ValueReader) net.IPNet { return zero } - if vr.Type().DataType != InetOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into inet", vr.Type().DataType))) + pgType := vr.Type() + if pgType.DataType != InetOid && pgType.DataType != CidrOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, vr.Type().Name))) return zero } diff --git a/values_test.go b/values_test.go index 35e5a75f..bf866fef 100644 --- a/values_test.go +++ b/values_test.go @@ -99,7 +99,7 @@ func mustParseCIDR(t *testing.T, s string) net.IPNet { return *ipnet } -func TestInetTranscode(t *testing.T) { +func TestInetCidrTranscode(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -119,6 +119,16 @@ func TestInetTranscode(t *testing.T) { {"select $1::inet", mustParseCIDR(t, "::/0")}, {"select $1::inet", mustParseCIDR(t, "::1/128")}, {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::cidr", mustParseCIDR(t, "::/128")}, + {"select $1::cidr", mustParseCIDR(t, "::/0")}, + {"select $1::cidr", mustParseCIDR(t, "::1/128")}, + {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, } for i, tt := range tests {