Merge branch 'improve-connect-timeout' of git://github.com/georgysavva/pgconn into georgysavva-improve-connect-timeout
This commit is contained in:
@@ -30,16 +30,17 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
|||||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and
|
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and
|
||||||
// then it can be modified. A manually initialized Config will cause ConnectConfig to panic.
|
// then it can be modified. A manually initialized Config will cause ConnectConfig to panic.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||||
Port uint16
|
Port uint16
|
||||||
Database string
|
Database string
|
||||||
User string
|
User string
|
||||||
Password string
|
Password string
|
||||||
TLSConfig *tls.Config // nil disables TLS
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
ConnectTimeout time.Duration
|
||||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||||
BuildFrontend BuildFrontendFunc
|
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
BuildFrontend BuildFrontendFunc
|
||||||
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
|
|
||||||
Fallbacks []*FallbackConfig
|
Fallbacks []*FallbackConfig
|
||||||
|
|
||||||
@@ -191,12 +192,13 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
|
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if connectTimeout, present := settings["connect_timeout"]; present {
|
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||||
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
|
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||||
}
|
}
|
||||||
config.DialFunc = dialFunc
|
config.ConnectTimeout = connectTimeout
|
||||||
|
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||||
} else {
|
} else {
|
||||||
defaultDialer := makeDefaultDialer()
|
defaultDialer := makeDefaultDialer()
|
||||||
config.DialFunc = defaultDialer.DialContext
|
config.DialFunc = defaultDialer.DialContext
|
||||||
@@ -672,18 +674,21 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
|
||||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if timeout < 0 {
|
if timeout < 0 {
|
||||||
return nil, errors.New("negative timeout")
|
return 0, errors.New("negative timeout")
|
||||||
}
|
}
|
||||||
|
return time.Duration(timeout) * time.Second, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||||
d := makeDefaultDialer()
|
d := makeDefaultDialer()
|
||||||
d.Timeout = time.Duration(timeout) * time.Second
|
d.Timeout = timeout
|
||||||
return d.DialContext, nil
|
return d.DialContext
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
|||||||
+31
-26
@@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -127,11 +128,11 @@ func TestParseConfig(t *testing.T) {
|
|||||||
name: "sslmode verify-ca",
|
name: "sslmode verify-ca",
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
User: "jack",
|
User: "jack",
|
||||||
Password: "secret",
|
Password: "secret",
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 5432,
|
Port: 5432,
|
||||||
Database: "mydb",
|
Database: "mydb",
|
||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
},
|
},
|
||||||
@@ -153,14 +154,15 @@ func TestParseConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "database url everything",
|
name: "database url everything",
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema",
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5",
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
User: "jack",
|
User: "jack",
|
||||||
Password: "secret",
|
Password: "secret",
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 5432,
|
Port: 5432,
|
||||||
Database: "mydb",
|
Database: "mydb",
|
||||||
TLSConfig: nil,
|
TLSConfig: nil,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
RuntimeParams: map[string]string{
|
RuntimeParams: map[string]string{
|
||||||
"application_name": "pgxtest",
|
"application_name": "pgxtest",
|
||||||
"search_path": "myschema",
|
"search_path": "myschema",
|
||||||
@@ -230,14 +232,15 @@ func TestParseConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "DSN everything",
|
name: "DSN everything",
|
||||||
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema",
|
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
User: "jack",
|
User: "jack",
|
||||||
Password: "secret",
|
Password: "secret",
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 5432,
|
Port: 5432,
|
||||||
Database: "mydb",
|
Database: "mydb",
|
||||||
TLSConfig: nil,
|
TLSConfig: nil,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
RuntimeParams: map[string]string{
|
RuntimeParams: map[string]string{
|
||||||
"application_name": "pgxtest",
|
"application_name": "pgxtest",
|
||||||
"search_path": "myschema",
|
"search_path": "myschema",
|
||||||
@@ -501,6 +504,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
|
|||||||
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||||
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
||||||
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
||||||
|
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
|
||||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||||
|
|
||||||
// Can't test function equality, so just test that they are set or not.
|
// Can't test function equality, so just test that they are set or not.
|
||||||
@@ -590,13 +594,14 @@ func TestParseConfigEnvLibpq(t *testing.T) {
|
|||||||
"PGAPPNAME": "pgxtest",
|
"PGAPPNAME": "pgxtest",
|
||||||
},
|
},
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
Host: "123.123.123.123",
|
Host: "123.123.123.123",
|
||||||
Port: 7777,
|
Port: 7777,
|
||||||
Database: "foo",
|
Database: "foo",
|
||||||
User: "bar",
|
User: "bar",
|
||||||
Password: "baz",
|
Password: "baz",
|
||||||
TLSConfig: nil,
|
ConnectTimeout: 10 * time.Second,
|
||||||
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -116,6 +116,12 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||||||
panic("config must be created by ParseConfig")
|
panic("config must be created by ParseConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnectTimeout restricts the whole connection process.
|
||||||
|
if config.ConnectTimeout != 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
// Simplify usage by treating primary config and fallbacks the same.
|
// Simplify usage by treating primary config and fallbacks the same.
|
||||||
fallbackConfigs := []*FallbackConfig{
|
fallbackConfigs := []*FallbackConfig{
|
||||||
{
|
{
|
||||||
|
|||||||
+71
-43
@@ -17,8 +17,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
|
||||||
"github.com/jackc/pgmock"
|
"github.com/jackc/pgmock"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgproto3/v2"
|
"github.com/jackc/pgproto3/v2"
|
||||||
errors "golang.org/x/xerrors"
|
errors "golang.org/x/xerrors"
|
||||||
|
|
||||||
@@ -81,58 +82,85 @@ func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectWithContextThatTimesOut(t *testing.T) {
|
func TestConnectTimeout(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
script := &pgmock.Script{
|
name string
|
||||||
Steps: []pgmock.Step{
|
connect func(connStr string) error
|
||||||
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
}{
|
||||||
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
|
{
|
||||||
pgmockWaitStep(time.Millisecond * 500),
|
name: "via context that times out",
|
||||||
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
connect: func(connStr string) error {
|
||||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
|
||||||
|
defer cancel()
|
||||||
|
_, err := pgconn.Connect(ctx, connStr)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "via config ConnectTimeout",
|
||||||
|
connect: func(connStr string) error {
|
||||||
|
conf, err := pgconn.ParseConfig(connStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conf.ConnectTimeout = time.Microsecond * 50
|
||||||
|
_, err = pgconn.ConnectConfig(context.Background(), conf)
|
||||||
|
return err
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
script := &pgmock.Script{
|
||||||
|
Steps: []pgmock.Step{
|
||||||
|
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||||
|
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
|
||||||
|
pgmockWaitStep(time.Millisecond * 500),
|
||||||
|
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||||
|
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
serverErrChan := make(chan error, 1)
|
serverErrChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(serverErrChan)
|
defer close(serverErrChan)
|
||||||
|
|
||||||
conn, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
serverErrChan <- err
|
serverErrChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
|
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
serverErrChan <- err
|
serverErrChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
|
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
serverErrChan <- err
|
serverErrChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
parts := strings.Split(ln.Addr().String(), ":")
|
parts := strings.Split(ln.Addr().String(), ":")
|
||||||
host := parts[0]
|
host := parts[0]
|
||||||
port := parts[1]
|
port := parts[1]
|
||||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||||
tooLate := time.Now().Add(time.Millisecond * 500)
|
tooLate := time.Now().Add(time.Millisecond * 500)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
|
err = tt.connect(connStr)
|
||||||
defer cancel()
|
require.True(t, pgconn.Timeout(err), err)
|
||||||
_, err = pgconn.Connect(ctx, connStr)
|
require.True(t, time.Now().Before(tooLate))
|
||||||
require.True(t, pgconn.Timeout(err), err)
|
})
|
||||||
require.True(t, time.Now().Before(tooLate))
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectInvalidUser(t *testing.T) {
|
func TestConnectInvalidUser(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user