2
0

Add inet support

This commit is contained in:
Jack Christensen
2015-09-03 09:33:19 -05:00
parent a56e35ad0a
commit d494f83cd1
4 changed files with 119 additions and 0 deletions
+2
View File
@@ -751,6 +751,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
err = encodeTimestampTz(wbuf, arguments[i]) err = encodeTimestampTz(wbuf, arguments[i])
case TimestampOid: case TimestampOid:
err = encodeTimestamp(wbuf, arguments[i]) err = encodeTimestamp(wbuf, arguments[i])
case InetOid:
err = encodeInet(wbuf, arguments[i])
case BoolArrayOid: case BoolArrayOid:
err = encodeBoolArray(wbuf, arguments[i]) err = encodeBoolArray(wbuf, arguments[i])
case Int2ArrayOid: case Int2ArrayOid:
+5
View File
@@ -3,6 +3,7 @@ package pgx
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"time" "time"
) )
@@ -275,6 +276,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
default: default:
rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType)) rows.Fatal(fmt.Errorf("Can't convert OID %v to time.Time", vr.Type().DataType))
} }
case *net.IPNet:
*d = decodeInet(vr)
case Scanner: case Scanner:
err = d.Scan(vr) err = d.Scan(vr)
if err != nil { if err != nil {
@@ -355,6 +358,8 @@ func (rows *Rows) Values() ([]interface{}, error) {
values = append(values, decodeTimestampTz(vr)) values = append(values, decodeTimestampTz(vr))
case TimestampOid: case TimestampOid:
values = append(values, decodeTimestamp(vr)) values = append(values, decodeTimestamp(vr))
case InetOid:
values = append(values, decodeInet(vr))
default: default:
rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types")) rows.Fatal(errors.New("Values cannot handle binary format non-intrinsic types"))
} }
+64
View File
@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"math" "math"
"net"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -20,6 +21,7 @@ const (
OidOid = 26 OidOid = 26
Float4Oid = 700 Float4Oid = 700
Float8Oid = 701 Float8Oid = 701
InetOid = 869
BoolArrayOid = 1000 BoolArrayOid = 1000
Int2ArrayOid = 1005 Int2ArrayOid = 1005
Int4ArrayOid = 1007 Int4ArrayOid = 1007
@@ -1101,6 +1103,68 @@ func encodeTimestamp(w *WriteBuf, value interface{}) error {
return encodeTimestampTz(w, value) return encodeTimestampTz(w, value)
} }
func decodeInet(vr *ValueReader) net.IPNet {
var zero net.IPNet
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into net.IPNet"))
return zero
}
if vr.Type().DataType != InetOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into inet", vr.Type().DataType)))
return zero
}
s := vr.ReadString(vr.Len())
hasNetmask := strings.ContainsRune(s, '/')
if !hasNetmask {
isIpv6 := strings.ContainsRune(s, ':')
if isIpv6 {
s += "/128"
} else {
s += "/32"
}
}
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
vr.Fatal(err)
return zero
}
// if vr.Type().FormatCode != BinaryFormatCode {
// vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
// return zero
// }
// if vr.Len() != 4 {
// vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", vr.Len())))
// return zero
// }
return *ipnet
}
func encodeInet(w *WriteBuf, value interface{}) error {
var ipnet net.IPNet
switch value := value.(type) {
case net.IPNet:
ipnet = value
default:
return fmt.Errorf("Expected net.IPNet, received %T %v", value, value)
}
s := ipnet.String()
w.WriteInt32(int32(len(s)))
w.WriteBytes([]byte(s))
return nil
}
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
numDims := vr.ReadInt32() numDims := vr.ReadInt32()
if numDims > 1 { if numDims > 1 {
+48
View File
@@ -3,6 +3,7 @@ package pgx_test
import ( import (
"fmt" "fmt"
"github.com/jackc/pgx" "github.com/jackc/pgx"
"net"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@@ -89,6 +90,53 @@ func TestTimestampTzTranscode(t *testing.T) {
} }
} }
func mustParseCIDR(t *testing.T, s string) net.IPNet {
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
t.Fatal(err)
}
return *ipnet
}
func TestInetTranscode(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
tests := []struct {
sql string
value net.IPNet
}{
{"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")},
{"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")},
{"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")},
{"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")},
{"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")},
{"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")},
{"select $1::inet", mustParseCIDR(t, "::/128")},
{"select $1::inet", mustParseCIDR(t, "::/0")},
{"select $1::inet", mustParseCIDR(t, "::1/128")},
{"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")},
}
for i, tt := range tests {
var actual net.IPNet
err := conn.QueryRow(tt.sql, tt.value).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value)
}
if actual.String() != tt.value.String() {
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql)
}
ensureConnValid(t, conn)
}
}
func TestNullX(t *testing.T) { func TestNullX(t *testing.T) {
t.Parallel() t.Parallel()