From d494f83cd141fe45a0e1964dfbd4d2474436e3e2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 3 Sep 2015 09:33:19 -0500 Subject: [PATCH] Add inet support --- conn.go | 2 ++ query.go | 5 ++++ values.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++ values_test.go | 48 +++++++++++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+) diff --git a/conn.go b/conn.go index a8c71636..00cee9fb 100644 --- a/conn.go +++ b/conn.go @@ -751,6 +751,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTimestampTz(wbuf, arguments[i]) case TimestampOid: err = encodeTimestamp(wbuf, arguments[i]) + case InetOid: + err = encodeInet(wbuf, arguments[i]) case BoolArrayOid: err = encodeBoolArray(wbuf, arguments[i]) case Int2ArrayOid: diff --git a/query.go b/query.go index bcbd593f..3de04111 100644 --- a/query.go +++ b/query.go @@ -3,6 +3,7 @@ package pgx import ( "errors" "fmt" + "net" "time" ) @@ -275,6 +276,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { default: rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)) } + case *net.IPNet: + *d = decodeInet(vr) case Scanner: err = d.Scan(vr) if err != nil { @@ -355,6 +358,8 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeTimestampTz(vr)) case TimestampOid: values = append(values, decodeTimestamp(vr)) + case InetOid: + values = append(values, decodeInet(vr)) default: rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) } diff --git a/values.go b/values.go index 2d5738b5..4718187a 100644 --- a/values.go +++ b/values.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "math" + "net" "strconv" "strings" "time" @@ -20,6 +21,7 @@ const ( OidOid = 26 Float4Oid = 700 Float8Oid = 701 + InetOid = 869 BoolArrayOid = 1000 Int2ArrayOid = 1005 Int4ArrayOid = 1007 @@ -1101,6 +1103,68 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error { return encodeTimestampTz(w, value) } +func decodeInet(vr *ValueReader) net.IPNet { + var zero net.IPNet + + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into net.IPNet")) + return zero + } + + if vr.Type().DataType != InetOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into inet", vr.Type().DataType))) + return zero + } + + s := vr.ReadString(vr.Len()) + hasNetmask := strings.ContainsRune(s, '/') + if !hasNetmask { + isIpv6 := strings.ContainsRune(s, ':') + if isIpv6 { + s += "/128" + } else { + s += "/32" + } + } + + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + vr.Fatal(err) + return zero + } + + // if vr.Type().FormatCode != BinaryFormatCode { + // vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + // return zero + // } + + // if vr.Len() != 4 { + // vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len()))) + // return zero + // } + + return *ipnet + +} + +func encodeInet(w *WriteBuf, value interface{}) error { + var ipnet net.IPNet + + switch value := value.(type) { + case net.IPNet: + ipnet = value + default: + return fmt.Errorf("Expected net.IPNet, received %T %v", value, value) + } + + s := ipnet.String() + + w.WriteInt32(int32(len(s))) + w.WriteBytes([]byte(s)) + + return nil +} + func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { numDims := vr.ReadInt32() if numDims > 1 { diff --git a/values_test.go b/values_test.go index 0f411157..35e5a75f 100644 --- a/values_test.go +++ b/values_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "fmt" "github.com/jackc/pgx" + "net" "reflect" "strings" "testing" @@ -89,6 +90,53 @@ func TestTimestampTzTranscode(t *testing.T) { } } +func mustParseCIDR(t *testing.T, s string) net.IPNet { + _, ipnet, err := net.ParseCIDR(s) + if err != nil { + t.Fatal(err) + } + + return *ipnet +} + +func TestInetTranscode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + value net.IPNet + }{ + {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::inet", mustParseCIDR(t, "::/128")}, + {"select $1::inet", mustParseCIDR(t, "::/0")}, + {"select $1::inet", mustParseCIDR(t, "::1/128")}, + {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + } + + for i, tt := range tests { + var actual net.IPNet + + err := conn.QueryRow(tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + } + + if actual.String() != tt.value.String() { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +} + func TestNullX(t *testing.T) { t.Parallel()