2
0

Merge branch 'improve-connect-timeout' of git://github.com/georgysavva/pgconn into georgysavva-improve-connect-timeout

This commit is contained in:
Jack Christensen
2020-05-13 07:43:15 -05:00
4 changed files with 131 additions and 87 deletions
+23 -18
View File
@@ -30,16 +30,17 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and
// then it can be modified. A manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
ConnectTimeout time.Duration
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Fallbacks []*FallbackConfig
@@ -191,12 +192,13 @@ func ParseConfig(connString string) (*Config, error) {
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
}
if connectTimeout, present := settings["connect_timeout"]; present {
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
}
config.DialFunc = dialFunc
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
} else {
defaultDialer := makeDefaultDialer()
config.DialFunc = defaultDialer.DialContext
@@ -672,18 +674,21 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
}
}
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
return 0, err
}
if timeout < 0 {
return nil, errors.New("negative timeout")
return 0, errors.New("negative timeout")
}
return time.Duration(timeout) * time.Second, nil
}
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
d := makeDefaultDialer()
d.Timeout = time.Duration(timeout) * time.Second
return d.DialContext, nil
d.Timeout = timeout
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
+31 -26
View File
@@ -7,6 +7,7 @@ import (
"os"
"os/user"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/stretchr/testify/assert"
@@ -127,11 +128,11 @@ func TestParseConfig(t *testing.T) {
name: "sslmode verify-ca",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
@@ -153,14 +154,15 @@ func TestParseConfig(t *testing.T) {
},
{
name: "database url everything",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
ConnectTimeout: 5 * time.Second,
RuntimeParams: map[string]string{
"application_name": "pgxtest",
"search_path": "myschema",
@@ -230,14 +232,15 @@ func TestParseConfig(t *testing.T) {
},
{
name: "DSN everything",
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema",
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
User: "jack",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
ConnectTimeout: 5 * time.Second,
RuntimeParams: map[string]string{
"application_name": "pgxtest",
"search_path": "myschema",
@@ -501,6 +504,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
// Can't test function equality, so just test that they are set or not.
@@ -590,13 +594,14 @@ func TestParseConfigEnvLibpq(t *testing.T) {
"PGAPPNAME": "pgxtest",
},
config: &pgconn.Config{
Host: "123.123.123.123",
Port: 7777,
Database: "foo",
User: "bar",
Password: "baz",
TLSConfig: nil,
RuntimeParams: map[string]string{"application_name": "pgxtest"},
Host: "123.123.123.123",
Port: 7777,
Database: "foo",
User: "bar",
Password: "baz",
ConnectTimeout: 10 * time.Second,
TLSConfig: nil,
RuntimeParams: map[string]string{"application_name": "pgxtest"},
},
},
}
+6
View File
@@ -116,6 +116,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
panic("config must be created by ParseConfig")
}
// ConnectTimeout restricts the whole connection process.
if config.ConnectTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout)
defer cancel()
}
// Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{
{
+71 -43
View File
@@ -17,8 +17,9 @@ import (
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgmock"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors"
@@ -81,58 +82,85 @@ func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
return nil
}
func TestConnectWithContextThatTimesOut(t *testing.T) {
func TestConnectTimeout(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: []pgmock.Step{
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
pgmockWaitStep(time.Millisecond * 500),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
tests := []struct {
name string
connect func(connStr string) error
}{
{
name: "via context that times out",
connect: func(connStr string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err := pgconn.Connect(ctx, connStr)
return err
},
},
{
name: "via config ConnectTimeout",
connect: func(connStr string) error {
conf, err := pgconn.ParseConfig(connStr)
require.NoError(t, err)
conf.ConnectTimeout = time.Microsecond * 50
_, err = pgconn.ConnectConfig(context.Background(), conf)
return err
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: []pgmock.Step{
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
pgmockWaitStep(time.Millisecond * 500),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
},
}
ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()
ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
conn, err := ln.Accept()
if err != nil {
serverErrChan <- err
return
}
defer conn.Close()
conn, err := ln.Accept()
if err != nil {
serverErrChan <- err
return
}
defer conn.Close()
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
if err != nil {
serverErrChan <- err
return
}
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
if err != nil {
serverErrChan <- err
return
}
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
if err != nil {
serverErrChan <- err
return
}
}()
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
if err != nil {
serverErrChan <- err
return
}
}()
parts := strings.Split(ln.Addr().String(), ":")
host := parts[0]
port := parts[1]
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
tooLate := time.Now().Add(time.Millisecond * 500)
parts := strings.Split(ln.Addr().String(), ":")
host := parts[0]
port := parts[1]
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
tooLate := time.Now().Add(time.Millisecond * 500)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err = pgconn.Connect(ctx, connStr)
require.True(t, pgconn.Timeout(err), err)
require.True(t, time.Now().Before(tooLate))
err = tt.connect(connStr)
require.True(t, pgconn.Timeout(err), err)
require.True(t, time.Now().Before(tooLate))
})
}
}
func TestConnectInvalidUser(t *testing.T) {