From cfbd2519e3a9dd64906a0888c38ee05d78e19889 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Mar 2020 13:17:39 -0600 Subject: [PATCH] Add PGSERVICE and PGSERVICEFILE support --- config.go | 93 +++++++++++++++++++++++++++++++++++++++++------- config_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 3 +- go.sum | 4 +++ 4 files changed, 182 insertions(+), 13 deletions(-) diff --git a/config.go b/config.go index 9876ac94..19521a8f 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,7 @@ import ( "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgservicefile" errors "golang.org/x/xerrors" ) @@ -108,6 +109,8 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGUSER // PGPASSWORD // PGPASSFILE +// PGSERVICE +// PGSERVICEFILE // PGSSLMODE // PGSSLCERT // PGSSLKEY @@ -145,25 +148,40 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // min_read_buffer_size // The minimum size of the internal read buffer. Default 8192. +// servicefile +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. func ParseConfig(connString string) (*Config, error) { - settings := defaultSettings() - addEnvSettings(settings) + defaultSettings := defaultSettings() + envSettings := parseEnvSettings() + connStringSettings := make(map[string]string) if connString != "" { + var err error // connString may be a database URL or a DSN if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { - err := addURLSettings(settings, connString) + connStringSettings, err = parseURLSettings(connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { - err := addDSNSettings(settings, connString) + connStringSettings, err = parseDSNSettings(connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } + settings := mergeSettings(defaultSettings, envSettings, connStringSettings) + if service, present := settings["service"]; present { + serviceSettings, err := parseServiceSettings(settings["servicefile"], service) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} + } + + settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) + } + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) if err != nil { return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} @@ -205,6 +223,8 @@ func ParseConfig(connString string) (*Config, error) { "sslrootcert": struct{}{}, "target_session_attrs": struct{}{}, "min_read_buffer_size": struct{}{}, + "service": struct{}{}, + "servicefile": struct{}{}, } for k, v := range settings { @@ -293,6 +313,7 @@ func defaultSettings() map[string]string { if err == nil { settings["user"] = user.Username settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") } settings["target_session_attrs"] = "any" @@ -321,7 +342,21 @@ func defaultHost() string { return "localhost" } -func addEnvSettings(settings map[string]string) { +func mergeSettings(settingSets ...map[string]string) map[string]string { + settings := make(map[string]string) + + for _, s2 := range settingSets { + for k, v := range s2 { + settings[k] = v + } + } + + return settings +} + +func parseEnvSettings() map[string]string { + settings := make(map[string]string) + nameMap := map[string]string{ "PGHOST": "host", "PGPORT": "port", @@ -336,6 +371,8 @@ func addEnvSettings(settings map[string]string) { "PGSSLCERT": "sslcert", "PGSSLROOTCERT": "sslrootcert", "PGTARGETSESSIONATTRS": "target_session_attrs", + "PGSERVICE": "service", + "PGSERVICEFILE": "servicefile", } for envname, realname := range nameMap { @@ -344,12 +381,16 @@ func addEnvSettings(settings map[string]string) { settings[realname] = value } } + + return settings } -func addURLSettings(settings map[string]string, connString string) error { +func parseURLSettings(connString string) (map[string]string, error) { + settings := make(map[string]string) + url, err := url.Parse(connString) if err != nil { - return err + return nil, err } if url.User != nil { @@ -387,12 +428,14 @@ func addURLSettings(settings map[string]string, connString string) error { settings[k] = v[0] } - return nil + return settings, nil } var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} -func addDSNSettings(settings map[string]string, s string) error { +func parseDSNSettings(s string) (map[string]string, error) { + settings := make(map[string]string) + nameMap := map[string]string{ "dbname": "database", } @@ -401,7 +444,7 @@ func addDSNSettings(settings map[string]string, s string) error { var key, val string eqIdx := strings.IndexRune(s, '=') if eqIdx < 0 { - return errors.New("invalid dsn") + return nil, errors.New("invalid dsn") } key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") @@ -434,7 +477,7 @@ func addDSNSettings(settings map[string]string, s string) error { } } if end == len(s) { - return errors.New("unterminated quoted string in connection info string") + return nil, errors.New("unterminated quoted string in connection info string") } val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) if end == len(s) { @@ -451,7 +494,33 @@ func addDSNSettings(settings map[string]string, s string) error { settings[key] = val } - return nil + return settings, nil +} + +func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { + servicefile, err := pgservicefile.ReadServicefile(servicefilePath) + if err != nil { + fmt.Errorf("failed to read service file: %v", servicefile) + } + + service, err := servicefile.GetService(serviceName) + if err != nil { + fmt.Errorf("unable to find service: %v", servicefile) + } + + nameMap := map[string]string{ + "dbname": "database", + } + + settings := make(map[string]string, len(service.Settings)) + for k, v := range service.Settings { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v + } + + return settings, nil } type pgTLSArgs struct { diff --git a/config_test.go b/config_test.go index 9eb5df2f..0819740f 100644 --- a/config_test.go +++ b/config_test.go @@ -648,6 +648,101 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { assertConfigsEqual(t, expected, actual, "passfile") } +func TestParseConfigReadsPgServiceFile(t *testing.T) { + t.Parallel() + + tf, err := ioutil.TempFile("", "") + require.NoError(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte(` +[abc] +host=abc.example.com +port=9999 +dbname=abcdb +user=abcuser + +[def] +host = def.example.com +dbname = defdb +user = defuser +application_name = spaced string +`)) + require.NoError(t, err) + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + { + name: "abc", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "abc.example.com", + Database: "abcdb", + User: "abcuser", + Port: 9999, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "abc.example.com", + Port: 9999, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "def", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"), + config: &pgconn.Config{ + Host: "def.example.com", + Port: 5432, + Database: "defdb", + User: "defuser", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{"application_name": "spaced string"}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "def.example.com", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "conn string has precedence", + connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "other.example.com", + Database: "abcdb", + User: "abcuser", + Port: 7777, + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index 37590559..b306e1e4 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,8 @@ require ( github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 - github.com/stretchr/testify v1.4.0 + github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 + github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 golang.org/x/text v0.3.2 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 diff --git a/go.sum b/go.sum index 28f094e7..13f276b2 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= @@ -69,6 +71,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=