Fix context query cancellation
Previous commits had a race condition due to not waiting for the PostgreSQL server to close the cancel query connection. This made it possible for the cancel request to impact a subsequent query on the same connection. This commit sets a flag that a cancel request was made and blocks until the PostgreSQL server closes the cancel connection.
This commit is contained in:
@@ -93,7 +93,9 @@ type Conn struct {
|
|||||||
status int32 // One of connStatus* constants
|
status int32 // One of connStatus* constants
|
||||||
causeOfDeath error
|
causeOfDeath error
|
||||||
|
|
||||||
readyForQuery bool // can the connection be used to send a query
|
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
|
||||||
|
cancelQueryInProgress int32
|
||||||
|
cancelQueryCompleted chan struct{}
|
||||||
|
|
||||||
// context support
|
// context support
|
||||||
ctxInProgress bool
|
ctxInProgress bool
|
||||||
@@ -268,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||||||
c.channels = make(map[string]struct{})
|
c.channels = make(map[string]struct{})
|
||||||
atomic.StoreInt32(&c.status, connStatusIdle)
|
atomic.StoreInt32(&c.status, connStatusIdle)
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
|
c.cancelQueryCompleted = make(chan struct{}, 1)
|
||||||
c.doneChan = make(chan struct{})
|
c.doneChan = make(chan struct{})
|
||||||
c.closedChan = make(chan error)
|
c.closedChan = make(chan error)
|
||||||
|
|
||||||
@@ -634,10 +637,15 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||||||
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
|
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
|
||||||
// concern for if the statement has already been prepared.
|
// concern for if the statement has already been prepared.
|
||||||
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||||
return c.prepareEx(name, sql, opts)
|
return c.PrepareExContext(context.Background(), name, sql, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||||
|
err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
err = c.initContext(ctx)
|
err = c.initContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -743,7 +751,25 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Deallocate released a prepared statement
|
// Deallocate released a prepared statement
|
||||||
func (c *Conn) Deallocate(name string) (err error) {
|
func (c *Conn) Deallocate(name string) error {
|
||||||
|
return c.deallocateContext(context.Background(), name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - consider making this public
|
||||||
|
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
|
||||||
|
err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.initContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = c.termContext(err)
|
||||||
|
}()
|
||||||
|
|
||||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -818,6 +844,13 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
|
|||||||
return notification, nil
|
return notification, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
|
||||||
|
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
|
||||||
|
cancelFn()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cancelFn()
|
||||||
|
|
||||||
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1318,21 +1351,55 @@ func quoteIdentifier(s string) string {
|
|||||||
// ensure that the query was canceled. As specified in the documentation, there
|
// ensure that the query was canceled. As specified in the documentation, there
|
||||||
// is no way to be sure a query was canceled. See
|
// is no way to be sure a query was canceled. See
|
||||||
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
|
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
|
||||||
func (c *Conn) cancelQuery() error {
|
func (c *Conn) cancelQuery() {
|
||||||
network, address := c.config.networkAddress()
|
if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) {
|
||||||
cancelConn, err := c.config.Dial(network, address)
|
panic("cancelQuery when cancelQueryInProgress")
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
defer cancelConn.Close()
|
|
||||||
|
|
||||||
buf := make([]byte, 16)
|
if err := c.conn.SetDeadline(time.Now()); err != nil {
|
||||||
binary.BigEndian.PutUint32(buf[0:4], 16)
|
c.Close() // Close connection if unable to set deadline
|
||||||
binary.BigEndian.PutUint32(buf[4:8], 80877102)
|
return
|
||||||
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
|
}
|
||||||
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
|
|
||||||
_, err = cancelConn.Write(buf)
|
doCancel := func() error {
|
||||||
return err
|
network, address := c.config.networkAddress()
|
||||||
|
cancelConn, err := c.config.Dial(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cancelConn.Close()
|
||||||
|
|
||||||
|
// If server doesn't process cancellation request in bounded time then abort.
|
||||||
|
err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 16)
|
||||||
|
binary.BigEndian.PutUint32(buf[0:4], 16)
|
||||||
|
binary.BigEndian.PutUint32(buf[4:8], 80877102)
|
||||||
|
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
|
||||||
|
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
|
||||||
|
_, err = cancelConn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = cancelConn.Read(buf)
|
||||||
|
if err != io.EOF {
|
||||||
|
return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := doCancel()
|
||||||
|
if err != nil {
|
||||||
|
c.Close() // Something is very wrong. Terminate the connection.
|
||||||
|
}
|
||||||
|
c.cancelQueryCompleted <- struct{}{}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Ping() error {
|
func (c *Conn) Ping() error {
|
||||||
@@ -1345,6 +1412,11 @@ func (c *Conn) PingContext(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||||
|
err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
err = c.initContext(ctx)
|
err = c.initContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1438,9 +1510,6 @@ func (c *Conn) termContext(opErr error) error {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case err = <-c.closedChan:
|
case err = <-c.closedChan:
|
||||||
if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil {
|
|
||||||
c.Close() // Close connection if unable to disable deadline
|
|
||||||
}
|
|
||||||
if opErr == nil {
|
if opErr == nil {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
@@ -1456,14 +1525,29 @@ func (c *Conn) contextHandler(ctx context.Context) {
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
c.cancelQuery()
|
c.cancelQuery()
|
||||||
if err := c.conn.SetDeadline(time.Now()); err != nil {
|
|
||||||
c.Close() // Close connection if unable to set deadline
|
|
||||||
}
|
|
||||||
c.closedChan <- ctx.Err()
|
c.closedChan <- ctx.Err()
|
||||||
case <-c.doneChan:
|
case <-c.doneChan:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
|
||||||
|
if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.cancelQueryCompleted:
|
||||||
|
atomic.StoreInt32(&c.cancelQueryInProgress, 0)
|
||||||
|
if err := c.conn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
c.Close() // Close connection if unable to disable deadline
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) ensureConnectionReadyForQuery() error {
|
func (c *Conn) ensureConnectionReadyForQuery() error {
|
||||||
for !c.readyForQuery {
|
for !c.readyForQuery {
|
||||||
t, r, err := c.rxMsg()
|
t, r, err := c.rxMsg()
|
||||||
|
|||||||
@@ -419,6 +419,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
|
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
|
||||||
|
err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
|
|
||||||
rows = c.getRows(sql, args)
|
rows = c.getRows(sql, args)
|
||||||
|
|||||||
+7
-7
@@ -66,7 +66,7 @@ func TestStressConnPool(t *testing.T) {
|
|||||||
action := actions[rand.Intn(len(actions))]
|
action := actions[rand.Intn(len(actions))]
|
||||||
err := action.fn(pool, n)
|
err := action.fn(pool, n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- err
|
errChan <- fmt.Errorf("%s: %v", action.name, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -355,19 +355,19 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
|
|||||||
cancelFunc()
|
cancelFunc()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
rows, err := pool.QueryContext(ctx, "select pg_sleep(5)")
|
rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
|
||||||
if err == context.Canceled {
|
if err == context.Canceled {
|
||||||
return nil
|
return nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err)
|
return fmt.Errorf("Only allowed error is context.Canceled, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
return errors.New("canceledQueryContext: should never receive row")
|
return errors.New("should never receive row")
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.Err() != context.Canceled {
|
if rows.Err() != context.Canceled {
|
||||||
return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err())
|
return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -380,9 +380,9 @@ func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
|
|||||||
cancelFunc()
|
cancelFunc()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := pool.ExecContext(ctx, "select pg_sleep(5)")
|
_, err := pool.ExecContext(ctx, "select pg_sleep(2)")
|
||||||
if err != context.Canceled {
|
if err != context.Canceled {
|
||||||
return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err)
|
return fmt.Errorf("Expected context.Canceled error, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
Reference in New Issue
Block a user