2
0

Fix context query cancellation

Previous commits had a race condition due to not waiting for the PostgreSQL
server to close the cancel query connection. This made it possible for the
cancel request to impact a subsequent query on the same connection. This
commit sets a flag that a cancel request was made and blocks until the
PostgreSQL server closes the cancel connection.
This commit is contained in:
Jack Christensen
2017-02-11 19:53:18 -06:00
parent deac6564ee
commit 048a75406f
3 changed files with 118 additions and 29 deletions
+106 -22
View File
@@ -93,7 +93,9 @@ type Conn struct {
status int32 // One of connStatus* constants status int32 // One of connStatus* constants
causeOfDeath error causeOfDeath error
readyForQuery bool // can the connection be used to send a query readyForQuery bool // connection has received ReadyForQuery message since last query was sent
cancelQueryInProgress int32
cancelQueryCompleted chan struct{}
// context support // context support
ctxInProgress bool ctxInProgress bool
@@ -268,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.channels = make(map[string]struct{}) c.channels = make(map[string]struct{})
atomic.StoreInt32(&c.status, connStatusIdle) atomic.StoreInt32(&c.status, connStatusIdle)
c.lastActivityTime = time.Now() c.lastActivityTime = time.Now()
c.cancelQueryCompleted = make(chan struct{}, 1)
c.doneChan = make(chan struct{}) c.doneChan = make(chan struct{})
c.closedChan = make(chan error) c.closedChan = make(chan error)
@@ -634,10 +637,15 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared. // concern for if the statement has already been prepared.
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
return c.prepareEx(name, sql, opts) return c.PrepareExContext(context.Background(), name, sql, opts)
} }
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
err = c.initContext(ctx) err = c.initContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -743,7 +751,25 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
} }
// Deallocate released a prepared statement // Deallocate released a prepared statement
func (c *Conn) Deallocate(name string) (err error) { func (c *Conn) Deallocate(name string) error {
return c.deallocateContext(context.Background(), name)
}
// TODO - consider making this public
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return err
}
err = c.initContext(ctx)
if err != nil {
return err
}
defer func() {
err = c.termContext(err)
}()
if err := c.ensureConnectionReadyForQuery(); err != nil { if err := c.ensureConnectionReadyForQuery(); err != nil {
return err return err
} }
@@ -818,6 +844,13 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
return notification, nil return notification, nil
} }
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
cancelFn()
return nil, err
}
cancelFn()
if err := c.ensureConnectionReadyForQuery(); err != nil { if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err return nil, err
} }
@@ -1318,21 +1351,55 @@ func quoteIdentifier(s string) string {
// ensure that the query was canceled. As specified in the documentation, there // ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. See // is no way to be sure a query was canceled. See
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 // https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
func (c *Conn) cancelQuery() error { func (c *Conn) cancelQuery() {
network, address := c.config.networkAddress() if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) {
cancelConn, err := c.config.Dial(network, address) panic("cancelQuery when cancelQueryInProgress")
if err != nil {
return err
} }
defer cancelConn.Close()
buf := make([]byte, 16) if err := c.conn.SetDeadline(time.Now()); err != nil {
binary.BigEndian.PutUint32(buf[0:4], 16) c.Close() // Close connection if unable to set deadline
binary.BigEndian.PutUint32(buf[4:8], 80877102) return
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) }
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
_, err = cancelConn.Write(buf) doCancel := func() error {
return err network, address := c.config.networkAddress()
cancelConn, err := c.config.Dial(network, address)
if err != nil {
return err
}
defer cancelConn.Close()
// If server doesn't process cancellation request in bounded time then abort.
err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second))
if err != nil {
return err
}
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
_, err = cancelConn.Write(buf)
if err != nil {
return err
}
_, err = cancelConn.Read(buf)
if err != io.EOF {
return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
}
return nil
}
go func() {
err := doCancel()
if err != nil {
c.Close() // Something is very wrong. Terminate the connection.
}
c.cancelQueryCompleted <- struct{}{}
}()
} }
func (c *Conn) Ping() error { func (c *Conn) Ping() error {
@@ -1345,6 +1412,11 @@ func (c *Conn) PingContext(ctx context.Context) error {
} }
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return "", err
}
err = c.initContext(ctx) err = c.initContext(ctx)
if err != nil { if err != nil {
return "", err return "", err
@@ -1438,9 +1510,6 @@ func (c *Conn) termContext(opErr error) error {
select { select {
case err = <-c.closedChan: case err = <-c.closedChan:
if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil {
c.Close() // Close connection if unable to disable deadline
}
if opErr == nil { if opErr == nil {
err = nil err = nil
} }
@@ -1456,14 +1525,29 @@ func (c *Conn) contextHandler(ctx context.Context) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.cancelQuery() c.cancelQuery()
if err := c.conn.SetDeadline(time.Now()); err != nil {
c.Close() // Close connection if unable to set deadline
}
c.closedChan <- ctx.Err() c.closedChan <- ctx.Err()
case <-c.doneChan: case <-c.doneChan:
} }
} }
func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 {
return nil
}
select {
case <-c.cancelQueryCompleted:
atomic.StoreInt32(&c.cancelQueryInProgress, 0)
if err := c.conn.SetDeadline(time.Time{}); err != nil {
c.Close() // Close connection if unable to disable deadline
return err
}
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (c *Conn) ensureConnectionReadyForQuery() error { func (c *Conn) ensureConnectionReadyForQuery() error {
for !c.readyForQuery { for !c.readyForQuery {
t, r, err := c.rxMsg() t, r, err := c.rxMsg()
+5
View File
@@ -419,6 +419,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
} }
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
}
c.lastActivityTime = time.Now() c.lastActivityTime = time.Now()
rows = c.getRows(sql, args) rows = c.getRows(sql, args)
+7 -7
View File
@@ -66,7 +66,7 @@ func TestStressConnPool(t *testing.T) {
action := actions[rand.Intn(len(actions))] action := actions[rand.Intn(len(actions))]
err := action.fn(pool, n) err := action.fn(pool, n)
if err != nil { if err != nil {
errChan <- err errChan <- fmt.Errorf("%s: %v", action.name, err)
break break
} }
} }
@@ -355,19 +355,19 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
cancelFunc() cancelFunc()
}() }()
rows, err := pool.QueryContext(ctx, "select pg_sleep(5)") rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
if err == context.Canceled { if err == context.Canceled {
return nil return nil
} else if err != nil { } else if err != nil {
return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err) return fmt.Errorf("Only allowed error is context.Canceled, got %v", err)
} }
for rows.Next() { for rows.Next() {
return errors.New("canceledQueryContext: should never receive row") return errors.New("should never receive row")
} }
if rows.Err() != context.Canceled { if rows.Err() != context.Canceled {
return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err()) return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err())
} }
return nil return nil
@@ -380,9 +380,9 @@ func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
cancelFunc() cancelFunc()
}() }()
_, err := pool.ExecContext(ctx, "select pg_sleep(5)") _, err := pool.ExecContext(ctx, "select pg_sleep(2)")
if err != context.Canceled { if err != context.Canceled {
return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err) return fmt.Errorf("Expected context.Canceled error, got %v", err)
} }
return nil return nil