diff --git a/inet.go b/inet.go index f35f88ba..25e56170 100644 --- a/inet.go +++ b/inet.go @@ -47,17 +47,26 @@ func (dst *Inet) Set(src interface{}) error { case string: ip, ipnet, err := net.ParseCIDR(value) if err != nil { - ip = net.ParseIP(value) + ip := net.ParseIP(value) if ip == nil { return fmt.Errorf("unable to parse inet address: %s", value) } - ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - ipnet.Mask = net.CIDRMask(32, 32) + ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)} + } else { + ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)} + } + } else { + ipnet.IP = ip + if ipv4 := ipnet.IP.To4(); ipv4 != nil { + ipnet.IP = ipv4 + if len(ipnet.Mask) == 16 { + ipnet.Mask = ipnet.Mask[12:] // Needed if input is IPv4-mapped IPv6. + } } } - ipnet.IP = ip + *dst = Inet{IPNet: ipnet, Status: Present} case *net.IPNet: if value == nil { diff --git a/inet_test.go b/inet_test.go index 8d70c0d0..badbf82e 100644 --- a/inet_test.go +++ b/inet_test.go @@ -52,10 +52,12 @@ func TestInetSet(t *testing.T) { {source: mustParseCIDR(t, "127.0.0.1/32"), result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: mustParseCIDR(t, "127.0.0.1/32").IP, result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, {source: "127.0.0.1/32", result: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}}, - {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4"), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, + {source: "1.2.3.4/24", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("1.2.3.4").To4(), Mask: net.CIDRMask(24, 32)}, Status: pgtype.Present}}, {source: "10.0.0.1", result: pgtype.Inet{IPNet: mustParseInet(t, "10.0.0.1"), Status: pgtype.Present}}, {source: "2607:f8b0:4009:80b::200e", result: pgtype.Inet{IPNet: mustParseInet(t, "2607:f8b0:4009:80b::200e"), Status: pgtype.Present}}, {source: net.ParseIP(""), result: pgtype.Inet{Status: pgtype.Null}}, + {source: "0.0.0.0/8", result: pgtype.Inet{IPNet: mustParseInet(t, "0.0.0.0/8"), Status: pgtype.Present}}, + {source: "::ffff:0.0.0.0/104", result: pgtype.Inet{IPNet: &net.IPNet{IP: net.ParseIP("0.0.0.0").To4(), Mask: net.CIDRMask(8, 32)}, Status: pgtype.Present}}, } for i, tt := range successfulTests { @@ -70,6 +72,7 @@ func TestInetSet(t *testing.T) { if tt.result.Status == pgtype.Present { assert.Equalf(t, tt.result.IPNet.Mask, r.IPNet.Mask, "%d: Mask", i) assert.Truef(t, tt.result.IPNet.IP.Equal(r.IPNet.IP), "%d: IP", i) + assert.Equalf(t, len(tt.result.IPNet.IP), len(r.IPNet.IP), "%d: IP length", i) } } }