2
0

Context cancellation is fatal during query

This commit is contained in:
Jack Christensen
2019-03-30 16:44:20 -05:00
parent b2fc69d32f
commit 444bd6deaf
5 changed files with 60 additions and 311 deletions
-11
View File
@@ -41,17 +41,6 @@ type Config struct {
// allows implementing high availability behavior such as libpq does with target_session_attrs.
AfterConnectFunc AfterConnectFunc
// OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context
// is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a
// query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire
// protocol do not support this cancellation method.
//
// It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be
// called whether it was successful or not. If an error occurs the connection should be closed. The connection must be
// in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read
// the connection until a ready for query message is received.
OnContextCancel func(*ContextCancel)
// OnNotice is a callback function called when a notice response is received.
OnNotice NoticeHandler
+5 -5
View File
@@ -20,10 +20,10 @@ result. The ReadAll method reads all query results into memory.
Context Support
All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the
method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the
cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is
safe to use the connection while this background cancellation is in progress. Any calls will block until the
cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation).
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
*/
package pgconn
+13 -176
View File
@@ -199,6 +199,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
for {
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.conn.Close()
return nil, err
}
@@ -502,7 +503,7 @@ readloop:
for {
msg, err := pgConn.ReceiveMessage()
if err != nil {
go pgConn.recoverFromTimeout()
pgConn.hardClose()
return nil, preferContextOverNetTimeoutError(ctx, err)
}
@@ -555,10 +556,10 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
return (*Notice)(pgerr)
}
// cancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel
// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel
// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9
func (pgConn *PgConn) cancelRequest(ctx context.Context) error {
func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
// the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance.
@@ -590,21 +591,6 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error {
return nil
}
// WaitUntilReady waits until a previous context cancellation has been completed and the connection is ready for use.
// This is done automatically by all methods that need the connection to be ready for use. The only expected use for
// this method is for a connection pool to wait for a returned connection to be usable again before making it available.
func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case pgConn.controller <- pgConn:
// The connection must be ready since it was locked. Immediately unlock it.
<-pgConn.controller
}
return nil
}
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
// received.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
@@ -778,6 +764,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
case pgConn.controller <- pgConn:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
// Send copy to command
var buf []byte
@@ -786,7 +773,6 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
@@ -798,13 +784,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
for {
msg, err := pgConn.ReceiveMessage()
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeout()
} else {
<-pgConn.controller
}
pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err)
}
@@ -813,9 +793,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
case *pgproto3.CopyData:
_, err := w.Write(msg.Data)
if err != nil {
// This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup.
cleanupContextDeadline()
go pgConn.recoverFromTimeout()
pgConn.hardClose()
return "", err
}
case *pgproto3.ReadyForQuery:
@@ -840,6 +818,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
case pgConn.controller <- pgConn:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
// Send copy to command
var buf []byte
@@ -848,7 +827,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
@@ -861,13 +839,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
for pendingCopyInResponse {
msg, err := pgConn.ReceiveMessage()
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeoutDuringCopyFrom()
} else {
<-pgConn.controller
}
pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err)
}
@@ -899,7 +871,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
@@ -910,13 +881,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
case <-signalMessageChan:
msg, err := pgConn.ReceiveMessage()
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeoutDuringCopyFrom()
} else {
<-pgConn.controller
}
pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err)
}
@@ -939,8 +904,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
cleanupContextDeadline()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
@@ -950,13 +913,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
for {
msg, err := pgConn.ReceiveMessage()
if err != nil {
cleanupContextDeadline()
if err, ok := err.(net.Error); ok && err.Timeout() {
go pgConn.recoverFromTimeout()
} else {
<-pgConn.controller
}
pgConn.hardClose()
return "", preferContextOverNetTimeoutError(ctx, err)
}
@@ -972,47 +929,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
}
func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() {
// Regardless of recovery outcome the lock on the pgConn must be released.
defer func() { <-pgConn.controller }()
// Limit time to wait for entire cancellation process.
err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second))
if err != nil {
pgConn.hardClose()
return
}
copyFail := &pgproto3.CopyFail{Error: "client cancel"}
buf := copyFail.Encode(nil)
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return
}
pendingReadyForQuery := true
for pendingReadyForQuery {
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.hardClose()
return
}
switch msg.(type) {
case *pgproto3.ReadyForQuery:
pendingReadyForQuery = false
}
}
err = pgConn.conn.SetDeadline(time.Time{})
if err != nil {
pgConn.hardClose()
}
}
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
pgConn *PgConn
@@ -1044,13 +960,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
mrr.cleanupContextDeadline()
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
mrr.closed = true
if err, ok := err.(net.Error); ok && err.Timeout() {
go mrr.pgConn.recoverFromTimeout()
} else {
<-mrr.pgConn.controller
}
mrr.pgConn.hardClose()
return nil, mrr.err
}
@@ -1236,11 +1146,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
rr.cleanupContextDeadline()
rr.closed = true
if rr.multiResultReader == nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
go rr.pgConn.recoverFromTimeout()
} else {
<-rr.pgConn.controller
}
rr.pgConn.hardClose()
}
return nil, rr.err
@@ -1270,75 +1176,6 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
rr.commandConcluded = true
}
func (pgConn *PgConn) defaultCancel() {
// Regardless of recovery outcome the lock on the pgConn must be released.
defer func() { <-pgConn.controller }()
// Send a cancellation request to the PostgreSQL server. If it is not successful in a reasonable amount of time do not
// try further to recover the connection.
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
err := pgConn.cancelRequest(ctx)
cancel()
if err != nil {
pgConn.hardClose()
return
}
// Limit time to wait for ReadyForQuery message.
err = pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second))
if err != nil {
pgConn.hardClose()
return
}
// A cancel query request will always return a "57014" error response, even if no query was in progress. This error
// may be returned before or after the ReadyForQuery message. Must ensure both messages are read.
needError57014 := true
needReadyForQuery := true
for needError57014 || needReadyForQuery {
msg, err := pgConn.ReceiveMessage()
if err != nil {
pgConn.hardClose()
return
}
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
if msg.Code == "57014" {
needError57014 = false
}
case *pgproto3.ReadyForQuery:
needReadyForQuery = false
}
}
err = pgConn.conn.SetDeadline(time.Time{})
if err != nil {
pgConn.hardClose()
}
}
type ContextCancel struct {
PgConn *PgConn
}
// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for
// query state or the connection must be closed. This must be called regardless of the success of the cancellation and
// whether the connection is still valid or not. It releases an internal busy lock on the connection.
func (cc *ContextCancel) Finish() {
<-cc.PgConn.controller
}
func (pgConn *PgConn) recoverFromTimeout() {
if pgConn.Config.OnContextCancel == nil {
pgConn.defaultCancel()
} else {
cc := &ContextCancel{PgConn: pgConn}
pgConn.Config.OnContextCancel(cc)
}
}
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct {
buf []byte
+16 -46
View File
@@ -4,9 +4,9 @@ import (
"context"
"math/rand"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/jackc/pgconn"
@@ -14,13 +14,11 @@ import (
)
func TestConnStress(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
actionCount := 100
actionCount := 10000
if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
stressFactor, err := strconv.ParseInt(s, 10, 64)
require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
@@ -36,9 +34,6 @@ func TestConnStress(t *testing.T) {
{"Exec Select", stressExecSelect},
{"ExecParams Select", stressExecParamsSelect},
{"Batch", stressBatch},
{"ExecCanceled", stressExecSelectCanceled},
{"ExecParamsCanceled", stressExecParamsSelectCanceled},
{"BatchCanceled", stressBatchCanceled},
}
for i := 0; i < actionCount; i++ {
@@ -46,6 +41,10 @@ func TestConnStress(t *testing.T) {
err := action.fn(pgConn)
require.Nilf(t, err, "%d: %s", i, action.name)
}
// Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled.
numGoroutine := runtime.NumGoroutine()
require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine)
}
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
@@ -65,56 +64,27 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
}
func stressExecSelect(pgConn *pgconn.PgConn) error {
_, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err := pgConn.Exec(ctx, "select * from widgets").ReadAll()
return err
}
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
return result.Err
}
func stressBatch(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
batch := &pgconn.Batch{}
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
_, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
return err
}
func stressExecSelectCanceled(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
_, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll()
cancel()
if err != context.DeadlineExceeded {
return err
}
return nil
}
func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
cancel()
if result.Err != context.DeadlineExceeded {
return result.Err
}
return nil
}
func stressBatchCanceled(pgConn *pgconn.PgConn) error {
batch := &pgconn.Batch{}
batch.ExecParams("select * from widgets", nil, nil, nil, nil)
batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
cancel()
if err != context.DeadlineExceeded {
return err
}
return nil
}
+26 -73
View File
@@ -16,7 +16,6 @@ import (
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
@@ -356,8 +355,7 @@ func TestConnExecContextCanceled(t *testing.T) {
}
err = multiResult.Close()
assert.Equal(t, context.DeadlineExceeded, err)
ensureConnValid(t, pgConn)
assert.False(t, pgConn.IsAlive())
}
func TestConnExecParams(t *testing.T) {
@@ -400,7 +398,7 @@ func TestConnExecParamsCanceled(t *testing.T) {
assert.Equal(t, pgconn.CommandTag(""), commandTag)
assert.Equal(t, context.DeadlineExceeded, err)
ensureConnValid(t, pgConn)
assert.False(t, pgConn.IsAlive())
}
func TestConnExecPrepared(t *testing.T) {
@@ -451,8 +449,7 @@ func TestConnExecPreparedCanceled(t *testing.T) {
commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(""), commandTag)
assert.Equal(t, context.DeadlineExceeded, err)
ensureConnValid(t, pgConn)
assert.False(t, pgConn.IsAlive())
}
func TestConnExecBatch(t *testing.T) {
@@ -510,72 +507,6 @@ func TestCommandTag(t *testing.T) {
}
}
func TestConnContextCancelWithOnContextCancel(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
calledChan := make(chan struct{})
config.OnContextCancel = func(cc *pgconn.ContextCancel) {
defer cc.Finish()
close(calledChan)
for {
msg, err := cc.PgConn.ReceiveMessage()
if err != nil {
cc.PgConn.Close(context.Background())
return
}
switch msg.(type) {
case *pgproto3.ReadyForQuery:
return
}
}
}
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil)
_, err = result.Close()
assert.Equal(t, context.DeadlineExceeded, err)
called := false
select {
case <-calledChan:
called = true
case <-time.NewTimer(time.Second).C:
}
assert.True(t, called)
ensureConnValid(t, pgConn)
}
func TestConnWaitUntilReady(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read()
assert.Equal(t, context.DeadlineExceeded, result.Err)
err = pgConn.WaitUntilReady(context.Background())
require.NoError(t, err)
ensureConnValid(t, pgConn)
}
func TestConnOnNotice(t *testing.T) {
t.Parallel()
@@ -792,7 +723,7 @@ func TestConnCopyToCanceled(t *testing.T) {
assert.Equal(t, context.DeadlineExceeded, err)
assert.Equal(t, pgconn.CommandTag(""), res)
ensureConnValid(t, pgConn)
assert.False(t, pgConn.IsAlive())
}
func TestConnCopyFrom(t *testing.T) {
@@ -991,6 +922,28 @@ func TestConnEscapeString(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnCancelRequest(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(5)")
err = pgConn.CancelRequest(context.Background())
require.NoError(t, err)
for multiResult.NextResult() {
}
err = multiResult.Close()
require.IsType(t, &pgconn.PgError{}, err)
require.Equal(t, "57014", err.(*pgconn.PgError).Code)
ensureConnValid(t, pgConn)
}
func Example() {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
if err != nil {