2
0

ParseConnectionString supports Multi-Hosts

Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
Artemiy Ryabinkov
2019-07-11 20:28:04 +03:00
parent 75b4ba635c
commit 18189fafd5
2 changed files with 176 additions and 23 deletions
+100 -23
View File
@@ -10,6 +10,7 @@ import (
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/url"
"os"
@@ -947,16 +948,26 @@ func ParseURI(uri string) (ConnConfig, error) {
cp.Password, _ = url.User.Password()
}
parts := strings.SplitN(url.Host, ":", 2)
cp.Host = parts[0]
if len(parts) == 2 {
p, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return cp, err
hasMuliHosts := strings.IndexByte(url.Host, ',') != -1
if !hasMuliHosts {
parts := strings.SplitN(url.Host, ":", 2)
cp.Host = parts[0]
if len(parts) == 2 {
p, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return cp, err
}
cp.Port = uint16(p)
}
cp.Port = uint16(p)
} else {
cp.Host = url.Host
}
cp.Database = strings.TrimLeft(url.Path, "/")
cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs"))
if err := cp.TargetSessionAttrs.isValid(); err != nil {
return cp, err
}
if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
@@ -980,11 +991,12 @@ func ParseURI(uri string) (ConnConfig, error) {
}
ignoreKeys := map[string]struct{}{
"connect_timeout": {},
"sslcert": {},
"sslkey": {},
"sslmode": {},
"sslrootcert": {},
"connect_timeout": {},
"sslcert": {},
"sslkey": {},
"sslmode": {},
"sslrootcert": {},
"target_session_attrs": {},
}
cp.RuntimeParams = make(map[string]string)
@@ -1029,6 +1041,7 @@ func ParseDSN(s string) (ConnConfig, error) {
cp.RuntimeParams = make(map[string]string)
var hostval, portval string
for _, b := range m {
switch b[1] {
case "user":
@@ -1036,13 +1049,9 @@ func ParseDSN(s string) (ConnConfig, error) {
case "password":
cp.Password = b[2]
case "host":
cp.Host = b[2]
hostval = b[2]
case "port":
p, err := strconv.ParseUint(b[2], 10, 16)
if err != nil {
return cp, err
}
cp.Port = uint16(p)
portval = b[2]
case "dbname":
cp.Database = b[2]
case "sslmode":
@@ -1061,26 +1070,94 @@ func ParseDSN(s string) (ConnConfig, error) {
d := defaultDialer()
d.Timeout = time.Duration(timeout) * time.Second
cp.Dial = d.Dial
case "target_session_attrs":
cp.TargetSessionAttrs = TargetSessionType(b[2])
if err := cp.TargetSessionAttrs.isValid(); err != nil {
return cp, err
}
default:
cp.RuntimeParams[b[1]] = b[2]
}
}
err := configTLS(tlsArgs, &cp)
host, port, err := parseHostPortDSN(hostval, portval)
if err != nil {
return cp, err
}
cp.Host, cp.Port = host, port
err = configTLS(tlsArgs, &cp)
if err != nil {
return cp, err
}
if cp.Password == "" {
pgpass(&cp)
}
return cp, nil
}
// ParseConnectionString parses either a URI or a DSN connection string.
// see ParseURI and ParseDSN for details.
func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) {
if portval == "" {
return hostval, 0, nil
}
hosts := strings.Split(hostval, ",")
ports := strings.Split(portval, ",")
if len(ports) == 1 {
port, err := parsePort(portval)
if err != nil {
return "", 0, errors.Errorf("invalid port: %v", err)
}
return hostval, port, nil
}
if len(hosts) != len(ports) {
return "", 0, errors.New("the number of hosts and ports must be the same")
}
hostports := make([]string, len(hosts))
for i, host := range hosts {
hostports[i] = host + ":" + ports[i]
}
return strings.Join(hostports, ","), 0, nil
}
func parsePort(s string) (uint16, error) {
port, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return 0, err
}
if port < 1 || port > math.MaxUint16 {
return 0, errors.New("outside range")
}
return uint16(port), nil
}
// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig.
//
// # Example DSN
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
//
// ParseConnectionString supports specifying multiple hosts in similar manner to libpq.
// Host and port may include comma separated values that will be tried in order.
// This can be used as part of a high availability system.
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
//
// # Example URL
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// # Example DSN
// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca
func ParseConnectionString(s string) (ConnConfig, error) {
// TODO: Multiple Hosts support
// @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS
if u, err := url.Parse(s); err == nil && u.Scheme != "" {
return ParseURI(s)
}
+76
View File
@@ -622,6 +622,38 @@ func TestParseURI(t *testing.T) {
RuntimeParams: map[string]string{},
},
},
{
url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb",
connParams: pgx.ConnConfig{
User: "jack",
Password: "secret",
Host: "foo.example.com:5432,bar.example.com:5432",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost,10.10.20.30",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{
"application_name": "pgxtest",
},
TargetSessionAttrs: pgx.ReadWriteTargetSession,
},
},
}
for i, tt := range tests {
@@ -748,6 +780,50 @@ func TestParseDSN(t *testing.T) {
RuntimeParams: map[string]string{},
},
},
{
url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost1,localhost2",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb",
connParams: pgx.ConnConfig{
User: "jack",
Host: "100.200.220.50:5432,localhost43:5433",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
TargetSessionAttrs: pgx.ReadWriteTargetSession,
},
},
}
for i, tt := range tests {