2
0

Support decoding inet/cidr to net.IP

fixes #137
This commit is contained in:
Jack Christensen
2016-04-22 16:00:11 -05:00
parent 5d6d01c41b
commit d62da82ab1
3 changed files with 166 additions and 9 deletions
+1 -1
View File
@@ -20,7 +20,7 @@ Pgx supports many additional features beyond what is available through database/
* PostgreSQL array to Go slice mapping for integers, floats, and strings * PostgreSQL array to Go slice mapping for integers, floats, and strings
* Hstore support * Hstore support
* JSON and JSONB support * JSON and JSONB support
* Maps inet and cidr PostgreSQL types to net.IPNet * Maps inet and cidr PostgreSQL types to net.IPNet and net.IP
* Large object support * Large object support
* Null mapping to Null* struct or pointer to pointer. * Null mapping to Null* struct or pointer to pointer.
* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types * Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types
+20 -5
View File
@@ -764,6 +764,22 @@ func Decode(vr *ValueReader, d interface{}) error {
default: default:
return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType) return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)
} }
case *net.IP:
ipnet := decodeInet(vr)
if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
return fmt.Errorf("Cannot decode netmask into *net.IP")
}
*v = ipnet.IP
case *[]net.IP:
ipnets := decodeInetArray(vr)
ips := make([]net.IP, len(ipnets))
for i, ipnet := range ipnets {
if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
return fmt.Errorf("Cannot decode netmask into *net.IP")
}
ips[i] = ipnet.IP
}
*v = ips
case *net.IPNet: case *net.IPNet:
*v = decodeInet(vr) *v = decodeInet(vr)
case *[]net.IPNet: case *[]net.IPNet:
@@ -1436,15 +1452,14 @@ func decodeInet(vr *ValueReader) net.IPNet {
} }
pgType := vr.Type() pgType := vr.Type()
if vr.Len() != 8 && vr.Len() != 20 {
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
return zero
}
if pgType.DataType != InetOid && pgType.DataType != CidrOid { if pgType.DataType != InetOid && pgType.DataType != CidrOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name))) vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
return zero return zero
} }
if vr.Len() != 8 && vr.Len() != 20 {
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
return zero
}
vr.ReadByte() // ignore family vr.ReadByte() // ignore family
bits := vr.ReadByte() bits := vr.ReadByte()
+145 -3
View File
@@ -1,12 +1,13 @@
package pgx_test package pgx_test
import ( import (
"github.com/jackc/pgx"
"net" "net"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/jackc/pgx"
) )
func TestDateTranscode(t *testing.T) { func TestDateTranscode(t *testing.T) {
@@ -258,7 +259,7 @@ func TestStringToNotTextTypeTranscode(t *testing.T) {
} }
} }
func TestInetCidrTranscode(t *testing.T) { func TestInetCidrTranscodeIPNet(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@@ -307,7 +308,67 @@ func TestInetCidrTranscode(t *testing.T) {
} }
} }
func TestInetCidrArrayTranscode(t *testing.T) { func TestInetCidrTranscodeIP(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
tests := []struct {
sql string
value net.IP
}{
{"select $1::inet", net.ParseIP("0.0.0.0")},
{"select $1::inet", net.ParseIP("127.0.0.1")},
{"select $1::inet", net.ParseIP("12.34.56.0")},
{"select $1::inet", net.ParseIP("255.255.255.255")},
{"select $1::inet", net.ParseIP("::1")},
{"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")},
{"select $1::cidr", net.ParseIP("0.0.0.0")},
{"select $1::cidr", net.ParseIP("127.0.0.1")},
{"select $1::cidr", net.ParseIP("12.34.56.0")},
{"select $1::cidr", net.ParseIP("255.255.255.255")},
{"select $1::cidr", net.ParseIP("::1")},
{"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")},
}
for i, tt := range tests {
var actual net.IP
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
continue
}
if !actual.Equal(tt.value) {
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
}
ensureConnValid(t, conn)
}
failTests := []struct {
sql string
value net.IPNet
}{
{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
}
for i, tt := range failTests {
var actual net.IP
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
if !strings.Contains(err.Error(), "Cannot decode netmask") {
t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
continue
}
ensureConnValid(t, conn)
}
}
func TestInetCidrArrayTranscodeIPNet(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@@ -366,6 +427,87 @@ func TestInetCidrArrayTranscode(t *testing.T) {
} }
} }
func TestInetCidrArrayTranscodeIP(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
tests := []struct {
sql string
value []net.IP
}{
{
"select $1::inet[]",
[]net.IP{
net.ParseIP("0.0.0.0"),
net.ParseIP("127.0.0.1"),
net.ParseIP("12.34.56.0"),
net.ParseIP("255.255.255.255"),
net.ParseIP("2607:f8b0:4009:80b::200e"),
},
},
{
"select $1::cidr[]",
[]net.IP{
net.ParseIP("0.0.0.0"),
net.ParseIP("127.0.0.1"),
net.ParseIP("12.34.56.0"),
net.ParseIP("255.255.255.255"),
net.ParseIP("2607:f8b0:4009:80b::200e"),
},
},
}
for i, tt := range tests {
var actual []net.IP
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
continue
}
if !reflect.DeepEqual(actual, tt.value) {
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
}
ensureConnValid(t, conn)
}
failTests := []struct {
sql string
value []net.IPNet
}{
{
"select $1::inet[]",
[]net.IPNet{
mustParseCIDR(t, "12.34.56.0/32"),
mustParseCIDR(t, "192.168.1.0/24"),
},
},
{
"select $1::cidr[]",
[]net.IPNet{
mustParseCIDR(t, "12.34.56.0/32"),
mustParseCIDR(t, "192.168.1.0/24"),
},
},
}
for i, tt := range failTests {
var actual []net.IP
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") {
t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
continue
}
ensureConnValid(t, conn)
}
}
func TestInetCidrTranscodeWithJustIP(t *testing.T) { func TestInetCidrTranscodeWithJustIP(t *testing.T) {
t.Parallel() t.Parallel()