diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 311b06a3..a7c4eea3 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -562,8 +562,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { } pgConn.resetBatch() + // Clear any existing timeout pgConn.NetConn.SetDeadline(time.Time{}) + // Try to cancel any in-progress requests + for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { + pgConn.CancelRequest(ctx) + } + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) defer cleanupContext() @@ -669,3 +675,38 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { Routine: msg.Routine, } } + +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing + // the connection config. This is important in high availability configurations where fallback connections may be + // specified or DNS may be used to load balance. + serverAddr := pgConn.NetConn.RemoteAddr() + cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + if err != nil { + return err + } + defer cancelConn.Close() + + cleanupContext := contextDoneToConnDeadline(ctx, cancelConn) + defer cleanupContext() + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.PID)) + binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.SecretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return fmt.Errorf("Server failed to close connection after cancel query request: %v", preferContextOverNetTimeoutError(ctx, err)) + } + + return nil +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 98fd198e..9873013c 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -264,6 +264,8 @@ func TestConnExecContextCanceled(t *testing.T) { result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") require.Nil(t, result) assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -287,3 +289,23 @@ func TestConnRecoverFromTimeout(t *testing.T) { } cancel() } + +func TestConnCancelQuery(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + pgConn.SendExec("select current_database(), pg_sleep(5)") + err = pgConn.Flush(context.Background()) + require.Nil(t, err) + + err = pgConn.CancelRequest(context.Background()) + require.Nil(t, err) + + _, err = pgConn.GetResult(context.Background()).Close() + if err, ok := err.(pgconn.PgError); ok { + assert.Equal(t, "57014", err.Code) + } else { + t.Errorf("expected pgconn.PgError got %v", err) + } +}