2
0

Merge branch 'improve-connect-timeout' of git://github.com/georgysavva/pgconn into georgysavva-improve-connect-timeout

This commit is contained in:
Jack Christensen
2020-05-13 07:43:15 -05:00
4 changed files with 131 additions and 87 deletions
+13 -8
View File
@@ -36,6 +36,7 @@ type Config struct {
User string User string
Password string Password string
TLSConfig *tls.Config // nil disables TLS TLSConfig *tls.Config // nil disables TLS
ConnectTimeout time.Duration
DialFunc DialFunc // e.g. net.Dialer.DialContext DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc BuildFrontend BuildFrontendFunc
@@ -191,12 +192,13 @@ func ParseConfig(connString string) (*Config, error) {
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
} }
if connectTimeout, present := settings["connect_timeout"]; present { if connectTimeoutSetting, present := settings["connect_timeout"]; present {
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
} }
config.DialFunc = dialFunc config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
} else { } else {
defaultDialer := makeDefaultDialer() defaultDialer := makeDefaultDialer()
config.DialFunc = defaultDialer.DialContext config.DialFunc = defaultDialer.DialContext
@@ -672,18 +674,21 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
} }
} }
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { func parseConnectTimeoutSetting(s string) (time.Duration, error) {
timeout, err := strconv.ParseInt(s, 10, 64) timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
return nil, err return 0, err
} }
if timeout < 0 { if timeout < 0 {
return nil, errors.New("negative timeout") return 0, errors.New("negative timeout")
} }
return time.Duration(timeout) * time.Second, nil
}
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
d := makeDefaultDialer() d := makeDefaultDialer()
d.Timeout = time.Duration(timeout) * time.Second d.Timeout = timeout
return d.DialContext, nil return d.DialContext
} }
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
+7 -2
View File
@@ -7,6 +7,7 @@ import (
"os" "os"
"os/user" "os/user"
"testing" "testing"
"time"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -153,7 +154,7 @@ func TestParseConfig(t *testing.T) {
}, },
{ {
name: "database url everything", name: "database url everything",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Password: "secret", Password: "secret",
@@ -161,6 +162,7 @@ func TestParseConfig(t *testing.T) {
Port: 5432, Port: 5432,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
ConnectTimeout: 5 * time.Second,
RuntimeParams: map[string]string{ RuntimeParams: map[string]string{
"application_name": "pgxtest", "application_name": "pgxtest",
"search_path": "myschema", "search_path": "myschema",
@@ -230,7 +232,7 @@ func TestParseConfig(t *testing.T) {
}, },
{ {
name: "DSN everything", name: "DSN everything",
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema", connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Password: "secret", Password: "secret",
@@ -238,6 +240,7 @@ func TestParseConfig(t *testing.T) {
Port: 5432, Port: 5432,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
ConnectTimeout: 5 * time.Second,
RuntimeParams: map[string]string{ RuntimeParams: map[string]string{
"application_name": "pgxtest", "application_name": "pgxtest",
"search_path": "myschema", "search_path": "myschema",
@@ -501,6 +504,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
assert.Equalf(t, expected.User, actual.User, "%s - User", testName) assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not. // Can't test function equality, so just test that they are set or not.
@@ -595,6 +599,7 @@ func TestParseConfigEnvLibpq(t *testing.T) {
Database: "foo", Database: "foo",
User: "bar", User: "bar",
Password: "baz", Password: "baz",
ConnectTimeout: 10 * time.Second,
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{"application_name": "pgxtest"}, RuntimeParams: map[string]string{"application_name": "pgxtest"},
}, },
+6
View File
@@ -116,6 +116,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
panic("config must be created by ParseConfig") panic("config must be created by ParseConfig")
} }
// ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout)
defer cancel()
}
// Simplify usage by treating primary config and fallbacks the same. // Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{ fallbackConfigs := []*FallbackConfig{
{ {
+34 -6
View File
@@ -17,8 +17,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/jackc/pgconn"
"github.com/jackc/pgmock" "github.com/jackc/pgmock"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2" "github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
@@ -81,9 +82,36 @@ func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
return nil return nil
} }
func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectTimeout(t *testing.T) {
t.Parallel()
tests := []struct {
name string
connect func(connStr string) error
}{
{
name: "via context that times out",
connect: func(connStr string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err := pgconn.Connect(ctx, connStr)
return err
},
},
{
name: "via config ConnectTimeout",
connect: func(connStr string) error {
conf, err := pgconn.ParseConfig(connStr)
require.NoError(t, err)
conf.ConnectTimeout = time.Microsecond * 50
_, err = pgconn.ConnectConfig(context.Background(), conf)
return err
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
script := &pgmock.Script{ script := &pgmock.Script{
Steps: []pgmock.Step{ Steps: []pgmock.Step{
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
@@ -128,11 +156,11 @@ func TestConnectWithContextThatTimesOut(t *testing.T) {
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
tooLate := time.Now().Add(time.Millisecond * 500) tooLate := time.Now().Add(time.Millisecond * 500)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) err = tt.connect(connStr)
defer cancel()
_, err = pgconn.Connect(ctx, connStr)
require.True(t, pgconn.Timeout(err), err) require.True(t, pgconn.Timeout(err), err)
require.True(t, time.Now().Before(tooLate)) require.True(t, time.Now().Before(tooLate))
})
}
} }
func TestConnectInvalidUser(t *testing.T) { func TestConnectInvalidUser(t *testing.T) {