@@ -20,7 +20,7 @@ Pgx supports many additional features beyond what is available through database/
|
|||||||
* PostgreSQL array to Go slice mapping for integers, floats, and strings
|
* PostgreSQL array to Go slice mapping for integers, floats, and strings
|
||||||
* Hstore support
|
* Hstore support
|
||||||
* JSON and JSONB support
|
* JSON and JSONB support
|
||||||
* Maps inet and cidr PostgreSQL types to net.IPNet
|
* Maps inet and cidr PostgreSQL types to net.IPNet and net.IP
|
||||||
* Large object support
|
* Large object support
|
||||||
* Null mapping to Null* struct or pointer to pointer.
|
* Null mapping to Null* struct or pointer to pointer.
|
||||||
* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types
|
* Supports database/sql.Scanner and database/sql/driver.Valuer interfaces for custom types
|
||||||
|
|||||||
@@ -764,6 +764,22 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||||||
default:
|
default:
|
||||||
return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)
|
return fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)
|
||||||
}
|
}
|
||||||
|
case *net.IP:
|
||||||
|
ipnet := decodeInet(vr)
|
||||||
|
if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
|
||||||
|
return fmt.Errorf("Cannot decode netmask into *net.IP")
|
||||||
|
}
|
||||||
|
*v = ipnet.IP
|
||||||
|
case *[]net.IP:
|
||||||
|
ipnets := decodeInetArray(vr)
|
||||||
|
ips := make([]net.IP, len(ipnets))
|
||||||
|
for i, ipnet := range ipnets {
|
||||||
|
if oneCount, bitCount := ipnet.Mask.Size(); oneCount != bitCount {
|
||||||
|
return fmt.Errorf("Cannot decode netmask into *net.IP")
|
||||||
|
}
|
||||||
|
ips[i] = ipnet.IP
|
||||||
|
}
|
||||||
|
*v = ips
|
||||||
case *net.IPNet:
|
case *net.IPNet:
|
||||||
*v = decodeInet(vr)
|
*v = decodeInet(vr)
|
||||||
case *[]net.IPNet:
|
case *[]net.IPNet:
|
||||||
@@ -1436,15 +1452,14 @@ func decodeInet(vr *ValueReader) net.IPNet {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pgType := vr.Type()
|
pgType := vr.Type()
|
||||||
if vr.Len() != 8 && vr.Len() != 20 {
|
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
|
|
||||||
return zero
|
|
||||||
}
|
|
||||||
|
|
||||||
if pgType.DataType != InetOid && pgType.DataType != CidrOid {
|
if pgType.DataType != InetOid && pgType.DataType != CidrOid {
|
||||||
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
|
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into %s", pgType.DataType, pgType.Name)))
|
||||||
return zero
|
return zero
|
||||||
}
|
}
|
||||||
|
if vr.Len() != 8 && vr.Len() != 20 {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for a %s: %d", pgType.Name, vr.Len())))
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
|
||||||
vr.ReadByte() // ignore family
|
vr.ReadByte() // ignore family
|
||||||
bits := vr.ReadByte()
|
bits := vr.ReadByte()
|
||||||
|
|||||||
+145
-3
@@ -1,12 +1,13 @@
|
|||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/jackc/pgx"
|
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDateTranscode(t *testing.T) {
|
func TestDateTranscode(t *testing.T) {
|
||||||
@@ -258,7 +259,7 @@ func TestStringToNotTextTypeTranscode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInetCidrTranscode(t *testing.T) {
|
func TestInetCidrTranscodeIPNet(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
@@ -307,7 +308,67 @@ func TestInetCidrTranscode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInetCidrArrayTranscode(t *testing.T) {
|
func TestInetCidrTranscodeIP(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
value net.IP
|
||||||
|
}{
|
||||||
|
{"select $1::inet", net.ParseIP("0.0.0.0")},
|
||||||
|
{"select $1::inet", net.ParseIP("127.0.0.1")},
|
||||||
|
{"select $1::inet", net.ParseIP("12.34.56.0")},
|
||||||
|
{"select $1::inet", net.ParseIP("255.255.255.255")},
|
||||||
|
{"select $1::inet", net.ParseIP("::1")},
|
||||||
|
{"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")},
|
||||||
|
{"select $1::cidr", net.ParseIP("0.0.0.0")},
|
||||||
|
{"select $1::cidr", net.ParseIP("127.0.0.1")},
|
||||||
|
{"select $1::cidr", net.ParseIP("12.34.56.0")},
|
||||||
|
{"select $1::cidr", net.ParseIP("255.255.255.255")},
|
||||||
|
{"select $1::cidr", net.ParseIP("::1")},
|
||||||
|
{"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
var actual net.IP
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Equal(tt.value) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
failTests := []struct {
|
||||||
|
sql string
|
||||||
|
value net.IPNet
|
||||||
|
}{
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
|
||||||
|
{"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")},
|
||||||
|
}
|
||||||
|
for i, tt := range failTests {
|
||||||
|
var actual net.IP
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||||
|
if !strings.Contains(err.Error(), "Cannot decode netmask") {
|
||||||
|
t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInetCidrArrayTranscodeIPNet(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnect(t, *defaultConnConfig)
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
@@ -366,6 +427,87 @@ func TestInetCidrArrayTranscode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInetCidrArrayTranscodeIP(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
value []net.IP
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"select $1::inet[]",
|
||||||
|
[]net.IP{
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
net.ParseIP("127.0.0.1"),
|
||||||
|
net.ParseIP("12.34.56.0"),
|
||||||
|
net.ParseIP("255.255.255.255"),
|
||||||
|
net.ParseIP("2607:f8b0:4009:80b::200e"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"select $1::cidr[]",
|
||||||
|
[]net.IP{
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
net.ParseIP("127.0.0.1"),
|
||||||
|
net.ParseIP("12.34.56.0"),
|
||||||
|
net.ParseIP("255.255.255.255"),
|
||||||
|
net.ParseIP("2607:f8b0:4009:80b::200e"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
var actual []net.IP
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(actual, tt.value) {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
failTests := []struct {
|
||||||
|
sql string
|
||||||
|
value []net.IPNet
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"select $1::inet[]",
|
||||||
|
[]net.IPNet{
|
||||||
|
mustParseCIDR(t, "12.34.56.0/32"),
|
||||||
|
mustParseCIDR(t, "192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"select $1::cidr[]",
|
||||||
|
[]net.IPNet{
|
||||||
|
mustParseCIDR(t, "12.34.56.0/32"),
|
||||||
|
mustParseCIDR(t, "192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range failTests {
|
||||||
|
var actual []net.IP
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "Cannot decode netmask") {
|
||||||
|
t.Errorf("%d. Expected failure cannot decode netmask, but got: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInetCidrTranscodeWithJustIP(t *testing.T) {
|
func TestInetCidrTranscodeWithJustIP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user