diff --git a/inet.go b/inet.go index 1645334e..f35f88ba 100644 --- a/inet.go +++ b/inet.go @@ -47,7 +47,15 @@ func (dst *Inet) Set(src interface{}) error { case string: ip, ipnet, err := net.ParseCIDR(value) if err != nil { - return err + ip = net.ParseIP(value) + if ip == nil { + return fmt.Errorf("unable to parse inet address: %s", value) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { + ip = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) + } } ipnet.IP = ip *dst = Inet{IPNet: ipnet, Status: Present} diff --git a/inet_test.go b/inet_test.go index 66fe777f..09c6b21f 100644 --- a/inet_test.go +++ b/inet_test.go @@ -18,6 +18,8 @@ func TestInetTranscode(t *testing.T) { &pgtype.Inet{IPNet: mustParseInet(t, "192.168.1.16/24"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "255.0.0.0/8"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "255.255.255.255/32"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}, + &pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "::1/64"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "::/0"), Status: pgtype.Present}, &pgtype.Inet{IPNet: mustParseInet(t, "::1/128"), Status: pgtype.Present}, @@ -51,6 +53,8 @@ func TestInetSet(t *testing.T) { {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, + {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}}, + {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, } @@ -59,6 +63,7 @@ func TestInetSet(t *testing.T) { err := r.Set(tt.source) if err != nil { t.Errorf("%d: %v", i, err) + continue } assert.Equalf(t, tt.result.Status, r.Status, "%d: Status", i) diff --git a/pgtype_test.go b/pgtype_test.go index 75e1909f..2506e0a3 100644 --- a/pgtype_test.go +++ b/pgtype_test.go @@ -37,15 +37,24 @@ func mustParseCIDR(t testing.TB, s string) *net.IPNet { func mustParseInet(t testing.TB, s string) *net.IPNet { ip, ipnet, err := net.ParseCIDR(s) - if err != nil { - t.Fatal(err) + if err == nil { + if ipv4 := ip.To4(); ipv4 != nil { + ipnet.IP = ipv4 + } + return ipnet } + + // May be bare IP address. + // + ip = net.ParseIP(s) + if ip == nil { + t.Fatal(errors.New("unable to parse inet address")) + } + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 + ipnet.IP = ipv4 + ipnet.Mask = net.CIDRMask(32, 32) } - - ipnet.IP = ip - return ipnet }