Fix TLS connection timeout
This commit is contained in:
committed by
Jack Christensen
parent
5a5260b73d
commit
c0a0be876d
@@ -241,13 +241,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
|
|
||||||
pgConn.parameterStatuses = make(map[string]string)
|
pgConn.parameterStatuses = make(map[string]string)
|
||||||
|
|
||||||
if fallbackConfig.TLSConfig != nil {
|
|
||||||
if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
|
|
||||||
pgConn.conn.Close()
|
|
||||||
return nil, &connectError{config: config, msg: "tls error", err: err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pgConn.status = connStatusConnecting
|
pgConn.status = connStatusConnecting
|
||||||
pgConn.contextWatcher = ctxwatch.NewContextWatcher(
|
pgConn.contextWatcher = ctxwatch.NewContextWatcher(
|
||||||
func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
||||||
@@ -257,6 +250,15 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
pgConn.contextWatcher.Watch(ctx)
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
defer pgConn.contextWatcher.Unwatch()
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
|
|
||||||
|
if fallbackConfig.TLSConfig != nil {
|
||||||
|
tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig)
|
||||||
|
if err != nil {
|
||||||
|
pgConn.conn.Close()
|
||||||
|
return nil, &connectError{config: config, msg: "tls error", err: err}
|
||||||
|
}
|
||||||
|
pgConn.conn = tlsConn
|
||||||
|
}
|
||||||
|
|
||||||
pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
|
pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
|
||||||
|
|
||||||
startupMsg := pgproto3.StartupMessage{
|
startupMsg := pgproto3.StartupMessage{
|
||||||
@@ -344,24 +346,22 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||||
err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103})
|
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, 1)
|
response := make([]byte, 1)
|
||||||
if _, err = io.ReadFull(pgConn.conn, response); err != nil {
|
if _, err = io.ReadFull(conn, response); err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if response[0] != 'S' {
|
if response[0] != 'S' {
|
||||||
return errors.New("server refused TLS connection")
|
return nil, errors.New("server refused TLS connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
|
return tls.Client(conn, tlsConfig), nil
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||||
|
|||||||
@@ -161,6 +161,84 @@ func TestConnectTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
connect func(connStr string) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "via context that times out",
|
||||||
|
connect: func(connStr string) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
|
||||||
|
defer cancel()
|
||||||
|
_, err := pgconn.Connect(ctx, connStr)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "via config ConnectTimeout",
|
||||||
|
connect: func(connStr string) error {
|
||||||
|
conf, err := pgconn.ParseConfig(connStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conf.ConnectTimeout = time.Millisecond * 10
|
||||||
|
_, err = pgconn.ConnectConfig(context.Background(), conf)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
serverErrChan := make(chan error)
|
||||||
|
defer close(serverErrChan)
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
var buf []byte
|
||||||
|
_, err = conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sleeping to hang the TLS handshake.
|
||||||
|
time.Sleep(time.Minute)
|
||||||
|
}()
|
||||||
|
|
||||||
|
parts := strings.Split(ln.Addr().String(), ":")
|
||||||
|
host := parts[0]
|
||||||
|
port := parts[1]
|
||||||
|
connStr := fmt.Sprintf("host=%s port=%s", host, port)
|
||||||
|
|
||||||
|
errChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
err := tt.connect(connStr)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-errChan:
|
||||||
|
require.True(t, pgconn.Timeout(err), err)
|
||||||
|
case err = <-serverErrChan:
|
||||||
|
t.Fatalf("server failed with error: %s", err)
|
||||||
|
case <-time.After(time.Millisecond * 100):
|
||||||
|
t.Fatal("exceeded connection timeout without erroring out")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnectInvalidUser(t *testing.T) {
|
func TestConnectInvalidUser(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user