From 5edd6609714a32f3fafc30b6f1680ee27381586e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Sep 2014 13:38:01 -0500 Subject: [PATCH] WaitForNotification detects lost connections quicker Ping server every 15 seconds while waiting if no traffic has occurred. --- conn.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++-- query.go | 3 ++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index 9b7cef65..c0aa36a4 100644 --- a/conn.go +++ b/conn.go @@ -35,6 +35,7 @@ type ConnConfig struct { // goroutines. type Conn struct { conn net.Conn // the underlying TCP or unix domain socket connection + lastActivityTime time.Time // the last time the connection was used reader *bufio.Reader // buffered reader to improve read performance wbuf [1024]byte Pid int32 // backend pid @@ -150,6 +151,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.alive = true + c.lastActivityTime = time.Now() if config.TLSConfig != nil { c.logger.Debug("Starting TLS handshake") @@ -385,15 +387,57 @@ func (c *Conn) Listen(channel string) (err error) { // WaitForNotification waits for a PostgreSQL notification for up to timeout. // If the timeout occurs it returns pgx.ErrNotificationTimeout func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) { + // Return already received notification immediately if len(c.notifications) > 0 { notification := c.notifications[0] c.notifications = c.notifications[1:] return notification, nil } - var zeroTime time.Time stopTime := time.Now().Add(timeout) + for { + now := time.Now() + + if now.After(stopTime) { + return nil, ErrNotificationTimeout + } + + // If there has been no activity on this connection for a while send a nop message just to ensure + // the connection is alive + nextEnsureAliveTime := c.lastActivityTime.Add(15 * time.Second) + if nextEnsureAliveTime.Before(now) { + // If the server can't respond to a nop in 15 seconds, assume it's dead + err := c.conn.SetReadDeadline(now.Add(15 * time.Second)) + if err != nil { + return nil, err + } + + _, err = c.Exec("--;") + if err != nil { + return nil, err + } + + c.lastActivityTime = now + } + + var deadline time.Time + if stopTime.Before(nextEnsureAliveTime) { + deadline = stopTime + } else { + deadline = nextEnsureAliveTime + } + + notification, err := c.waitForNotification(deadline) + if err != ErrNotificationTimeout { + return notification, err + } + } +} + +func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { + var zeroTime time.Time + for { // Use SetReadDeadline to implement the timeout. SetReadDeadline will // cause operations to fail with a *net.OpError that has a Timeout() @@ -403,7 +447,7 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) // deadline and peek into the reader. If a timeout error occurs there // we don't break the pgx connection. If the Peek returns that data // is available then we turn off the read deadline before the rxMsg. - err := c.conn.SetReadDeadline(stopTime) + err := c.conn.SetReadDeadline(deadline) if err != nil { return nil, err } @@ -594,6 +638,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { startTime := time.Now() + c.lastActivityTime = startTime defer func() { if err == nil { @@ -667,6 +712,8 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { c.die(err) } + c.lastActivityTime = time.Now() + return t, &c.mr, err } diff --git a/query.go b/query.go index a6f8fc34..e4f24a32 100644 --- a/query.go +++ b/query.go @@ -357,7 +357,8 @@ func (rows *Rows) Values() ([]interface{}, error) { // be returned in an error state. So it is allowed to ignore the error returned // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - rows := &Rows{conn: c, startTime: time.Now(), sql: sql, args: args, logger: c.logger} + c.lastActivityTime = time.Now() + rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, logger: c.logger} ps, ok := c.preparedStatements[sql] if !ok {