Parse connect_timeout into Dial func
Instead of adding Timeout field which could conflict with custom Dial func.
This commit is contained in:
@@ -72,7 +72,6 @@ type ConnConfig struct {
|
|||||||
Logger Logger
|
Logger Logger
|
||||||
LogLevel int
|
LogLevel int
|
||||||
Dial DialFunc
|
Dial DialFunc
|
||||||
Timeout time.Duration
|
|
||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
OnNotice NoticeHandler // Callback function called when a notice response is received.
|
OnNotice NoticeHandler // Callback function called when a notice response is received.
|
||||||
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
|
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
|
||||||
@@ -224,6 +223,10 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
|||||||
return connect(config, minimalConnInfo)
|
return connect(config, minimalConnInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func defaultDialer() *net.Dialer {
|
||||||
|
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||||
|
}
|
||||||
|
|
||||||
func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
|
func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) {
|
||||||
c = new(Conn)
|
c = new(Conn)
|
||||||
|
|
||||||
@@ -260,7 +263,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
|||||||
|
|
||||||
network, address := c.config.networkAddress()
|
network, address := c.config.networkAddress()
|
||||||
if c.config.Dial == nil {
|
if c.config.Dial == nil {
|
||||||
c.config.Dial = (&net.Dialer{Timeout: c.config.Timeout, KeepAlive: 5 * time.Minute}).Dial
|
d := defaultDialer()
|
||||||
|
c.config.Dial = d.Dial
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.shouldLog(LogLevelInfo) {
|
if c.shouldLog(LogLevelInfo) {
|
||||||
@@ -692,7 +696,9 @@ func ParseURI(uri string) (ConnConfig, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return cp, err
|
return cp, err
|
||||||
}
|
}
|
||||||
cp.Timeout = time.Duration(timeout) * time.Second
|
d := defaultDialer()
|
||||||
|
d.Timeout = time.Duration(timeout) * time.Second
|
||||||
|
cp.Dial = d.Dial
|
||||||
}
|
}
|
||||||
|
|
||||||
err = configSSL(url.Query().Get("sslmode"), &cp)
|
err = configSSL(url.Query().Get("sslmode"), &cp)
|
||||||
@@ -761,11 +767,13 @@ func ParseDSN(s string) (ConnConfig, error) {
|
|||||||
case "sslmode":
|
case "sslmode":
|
||||||
sslmode = b[2]
|
sslmode = b[2]
|
||||||
case "connect_timeout":
|
case "connect_timeout":
|
||||||
t, err := strconv.ParseInt(b[2], 10, 64)
|
timeout, err := strconv.ParseInt(b[2], 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cp, err
|
return cp, err
|
||||||
}
|
}
|
||||||
cp.Timeout = time.Duration(t) * time.Second
|
d := defaultDialer()
|
||||||
|
d.Timeout = time.Duration(timeout) * time.Second
|
||||||
|
cp.Dial = d.Dial
|
||||||
default:
|
default:
|
||||||
cp.RuntimeParams[b[1]] = b[2]
|
cp.RuntimeParams[b[1]] = b[2]
|
||||||
}
|
}
|
||||||
@@ -841,7 +849,9 @@ func ParseEnvLibpq() (ConnConfig, error) {
|
|||||||
|
|
||||||
if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
|
if pgtimeout := os.Getenv("PGCONNECT_TIMEOUT"); pgtimeout != "" {
|
||||||
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
|
if timeout, err := strconv.ParseInt(pgtimeout, 10, 64); err == nil {
|
||||||
cc.Timeout = time.Duration(timeout) * time.Second
|
d := defaultDialer()
|
||||||
|
d.Timeout = time.Duration(timeout) * time.Second
|
||||||
|
cc.Dial = d.Dial
|
||||||
} else {
|
} else {
|
||||||
return cc, err
|
return cc, err
|
||||||
}
|
}
|
||||||
|
|||||||
+74
-70
@@ -576,7 +576,7 @@ func TestParseDSN(t *testing.T) {
|
|||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
},
|
},
|
||||||
Timeout: 10 * time.Second,
|
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||||
UseFallbackTLS: true,
|
UseFallbackTLS: true,
|
||||||
FallbackTLSConfig: nil,
|
FallbackTLSConfig: nil,
|
||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
@@ -585,15 +585,13 @@ func TestParseDSN(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
connParams, err := pgx.ParseDSN(tt.url)
|
actual, err := pgx.ParseDSN(tt.url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
|
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(connParams, tt.connParams) {
|
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
|
||||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -721,7 +719,7 @@ func TestParseConnectionString(t *testing.T) {
|
|||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
},
|
},
|
||||||
Timeout: 10 * time.Second,
|
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||||
UseFallbackTLS: true,
|
UseFallbackTLS: true,
|
||||||
FallbackTLSConfig: nil,
|
FallbackTLSConfig: nil,
|
||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
@@ -819,16 +817,80 @@ func TestParseConnectionString(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
connParams, err := pgx.ParseConnectionString(tt.url)
|
actual, err := pgx.ParseConnectionString(tt.url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
|
t.Errorf("%d. Unexpected error from pgx.ParseDSN(%q) => %v", i, tt.url, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(connParams, tt.connParams) {
|
testConnConfigEquals(t, tt.connParams, actual, strconv.Itoa(i))
|
||||||
t.Errorf("%d. expected %#v got %#v", i, tt.connParams, connParams)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testConnConfigEquals(t *testing.T, expected pgx.ConnConfig, actual pgx.ConnConfig, testName string) {
|
||||||
|
if actual.Host != expected.Host {
|
||||||
|
t.Errorf("%s: expected Host to be %v got %v", testName, expected.Host, actual.Host)
|
||||||
|
}
|
||||||
|
if actual.Port != expected.Port {
|
||||||
|
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
|
||||||
|
}
|
||||||
|
if actual.Port != expected.Port {
|
||||||
|
t.Errorf("%s: expected Port to be %v got %v", testName, expected.Port, actual.Port)
|
||||||
|
}
|
||||||
|
if actual.User != expected.User {
|
||||||
|
t.Errorf("%s: expected User to be %v got %v", testName, expected.User, actual.User)
|
||||||
|
}
|
||||||
|
if actual.Password != expected.Password {
|
||||||
|
t.Errorf("%s: expected Password to be %v got %v", testName, expected.Password, actual.Password)
|
||||||
|
}
|
||||||
|
// Cannot test value of underlying Dialer stuct but can at least test if Dial func is set.
|
||||||
|
if (actual.Dial != nil) != (expected.Dial != nil) {
|
||||||
|
t.Errorf("%s: expected Dial mismatch", testName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(actual.RuntimeParams, expected.RuntimeParams) {
|
||||||
|
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", testName, expected.RuntimeParams, actual.RuntimeParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsTests := []struct {
|
||||||
|
name string
|
||||||
|
expected *tls.Config
|
||||||
|
actual *tls.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "TLSConfig",
|
||||||
|
expected: expected.TLSConfig,
|
||||||
|
actual: actual.TLSConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "FallbackTLSConfig",
|
||||||
|
expected: expected.FallbackTLSConfig,
|
||||||
|
actual: actual.FallbackTLSConfig,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tlsTest := range tlsTests {
|
||||||
|
name := tlsTest.name
|
||||||
|
expected := tlsTest.expected
|
||||||
|
actual := tlsTest.actual
|
||||||
|
|
||||||
|
if expected == nil && actual != nil {
|
||||||
|
t.Errorf("%s / %s: expected nil, but it was set", testName, name)
|
||||||
|
} else if expected != nil && actual == nil {
|
||||||
|
t.Errorf("%s / %s: expected to be set, but got nil", testName, name)
|
||||||
|
} else if expected != nil && actual != nil {
|
||||||
|
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
|
||||||
|
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", testName, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.ServerName != expected.ServerName {
|
||||||
|
t.Errorf("%s / %s: expected ServerName to be %v got %v", testName, name, expected.ServerName, actual.ServerName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if actual.UseFallbackTLS != expected.UseFallbackTLS {
|
||||||
|
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", testName, expected.UseFallbackTLS, actual.UseFallbackTLS)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseEnvLibpq(t *testing.T) {
|
func TestParseEnvLibpq(t *testing.T) {
|
||||||
@@ -881,7 +943,7 @@ func TestParseEnvLibpq(t *testing.T) {
|
|||||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
UseFallbackTLS: true,
|
UseFallbackTLS: true,
|
||||||
FallbackTLSConfig: nil,
|
FallbackTLSConfig: nil,
|
||||||
Timeout: 10 * time.Second,
|
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||||
RuntimeParams: map[string]string{},
|
RuntimeParams: map[string]string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -997,71 +1059,13 @@ func TestParseEnvLibpq(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := pgx.ParseEnvLibpq()
|
actual, err := pgx.ParseEnvLibpq()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err)
|
t.Errorf("%s: Unexpected error from pgx.ParseLibpq() => %v", tt.name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Host != tt.config.Host {
|
testConnConfigEquals(t, tt.config, actual, tt.name)
|
||||||
t.Errorf("%s: expected Host to be %v got %v", tt.name, tt.config.Host, config.Host)
|
|
||||||
}
|
|
||||||
if config.Port != tt.config.Port {
|
|
||||||
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
|
|
||||||
}
|
|
||||||
if config.Port != tt.config.Port {
|
|
||||||
t.Errorf("%s: expected Port to be %v got %v", tt.name, tt.config.Port, config.Port)
|
|
||||||
}
|
|
||||||
if config.User != tt.config.User {
|
|
||||||
t.Errorf("%s: expected User to be %v got %v", tt.name, tt.config.User, config.User)
|
|
||||||
}
|
|
||||||
if config.Password != tt.config.Password {
|
|
||||||
t.Errorf("%s: expected Password to be %v got %v", tt.name, tt.config.Password, config.Password)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(config.RuntimeParams, tt.config.RuntimeParams) {
|
|
||||||
t.Errorf("%s: expected RuntimeParams to be %#v got %#v", tt.name, tt.config.RuntimeParams, config.RuntimeParams)
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsTests := []struct {
|
|
||||||
name string
|
|
||||||
expected *tls.Config
|
|
||||||
actual *tls.Config
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "TLSConfig",
|
|
||||||
expected: tt.config.TLSConfig,
|
|
||||||
actual: config.TLSConfig,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FallbackTLSConfig",
|
|
||||||
expected: tt.config.FallbackTLSConfig,
|
|
||||||
actual: config.FallbackTLSConfig,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tlsTest := range tlsTests {
|
|
||||||
name := tlsTest.name
|
|
||||||
expected := tlsTest.expected
|
|
||||||
actual := tlsTest.actual
|
|
||||||
|
|
||||||
if expected == nil && actual != nil {
|
|
||||||
t.Errorf("%s / %s: expected nil, but it was set", tt.name, name)
|
|
||||||
} else if expected != nil && actual == nil {
|
|
||||||
t.Errorf("%s / %s: expected to be set, but got nil", tt.name, name)
|
|
||||||
} else if expected != nil && actual != nil {
|
|
||||||
if actual.InsecureSkipVerify != expected.InsecureSkipVerify {
|
|
||||||
t.Errorf("%s / %s: expected InsecureSkipVerify to be %v got %v", tt.name, name, expected.InsecureSkipVerify, actual.InsecureSkipVerify)
|
|
||||||
}
|
|
||||||
|
|
||||||
if actual.ServerName != expected.ServerName {
|
|
||||||
t.Errorf("%s / %s: expected ServerName to be %v got %v", tt.name, name, expected.ServerName, actual.ServerName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.UseFallbackTLS != tt.config.UseFallbackTLS {
|
|
||||||
t.Errorf("%s: expected UseFallbackTLS to be %v got %v", tt.name, tt.config.UseFallbackTLS, config.UseFallbackTLS)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user