Replace chan based conn locking with bool
This is conceptually simpler and will lead to error messages instead of deadlocks.
This commit is contained in:
@@ -89,8 +89,7 @@ type PgConn struct {
|
||||
|
||||
Config *Config
|
||||
|
||||
controller chan interface{}
|
||||
|
||||
locked bool
|
||||
closed bool
|
||||
|
||||
bufferingReceive bool
|
||||
@@ -153,7 +152,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||
pgConn := new(PgConn)
|
||||
pgConn.Config = config
|
||||
pgConn.controller = make(chan interface{}, 1)
|
||||
|
||||
var err error
|
||||
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
||||
@@ -405,6 +403,29 @@ func (pgConn *PgConn) IsAlive() bool {
|
||||
return !pgConn.closed
|
||||
}
|
||||
|
||||
// lock locks the connection. It returns an error if the connection is already locked or is closed.
|
||||
func (pgConn *PgConn) lock() error {
|
||||
if pgConn.locked {
|
||||
return errors.New("connection busy")
|
||||
}
|
||||
|
||||
if pgConn.closed {
|
||||
return errors.New("connection closed")
|
||||
}
|
||||
|
||||
pgConn.locked = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) unlock() {
|
||||
if !pgConn.locked {
|
||||
panic("BUG: cannot unlock unlocked connection")
|
||||
}
|
||||
|
||||
pgConn.locked = false
|
||||
}
|
||||
|
||||
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
||||
// server_version). Returns an empty string for unknown parameters.
|
||||
func (pgConn *PgConn) ParameterStatus(key string) string {
|
||||
@@ -476,10 +497,14 @@ type PreparedStatementDescription struct {
|
||||
|
||||
// Prepare creates a prepared statement.
|
||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case pgConn.controller <- pgConn:
|
||||
default:
|
||||
}
|
||||
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanupContextDeadline()
|
||||
@@ -521,7 +546,7 @@ readloop:
|
||||
}
|
||||
}
|
||||
|
||||
<-pgConn.controller
|
||||
pgConn.unlock()
|
||||
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
@@ -594,14 +619,18 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
||||
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
|
||||
// received.
|
||||
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case pgConn.controller <- pgConn:
|
||||
default:
|
||||
}
|
||||
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanupContextDeadline()
|
||||
defer func() { <-pgConn.controller }()
|
||||
defer pgConn.unlock()
|
||||
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
@@ -628,12 +657,18 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
||||
cleanupContextDeadline: func() {},
|
||||
}
|
||||
|
||||
if err := pgConn.lock(); err != nil {
|
||||
multiResult.closed = true
|
||||
multiResult.err = err
|
||||
return multiResult
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = ctx.Err()
|
||||
return multiResult
|
||||
case pgConn.controller <- multiResult:
|
||||
default:
|
||||
}
|
||||
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
|
||||
@@ -646,7 +681,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
||||
multiResult.cleanupContextDeadline()
|
||||
multiResult.closed = true
|
||||
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
|
||||
<-pgConn.controller
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
@@ -679,12 +714,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
||||
cleanupContextDeadline: func() {},
|
||||
}
|
||||
|
||||
if err := pgConn.lock(); err != nil {
|
||||
result.concludeCommand("", err)
|
||||
result.closed = true
|
||||
return result
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
result.concludeCommand("", ctx.Err())
|
||||
result.closed = true
|
||||
return result
|
||||
case pgConn.controller <- result:
|
||||
default:
|
||||
}
|
||||
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
|
||||
@@ -704,7 +745,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
||||
result.concludeCommand("", err)
|
||||
result.cleanupContextDeadline()
|
||||
result.closed = true
|
||||
<-pgConn.controller
|
||||
pgConn.unlock()
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -729,12 +770,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
||||
cleanupContextDeadline: func() {},
|
||||
}
|
||||
|
||||
if err := pgConn.lock(); err != nil {
|
||||
result.concludeCommand("", err)
|
||||
result.closed = true
|
||||
return result
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
result.concludeCommand("", ctx.Err())
|
||||
result.closed = true
|
||||
return result
|
||||
case pgConn.controller <- result:
|
||||
default:
|
||||
}
|
||||
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
|
||||
@@ -750,7 +797,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
||||
result.concludeCommand("", err)
|
||||
result.cleanupContextDeadline()
|
||||
result.closed = true
|
||||
<-pgConn.controller
|
||||
pgConn.unlock()
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -758,10 +805,14 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
||||
|
||||
// CopyTo executes the copy command sql and copies the results to w.
|
||||
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case pgConn.controller <- pgConn:
|
||||
default:
|
||||
}
|
||||
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanupContextDeadline()
|
||||
@@ -773,7 +824,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
||||
_, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
<-pgConn.controller
|
||||
pgConn.unlock()
|
||||
|
||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
@@ -797,7 +848,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
||||
return "", err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
<-pgConn.controller
|
||||
pgConn.unlock()
|
||||
return commandTag, pgErr
|
||||
case *pgproto3.CommandComplete:
|
||||
commandTag = CommandTag(msg.CommandTag)
|
||||
@@ -812,10 +863,15 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
||||
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
|
||||
// could still block.
|
||||
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case pgConn.controller <- pgConn:
|
||||
default:
|
||||
}
|
||||
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
defer cleanupContextDeadline()
|
||||
@@ -827,8 +883,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
_, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
<-pgConn.controller
|
||||
|
||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -849,7 +903,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr = errorResponseToPgError(msg)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
<-pgConn.controller
|
||||
return commandTag, pgErr
|
||||
}
|
||||
}
|
||||
@@ -871,8 +924,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
_, err = pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
<-pgConn.controller
|
||||
|
||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
}
|
||||
@@ -904,8 +955,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
_, err = pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
<-pgConn.controller
|
||||
|
||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -919,7 +968,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
<-pgConn.controller
|
||||
return commandTag, pgErr
|
||||
case *pgproto3.CommandComplete:
|
||||
commandTag = CommandTag(msg.CommandTag)
|
||||
@@ -968,7 +1016,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
mrr.cleanupContextDeadline()
|
||||
mrr.closed = true
|
||||
<-mrr.pgConn.controller
|
||||
mrr.pgConn.unlock()
|
||||
case *pgproto3.ErrorResponse:
|
||||
mrr.err = errorResponseToPgError(msg)
|
||||
}
|
||||
@@ -1125,7 +1173,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
||||
switch msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
rr.cleanupContextDeadline()
|
||||
<-rr.pgConn.controller
|
||||
rr.pgConn.unlock()
|
||||
return rr.commandTag, rr.err
|
||||
}
|
||||
}
|
||||
@@ -1203,12 +1251,18 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||
cleanupContextDeadline: func() {},
|
||||
}
|
||||
|
||||
if err := pgConn.lock(); err != nil {
|
||||
multiResult.closed = true
|
||||
multiResult.err = ctx.Err()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = ctx.Err()
|
||||
return multiResult
|
||||
case pgConn.controller <- multiResult:
|
||||
default:
|
||||
}
|
||||
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||
|
||||
@@ -1219,7 +1273,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||
multiResult.cleanupContextDeadline()
|
||||
multiResult.closed = true
|
||||
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
|
||||
<-pgConn.controller
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
|
||||
@@ -484,6 +484,29 @@ func TestConnExecBatch(t *testing.T) {
|
||||
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
|
||||
}
|
||||
|
||||
func TestConnLocking(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
|
||||
results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "connection busy", err.Error())
|
||||
|
||||
results, err = mrr.ReadAll()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Nil(t, results[0].Err)
|
||||
assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
|
||||
assert.Len(t, results[0].Rows, 1)
|
||||
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestCommandTag(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user