ParseConnectionString supports Multi-Hosts
Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -947,16 +948,26 @@ func ParseURI(uri string) (ConnConfig, error) {
|
||||
cp.Password, _ = url.User.Password()
|
||||
}
|
||||
|
||||
parts := strings.SplitN(url.Host, ":", 2)
|
||||
cp.Host = parts[0]
|
||||
if len(parts) == 2 {
|
||||
p, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
hasMuliHosts := strings.IndexByte(url.Host, ',') != -1
|
||||
if !hasMuliHosts {
|
||||
parts := strings.SplitN(url.Host, ":", 2)
|
||||
cp.Host = parts[0]
|
||||
if len(parts) == 2 {
|
||||
p, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
cp.Port = uint16(p)
|
||||
}
|
||||
cp.Port = uint16(p)
|
||||
} else {
|
||||
cp.Host = url.Host
|
||||
}
|
||||
|
||||
cp.Database = strings.TrimLeft(url.Path, "/")
|
||||
cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs"))
|
||||
if err := cp.TargetSessionAttrs.isValid(); err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
|
||||
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
|
||||
@@ -980,11 +991,12 @@ func ParseURI(uri string) (ConnConfig, error) {
|
||||
}
|
||||
|
||||
ignoreKeys := map[string]struct{}{
|
||||
"connect_timeout": {},
|
||||
"sslcert": {},
|
||||
"sslkey": {},
|
||||
"sslmode": {},
|
||||
"sslrootcert": {},
|
||||
"connect_timeout": {},
|
||||
"sslcert": {},
|
||||
"sslkey": {},
|
||||
"sslmode": {},
|
||||
"sslrootcert": {},
|
||||
"target_session_attrs": {},
|
||||
}
|
||||
|
||||
cp.RuntimeParams = make(map[string]string)
|
||||
@@ -1029,6 +1041,7 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||
|
||||
cp.RuntimeParams = make(map[string]string)
|
||||
|
||||
var hostval, portval string
|
||||
for _, b := range m {
|
||||
switch b[1] {
|
||||
case "user":
|
||||
@@ -1036,13 +1049,9 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||
case "password":
|
||||
cp.Password = b[2]
|
||||
case "host":
|
||||
cp.Host = b[2]
|
||||
hostval = b[2]
|
||||
case "port":
|
||||
p, err := strconv.ParseUint(b[2], 10, 16)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
cp.Port = uint16(p)
|
||||
portval = b[2]
|
||||
case "dbname":
|
||||
cp.Database = b[2]
|
||||
case "sslmode":
|
||||
@@ -1061,26 +1070,94 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||
d := defaultDialer()
|
||||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
cp.Dial = d.Dial
|
||||
case "target_session_attrs":
|
||||
cp.TargetSessionAttrs = TargetSessionType(b[2])
|
||||
if err := cp.TargetSessionAttrs.isValid(); err != nil {
|
||||
return cp, err
|
||||
}
|
||||
default:
|
||||
cp.RuntimeParams[b[1]] = b[2]
|
||||
}
|
||||
}
|
||||
|
||||
err := configTLS(tlsArgs, &cp)
|
||||
host, port, err := parseHostPortDSN(hostval, portval)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
cp.Host, cp.Port = host, port
|
||||
|
||||
err = configTLS(tlsArgs, &cp)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
if cp.Password == "" {
|
||||
pgpass(&cp)
|
||||
}
|
||||
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
// ParseConnectionString parses either a URI or a DSN connection string.
|
||||
// see ParseURI and ParseDSN for details.
|
||||
func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) {
|
||||
if portval == "" {
|
||||
return hostval, 0, nil
|
||||
}
|
||||
|
||||
hosts := strings.Split(hostval, ",")
|
||||
ports := strings.Split(portval, ",")
|
||||
|
||||
if len(ports) == 1 {
|
||||
port, err := parsePort(portval)
|
||||
if err != nil {
|
||||
return "", 0, errors.Errorf("invalid port: %v", err)
|
||||
}
|
||||
|
||||
return hostval, port, nil
|
||||
}
|
||||
|
||||
if len(hosts) != len(ports) {
|
||||
return "", 0, errors.New("the number of hosts and ports must be the same")
|
||||
}
|
||||
|
||||
hostports := make([]string, len(hosts))
|
||||
for i, host := range hosts {
|
||||
hostports[i] = host + ":" + ports[i]
|
||||
}
|
||||
|
||||
return strings.Join(hostports, ","), 0, nil
|
||||
}
|
||||
|
||||
func parsePort(s string) (uint16, error) {
|
||||
port, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if port < 1 || port > math.MaxUint16 {
|
||||
return 0, errors.New("outside range")
|
||||
}
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig.
|
||||
//
|
||||
// # Example DSN
|
||||
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||
//
|
||||
// ParseConnectionString supports specifying multiple hosts in similar manner to libpq.
|
||||
// Host and port may include comma separated values that will be tried in order.
|
||||
// This can be used as part of a high availability system.
|
||||
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||
//
|
||||
// # Example DSN
|
||||
// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca
|
||||
func ParseConnectionString(s string) (ConnConfig, error) {
|
||||
// TODO: Multiple Hosts support
|
||||
// @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS
|
||||
if u, err := url.Parse(s); err == nil && u.Scheme != "" {
|
||||
return ParseURI(s)
|
||||
}
|
||||
|
||||
@@ -622,6 +622,38 @@ func TestParseURI(t *testing.T) {
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo.example.com:5432,bar.example.com:5432",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Host: "localhost,10.10.20.30",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{
|
||||
"application_name": "pgxtest",
|
||||
},
|
||||
TargetSessionAttrs: pgx.ReadWriteTargetSession,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
@@ -748,6 +780,50 @@ func TestParseDSN(t *testing.T) {
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Host: "localhost1,localhost2",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Host: "100.200.220.50:5432,localhost43:5433",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Host: "localhost",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
TargetSessionAttrs: pgx.ReadWriteTargetSession,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
||||
Reference in New Issue
Block a user