Add PGSERVICE and PGSERVICEFILE support
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/jackc/chunkreader/v2"
|
||||
"github.com/jackc/pgpassfile"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
"github.com/jackc/pgservicefile"
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
@@ -108,6 +109,8 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
// PGUSER
|
||||
// PGPASSWORD
|
||||
// PGPASSFILE
|
||||
// PGSERVICE
|
||||
// PGSERVICEFILE
|
||||
// PGSSLMODE
|
||||
// PGSSLCERT
|
||||
// PGSSLKEY
|
||||
@@ -145,25 +148,40 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
//
|
||||
// min_read_buffer_size
|
||||
// The minimum size of the internal read buffer. Default 8192.
|
||||
// servicefile
|
||||
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||
// part of the connection string.
|
||||
func ParseConfig(connString string) (*Config, error) {
|
||||
settings := defaultSettings()
|
||||
addEnvSettings(settings)
|
||||
defaultSettings := defaultSettings()
|
||||
envSettings := parseEnvSettings()
|
||||
|
||||
connStringSettings := make(map[string]string)
|
||||
if connString != "" {
|
||||
var err error
|
||||
// connString may be a database URL or a DSN
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
err := addURLSettings(settings, connString)
|
||||
connStringSettings, err = parseURLSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||
}
|
||||
} else {
|
||||
err := addDSNSettings(settings, connString)
|
||||
connStringSettings, err = parseDSNSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
|
||||
if service, present := settings["service"]; present {
|
||||
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
|
||||
}
|
||||
|
||||
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
|
||||
}
|
||||
|
||||
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
|
||||
@@ -205,6 +223,8 @@ func ParseConfig(connString string) (*Config, error) {
|
||||
"sslrootcert": struct{}{},
|
||||
"target_session_attrs": struct{}{},
|
||||
"min_read_buffer_size": struct{}{},
|
||||
"service": struct{}{},
|
||||
"servicefile": struct{}{},
|
||||
}
|
||||
|
||||
for k, v := range settings {
|
||||
@@ -293,6 +313,7 @@ func defaultSettings() map[string]string {
|
||||
if err == nil {
|
||||
settings["user"] = user.Username
|
||||
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
@@ -321,7 +342,21 @@ func defaultHost() string {
|
||||
return "localhost"
|
||||
}
|
||||
|
||||
func addEnvSettings(settings map[string]string) {
|
||||
func mergeSettings(settingSets ...map[string]string) map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
for _, s2 := range settingSets {
|
||||
for k, v := range s2 {
|
||||
settings[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func parseEnvSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
"PGHOST": "host",
|
||||
"PGPORT": "port",
|
||||
@@ -336,6 +371,8 @@ func addEnvSettings(settings map[string]string) {
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||
"PGSERVICE": "service",
|
||||
"PGSERVICEFILE": "servicefile",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
@@ -344,12 +381,16 @@ func addEnvSettings(settings map[string]string) {
|
||||
settings[realname] = value
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func addURLSettings(settings map[string]string, connString string) error {
|
||||
func parseURLSettings(connString string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
url, err := url.Parse(connString)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if url.User != nil {
|
||||
@@ -387,12 +428,14 @@ func addURLSettings(settings map[string]string, connString string) error {
|
||||
settings[k] = v[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||
|
||||
func addDSNSettings(settings map[string]string, s string) error {
|
||||
func parseDSNSettings(s string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
@@ -401,7 +444,7 @@ func addDSNSettings(settings map[string]string, s string) error {
|
||||
var key, val string
|
||||
eqIdx := strings.IndexRune(s, '=')
|
||||
if eqIdx < 0 {
|
||||
return errors.New("invalid dsn")
|
||||
return nil, errors.New("invalid dsn")
|
||||
}
|
||||
|
||||
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||
@@ -434,7 +477,7 @@ func addDSNSettings(settings map[string]string, s string) error {
|
||||
}
|
||||
}
|
||||
if end == len(s) {
|
||||
return errors.New("unterminated quoted string in connection info string")
|
||||
return nil, errors.New("unterminated quoted string in connection info string")
|
||||
}
|
||||
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||
if end == len(s) {
|
||||
@@ -451,7 +494,33 @@ func addDSNSettings(settings map[string]string, s string) error {
|
||||
settings[key] = val
|
||||
}
|
||||
|
||||
return nil
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
|
||||
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
|
||||
if err != nil {
|
||||
fmt.Errorf("failed to read service file: %v", servicefile)
|
||||
}
|
||||
|
||||
service, err := servicefile.GetService(serviceName)
|
||||
if err != nil {
|
||||
fmt.Errorf("unable to find service: %v", servicefile)
|
||||
}
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
settings := make(map[string]string, len(service.Settings))
|
||||
for k, v := range service.Settings {
|
||||
if k2, present := nameMap[k]; present {
|
||||
k = k2
|
||||
}
|
||||
settings[k] = v
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
type pgTLSArgs struct {
|
||||
|
||||
@@ -648,6 +648,101 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
|
||||
assertConfigsEqual(t, expected, actual, "passfile")
|
||||
}
|
||||
|
||||
func TestParseConfigReadsPgServiceFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tf, err := ioutil.TempFile("", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
defer tf.Close()
|
||||
defer os.Remove(tf.Name())
|
||||
|
||||
_, err = tf.Write([]byte(`
|
||||
[abc]
|
||||
host=abc.example.com
|
||||
port=9999
|
||||
dbname=abcdb
|
||||
user=abcuser
|
||||
|
||||
[def]
|
||||
host = def.example.com
|
||||
dbname = defdb
|
||||
user = defuser
|
||||
application_name = spaced string
|
||||
`))
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connString string
|
||||
config *pgconn.Config
|
||||
}{
|
||||
{
|
||||
name: "abc",
|
||||
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"),
|
||||
config: &pgconn.Config{
|
||||
Host: "abc.example.com",
|
||||
Database: "abcdb",
|
||||
User: "abcuser",
|
||||
Port: 9999,
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "abc.example.com",
|
||||
Port: 9999,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "def",
|
||||
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"),
|
||||
config: &pgconn.Config{
|
||||
Host: "def.example.com",
|
||||
Port: 5432,
|
||||
Database: "defdb",
|
||||
User: "defuser",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
RuntimeParams: map[string]string{"application_name": "spaced string"},
|
||||
Fallbacks: []*pgconn.FallbackConfig{
|
||||
&pgconn.FallbackConfig{
|
||||
Host: "def.example.com",
|
||||
Port: 5432,
|
||||
TLSConfig: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "conn string has precedence",
|
||||
connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"),
|
||||
config: &pgconn.Config{
|
||||
Host: "other.example.com",
|
||||
Database: "abcdb",
|
||||
User: "abcuser",
|
||||
Port: 7777,
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
config, err := pgconn.ParseConfig(tt.connString)
|
||||
if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) {
|
||||
continue
|
||||
}
|
||||
|
||||
assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigExtractsMinReadBufferSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ require (
|
||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2
|
||||
github.com/jackc/pgpassfile v1.0.0
|
||||
github.com/jackc/pgproto3/v2 v2.0.1
|
||||
github.com/stretchr/testify v1.4.0
|
||||
github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8
|
||||
github.com/stretchr/testify v1.5.1
|
||||
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586
|
||||
golang.org/x/text v0.3.2
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
|
||||
|
||||
@@ -30,6 +30,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||
github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M=
|
||||
github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||
github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU=
|
||||
github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
||||
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
|
||||
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
|
||||
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
|
||||
@@ -69,6 +71,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
|
||||
Reference in New Issue
Block a user