Restructure connect process
- Moved lots of connection logic to pgconn from pgx - Extracted pgpassfile package
This commit is contained in:
@@ -0,0 +1,421 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgpassfile"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
Fallbacks []*FallbackConfig
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||
type FallbackConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||
// net.Dial.
|
||||
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
if strings.HasPrefix(host, "/") {
|
||||
network = "unix"
|
||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||
} else {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", host, port)
|
||||
}
|
||||
return network, address
|
||||
}
|
||||
|
||||
// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq.
|
||||
// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment
|
||||
// variables. connString may be a URL or a DSN. It also may be empty to only read from the
|
||||
// environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||
//
|
||||
// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca"
|
||||
//
|
||||
// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca"
|
||||
//
|
||||
// Multiple configs may be returned due to sslmode settings with fallback options (e.g.
|
||||
// sslmode=prefer). Future implementations may also support multiple hosts
|
||||
// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS).
|
||||
//
|
||||
// ParseConfig currently recognizes the following environment variable and their parameter key word
|
||||
// equivalents passed via database URL or DSN:
|
||||
//
|
||||
// PGHOST
|
||||
// PGPORT
|
||||
// PGDATABASE
|
||||
// PGUSER
|
||||
// PGPASSWORD
|
||||
// PGPASSFILE
|
||||
// PGSSLMODE
|
||||
// PGSSLCERT
|
||||
// PGSSLKEY
|
||||
// PGSSLROOTCERT
|
||||
// PGAPPNAME
|
||||
// PGCONNECT_TIMEOUT
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of
|
||||
// environment variables.
|
||||
//
|
||||
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key
|
||||
// word names. They are usually but not always the environment variable name downcased and without
|
||||
// the "PG" prefix.
|
||||
//
|
||||
// Important TLS Security Notes:
|
||||
//
|
||||
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to
|
||||
// "prefer" behavior if not set.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on
|
||||
// what level of security each sslmode provides.
|
||||
//
|
||||
// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger
|
||||
// security guarantees than it would with libpq. Do not rely on this behavior as it
|
||||
// may be possible to match libpq in the future. If you need full security use
|
||||
// "verify-full".
|
||||
func ParseConfig(connString string) (*Config, error) {
|
||||
settings := defaultSettings()
|
||||
addEnvSettings(settings)
|
||||
|
||||
if connString != "" {
|
||||
// connString may be a database URL or a DSN
|
||||
if strings.HasPrefix(connString, "postgres://") {
|
||||
url, err := url.Parse(connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = addURLSettings(settings, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
err := addDSNSettings(settings, connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Host: settings["host"],
|
||||
Database: settings["database"],
|
||||
User: settings["user"],
|
||||
Password: settings["password"],
|
||||
RuntimeParams: make(map[string]string),
|
||||
}
|
||||
|
||||
if port, err := parsePort(settings["port"]); err == nil {
|
||||
config.Port = port
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid port: %v", settings["port"])
|
||||
}
|
||||
|
||||
if connectTimeout, present := settings["connect_timeout"]; present {
|
||||
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.DialFunc = dialFunc
|
||||
} else {
|
||||
defaultDialer := makeDefaultDialer()
|
||||
config.DialFunc = defaultDialer.DialContext
|
||||
}
|
||||
|
||||
notRuntimeParams := map[string]struct{}{
|
||||
"host": struct{}{},
|
||||
"port": struct{}{},
|
||||
"database": struct{}{},
|
||||
"user": struct{}{},
|
||||
"password": struct{}{},
|
||||
"passfile": struct{}{},
|
||||
"connect_timeout": struct{}{},
|
||||
"sslmode": struct{}{},
|
||||
"sslkey": struct{}{},
|
||||
"sslcert": struct{}{},
|
||||
"sslrootcert": struct{}{},
|
||||
}
|
||||
|
||||
for k, v := range settings {
|
||||
if _, present := notRuntimeParams[k]; present {
|
||||
continue
|
||||
}
|
||||
config.RuntimeParams[k] = v
|
||||
}
|
||||
|
||||
var tlsConfigs []*tls.Config
|
||||
|
||||
// Ignore TLS settings if Unix domain socket like libpq
|
||||
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||
tlsConfigs = append(tlsConfigs, nil)
|
||||
} else {
|
||||
var err error
|
||||
tlsConfigs, err = configTLS(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
config.TLSConfig = tlsConfigs[0]
|
||||
|
||||
for _, tlsConfig := range tlsConfigs[1:] {
|
||||
config.Fallbacks = append(config.Fallbacks, &FallbackConfig{
|
||||
Host: config.Host,
|
||||
Port: config.Port,
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
}
|
||||
|
||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||
if err == nil {
|
||||
if config.Password == "" {
|
||||
host := config.Host
|
||||
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func defaultSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
settings["host"] = defaultHost()
|
||||
settings["port"] = "5432"
|
||||
|
||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||
// OS. The client application will simply have to specify the user in that
|
||||
// case (which they typically will be doing anyway).
|
||||
user, err := user.Current()
|
||||
if err == nil {
|
||||
settings["user"] = user.Username
|
||||
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||
// checks the existence of common locations.
|
||||
func defaultHost() string {
|
||||
candidatePaths := []string{
|
||||
"/var/run/postgresql", // Debian
|
||||
"/private/tmp", // OSX - homebrew
|
||||
"/tmp", // standard PostgreSQL
|
||||
}
|
||||
|
||||
for _, path := range candidatePaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
return "localhost"
|
||||
}
|
||||
|
||||
func addEnvSettings(settings map[string]string) {
|
||||
nameMap := map[string]string{
|
||||
"PGHOST": "host",
|
||||
"PGPORT": "port",
|
||||
"PGDATABASE": "database",
|
||||
"PGUSER": "user",
|
||||
"PGPASSWORD": "password",
|
||||
"PGPASSFILE": "passfile",
|
||||
"PGAPPNAME": "application_name",
|
||||
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||
"PGSSLMODE": "sslmode",
|
||||
"PGSSLKEY": "sslkey",
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
value := os.Getenv(envname)
|
||||
if value != "" {
|
||||
settings[realname] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addURLSettings(settings map[string]string, url *url.URL) error {
|
||||
if url.User != nil {
|
||||
settings["user"] = url.User.Username()
|
||||
if password, present := url.User.Password(); present {
|
||||
settings["password"] = password
|
||||
}
|
||||
}
|
||||
|
||||
parts := strings.SplitN(url.Host, ":", 2)
|
||||
if parts[0] != "" {
|
||||
settings["host"] = parts[0]
|
||||
}
|
||||
if len(parts) == 2 {
|
||||
settings["port"] = parts[1]
|
||||
}
|
||||
|
||||
database := strings.TrimLeft(url.Path, "/")
|
||||
if database != "" {
|
||||
settings["database"] = database
|
||||
}
|
||||
|
||||
for k, v := range url.Query() {
|
||||
settings[k] = v[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
|
||||
|
||||
func addDSNSettings(settings map[string]string, s string) error {
|
||||
m := dsnRegexp.FindAllStringSubmatch(s, -1)
|
||||
|
||||
for _, b := range m {
|
||||
settings[b[1]] = b[2]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type pgTLSArgs struct {
|
||||
sslMode string
|
||||
sslRootCert string
|
||||
sslCert string
|
||||
sslKey string
|
||||
}
|
||||
|
||||
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||
// "prefer" allow fallback.
|
||||
func configTLS(settings map[string]string) ([]*tls.Config, error) {
|
||||
host := settings["host"]
|
||||
sslmode := settings["sslmode"]
|
||||
sslrootcert := settings["sslrootcert"]
|
||||
sslcert := settings["sslcert"]
|
||||
sslkey := settings["sslkey"]
|
||||
|
||||
// Match libpq default behavior
|
||||
if sslmode == "" {
|
||||
sslmode = "prefer"
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
switch sslmode {
|
||||
case "disable":
|
||||
return []*tls.Config{nil}, nil
|
||||
case "allow", "prefer":
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
case "require":
|
||||
tlsConfig.InsecureSkipVerify = sslrootcert == ""
|
||||
case "verify-ca", "verify-full":
|
||||
tlsConfig.ServerName = host
|
||||
default:
|
||||
return nil, errors.New("sslmode is invalid")
|
||||
}
|
||||
|
||||
if sslrootcert != "" {
|
||||
caCertPool := x509.NewCertPool()
|
||||
|
||||
caPath := sslrootcert
|
||||
caCert, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "unable to read CA file %q", caPath)
|
||||
}
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, errors.Wrap(err, "unable to add CA to cert pool")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||
return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`)
|
||||
}
|
||||
|
||||
if sslcert != "" && sslkey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to read cert")
|
||||
}
|
||||
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "allow":
|
||||
return []*tls.Config{nil, tlsConfig}, nil
|
||||
case "prefer":
|
||||
return []*tls.Config{tlsConfig, nil}, nil
|
||||
case "require", "verify-ca", "verify-full":
|
||||
return []*tls.Config{tlsConfig}, nil
|
||||
default:
|
||||
panic("BUG: bad sslmode should already have been caught")
|
||||
}
|
||||
}
|
||||
|
||||
func parsePort(s string) (uint16, error) {
|
||||
port, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if port < 1 || port > math.MaxUint16 {
|
||||
return 0, errors.New("outside range")
|
||||
}
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
func makeDefaultDialer() *net.Dialer {
|
||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
}
|
||||
|
||||
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if timeout < 0 {
|
||||
return nil, errors.New("negative timeout")
|
||||
}
|
||||
|
||||
d := makeDefaultDialer()
|
||||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
return d.DialContext, nil
|
||||
}
|
||||
+392
@@ -0,0 +1,392 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/user"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var osUserName string
|
||||
osUser, err := user.Current()
|
||||
if err == nil {
|
||||
osUserName = osUser.Username
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connString string
|
||||
config *pgconn.Config
|
||||
}{
|
||||
// Test all sslmodes
|
||||
{
|
||||
name: "sslmode not set (prefer)",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode disable",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode allow",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode prefer",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
||||
config: &pgconn.Config{
|
||||
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode require",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode verify-ca",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{ServerName: "localhost"},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sslmode verify-full",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{ServerName: "localhost"},
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url everything",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{
|
||||
"application_name": "pgxtest",
|
||||
"search_path": "myschema",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url missing password",
|
||||
connString: "postgres://jack@localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url missing user and password",
|
||||
connString: "postgres://localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url missing port",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url unix domain socket host",
|
||||
connString: "postgres:///foo?host=/tmp",
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
Host: "/tmp",
|
||||
Port: 5432,
|
||||
Database: "foo",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN everything",
|
||||
connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{
|
||||
"application_name": "pgxtest",
|
||||
"search_path": "myschema",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
config, err := pgconn.ParseConfig(tt.connString)
|
||||
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
||||
continue
|
||||
}
|
||||
|
||||
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||
}
|
||||
}
|
||||
|
||||
func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) {
|
||||
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
||||
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
||||
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
||||
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||
|
||||
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
||||
if expected.TLSConfig != nil {
|
||||
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
|
||||
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
|
||||
}
|
||||
}
|
||||
|
||||
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) {
|
||||
for i := range expected.Fallbacks {
|
||||
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
|
||||
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
|
||||
|
||||
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) {
|
||||
if expected.Fallbacks[i].TLSConfig != nil {
|
||||
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
|
||||
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigEnvLibpq(t *testing.T) {
|
||||
var osUserName string
|
||||
osUser, err := user.Current()
|
||||
if err == nil {
|
||||
osUserName = osUser.Username
|
||||
}
|
||||
|
||||
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
|
||||
|
||||
savedEnv := make(map[string]string)
|
||||
for _, n := range pgEnvvars {
|
||||
savedEnv[n] = os.Getenv(n)
|
||||
}
|
||||
defer func() {
|
||||
for k, v := range savedEnv {
|
||||
err := os.Setenv(k, v)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to restore environment: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
envvars map[string]string
|
||||
config *pgconn.Config
|
||||
}{
|
||||
{
|
||||
// not testing no environment at all as that would use default host and that can vary.
|
||||
name: "PGHOST only",
|
||||
envvars: map[string]string{"PGHOST": "123.123.123.123"},
|
||||
config: &pgconn.Config{
|
||||
User: osUserName,
|
||||
Host: "123.123.123.123",
|
||||
Port: 5432,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "123.123.123.123",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "All non-TLS environment",
|
||||
envvars: map[string]string{
|
||||
"PGHOST": "123.123.123.123",
|
||||
"PGPORT": "7777",
|
||||
"PGDATABASE": "foo",
|
||||
"PGUSER": "bar",
|
||||
"PGPASSWORD": "baz",
|
||||
"PGCONNECT_TIMEOUT": "10",
|
||||
"PGSSLMODE": "disable",
|
||||
"PGAPPNAME": "pgxtest",
|
||||
},
|
||||
config: &pgconn.Config{
|
||||
Host: "123.123.123.123",
|
||||
Port: 7777,
|
||||
Database: "foo",
|
||||
User: "bar",
|
||||
Password: "baz",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
for _, n := range pgEnvvars {
|
||||
err := os.Unsetenv(n)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
for k, v := range tt.envvars {
|
||||
err := os.Setenv(k, v)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
config, err := pgconn.ParseConfig("")
|
||||
if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) {
|
||||
continue
|
||||
}
|
||||
|
||||
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigReadsPgPassfile(t *testing.T) {
|
||||
tf, err := ioutil.TempFile("", "")
|
||||
require.Nil(t, err)
|
||||
|
||||
defer tf.Close()
|
||||
defer os.Remove(tf.Name())
|
||||
|
||||
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
|
||||
require.Nil(t, err)
|
||||
|
||||
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
|
||||
expected := &pgconn.Config{
|
||||
User: "curly",
|
||||
Password: "nyuknyuknyuk",
|
||||
Host: "test1",
|
||||
Port: 5432,
|
||||
Database: "curlydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
}
|
||||
|
||||
actual, err := pgconn.ParseConfig(connString)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assertConfigsEqual(t, expected, actual, "passfile")
|
||||
}
|
||||
@@ -1,20 +1,16 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
@@ -23,7 +19,7 @@ import (
|
||||
const batchBufferSize = 4096
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
||||
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
type PgError struct {
|
||||
Severity string
|
||||
@@ -50,60 +46,12 @@ func (pe PgError) Error() string {
|
||||
}
|
||||
|
||||
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
||||
type DialFunc func(network, addr string) (net.Conn, error)
|
||||
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
||||
// PostgreSQL server refuses to use TLS
|
||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||
|
||||
type ConnConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16 // default: 5432
|
||||
Database string
|
||||
User string // default: OS user name
|
||||
Password string
|
||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
||||
Dial DialFunc
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) NetworkAddress() (network, address string) {
|
||||
// If host is a valid path, then address is unix socket
|
||||
if _, err := os.Stat(cc.Host); err == nil {
|
||||
network = "unix"
|
||||
address = cc.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
||||
}
|
||||
} else {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
||||
}
|
||||
|
||||
return network, address
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) assignDefaults() error {
|
||||
if cc.User == "" {
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.User = user.Username
|
||||
}
|
||||
|
||||
if cc.Port == 0 {
|
||||
cc.Port = 5432
|
||||
}
|
||||
|
||||
if cc.Dial == nil {
|
||||
defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
cc.Dial = defaultDialer.Dial
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
type PgConn struct {
|
||||
NetConn net.Conn // the underlying TCP or unix domain socket connection
|
||||
@@ -113,7 +61,7 @@ type PgConn struct {
|
||||
TxStatus byte
|
||||
Frontend *pgproto3.Frontend
|
||||
|
||||
Config ConnConfig
|
||||
Config *Config
|
||||
|
||||
batchBuf []byte
|
||||
batchCount int32
|
||||
@@ -123,24 +71,72 @@ type PgConn struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func Connect(cc ConnConfig) (*PgConn, error) {
|
||||
err := cc.assignDefaults()
|
||||
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
|
||||
// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt.
|
||||
func Connect(ctx context.Context, connString string) (*PgConn, error) {
|
||||
config, err := ParseConfig(connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pgConn := new(PgConn)
|
||||
pgConn.Config = cc
|
||||
return ConnectConfig(ctx, config)
|
||||
}
|
||||
|
||||
pgConn.NetConn, err = cc.Dial(cc.NetworkAddress())
|
||||
// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt.
|
||||
//
|
||||
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
|
||||
// authentication error will terminate the chain of attempts (like libpq:
|
||||
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
|
||||
// if all attempts fail the last error is returned.
|
||||
func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) {
|
||||
// For convenience set a few defaults if not already set. This makes it simpler to directly construct a config.
|
||||
if config.Port == 0 {
|
||||
config.Port = 5432
|
||||
}
|
||||
if config.DialFunc == nil {
|
||||
config.DialFunc = makeDefaultDialer().DialContext
|
||||
}
|
||||
if config.RuntimeParams == nil {
|
||||
config.RuntimeParams = make(map[string]string)
|
||||
}
|
||||
|
||||
// Simplify usage by treating primary config and fallbacks the same.
|
||||
fallbackConfigs := []*FallbackConfig{
|
||||
{
|
||||
Host: config.Host,
|
||||
Port: config.Port,
|
||||
TLSConfig: config.TLSConfig,
|
||||
},
|
||||
}
|
||||
fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
|
||||
|
||||
for _, fc := range fallbackConfigs {
|
||||
pgConn, err = connect(ctx, config, fc)
|
||||
if err == nil {
|
||||
return pgConn, nil
|
||||
} else if err, ok := err.(PgError); ok {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||
pgConn := new(PgConn)
|
||||
pgConn.Config = config
|
||||
|
||||
var err error
|
||||
network, address := NetworkAddress(config.Host, config.Port)
|
||||
pgConn.NetConn, err = config.DialFunc(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pgConn.parameterStatuses = make(map[string]string)
|
||||
|
||||
if cc.TLSConfig != nil {
|
||||
if err := pgConn.startTLS(cc.TLSConfig); err != nil {
|
||||
if config.TLSConfig != nil {
|
||||
if err := pgConn.startTLS(config.TLSConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -156,13 +152,13 @@ func Connect(cc ConnConfig) (*PgConn, error) {
|
||||
}
|
||||
|
||||
// Copy default run-time params
|
||||
for k, v := range cc.RuntimeParams {
|
||||
for k, v := range config.RuntimeParams {
|
||||
startupMsg.Parameters[k] = v
|
||||
}
|
||||
|
||||
startupMsg.Parameters["user"] = cc.User
|
||||
if cc.Database != "" {
|
||||
startupMsg.Parameters["database"] = cc.Database
|
||||
startupMsg.Parameters["user"] = config.User
|
||||
if config.Database != "" {
|
||||
startupMsg.Parameters["database"] = config.Database
|
||||
}
|
||||
|
||||
if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
|
||||
|
||||
+5
-3
@@ -1,16 +1,18 @@
|
||||
package pgconn_test
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSimple(t *testing.T) {
|
||||
pgConn, err := pgconn.Connect(pgconn.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"})
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
|
||||
pgConn.SendExec("select current_database()")
|
||||
|
||||
Reference in New Issue
Block a user