2
0

Merge branch 'improve-connect-timeout' of git://github.com/georgysavva/pgconn into georgysavva-improve-connect-timeout

This commit is contained in:
Jack Christensen
2020-05-13 07:43:15 -05:00
4 changed files with 131 additions and 87 deletions
+23 -18
View File
@@ -30,16 +30,17 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and
// then it can be modified. A manually initialized Config will cause ConnectConfig to panic. // then it can be modified. A manually initialized Config will cause ConnectConfig to panic.
type Config struct { type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
Port uint16 Port uint16
Database string Database string
User string User string
Password string Password string
TLSConfig *tls.Config // nil disables TLS TLSConfig *tls.Config // nil disables TLS
DialFunc DialFunc // e.g. net.Dialer.DialContext ConnectTimeout time.Duration
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost DialFunc DialFunc // e.g. net.Dialer.DialContext
BuildFrontend BuildFrontendFunc LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Fallbacks []*FallbackConfig Fallbacks []*FallbackConfig
@@ -191,12 +192,13 @@ func ParseConfig(connString string) (*Config, error) {
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
} }
if connectTimeout, present := settings["connect_timeout"]; present { if connectTimeoutSetting, present := settings["connect_timeout"]; present {
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
} }
config.DialFunc = dialFunc config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
} else { } else {
defaultDialer := makeDefaultDialer() defaultDialer := makeDefaultDialer()
config.DialFunc = defaultDialer.DialContext config.DialFunc = defaultDialer.DialContext
@@ -672,18 +674,21 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
} }
} }
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { func parseConnectTimeoutSetting(s string) (time.Duration, error) {
timeout, err := strconv.ParseInt(s, 10, 64) timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
return nil, err return 0, err
} }
if timeout < 0 { if timeout < 0 {
return nil, errors.New("negative timeout") return 0, errors.New("negative timeout")
} }
return time.Duration(timeout) * time.Second, nil
}
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
d := makeDefaultDialer() d := makeDefaultDialer()
d.Timeout = time.Duration(timeout) * time.Second d.Timeout = timeout
return d.DialContext, nil return d.DialContext
} }
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
+31 -26
View File
@@ -7,6 +7,7 @@ import (
"os" "os"
"os/user" "os/user"
"testing" "testing"
"time"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -127,11 +128,11 @@ func TestParseConfig(t *testing.T) {
name: "sslmode verify-ca", name: "sslmode verify-ca",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Password: "secret", Password: "secret",
Host: "localhost", Host: "localhost",
Port: 5432, Port: 5432,
Database: "mydb", Database: "mydb",
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
}, },
@@ -153,14 +154,15 @@ func TestParseConfig(t *testing.T) {
}, },
{ {
name: "database url everything", name: "database url everything",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Password: "secret", Password: "secret",
Host: "localhost", Host: "localhost",
Port: 5432, Port: 5432,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
ConnectTimeout: 5 * time.Second,
RuntimeParams: map[string]string{ RuntimeParams: map[string]string{
"application_name": "pgxtest", "application_name": "pgxtest",
"search_path": "myschema", "search_path": "myschema",
@@ -230,14 +232,15 @@ func TestParseConfig(t *testing.T) {
}, },
{ {
name: "DSN everything", name: "DSN everything",
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema", connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Password: "secret", Password: "secret",
Host: "localhost", Host: "localhost",
Port: 5432, Port: 5432,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
ConnectTimeout: 5 * time.Second,
RuntimeParams: map[string]string{ RuntimeParams: map[string]string{
"application_name": "pgxtest", "application_name": "pgxtest",
"search_path": "myschema", "search_path": "myschema",
@@ -501,6 +504,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
assert.Equalf(t, expected.User, actual.User, "%s - User", testName) assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not. // Can't test function equality, so just test that they are set or not.
@@ -590,13 +594,14 @@ func TestParseConfigEnvLibpq(t *testing.T) {
"PGAPPNAME": "pgxtest", "PGAPPNAME": "pgxtest",
}, },
config: &pgconn.Config{ config: &pgconn.Config{
Host: "123.123.123.123", Host: "123.123.123.123",
Port: 7777, Port: 7777,
Database: "foo", Database: "foo",
User: "bar", User: "bar",
Password: "baz", Password: "baz",
TLSConfig: nil, ConnectTimeout: 10 * time.Second,
RuntimeParams: map[string]string{"application_name": "pgxtest"}, TLSConfig: nil,
RuntimeParams: map[string]string{"application_name": "pgxtest"},
}, },
}, },
} }
+6
View File
@@ -116,6 +116,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
panic("config must be created by ParseConfig") panic("config must be created by ParseConfig")
} }
// ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout)
defer cancel()
}
// Simplify usage by treating primary config and fallbacks the same. // Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{ fallbackConfigs := []*FallbackConfig{
{ {
+71 -43
View File
@@ -17,8 +17,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/jackc/pgconn"
"github.com/jackc/pgmock" "github.com/jackc/pgmock"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2" "github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
@@ -81,58 +82,85 @@ func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
return nil return nil
} }
func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectTimeout(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct {
script := &pgmock.Script{ name string
Steps: []pgmock.Step{ connect func(connStr string) error
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), }{
pgmock.SendMessage(&pgproto3.AuthenticationOk{}), {
pgmockWaitStep(time.Millisecond * 500), name: "via context that times out",
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), connect: func(connStr string) error {
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err := pgconn.Connect(ctx, connStr)
return err
},
},
{
name: "via config ConnectTimeout",
connect: func(connStr string) error {
conf, err := pgconn.ParseConfig(connStr)
require.NoError(t, err)
conf.ConnectTimeout = time.Microsecond * 50
_, err = pgconn.ConnectConfig(context.Background(), conf)
return err
},
}, },
} }
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: []pgmock.Step{
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
pgmockWaitStep(time.Millisecond * 500),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
},
}
ln, err := net.Listen("tcp", "127.0.0.1:") ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err) require.NoError(t, err)
defer ln.Close() defer ln.Close()
serverErrChan := make(chan error, 1) serverErrChan := make(chan error, 1)
go func() { go func() {
defer close(serverErrChan) defer close(serverErrChan)
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
serverErrChan <- err serverErrChan <- err
return return
} }
defer conn.Close() defer conn.Close()
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
if err != nil { if err != nil {
serverErrChan <- err serverErrChan <- err
return return
} }
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
if err != nil { if err != nil {
serverErrChan <- err serverErrChan <- err
return return
} }
}() }()
parts := strings.Split(ln.Addr().String(), ":") parts := strings.Split(ln.Addr().String(), ":")
host := parts[0] host := parts[0]
port := parts[1] port := parts[1]
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
tooLate := time.Now().Add(time.Millisecond * 500) tooLate := time.Now().Add(time.Millisecond * 500)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) err = tt.connect(connStr)
defer cancel() require.True(t, pgconn.Timeout(err), err)
_, err = pgconn.Connect(ctx, connStr) require.True(t, time.Now().Before(tooLate))
require.True(t, pgconn.Timeout(err), err) })
require.True(t, time.Now().Before(tooLate)) }
} }
func TestConnectInvalidUser(t *testing.T) { func TestConnectInvalidUser(t *testing.T) {