diff --git a/config.go b/config.go index 1cde9c57..d392924c 100644 --- a/config.go +++ b/config.go @@ -41,17 +41,6 @@ type Config struct { // allows implementing high availability behavior such as libpq does with target_session_attrs. AfterConnectFunc AfterConnectFunc - // OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context - // is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a - // query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire - // protocol do not support this cancellation method. - // - // It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be - // called whether it was successful or not. If an error occurs the connection should be closed. The connection must be - // in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read - // the connection until a ready for query message is received. - OnContextCancel func(*ContextCancel) - // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler diff --git a/doc.go b/doc.go index 89e47536..d36eb0fd 100644 --- a/doc.go +++ b/doc.go @@ -20,10 +20,10 @@ result. The ReadAll method reads all query results into memory. Context Support -All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the -method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the -cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is -safe to use the connection while this background cancellation is in progress. Any calls will block until the -cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation). +All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the +method immediately returns. In most circumstances, this will close the underlying connection. + +The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the +client to abort. */ package pgconn diff --git a/pgconn.go b/pgconn.go index 6490617a..8b0ddcb4 100644 --- a/pgconn.go +++ b/pgconn.go @@ -199,6 +199,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig for { msg, err := pgConn.ReceiveMessage() if err != nil { + pgConn.conn.Close() return nil, err } @@ -502,7 +503,7 @@ readloop: for { msg, err := pgConn.ReceiveMessage() if err != nil { - go pgConn.recoverFromTimeout() + pgConn.hardClose() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -555,10 +556,10 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { return (*Notice)(pgerr) } -// cancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 -func (pgConn *PgConn) cancelRequest(ctx context.Context) error { +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. @@ -590,21 +591,6 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { return nil } -// WaitUntilReady waits until a previous context cancellation has been completed and the connection is ready for use. -// This is done automatically by all methods that need the connection to be ready for use. The only expected use for -// this method is for a connection pool to wait for a returned connection to be usable again before making it available. -func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { - select { - case <-ctx.Done(): - return ctx.Err() - case pgConn.controller <- pgConn: - // The connection must be ready since it was locked. Immediately unlock it. - <-pgConn.controller - } - - return nil -} - // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { @@ -778,6 +764,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case pgConn.controller <- pgConn: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() // Send copy to command var buf []byte @@ -786,7 +773,6 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -798,13 +784,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm for { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -813,9 +793,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { - // This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup. - cleanupContextDeadline() - go pgConn.recoverFromTimeout() + pgConn.hardClose() return "", err } case *pgproto3.ReadyForQuery: @@ -840,6 +818,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case pgConn.controller <- pgConn: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() // Send copy to command var buf []byte @@ -848,7 +827,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -861,13 +839,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for pendingCopyInResponse { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -899,7 +871,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -910,13 +881,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case <-signalMessageChan: msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -939,8 +904,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -950,13 +913,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -972,47 +929,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } } -func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() { - // Regardless of recovery outcome the lock on the pgConn must be released. - defer func() { <-pgConn.controller }() - - // Limit time to wait for entire cancellation process. - err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - pgConn.hardClose() - return - } - - copyFail := &pgproto3.CopyFail{Error: "client cancel"} - buf := copyFail.Encode(nil) - - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - return - } - - pendingReadyForQuery := true - - for pendingReadyForQuery { - msg, err := pgConn.ReceiveMessage() - if err != nil { - pgConn.hardClose() - return - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - pendingReadyForQuery = false - } - } - - err = pgConn.conn.SetDeadline(time.Time{}) - if err != nil { - pgConn.hardClose() - } -} - // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn @@ -1044,13 +960,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.cleanupContextDeadline() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true - - if err, ok := err.(net.Error); ok && err.Timeout() { - go mrr.pgConn.recoverFromTimeout() - } else { - <-mrr.pgConn.controller - } - + mrr.pgConn.hardClose() return nil, mrr.err } @@ -1236,11 +1146,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.cleanupContextDeadline() rr.closed = true if rr.multiResultReader == nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - go rr.pgConn.recoverFromTimeout() - } else { - <-rr.pgConn.controller - } + rr.pgConn.hardClose() } return nil, rr.err @@ -1270,75 +1176,6 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { rr.commandConcluded = true } -func (pgConn *PgConn) defaultCancel() { - // Regardless of recovery outcome the lock on the pgConn must be released. - defer func() { <-pgConn.controller }() - - // Send a cancellation request to the PostgreSQL server. If it is not successful in a reasonable amount of time do not - // try further to recover the connection. - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - err := pgConn.cancelRequest(ctx) - cancel() - if err != nil { - pgConn.hardClose() - return - } - - // Limit time to wait for ReadyForQuery message. - err = pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - pgConn.hardClose() - return - } - - // A cancel query request will always return a "57014" error response, even if no query was in progress. This error - // may be returned before or after the ReadyForQuery message. Must ensure both messages are read. - needError57014 := true - needReadyForQuery := true - - for needError57014 || needReadyForQuery { - msg, err := pgConn.ReceiveMessage() - if err != nil { - pgConn.hardClose() - return - } - - switch msg := msg.(type) { - case *pgproto3.ErrorResponse: - if msg.Code == "57014" { - needError57014 = false - } - case *pgproto3.ReadyForQuery: - needReadyForQuery = false - } - } - - err = pgConn.conn.SetDeadline(time.Time{}) - if err != nil { - pgConn.hardClose() - } -} - -type ContextCancel struct { - PgConn *PgConn -} - -// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for -// query state or the connection must be closed. This must be called regardless of the success of the cancellation and -// whether the connection is still valid or not. It releases an internal busy lock on the connection. -func (cc *ContextCancel) Finish() { - <-cc.PgConn.controller -} - -func (pgConn *PgConn) recoverFromTimeout() { - if pgConn.Config.OnContextCancel == nil { - pgConn.defaultCancel() - } else { - cc := &ContextCancel{PgConn: pgConn} - pgConn.Config.OnContextCancel(cc) - } -} - // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 1ebbe04a..7288c9b4 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -4,9 +4,9 @@ import ( "context" "math/rand" "os" + "runtime" "strconv" "testing" - "time" "github.com/jackc/pgconn" @@ -14,13 +14,11 @@ import ( ) func TestConnStress(t *testing.T) { - t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer closeConn(t, pgConn) - actionCount := 100 + actionCount := 10000 if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { stressFactor, err := strconv.ParseInt(s, 10, 64) require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") @@ -36,9 +34,6 @@ func TestConnStress(t *testing.T) { {"Exec Select", stressExecSelect}, {"ExecParams Select", stressExecParamsSelect}, {"Batch", stressBatch}, - {"ExecCanceled", stressExecSelectCanceled}, - {"ExecParamsCanceled", stressExecParamsSelectCanceled}, - {"BatchCanceled", stressBatchCanceled}, } for i := 0; i < actionCount; i++ { @@ -46,6 +41,10 @@ func TestConnStress(t *testing.T) { err := action.fn(pgConn) require.Nilf(t, err, "%d: %s", i, action.name) } + + // Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled. + numGoroutine := runtime.NumGoroutine() + require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine) } func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { @@ -65,56 +64,27 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { } func stressExecSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := pgConn.Exec(ctx, "select * from widgets").ReadAll() return err } func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() return result.Err } func stressBatch(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + batch := &pgconn.Batch{} batch.ExecParams("select * from widgets", nil, nil, nil, nil) batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - _, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() return err } - -func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll() - cancel() - if err != context.DeadlineExceeded { - return err - } - - return nil -} - -func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() - cancel() - if result.Err != context.DeadlineExceeded { - return result.Err - } - - return nil -} - -func stressBatchCanceled(pgConn *pgconn.PgConn) error { - batch := &pgconn.Batch{} - batch.ExecParams("select * from widgets", nil, nil, nil, nil) - batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.ExecBatch(ctx, batch).ReadAll() - cancel() - if err != context.DeadlineExceeded { - return err - } - - return nil -} diff --git a/pgconn_test.go b/pgconn_test.go index 716761ad..88c6f7c4 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -16,7 +16,6 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgproto3" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -356,8 +355,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecParams(t *testing.T) { @@ -400,7 +398,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecPrepared(t *testing.T) { @@ -451,8 +449,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecBatch(t *testing.T) { @@ -510,72 +507,6 @@ func TestCommandTag(t *testing.T) { } } -func TestConnContextCancelWithOnContextCancel(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - calledChan := make(chan struct{}) - - config.OnContextCancel = func(cc *pgconn.ContextCancel) { - defer cc.Finish() - close(calledChan) - - for { - msg, err := cc.PgConn.ReceiveMessage() - if err != nil { - cc.PgConn.Close(context.Background()) - return - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - return - } - } - } - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil) - _, err = result.Close() - assert.Equal(t, context.DeadlineExceeded, err) - - called := false - select { - case <-calledChan: - called = true - case <-time.NewTimer(time.Second).C: - } - - assert.True(t, called) - - ensureConnValid(t, pgConn) -} - -func TestConnWaitUntilReady(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read() - assert.Equal(t, context.DeadlineExceeded, result.Err) - - err = pgConn.WaitUntilReady(context.Background()) - require.NoError(t, err) - - ensureConnValid(t, pgConn) -} - func TestConnOnNotice(t *testing.T) { t.Parallel() @@ -792,7 +723,7 @@ func TestConnCopyToCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, pgconn.CommandTag(""), res) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnCopyFrom(t *testing.T) { @@ -991,6 +922,28 @@ func TestConnEscapeString(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCancelRequest(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(5)") + + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) + + for multiResult.NextResult() { + } + err = multiResult.Close() + + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil {