Add inet support
This commit is contained in:
@@ -751,6 +751,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
err = encodeTimestampTz(wbuf, arguments[i])
|
err = encodeTimestampTz(wbuf, arguments[i])
|
||||||
case TimestampOid:
|
case TimestampOid:
|
||||||
err = encodeTimestamp(wbuf, arguments[i])
|
err = encodeTimestamp(wbuf, arguments[i])
|
||||||
|
case InetOid:
|
||||||
|
err = encodeInet(wbuf, arguments[i])
|
||||||
case BoolArrayOid:
|
case BoolArrayOid:
|
||||||
err = encodeBoolArray(wbuf, arguments[i])
|
err = encodeBoolArray(wbuf, arguments[i])
|
||||||
case Int2ArrayOid:
|
case Int2ArrayOid:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package pgx
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -275,6 +276,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
|
|||||||
default:
|
default:
|
||||||
rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
|
rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
|
||||||
}
|
}
|
||||||
|
case *net.IPNet:
|
||||||
|
*d = decodeInet(vr)
|
||||||
case Scanner:
|
case Scanner:
|
||||||
err = d.Scan(vr)
|
err = d.Scan(vr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -355,6 +358,8 @@ func (rows *Rows) Values() ([]interface{}, error) {
|
|||||||
values = append(values, decodeTimestampTz(vr))
|
values = append(values, decodeTimestampTz(vr))
|
||||||
case TimestampOid:
|
case TimestampOid:
|
||||||
values = append(values, decodeTimestamp(vr))
|
values = append(values, decodeTimestamp(vr))
|
||||||
|
case InetOid:
|
||||||
|
values = append(values, decodeInet(vr))
|
||||||
default:
|
default:
|
||||||
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
|
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -20,6 +21,7 @@ const (
|
|||||||
OidOid = 26
|
OidOid = 26
|
||||||
Float4Oid = 700
|
Float4Oid = 700
|
||||||
Float8Oid = 701
|
Float8Oid = 701
|
||||||
|
InetOid = 869
|
||||||
BoolArrayOid = 1000
|
BoolArrayOid = 1000
|
||||||
Int2ArrayOid = 1005
|
Int2ArrayOid = 1005
|
||||||
Int4ArrayOid = 1007
|
Int4ArrayOid = 1007
|
||||||
@@ -1101,6 +1103,68 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error {
|
|||||||
return encodeTimestampTz(w, value)
|
return encodeTimestampTz(w, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeInet(vr *ValueReader) net.IPNet {
|
||||||
|
var zero net.IPNet
|
||||||
|
|
||||||
|
if vr.Len() == -1 {
|
||||||
|
vr.Fatal(ProtocolError("Cannot decode null into net.IPNet"))
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
|
||||||
|
if vr.Type().DataType != InetOid {
|
||||||
|
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into inet", vr.Type().DataType)))
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
|
||||||
|
s := vr.ReadString(vr.Len())
|
||||||
|
hasNetmask := strings.ContainsRune(s, '/')
|
||||||
|
if !hasNetmask {
|
||||||
|
isIpv6 := strings.ContainsRune(s, ':')
|
||||||
|
if isIpv6 {
|
||||||
|
s += "/128"
|
||||||
|
} else {
|
||||||
|
s += "/32"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ipnet, err := net.ParseCIDR(s)
|
||||||
|
if err != nil {
|
||||||
|
vr.Fatal(err)
|
||||||
|
return zero
|
||||||
|
}
|
||||||
|
|
||||||
|
// if vr.Type().FormatCode != BinaryFormatCode {
|
||||||
|
// vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
|
||||||
|
// return zero
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if vr.Len() != 4 {
|
||||||
|
// vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
|
||||||
|
// return zero
|
||||||
|
// }
|
||||||
|
|
||||||
|
return *ipnet
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeInet(w *WriteBuf, value interface{}) error {
|
||||||
|
var ipnet net.IPNet
|
||||||
|
|
||||||
|
switch value := value.(type) {
|
||||||
|
case net.IPNet:
|
||||||
|
ipnet = value
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Expected net.IPNet, received %T %v", value, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := ipnet.String()
|
||||||
|
|
||||||
|
w.WriteInt32(int32(len(s)))
|
||||||
|
w.WriteBytes([]byte(s))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
|
||||||
numDims := vr.ReadInt32()
|
numDims := vr.ReadInt32()
|
||||||
if numDims > 1 {
|
if numDims > 1 {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package pgx_test
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/jackc/pgx"
|
"github.com/jackc/pgx"
|
||||||
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -89,6 +90,53 @@ func TestTimestampTzTranscode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustParseCIDR(t *testing.T, s string) net.IPNet {
|
||||||
|
_, ipnet, err := net.ParseCIDR(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return *ipnet
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInetTranscode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
value net.IPNet
|
||||||
|
}{
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "::/128")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "::/0")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "::1/128")},
|
||||||
|
{"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
var actual net.IPNet
|
||||||
|
|
||||||
|
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.String() != tt.value.String() {
|
||||||
|
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNullX(t *testing.T) {
|
func TestNullX(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user