From 749fdfe7d5e0ccca384376a583030b0a2fa74b9d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 21 May 2017 19:35:37 -0500 Subject: [PATCH] Resolve race on conn.Close/die Use sync.Mutex instead of atomic operations for clarity. --- conn.go | 68 +++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/conn.go b/conn.go index 04299de7..c4c054dd 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,7 @@ import ( "regexp" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -102,7 +103,8 @@ type Conn struct { poolResetCount int preallocatedRows []Rows - status int32 // One of connStatus* constants + mux sync.Mutex + status byte // One of connStatus* constants causeOfDeath error readyForQuery bool // connection has received ReadyForQuery message since last query was sent @@ -267,20 +269,25 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl defer func() { if c != nil && err != nil { c.conn.Close() - atomic.StoreInt32(&c.status, connStatusClosed) + c.mux.Lock() + c.status = connStatusClosed + c.mux.Unlock() } }() c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) - atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() c.cancelQueryCompleted = make(chan struct{}, 1) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) + c.mux.Lock() + c.status = connStatusIdle + c.mux.Unlock() + if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { c.log(LogLevelDebug, "starting TLS handshake", nil) @@ -401,19 +408,17 @@ func (c *Conn) PID() uint32 { // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - for { - status := atomic.LoadInt32(&c.status) - if status < connStatusIdle { - return nil - } - if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) { - break - } + c.mux.Lock() + defer c.mux.Unlock() + + if c.status < connStatusIdle { + return nil } + c.status = connStatusClosed defer func() { c.conn.Close() - c.die(errors.New("Closed")) + c.causeOfDeath = errors.New("Closed") if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "closed connection", nil) } @@ -989,10 +994,14 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat } func (c *Conn) IsAlive() bool { - return atomic.LoadInt32(&c.status) >= connStatusIdle + c.mux.Lock() + defer c.mux.Unlock() + return c.status >= connStatusIdle } func (c *Conn) CauseOfDeath() error { + c.mux.Lock() + defer c.mux.Unlock() return c.causeOfDeath } @@ -1131,7 +1140,7 @@ func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { } func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { - if atomic.LoadInt32(&c.status) < connStatusIdle { + if !c.IsAlive() { return nil, ErrDeadConn } @@ -1283,23 +1292,40 @@ func (c *Conn) txPasswordMessage(password string) (err error) { } func (c *Conn) die(err error) { - atomic.StoreInt32(&c.status, connStatusClosed) + c.mux.Lock() + defer c.mux.Unlock() + + if c.status == connStatusClosed { + return + } + + c.status = connStatusClosed c.causeOfDeath = err c.conn.Close() } func (c *Conn) lock() error { - if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) { - return nil + c.mux.Lock() + defer c.mux.Unlock() + + if c.status != connStatusIdle { + return ErrConnBusy } - return ErrConnBusy + + c.status = connStatusBusy + return nil } func (c *Conn) unlock() error { - if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) { - return nil + c.mux.Lock() + defer c.mux.Unlock() + + if c.status != connStatusBusy { + return errors.New("unlock conn that is not busy") } - return errors.New("unlock conn that is not busy") + + c.status = connStatusIdle + return nil } func (c *Conn) shouldLog(lvl int) bool {