diff --git a/pgconn.go b/pgconn.go index 7e5d585b..6fde4e50 100644 --- a/pgconn.go +++ b/pgconn.go @@ -230,7 +230,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - pgConn.conn, err = config.DialFunc(ctx, network, address) + netConn, err := config.DialFunc(ctx, network, address) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { @@ -239,26 +239,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, &connectError{config: config, msg: "dial error", err: err} } - pgConn.parameterStatuses = make(map[string]string) - - pgConn.status = connStatusConnecting - pgConn.contextWatcher = ctxwatch.NewContextWatcher( - func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { pgConn.conn.SetDeadline(time.Time{}) }, - ) - + pgConn.contextWatcher = contextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() + pgConn.status = connStatusConnecting + pgConn.conn = netConn + if fallbackConfig.TLSConfig != nil { - tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig) + tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) if err != nil { - pgConn.conn.Close() + netConn.Close() return nil, &connectError{config: config, msg: "tls error", err: err} } + + pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher = contextWatcher(tlsConn) + pgConn.contextWatcher.Watch(ctx) + pgConn.conn = tlsConn } + pgConn.parameterStatuses = make(map[string]string) pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ @@ -346,6 +348,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } +func contextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { + return ctxwatch.NewContextWatcher( + func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { conn.SetDeadline(time.Time{}) }, + ) +} + func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { @@ -1709,10 +1718,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { cleanupDone: make(chan struct{}), } - pgConn.contextWatcher = ctxwatch.NewContextWatcher( - func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { pgConn.conn.SetDeadline(time.Time{}) }, - ) + pgConn.contextWatcher = contextWatcher(pgConn.conn) return pgConn, nil }