From c0a0be876d02652626514ed97b0349f245e3bf76 Mon Sep 17 00:00:00 2001 From: Blake Embrey Date: Wed, 22 Dec 2021 08:33:10 -0800 Subject: [PATCH] Fix TLS connection timeout --- pgconn.go | 32 ++++++++++----------- pgconn_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 16 deletions(-) diff --git a/pgconn.go b/pgconn.go index dad522c6..7e5d585b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -241,13 +241,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) - if fallbackConfig.TLSConfig != nil { - if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { - pgConn.conn.Close() - return nil, &connectError{config: config, msg: "tls error", err: err} - } - } - pgConn.status = connStatusConnecting pgConn.contextWatcher = ctxwatch.NewContextWatcher( func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, @@ -257,6 +250,15 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() + if fallbackConfig.TLSConfig != nil { + tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "tls error", err: err} + } + pgConn.conn = tlsConn + } + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ @@ -344,24 +346,22 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } -func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) +func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { + err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { - return + return nil, err } response := make([]byte, 1) - if _, err = io.ReadFull(pgConn.conn, response); err != nil { - return + if _, err = io.ReadFull(conn, response); err != nil { + return nil, err } if response[0] != 'S' { - return errors.New("server refused TLS connection") + return nil, errors.New("server refused TLS connection") } - pgConn.conn = tls.Client(pgConn.conn, tlsConfig) - - return nil + return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { diff --git a/pgconn_test.go b/pgconn_test.go index 43e97eef..b22792fb 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -161,6 +161,84 @@ func TestConnectTimeout(t *testing.T) { } } +func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) { + t.Parallel() + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Millisecond * 10 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error) + defer close(serverErrChan) + go func() { + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + var buf []byte + _, err = conn.Read(buf) + if err != nil { + serverErrChan <- err + return + } + + // Sleeping to hang the TLS handshake. + time.Sleep(time.Minute) + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("host=%s port=%s", host, port) + + errChan := make(chan error) + go func() { + err := tt.connect(connStr) + errChan <- err + }() + + select { + case err = <-errChan: + require.True(t, pgconn.Timeout(err), err) + case err = <-serverErrChan: + t.Fatalf("server failed with error: %s", err) + case <-time.After(time.Millisecond * 100): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} + func TestConnectInvalidUser(t *testing.T) { t.Parallel()