From beeb69ff0bed06647f93f4eafae419ff43fd4da1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 30 Dec 2018 16:53:57 -0600 Subject: [PATCH] Restructure connect process - Moved lots of connection logic to pgconn from pgx - Extracted pgpassfile package --- config.go | 421 +++++++++++++++++++++++++++++++++++++++++++++++++ config_test.go | 392 +++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 130 ++++++++------- pgconn_test.go | 8 +- 4 files changed, 881 insertions(+), 70 deletions(-) create mode 100644 config.go create mode 100644 config_test.go diff --git a/config.go b/config.go new file mode 100644 index 00000000..515d6356 --- /dev/null +++ b/config.go @@ -0,0 +1,421 @@ +package pgconn + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "math" + "net" + "net/url" + "os" + "os/user" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgpassfile" + "github.com/pkg/errors" +) + +// Config is the settings used to establish a connection to a PostgreSQL server. +type Config struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + Fallbacks []*FallbackConfig +} + +// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a +// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. +type FallbackConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + TLSConfig *tls.Config // nil disables TLS +} + +// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with +// net.Dial. +func NetworkAddress(host string, port uint16) (network, address string) { + if strings.HasPrefix(host, "/") { + network = "unix" + address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) + } else { + network = "tcp" + address = fmt.Sprintf("%s:%d", host, port) + } + return network, address +} + +// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. +// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment +// variables. connString may be a URL or a DSN. It also may be empty to only read from the +// environment. If a password is not supplied it will attempt to read the .pgpass file. +// +// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca" +// +// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca" +// +// Multiple configs may be returned due to sslmode settings with fallback options (e.g. +// sslmode=prefer). Future implementations may also support multiple hosts +// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS). +// +// ParseConfig currently recognizes the following environment variable and their parameter key word +// equivalents passed via database URL or DSN: +// +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGAPPNAME +// PGCONNECT_TIMEOUT +// +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of +// environment variables. +// +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key +// word names. They are usually but not always the environment variable name downcased and without +// the "PG" prefix. +// +// Important TLS Security Notes: +// +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to +// "prefer" behavior if not set. +// +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on +// what level of security each sslmode provides. +// +// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger +// security guarantees than it would with libpq. Do not rely on this behavior as it +// may be possible to match libpq in the future. If you need full security use +// "verify-full". +func ParseConfig(connString string) (*Config, error) { + settings := defaultSettings() + addEnvSettings(settings) + + if connString != "" { + // connString may be a database URL or a DSN + if strings.HasPrefix(connString, "postgres://") { + url, err := url.Parse(connString) + if err != nil { + return nil, err + } + + err = addURLSettings(settings, url) + if err != nil { + return nil, err + } + } else { + err := addDSNSettings(settings, connString) + if err != nil { + return nil, err + } + } + } + + config := &Config{ + Host: settings["host"], + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + } + + if port, err := parsePort(settings["port"]); err == nil { + config.Port = port + } else { + return nil, fmt.Errorf("invalid port: %v", settings["port"]) + } + + if connectTimeout, present := settings["connect_timeout"]; present { + dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) + if err != nil { + return nil, err + } + config.DialFunc = dialFunc + } else { + defaultDialer := makeDefaultDialer() + config.DialFunc = defaultDialer.DialContext + } + + notRuntimeParams := map[string]struct{}{ + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + } + + for k, v := range settings { + if _, present := notRuntimeParams[k]; present { + continue + } + config.RuntimeParams[k] = v + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings) + if err != nil { + return nil, err + } + } + + config.TLSConfig = tlsConfigs[0] + + for _, tlsConfig := range tlsConfigs[1:] { + config.Fallbacks = append(config.Fallbacks, &FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: tlsConfig, + }) + } + + passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) + if err == nil { + if config.Password == "" { + host := config.Host + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + host = "localhost" + } + + config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) + } + } + + return config, nil +} + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + } + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} + +func addEnvSettings(settings map[string]string) { + nameMap := map[string]string{ + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + } + + for envname, realname := range nameMap { + value := os.Getenv(envname) + if value != "" { + settings[realname] = value + } + } +} + +func addURLSettings(settings map[string]string, url *url.URL) error { + if url.User != nil { + settings["user"] = url.User.Username() + if password, present := url.User.Password(); present { + settings["password"] = password + } + } + + parts := strings.SplitN(url.Host, ":", 2) + if parts[0] != "" { + settings["host"] = parts[0] + } + if len(parts) == 2 { + settings["port"] = parts[1] + } + + database := strings.TrimLeft(url.Path, "/") + if database != "" { + settings["database"] = database + } + + for k, v := range url.Query() { + settings[k] = v[0] + } + + return nil +} + +var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) + +func addDSNSettings(settings map[string]string, s string) error { + m := dsnRegexp.FindAllStringSubmatch(s, -1) + + for _, b := range m { + settings[b[1]] = b[2] + } + + return nil +} + +type pgTLSArgs struct { + sslMode string + sslRootCert string + sslCert string + sslKey string +} + +// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +func configTLS(settings map[string]string) ([]*tls.Config, error) { + host := settings["host"] + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + + tlsConfig := &tls.Config{} + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + tlsConfig.InsecureSkipVerify = sslrootcert == "" + case "verify-ca", "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + + caPath := sslrootcert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, errors.Wrapf(err, "unable to read CA file %q", caPath) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.Wrap(err, "unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return nil, errors.Wrap(err, "unable to read cert") + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +func makeDefaultDialer() *net.Dialer { + return &net.Dialer{KeepAlive: 5 * time.Minute} +} + +func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { + timeout, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + if timeout < 0 { + return nil, errors.New("negative timeout") + } + + d := makeDefaultDialer() + d.Timeout = time.Duration(timeout) * time.Second + return d.DialContext, nil +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 00000000..796876f2 --- /dev/null +++ b/config_test.go @@ -0,0 +1,392 @@ +package pgconn_test + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "os" + "os/user" + "testing" + + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + t.Parallel() + + var osUserName string + osUser, err := user.Current() + if err == nil { + osUserName = osUser.Username + } + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + // Test all sslmodes + { + name: "sslmode not set (prefer)", + connString: "postgres://jack:secret@localhost:5432/mydb", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode disable", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode allow", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + }, + }, + { + name: "sslmode prefer", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + config: &pgconn.Config{ + + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode require", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-ca", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-full", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url everything", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "database url missing password", + connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing user and password", + connString: "postgres://localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing port", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host", + connString: "postgres:///foo?host=/tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "/tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN everything", + connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} + +func TestParseConfigEnvLibpq(t *testing.T) { + var osUserName string + osUser, err := user.Current() + if err == nil { + osUserName = osUser.Username + } + + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} + + savedEnv := make(map[string]string) + for _, n := range pgEnvvars { + savedEnv[n] = os.Getenv(n) + } + defer func() { + for k, v := range savedEnv { + err := os.Setenv(k, v) + if err != nil { + t.Fatalf("Unable to restore environment: %v", err) + } + } + }() + + tests := []struct { + name string + envvars map[string]string + config *pgconn.Config + }{ + { + // not testing no environment at all as that would use default host and that can vary. + name: "PGHOST only", + envvars: map[string]string{"PGHOST": "123.123.123.123"}, + config: &pgconn.Config{ + User: osUserName, + Host: "123.123.123.123", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "123.123.123.123", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "All non-TLS environment", + envvars: map[string]string{ + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", + "PGSSLMODE": "disable", + "PGAPPNAME": "pgxtest", + }, + config: &pgconn.Config{ + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, + }, + }, + } + + for i, tt := range tests { + for _, n := range pgEnvvars { + err := os.Unsetenv(n) + require.Nil(t, err) + } + + for k, v := range tt.envvars { + err := os.Setenv(k, v) + require.Nil(t, err) + } + + config, err := pgconn.ParseConfig("") + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func TestParseConfigReadsPgPassfile(t *testing.T) { + tf, err := ioutil.TempFile("", "") + require.Nil(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) + require.Nil(t, err) + + connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) + expected := &pgconn.Config{ + User: "curly", + Password: "nyuknyuknyuk", + Host: "test1", + Port: 5432, + Database: "curlydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + } + + actual, err := pgconn.ParseConfig(connString) + assert.Nil(t, err) + + assertConfigsEqual(t, expected, actual, "passfile") +} diff --git a/pgconn.go b/pgconn.go index c9caef42..37a205dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1,20 +1,16 @@ package pgconn import ( + "context" "crypto/md5" "crypto/tls" "encoding/binary" "encoding/hex" "errors" - "fmt" "io" "net" - "os" - "os/user" - "path/filepath" "strconv" "strings" - "time" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" @@ -23,7 +19,7 @@ import ( const batchBufferSize = 4096 // PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. type PgError struct { Severity string @@ -50,60 +46,12 @@ func (pe PgError) Error() string { } // DialFunc is a function that can be used to connect to a PostgreSQL server -type DialFunc func(network, addr string) (net.Conn, error) +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") -type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name - Password string - TLSConfig *tls.Config // config for TLS connection -- nil disables TLS - Dial DialFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) -} - -func (cc *ConnConfig) NetworkAddress() (network, address string) { - // If host is a valid path, then address is unix socket - if _, err := os.Stat(cc.Host); err == nil { - network = "unix" - address = cc.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) - } - } else { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - } - - return network, address -} - -func (cc *ConnConfig) assignDefaults() error { - if cc.User == "" { - user, err := user.Current() - if err != nil { - return err - } - cc.User = user.Username - } - - if cc.Port == 0 { - cc.Port = 5432 - } - - if cc.Dial == nil { - defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute} - cc.Dial = defaultDialer.Dial - } - - return nil -} - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { NetConn net.Conn // the underlying TCP or unix domain socket connection @@ -113,7 +61,7 @@ type PgConn struct { TxStatus byte Frontend *pgproto3.Frontend - Config ConnConfig + Config *Config batchBuf []byte batchCount int32 @@ -123,24 +71,72 @@ type PgConn struct { closed bool } -func Connect(cc ConnConfig) (*PgConn, error) { - err := cc.assignDefaults() +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. +func Connect(ctx context.Context, connString string) (*PgConn, error) { + config, err := ParseConfig(connString) if err != nil { return nil, err } - pgConn := new(PgConn) - pgConn.Config = cc + return ConnectConfig(ctx, config) +} - pgConn.NetConn, err = cc.Dial(cc.NetworkAddress()) +// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. +// +// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An +// authentication error will terminate the chain of attempts (like libpq: +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, +// if all attempts fail the last error is returned. +func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { + // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. + if config.Port == 0 { + config.Port = 5432 + } + if config.DialFunc == nil { + config.DialFunc = makeDefaultDialer().DialContext + } + if config.RuntimeParams == nil { + config.RuntimeParams = make(map[string]string) + } + + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + for _, fc := range fallbackConfigs { + pgConn, err = connect(ctx, config, fc) + if err == nil { + return pgConn, nil + } else if err, ok := err.(PgError); ok { + return nil, err + } + } + + return nil, err +} + +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { + pgConn := new(PgConn) + pgConn.Config = config + + var err error + network, address := NetworkAddress(config.Host, config.Port) + pgConn.NetConn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } pgConn.parameterStatuses = make(map[string]string) - if cc.TLSConfig != nil { - if err := pgConn.startTLS(cc.TLSConfig); err != nil { + if config.TLSConfig != nil { + if err := pgConn.startTLS(config.TLSConfig); err != nil { return nil, err } } @@ -156,13 +152,13 @@ func Connect(cc ConnConfig) (*PgConn, error) { } // Copy default run-time params - for k, v := range cc.RuntimeParams { + for k, v := range config.RuntimeParams { startupMsg.Parameters[k] = v } - startupMsg.Parameters["user"] = cc.User - if cc.Database != "" { - startupMsg.Parameters["database"] = cc.Database + startupMsg.Parameters["user"] = config.User + if config.Database != "" { + startupMsg.Parameters["database"] = config.Database } if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { diff --git a/pgconn_test.go b/pgconn_test.go index dbcf2704..f165786e 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,16 +1,18 @@ package pgconn_test import ( - "github.com/jackc/pgx/pgconn" - + "context" + "os" "testing" + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSimple(t *testing.T) { - pgConn, err := pgconn.Connect(pgconn.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"}) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) pgConn.SendExec("select current_database()")