Resolve race on conn.Close/die
Use sync.Mutex instead of atomic operations for clarity.
This commit is contained in:
@@ -17,6 +17,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -102,7 +103,8 @@ type Conn struct {
|
|||||||
poolResetCount int
|
poolResetCount int
|
||||||
preallocatedRows []Rows
|
preallocatedRows []Rows
|
||||||
|
|
||||||
status int32 // One of connStatus* constants
|
mux sync.Mutex
|
||||||
|
status byte // One of connStatus* constants
|
||||||
causeOfDeath error
|
causeOfDeath error
|
||||||
|
|
||||||
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
|
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
|
||||||
@@ -267,20 +269,25 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||||||
defer func() {
|
defer func() {
|
||||||
if c != nil && err != nil {
|
if c != nil && err != nil {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
atomic.StoreInt32(&c.status, connStatusClosed)
|
c.mux.Lock()
|
||||||
|
c.status = connStatusClosed
|
||||||
|
c.mux.Unlock()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.RuntimeParams = make(map[string]string)
|
c.RuntimeParams = make(map[string]string)
|
||||||
c.preparedStatements = make(map[string]*PreparedStatement)
|
c.preparedStatements = make(map[string]*PreparedStatement)
|
||||||
c.channels = make(map[string]struct{})
|
c.channels = make(map[string]struct{})
|
||||||
atomic.StoreInt32(&c.status, connStatusIdle)
|
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
c.cancelQueryCompleted = make(chan struct{}, 1)
|
c.cancelQueryCompleted = make(chan struct{}, 1)
|
||||||
c.doneChan = make(chan struct{})
|
c.doneChan = make(chan struct{})
|
||||||
c.closedChan = make(chan error)
|
c.closedChan = make(chan error)
|
||||||
c.wbuf = make([]byte, 0, 1024)
|
c.wbuf = make([]byte, 0, 1024)
|
||||||
|
|
||||||
|
c.mux.Lock()
|
||||||
|
c.status = connStatusIdle
|
||||||
|
c.mux.Unlock()
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
if c.shouldLog(LogLevelDebug) {
|
if c.shouldLog(LogLevelDebug) {
|
||||||
c.log(LogLevelDebug, "starting TLS handshake", nil)
|
c.log(LogLevelDebug, "starting TLS handshake", nil)
|
||||||
@@ -401,19 +408,17 @@ func (c *Conn) PID() uint32 {
|
|||||||
// Close closes a connection. It is safe to call Close on a already closed
|
// Close closes a connection. It is safe to call Close on a already closed
|
||||||
// connection.
|
// connection.
|
||||||
func (c *Conn) Close() (err error) {
|
func (c *Conn) Close() (err error) {
|
||||||
for {
|
c.mux.Lock()
|
||||||
status := atomic.LoadInt32(&c.status)
|
defer c.mux.Unlock()
|
||||||
if status < connStatusIdle {
|
|
||||||
return nil
|
if c.status < connStatusIdle {
|
||||||
}
|
return nil
|
||||||
if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
c.status = connStatusClosed
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
c.die(errors.New("Closed"))
|
c.causeOfDeath = errors.New("Closed")
|
||||||
if c.shouldLog(LogLevelInfo) {
|
if c.shouldLog(LogLevelInfo) {
|
||||||
c.log(LogLevelInfo, "closed connection", nil)
|
c.log(LogLevelInfo, "closed connection", nil)
|
||||||
}
|
}
|
||||||
@@ -989,10 +994,14 @@ func (c *Conn) WaitForNotification(ctx context.Context) (notification *Notificat
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) IsAlive() bool {
|
func (c *Conn) IsAlive() bool {
|
||||||
return atomic.LoadInt32(&c.status) >= connStatusIdle
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
return c.status >= connStatusIdle
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) CauseOfDeath() error {
|
func (c *Conn) CauseOfDeath() error {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
return c.causeOfDeath
|
return c.causeOfDeath
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1131,7 +1140,7 @@ func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
|
func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) {
|
||||||
if atomic.LoadInt32(&c.status) < connStatusIdle {
|
if !c.IsAlive() {
|
||||||
return nil, ErrDeadConn
|
return nil, ErrDeadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1283,23 +1292,40 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) die(err error) {
|
func (c *Conn) die(err error) {
|
||||||
atomic.StoreInt32(&c.status, connStatusClosed)
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if c.status == connStatusClosed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.status = connStatusClosed
|
||||||
c.causeOfDeath = err
|
c.causeOfDeath = err
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) lock() error {
|
func (c *Conn) lock() error {
|
||||||
if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) {
|
c.mux.Lock()
|
||||||
return nil
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if c.status != connStatusIdle {
|
||||||
|
return ErrConnBusy
|
||||||
}
|
}
|
||||||
return ErrConnBusy
|
|
||||||
|
c.status = connStatusBusy
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) unlock() error {
|
func (c *Conn) unlock() error {
|
||||||
if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) {
|
c.mux.Lock()
|
||||||
return nil
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if c.status != connStatusBusy {
|
||||||
|
return errors.New("unlock conn that is not busy")
|
||||||
}
|
}
|
||||||
return errors.New("unlock conn that is not busy")
|
|
||||||
|
c.status = connStatusIdle
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) shouldLog(lvl int) bool {
|
func (c *Conn) shouldLog(lvl int) bool {
|
||||||
|
|||||||
Reference in New Issue
Block a user