ParseConnectionString supports Multi-Hosts
Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -947,16 +948,26 @@ func ParseURI(uri string) (ConnConfig, error) {
|
|||||||
cp.Password, _ = url.User.Password()
|
cp.Password, _ = url.User.Password()
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.SplitN(url.Host, ":", 2)
|
hasMuliHosts := strings.IndexByte(url.Host, ',') != -1
|
||||||
cp.Host = parts[0]
|
if !hasMuliHosts {
|
||||||
if len(parts) == 2 {
|
parts := strings.SplitN(url.Host, ":", 2)
|
||||||
p, err := strconv.ParseUint(parts[1], 10, 16)
|
cp.Host = parts[0]
|
||||||
if err != nil {
|
if len(parts) == 2 {
|
||||||
return cp, err
|
p, err := strconv.ParseUint(parts[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return cp, err
|
||||||
|
}
|
||||||
|
cp.Port = uint16(p)
|
||||||
}
|
}
|
||||||
cp.Port = uint16(p)
|
} else {
|
||||||
|
cp.Host = url.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
cp.Database = strings.TrimLeft(url.Path, "/")
|
cp.Database = strings.TrimLeft(url.Path, "/")
|
||||||
|
cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs"))
|
||||||
|
if err := cp.TargetSessionAttrs.isValid(); err != nil {
|
||||||
|
return cp, err
|
||||||
|
}
|
||||||
|
|
||||||
if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
|
if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
|
||||||
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
|
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
|
||||||
@@ -980,11 +991,12 @@ func ParseURI(uri string) (ConnConfig, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ignoreKeys := map[string]struct{}{
|
ignoreKeys := map[string]struct{}{
|
||||||
"connect_timeout": {},
|
"connect_timeout": {},
|
||||||
"sslcert": {},
|
"sslcert": {},
|
||||||
"sslkey": {},
|
"sslkey": {},
|
||||||
"sslmode": {},
|
"sslmode": {},
|
||||||
"sslrootcert": {},
|
"sslrootcert": {},
|
||||||
|
"target_session_attrs": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
cp.RuntimeParams = make(map[string]string)
|
cp.RuntimeParams = make(map[string]string)
|
||||||
@@ -1029,6 +1041,7 @@ func ParseDSN(s string) (ConnConfig, error) {
|
|||||||
|
|
||||||
cp.RuntimeParams = make(map[string]string)
|
cp.RuntimeParams = make(map[string]string)
|
||||||
|
|
||||||
|
var hostval, portval string
|
||||||
for _, b := range m {
|
for _, b := range m {
|
||||||
switch b[1] {
|
switch b[1] {
|
||||||
case "user":
|
case "user":
|
||||||
@@ -1036,13 +1049,9 @@ func ParseDSN(s string) (ConnConfig, error) {
|
|||||||
case "password":
|
case "password":
|
||||||
cp.Password = b[2]
|
cp.Password = b[2]
|
||||||
case "host":
|
case "host":
|
||||||
cp.Host = b[2]
|
hostval = b[2]
|
||||||
case "port":
|
case "port":
|
||||||
p, err := strconv.ParseUint(b[2], 10, 16)
|
portval = b[2]
|
||||||
if err != nil {
|
|
||||||
return cp, err
|
|
||||||
}
|
|
||||||
cp.Port = uint16(p)
|
|
||||||
case "dbname":
|
case "dbname":
|
||||||
cp.Database = b[2]
|
cp.Database = b[2]
|
||||||
case "sslmode":
|
case "sslmode":
|
||||||
@@ -1061,26 +1070,94 @@ func ParseDSN(s string) (ConnConfig, error) {
|
|||||||
d := defaultDialer()
|
d := defaultDialer()
|
||||||
d.Timeout = time.Duration(timeout) * time.Second
|
d.Timeout = time.Duration(timeout) * time.Second
|
||||||
cp.Dial = d.Dial
|
cp.Dial = d.Dial
|
||||||
|
case "target_session_attrs":
|
||||||
|
cp.TargetSessionAttrs = TargetSessionType(b[2])
|
||||||
|
if err := cp.TargetSessionAttrs.isValid(); err != nil {
|
||||||
|
return cp, err
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
cp.RuntimeParams[b[1]] = b[2]
|
cp.RuntimeParams[b[1]] = b[2]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := configTLS(tlsArgs, &cp)
|
host, port, err := parseHostPortDSN(hostval, portval)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cp, err
|
return cp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cp.Host, cp.Port = host, port
|
||||||
|
|
||||||
|
err = configTLS(tlsArgs, &cp)
|
||||||
|
if err != nil {
|
||||||
|
return cp, err
|
||||||
|
}
|
||||||
|
|
||||||
if cp.Password == "" {
|
if cp.Password == "" {
|
||||||
pgpass(&cp)
|
pgpass(&cp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cp, nil
|
return cp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseConnectionString parses either a URI or a DSN connection string.
|
func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) {
|
||||||
// see ParseURI and ParseDSN for details.
|
if portval == "" {
|
||||||
|
return hostval, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hosts := strings.Split(hostval, ",")
|
||||||
|
ports := strings.Split(portval, ",")
|
||||||
|
|
||||||
|
if len(ports) == 1 {
|
||||||
|
port, err := parsePort(portval)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, errors.Errorf("invalid port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hostval, port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hosts) != len(ports) {
|
||||||
|
return "", 0, errors.New("the number of hosts and ports must be the same")
|
||||||
|
}
|
||||||
|
|
||||||
|
hostports := make([]string, len(hosts))
|
||||||
|
for i, host := range hosts {
|
||||||
|
hostports[i] = host + ":" + ports[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(hostports, ","), 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePort(s string) (uint16, error) {
|
||||||
|
port, err := strconv.ParseUint(s, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if port < 1 || port > math.MaxUint16 {
|
||||||
|
return 0, errors.New("outside range")
|
||||||
|
}
|
||||||
|
return uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig.
|
||||||
|
//
|
||||||
|
// # Example DSN
|
||||||
|
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||||
|
//
|
||||||
|
// # Example URL
|
||||||
|
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||||
|
//
|
||||||
|
// ParseConnectionString supports specifying multiple hosts in similar manner to libpq.
|
||||||
|
// Host and port may include comma separated values that will be tried in order.
|
||||||
|
// This can be used as part of a high availability system.
|
||||||
|
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||||
|
//
|
||||||
|
// # Example URL
|
||||||
|
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||||
|
//
|
||||||
|
// # Example DSN
|
||||||
|
// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca
|
||||||
func ParseConnectionString(s string) (ConnConfig, error) {
|
func ParseConnectionString(s string) (ConnConfig, error) {
|
||||||
// TODO: Multiple Hosts support
|
|
||||||
// @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS
|
|
||||||
if u, err := url.Parse(s); err == nil && u.Scheme != "" {
|
if u, err := url.Parse(s); err == nil && u.Scheme != "" {
|
||||||
return ParseURI(s)
|
return ParseURI(s)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -622,6 +622,38 @@ func TestParseURI(t *testing.T) {
|
|||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb",
|
||||||
|
connParams: pgx.ConnConfig{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "foo.example.com:5432,bar.example.com:5432",
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
UseFallbackTLS: true,
|
||||||
|
FallbackTLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write",
|
||||||
|
connParams: pgx.ConnConfig{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost,10.10.20.30",
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
UseFallbackTLS: true,
|
||||||
|
FallbackTLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{
|
||||||
|
"application_name": "pgxtest",
|
||||||
|
},
|
||||||
|
TargetSessionAttrs: pgx.ReadWriteTargetSession,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
@@ -748,6 +780,50 @@ func TestParseDSN(t *testing.T) {
|
|||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10",
|
||||||
|
connParams: pgx.ConnConfig{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost1,localhost2",
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||||
|
UseFallbackTLS: true,
|
||||||
|
FallbackTLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb",
|
||||||
|
connParams: pgx.ConnConfig{
|
||||||
|
User: "jack",
|
||||||
|
Host: "100.200.220.50:5432,localhost43:5433",
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
UseFallbackTLS: true,
|
||||||
|
FallbackTLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write",
|
||||||
|
connParams: pgx.ConnConfig{
|
||||||
|
User: "jack",
|
||||||
|
Host: "localhost",
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
UseFallbackTLS: true,
|
||||||
|
FallbackTLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
TargetSessionAttrs: pgx.ReadWriteTargetSession,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
|||||||
Reference in New Issue
Block a user