2
0

Handle timeout parameters

This commit is contained in:
Timothée Peignier
2017-12-16 19:10:22 -08:00
parent cbb3fa5ecc
commit 1bec450326
2 changed files with 68 additions and 11 deletions
+27 -2
View File
@@ -72,6 +72,7 @@ type ConnConfig struct {
Logger Logger Logger Logger
LogLevel int LogLevel int
Dial DialFunc Dial DialFunc
Timeout time.Duration
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
OnNotice NoticeHandler // Callback function called when a notice response is received. OnNotice NoticeHandler // Callback function called when a notice response is received.
} }
@@ -247,7 +248,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
network, address := c.config.networkAddress() network, address := c.config.networkAddress()
if c.config.Dial == nil { if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial
} }
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
@@ -655,13 +656,22 @@ func ParseURI(uri string) (ConnConfig, error) {
} }
cp.Database = strings.TrimLeft(url.Path, "/") cp.Database = strings.TrimLeft(url.Path, "/")
if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
if err != nil {
return cp, err
}
cp.Timeout = time.Duration(timeout) * time.Second
}
err = configSSL(url.Query().Get("sslmode"), &cp) err = configSSL(url.Query().Get("sslmode"), &cp)
if err != nil { if err != nil {
return cp, err return cp, err
} }
ignoreKeys := map[string]struct{}{ ignoreKeys := map[string]struct{}{
"sslmode": {}, "sslmode": {},
"connect_timeout": {},
} }
cp.RuntimeParams = make(map[string]string) cp.RuntimeParams = make(map[string]string)
@@ -719,6 +729,12 @@ func ParseDSN(s string) (ConnConfig, error) {
cp.Database = b[2] cp.Database = b[2]
case "sslmode": case "sslmode":
sslmode = b[2] sslmode = b[2]
case "connect_timeout":
t, err := strconv.ParseInt(b[2], 10, 64)
if err != nil {
return cp, err
}
cp.Timeout = time.Duration(t) * time.Second
default: default:
cp.RuntimeParams[b[1]] = b[2] cp.RuntimeParams[b[1]] = b[2]
} }
@@ -756,6 +772,7 @@ func ParseConnectionString(s string) (ConnConfig, error) {
// PGPASSWORD // PGPASSWORD
// PGSSLMODE // PGSSLMODE
// PGAPPNAME // PGAPPNAME
// PGCONNECT_TIMEOUT
// //
// Important TLS Security Notes: // Important TLS Security Notes:
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
@@ -791,6 +808,14 @@ func ParseEnvLibpq() (ConnConfig, error) {
cc.User = os.Getenv("PGUSER") cc.User = os.Getenv("PGUSER")
cc.Password = os.Getenv("PGPASSWORD") cc.Password = os.Getenv("PGPASSWORD")
if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
cc.Timeout = time.Duration(timeout) * time.Second
} else {
return cc, err
}
}
sslmode := os.Getenv("PGSSLMODE") sslmode := os.Getenv("PGSSLMODE")
err := configSSL(sslmode, &cc) err := configSSL(sslmode, &cc)
+41 -9
View File
@@ -542,6 +542,21 @@ func TestParseDSN(t *testing.T) {
}, },
}, },
}, },
{
url: "user=jack host=localhost dbname=mydb connect_timeout=10",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Timeout: 10 * time.Second,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
} }
for i, tt := range tests { for i, tt := range tests {
@@ -672,6 +687,21 @@ func TestParseConnectionString(t *testing.T) {
}, },
}, },
}, },
{
url: "postgres://jack@localhost/mydb?connect_timeout=10",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Timeout: 10 * time.Second,
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{ {
url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
connParams: pgx.ConnConfig{ connParams: pgx.ConnConfig{
@@ -777,7 +807,7 @@ func TestParseConnectionString(t *testing.T) {
} }
func TestParseEnvLibpq(t *testing.T) { func TestParseEnvLibpq(t *testing.T) {
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE"} pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"}
savedEnv := make(map[string]string) savedEnv := make(map[string]string)
for _, n := range pgEnvvars { for _, n := range pgEnvvars {
@@ -810,11 +840,12 @@ func TestParseEnvLibpq(t *testing.T) {
{ {
name: "Normal PG vars", name: "Normal PG vars",
envvars: map[string]string{ envvars: map[string]string{
"PGHOST": "123.123.123.123", "PGHOST": "123.123.123.123",
"PGPORT": "7777", "PGPORT": "7777",
"PGDATABASE": "foo", "PGDATABASE": "foo",
"PGUSER": "bar", "PGUSER": "bar",
"PGPASSWORD": "baz", "PGPASSWORD": "baz",
"PGCONNECT_TIMEOUT": "10",
}, },
config: pgx.ConnConfig{ config: pgx.ConnConfig{
Host: "123.123.123.123", Host: "123.123.123.123",
@@ -825,6 +856,7 @@ func TestParseEnvLibpq(t *testing.T) {
TLSConfig: &tls.Config{InsecureSkipVerify: true}, TLSConfig: &tls.Config{InsecureSkipVerify: true},
UseFallbackTLS: true, UseFallbackTLS: true,
FallbackTLSConfig: nil, FallbackTLSConfig: nil,
Timeout: 10 * time.Second,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
}, },
@@ -1988,9 +2020,9 @@ func TestConnInitConnInfo(t *testing.T) {
// spot check that the standard postgres type names aren't qualified // spot check that the standard postgres type names aren't qualified
nameOIDs := map[string]pgtype.OID{ nameOIDs := map[string]pgtype.OID{
"_int8": pgtype.Int8ArrayOID, "_int8": pgtype.Int8ArrayOID,
"int8": pgtype.Int8OID, "int8": pgtype.Int8OID,
"json": pgtype.JSONOID, "json": pgtype.JSONOID,
"text": pgtype.TextOID, "text": pgtype.TextOID,
} }
for name, oid := range nameOIDs { for name, oid := range nameOIDs {
dtByName, ok := conn.ConnInfo.DataTypeForName(name) dtByName, ok := conn.ConnInfo.DataTypeForName(name)