Parse connect timeout setting into Config. Restrict context timeout via Config.ConnectTimeout on .Connect() call.
This commit is contained in:
@@ -30,16 +30,17 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and
|
||||
// then it can be modified. A manually initialized Config will cause ConnectConfig to panic.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||
BuildFrontend BuildFrontendFunc
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
ConnectTimeout time.Duration
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||
BuildFrontend BuildFrontendFunc
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
@@ -191,12 +192,13 @@ func ParseConfig(connString string) (*Config, error) {
|
||||
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
|
||||
}
|
||||
|
||||
if connectTimeout, present := settings["connect_timeout"]; present {
|
||||
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
|
||||
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||
}
|
||||
config.DialFunc = dialFunc
|
||||
config.ConnectTimeout = connectTimeout
|
||||
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||
} else {
|
||||
defaultDialer := makeDefaultDialer()
|
||||
config.DialFunc = defaultDialer.DialContext
|
||||
@@ -672,18 +674,21 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
||||
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
|
||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
if timeout < 0 {
|
||||
return nil, errors.New("negative timeout")
|
||||
return 0, errors.New("negative timeout")
|
||||
}
|
||||
return time.Duration(timeout) * time.Second, nil
|
||||
}
|
||||
|
||||
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||
d := makeDefaultDialer()
|
||||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
return d.DialContext, nil
|
||||
d.Timeout = timeout
|
||||
return d.DialContext
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||
|
||||
+31
-26
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -127,11 +128,11 @@ func TestParseConfig(t *testing.T) {
|
||||
name: "sslmode verify-ca",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
@@ -153,14 +154,15 @@ func TestParseConfig(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "database url everything",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema",
|
||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
RuntimeParams: map[string]string{
|
||||
"application_name": "pgxtest",
|
||||
"search_path": "myschema",
|
||||
@@ -230,14 +232,15 @@ func TestParseConfig(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "DSN everything",
|
||||
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema",
|
||||
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
RuntimeParams: map[string]string{
|
||||
"application_name": "pgxtest",
|
||||
"search_path": "myschema",
|
||||
@@ -501,6 +504,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
|
||||
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
||||
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
||||
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
||||
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
|
||||
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
||||
|
||||
// Can't test function equality, so just test that they are set or not.
|
||||
@@ -590,13 +594,14 @@ func TestParseConfigEnvLibpq(t *testing.T) {
|
||||
"PGAPPNAME": "pgxtest",
|
||||
},
|
||||
config: &pgconn.Config{
|
||||
Host: "123.123.123.123",
|
||||
Port: 7777,
|
||||
Database: "foo",
|
||||
User: "bar",
|
||||
Password: "baz",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
||||
Host: "123.123.123.123",
|
||||
Port: 7777,
|
||||
Database: "foo",
|
||||
User: "bar",
|
||||
Password: "baz",
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{"application_name": "pgxtest"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -116,6 +116,11 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
||||
panic("config must be created by ParseConfig")
|
||||
}
|
||||
|
||||
if config.ConnectTimeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
// Simplify usage by treating primary config and fallbacks the same.
|
||||
fallbackConfigs := []*FallbackConfig{
|
||||
{
|
||||
|
||||
+70
-43
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/jackc/pgmock"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgmock"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
errors "golang.org/x/xerrors"
|
||||
|
||||
@@ -81,58 +81,85 @@ func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConnectWithContextThatTimesOut(t *testing.T) {
|
||||
func TestConnectTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
script := &pgmock.Script{
|
||||
Steps: []pgmock.Step{
|
||||
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
|
||||
pgmockWaitStep(time.Millisecond * 500),
|
||||
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
tests := []struct {
|
||||
name string
|
||||
connect func(connStr string) error
|
||||
}{
|
||||
{
|
||||
name: "via context that times out",
|
||||
connect: func(connStr string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
|
||||
defer cancel()
|
||||
_, err := pgconn.Connect(ctx, connStr)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "via config ConnectTimeout",
|
||||
connect: func(connStr string) error {
|
||||
conf, err := pgconn.ParseConfig(connStr)
|
||||
require.NoError(t, err)
|
||||
conf.ConnectTimeout = time.Microsecond * 50
|
||||
_, err = pgconn.ConnectConfig(context.Background(), conf)
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
script := &pgmock.Script{
|
||||
Steps: []pgmock.Step{
|
||||
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
|
||||
pgmockWaitStep(time.Millisecond * 500),
|
||||
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
parts := strings.Split(ln.Addr().String(), ":")
|
||||
host := parts[0]
|
||||
port := parts[1]
|
||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||
tooLate := time.Now().Add(time.Millisecond * 500)
|
||||
parts := strings.Split(ln.Addr().String(), ":")
|
||||
host := parts[0]
|
||||
port := parts[1]
|
||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||
tooLate := time.Now().Add(time.Millisecond * 500)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
|
||||
defer cancel()
|
||||
_, err = pgconn.Connect(ctx, connStr)
|
||||
require.True(t, pgconn.Timeout(err), err)
|
||||
require.True(t, time.Now().Before(tooLate))
|
||||
err = tt.connect(connStr)
|
||||
require.True(t, pgconn.Timeout(err), err)
|
||||
require.True(t, time.Now().Before(tooLate))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectInvalidUser(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user