Add guards against usage of busy connection
This commit is contained in:
@@ -62,6 +62,7 @@ type Conn struct {
|
|||||||
fp *fastpath
|
fp *fastpath
|
||||||
pgsql_af_inet byte
|
pgsql_af_inet byte
|
||||||
pgsql_af_inet6 byte
|
pgsql_af_inet6 byte
|
||||||
|
busy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type PreparedStatement struct {
|
type PreparedStatement struct {
|
||||||
@@ -99,6 +100,7 @@ var ErrNoRows = errors.New("no rows in result set")
|
|||||||
var ErrNotificationTimeout = errors.New("notification timeout")
|
var ErrNotificationTimeout = errors.New("notification timeout")
|
||||||
var ErrDeadConn = errors.New("conn is dead")
|
var ErrDeadConn = errors.New("conn is dead")
|
||||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||||
|
var ErrConnBusy = errors.New("conn is busy")
|
||||||
|
|
||||||
type ProtocolError string
|
type ProtocolError string
|
||||||
|
|
||||||
@@ -878,19 +880,29 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
||||||
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
||||||
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||||
|
if err = c.lock(); err != nil {
|
||||||
|
return commandTag, err
|
||||||
|
}
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
c.lastActivityTime = startTime
|
c.lastActivityTime = startTime
|
||||||
|
|
||||||
if c.logLevel >= LogLevelError {
|
defer func() {
|
||||||
defer func() {
|
if err == nil {
|
||||||
if err == nil {
|
if c.logLevel >= LogLevelInfo {
|
||||||
endTime := time.Now()
|
endTime := time.Now()
|
||||||
c.logger.Info("Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
c.logger.Info("Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
||||||
} else {
|
}
|
||||||
|
} else {
|
||||||
|
if c.logLevel >= LogLevelError {
|
||||||
c.logger.Error("Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
c.logger.Error("Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
|
||||||
|
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
||||||
|
err = unlockErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||||
return
|
return
|
||||||
@@ -1137,3 +1149,19 @@ func (c *Conn) die(err error) {
|
|||||||
c.causeOfDeath = err
|
c.causeOfDeath = err
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) lock() error {
|
||||||
|
if c.busy {
|
||||||
|
return ErrConnBusy
|
||||||
|
}
|
||||||
|
c.busy = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) unlock() error {
|
||||||
|
if !c.busy {
|
||||||
|
return errors.New("unlock conn that is not busy")
|
||||||
|
}
|
||||||
|
c.busy = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -487,6 +487,11 @@ func TestConnPoolQueryConcurrentLoad(t *testing.T) {
|
|||||||
if rowCount != 1000 {
|
if rowCount != 1000 {
|
||||||
t.Error("Select called onDataRow wrong number of times")
|
t.Error("Select called onDataRow wrong number of times")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = pool.Exec("--;")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("pool.Exec failed: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1030,3 +1030,39 @@ func TestInsertTimestampArray(t *testing.T) {
|
|||||||
t.Errorf("Unexpected results from Exec: %v", results)
|
t.Errorf("Unexpected results from Exec: %v", results)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCatchSimultaneousConnectionQueries(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
rows1, err := conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
defer rows1.Close()
|
||||||
|
|
||||||
|
_, err = conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != pgx.ErrConnBusy {
|
||||||
|
t.Fatalf("conn.Query should have failed with pgx.ErrConnBusy, but it was %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCatchSimultaneousConnectionQueryAndExec(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
rows, err := conn.Query("select generate_series(1,$1)", 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("conn.Query failed: ", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
_, err = conn.Exec("create temporary table foo(spice timestamp[])")
|
||||||
|
if err != pgx.ErrConnBusy {
|
||||||
|
t.Fatalf("conn.Exec should have failed with pgx.ErrConnBusy, but it was %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,20 +39,21 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
|
|||||||
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
|
// the *Conn can be used again. Rows are closed by explicitly calling Close(),
|
||||||
// calling Next() until it returns false, or when a fatal error occurs.
|
// calling Next() until it returns false, or when a fatal error occurs.
|
||||||
type Rows struct {
|
type Rows struct {
|
||||||
pool *ConnPool
|
pool *ConnPool
|
||||||
conn *Conn
|
conn *Conn
|
||||||
mr *msgReader
|
mr *msgReader
|
||||||
fields []FieldDescription
|
fields []FieldDescription
|
||||||
vr ValueReader
|
vr ValueReader
|
||||||
rowCount int
|
rowCount int
|
||||||
columnIdx int
|
columnIdx int
|
||||||
err error
|
err error
|
||||||
closed bool
|
closed bool
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
sql string
|
sql string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
logger Logger
|
logger Logger
|
||||||
logLevel int
|
logLevel int
|
||||||
|
unlockConn bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *Rows) FieldDescriptions() []FieldDescription {
|
func (rows *Rows) FieldDescriptions() []FieldDescription {
|
||||||
@@ -64,6 +65,11 @@ func (rows *Rows) close() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rows.unlockConn {
|
||||||
|
rows.conn.unlock()
|
||||||
|
rows.unlockConn = false
|
||||||
|
}
|
||||||
|
|
||||||
if rows.pool != nil {
|
if rows.pool != nil {
|
||||||
rows.pool.Release(rows.conn)
|
rows.pool.Release(rows.conn)
|
||||||
rows.pool = nil
|
rows.pool = nil
|
||||||
@@ -421,6 +427,12 @@ func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
|||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, logger: c.logger, logLevel: c.logLevel}
|
rows := &Rows{conn: c, startTime: c.lastActivityTime, sql: sql, args: args, logger: c.logger, logLevel: c.logLevel}
|
||||||
|
|
||||||
|
if err := c.lock(); err != nil {
|
||||||
|
rows.abort(err)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
rows.unlockConn = true
|
||||||
|
|
||||||
ps, ok := c.preparedStatements[sql]
|
ps, ok := c.preparedStatements[sql]
|
||||||
if !ok {
|
if !ok {
|
||||||
var err error
|
var err error
|
||||||
|
|||||||
Reference in New Issue
Block a user