diff --git a/conn.go b/conn.go index 975cf337..fd134461 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "net/url" "os" @@ -947,16 +948,26 @@ func ParseURI(uri string) (ConnConfig, error) { cp.Password, _ = url.User.Password() } - parts := strings.SplitN(url.Host, ":", 2) - cp.Host = parts[0] - if len(parts) == 2 { - p, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return cp, err + hasMuliHosts := strings.IndexByte(url.Host, ',') != -1 + if !hasMuliHosts { + parts := strings.SplitN(url.Host, ":", 2) + cp.Host = parts[0] + if len(parts) == 2 { + p, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return cp, err + } + cp.Port = uint16(p) } - cp.Port = uint16(p) + } else { + cp.Host = url.Host } + cp.Database = strings.TrimLeft(url.Path, "/") + cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs")) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { timeout, err := strconv.ParseInt(pgtimeout, 10, 64) @@ -980,11 +991,12 @@ func ParseURI(uri string) (ConnConfig, error) { } ignoreKeys := map[string]struct{}{ - "connect_timeout": {}, - "sslcert": {}, - "sslkey": {}, - "sslmode": {}, - "sslrootcert": {}, + "connect_timeout": {}, + "sslcert": {}, + "sslkey": {}, + "sslmode": {}, + "sslrootcert": {}, + "target_session_attrs": {}, } cp.RuntimeParams = make(map[string]string) @@ -1029,6 +1041,7 @@ func ParseDSN(s string) (ConnConfig, error) { cp.RuntimeParams = make(map[string]string) + var hostval, portval string for _, b := range m { switch b[1] { case "user": @@ -1036,13 +1049,9 @@ func ParseDSN(s string) (ConnConfig, error) { case "password": cp.Password = b[2] case "host": - cp.Host = b[2] + hostval = b[2] case "port": - p, err := strconv.ParseUint(b[2], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) + portval = b[2] case "dbname": cp.Database = b[2] case "sslmode": @@ -1061,26 +1070,94 @@ func ParseDSN(s string) (ConnConfig, error) { d := defaultDialer() d.Timeout = time.Duration(timeout) * time.Second cp.Dial = d.Dial + case "target_session_attrs": + cp.TargetSessionAttrs = TargetSessionType(b[2]) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } default: cp.RuntimeParams[b[1]] = b[2] } } - err := configTLS(tlsArgs, &cp) + host, port, err := parseHostPortDSN(hostval, portval) if err != nil { return cp, err } + + cp.Host, cp.Port = host, port + + err = configTLS(tlsArgs, &cp) + if err != nil { + return cp, err + } + if cp.Password == "" { pgpass(&cp) } + return cp, nil } -// ParseConnectionString parses either a URI or a DSN connection string. -// see ParseURI and ParseDSN for details. +func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) { + if portval == "" { + return hostval, 0, nil + } + + hosts := strings.Split(hostval, ",") + ports := strings.Split(portval, ",") + + if len(ports) == 1 { + port, err := parsePort(portval) + if err != nil { + return "", 0, errors.Errorf("invalid port: %v", err) + } + + return hostval, port, nil + } + + if len(hosts) != len(ports) { + return "", 0, errors.New("the number of hosts and ports must be the same") + } + + hostports := make([]string, len(hosts)) + for i, host := range hosts { + hostports[i] = host + ":" + ports[i] + } + + return strings.Join(hostports, ","), 0, nil +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// ParseConnectionString supports specifying multiple hosts in similar manner to libpq. +// Host and port may include comma separated values that will be tried in order. +// This can be used as part of a high availability system. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// # Example DSN +// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca func ParseConnectionString(s string) (ConnConfig, error) { - // TODO: Multiple Hosts support - // @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) } diff --git a/conn_test.go b/conn_test.go index 28bfe48b..7719bec7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -622,6 +622,38 @@ func TestParseURI(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "foo.example.com:5432,bar.example.com:5432", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost,10.10.20.30", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + }, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { @@ -748,6 +780,50 @@ func TestParseDSN(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost1,localhost2", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "100.200.220.50:5432,localhost43:5433", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests {