Replace chan based conn locking with bool
This is conceptually simpler and will lead to error messages instead of deadlocks.
This commit is contained in:
@@ -89,8 +89,7 @@ type PgConn struct {
|
|||||||
|
|
||||||
Config *Config
|
Config *Config
|
||||||
|
|
||||||
controller chan interface{}
|
locked bool
|
||||||
|
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
bufferingReceive bool
|
bufferingReceive bool
|
||||||
@@ -153,7 +152,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
|||||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||||
pgConn := new(PgConn)
|
pgConn := new(PgConn)
|
||||||
pgConn.Config = config
|
pgConn.Config = config
|
||||||
pgConn.controller = make(chan interface{}, 1)
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
||||||
@@ -405,6 +403,29 @@ func (pgConn *PgConn) IsAlive() bool {
|
|||||||
return !pgConn.closed
|
return !pgConn.closed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// lock locks the connection. It returns an error if the connection is already locked or is closed.
|
||||||
|
func (pgConn *PgConn) lock() error {
|
||||||
|
if pgConn.locked {
|
||||||
|
return errors.New("connection busy")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pgConn.closed {
|
||||||
|
return errors.New("connection closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
pgConn.locked = true
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pgConn *PgConn) unlock() {
|
||||||
|
if !pgConn.locked {
|
||||||
|
panic("BUG: cannot unlock unlocked connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
pgConn.locked = false
|
||||||
|
}
|
||||||
|
|
||||||
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
// ParameterStatus returns the value of a parameter reported by the server (e.g.
|
||||||
// server_version). Returns an empty string for unknown parameters.
|
// server_version). Returns an empty string for unknown parameters.
|
||||||
func (pgConn *PgConn) ParameterStatus(key string) string {
|
func (pgConn *PgConn) ParameterStatus(key string) string {
|
||||||
@@ -476,10 +497,14 @@ type PreparedStatementDescription struct {
|
|||||||
|
|
||||||
// Prepare creates a prepared statement.
|
// Prepare creates a prepared statement.
|
||||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
|
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case pgConn.controller <- pgConn:
|
default:
|
||||||
}
|
}
|
||||||
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanupContextDeadline()
|
defer cleanupContextDeadline()
|
||||||
@@ -521,7 +546,7 @@ readloop:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-pgConn.controller
|
pgConn.unlock()
|
||||||
|
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
return nil, parseErr
|
return nil, parseErr
|
||||||
@@ -594,14 +619,18 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
|||||||
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
|
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
|
||||||
// received.
|
// received.
|
||||||
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case pgConn.controller <- pgConn:
|
default:
|
||||||
}
|
}
|
||||||
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanupContextDeadline()
|
defer cleanupContextDeadline()
|
||||||
defer func() { <-pgConn.controller }()
|
defer pgConn.unlock()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
msg, err := pgConn.ReceiveMessage()
|
msg, err := pgConn.ReceiveMessage()
|
||||||
@@ -628,12 +657,18 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||||||
cleanupContextDeadline: func() {},
|
cleanupContextDeadline: func() {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
multiResult.closed = true
|
||||||
|
multiResult.err = err
|
||||||
|
return multiResult
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = ctx.Err()
|
multiResult.err = ctx.Err()
|
||||||
return multiResult
|
return multiResult
|
||||||
case pgConn.controller <- multiResult:
|
default:
|
||||||
}
|
}
|
||||||
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
|
||||||
@@ -646,7 +681,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||||||
multiResult.cleanupContextDeadline()
|
multiResult.cleanupContextDeadline()
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
|
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
|
||||||
<-pgConn.controller
|
pgConn.unlock()
|
||||||
return multiResult
|
return multiResult
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -679,12 +714,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
|||||||
cleanupContextDeadline: func() {},
|
cleanupContextDeadline: func() {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
result.concludeCommand("", err)
|
||||||
|
result.closed = true
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
result.concludeCommand("", ctx.Err())
|
result.concludeCommand("", ctx.Err())
|
||||||
result.closed = true
|
result.closed = true
|
||||||
return result
|
return result
|
||||||
case pgConn.controller <- result:
|
default:
|
||||||
}
|
}
|
||||||
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
|
||||||
@@ -704,7 +745,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
|||||||
result.concludeCommand("", err)
|
result.concludeCommand("", err)
|
||||||
result.cleanupContextDeadline()
|
result.cleanupContextDeadline()
|
||||||
result.closed = true
|
result.closed = true
|
||||||
<-pgConn.controller
|
pgConn.unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -729,12 +770,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||||||
cleanupContextDeadline: func() {},
|
cleanupContextDeadline: func() {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
result.concludeCommand("", err)
|
||||||
|
result.closed = true
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
result.concludeCommand("", ctx.Err())
|
result.concludeCommand("", ctx.Err())
|
||||||
result.closed = true
|
result.closed = true
|
||||||
return result
|
return result
|
||||||
case pgConn.controller <- result:
|
default:
|
||||||
}
|
}
|
||||||
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
|
||||||
@@ -750,7 +797,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||||||
result.concludeCommand("", err)
|
result.concludeCommand("", err)
|
||||||
result.cleanupContextDeadline()
|
result.cleanupContextDeadline()
|
||||||
result.closed = true
|
result.closed = true
|
||||||
<-pgConn.controller
|
pgConn.unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -758,10 +805,14 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
|||||||
|
|
||||||
// CopyTo executes the copy command sql and copies the results to w.
|
// CopyTo executes the copy command sql and copies the results to w.
|
||||||
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
|
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return "", ctx.Err()
|
return "", ctx.Err()
|
||||||
case pgConn.controller <- pgConn:
|
default:
|
||||||
}
|
}
|
||||||
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanupContextDeadline()
|
defer cleanupContextDeadline()
|
||||||
@@ -773,7 +824,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
_, err := pgConn.conn.Write(buf)
|
_, err := pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
<-pgConn.controller
|
pgConn.unlock()
|
||||||
|
|
||||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
@@ -797,7 +848,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
<-pgConn.controller
|
pgConn.unlock()
|
||||||
return commandTag, pgErr
|
return commandTag, pgErr
|
||||||
case *pgproto3.CommandComplete:
|
case *pgproto3.CommandComplete:
|
||||||
commandTag = CommandTag(msg.CommandTag)
|
commandTag = CommandTag(msg.CommandTag)
|
||||||
@@ -812,10 +863,15 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||||||
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
|
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
|
||||||
// could still block.
|
// could still block.
|
||||||
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
|
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer pgConn.unlock()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return "", ctx.Err()
|
return "", ctx.Err()
|
||||||
case pgConn.controller <- pgConn:
|
default:
|
||||||
}
|
}
|
||||||
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
defer cleanupContextDeadline()
|
defer cleanupContextDeadline()
|
||||||
@@ -827,8 +883,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
_, err := pgConn.conn.Write(buf)
|
_, err := pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
<-pgConn.controller
|
|
||||||
|
|
||||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -849,7 +903,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
pgErr = errorResponseToPgError(msg)
|
pgErr = errorResponseToPgError(msg)
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
<-pgConn.controller
|
|
||||||
return commandTag, pgErr
|
return commandTag, pgErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -871,8 +924,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
_, err = pgConn.conn.Write(buf)
|
_, err = pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
<-pgConn.controller
|
|
||||||
|
|
||||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -904,8 +955,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
_, err = pgConn.conn.Write(buf)
|
_, err = pgConn.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.hardClose()
|
pgConn.hardClose()
|
||||||
<-pgConn.controller
|
|
||||||
|
|
||||||
return "", preferContextOverNetTimeoutError(ctx, err)
|
return "", preferContextOverNetTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -919,7 +968,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
<-pgConn.controller
|
|
||||||
return commandTag, pgErr
|
return commandTag, pgErr
|
||||||
case *pgproto3.CommandComplete:
|
case *pgproto3.CommandComplete:
|
||||||
commandTag = CommandTag(msg.CommandTag)
|
commandTag = CommandTag(msg.CommandTag)
|
||||||
@@ -968,7 +1016,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
|||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
mrr.cleanupContextDeadline()
|
mrr.cleanupContextDeadline()
|
||||||
mrr.closed = true
|
mrr.closed = true
|
||||||
<-mrr.pgConn.controller
|
mrr.pgConn.unlock()
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
mrr.err = errorResponseToPgError(msg)
|
mrr.err = errorResponseToPgError(msg)
|
||||||
}
|
}
|
||||||
@@ -1125,7 +1173,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
|||||||
switch msg.(type) {
|
switch msg.(type) {
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
rr.cleanupContextDeadline()
|
rr.cleanupContextDeadline()
|
||||||
<-rr.pgConn.controller
|
rr.pgConn.unlock()
|
||||||
return rr.commandTag, rr.err
|
return rr.commandTag, rr.err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1203,12 +1251,18 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||||||
cleanupContextDeadline: func() {},
|
cleanupContextDeadline: func() {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := pgConn.lock(); err != nil {
|
||||||
|
multiResult.closed = true
|
||||||
|
multiResult.err = ctx.Err()
|
||||||
|
return multiResult
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = ctx.Err()
|
multiResult.err = ctx.Err()
|
||||||
return multiResult
|
return multiResult
|
||||||
case pgConn.controller <- multiResult:
|
default:
|
||||||
}
|
}
|
||||||
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
|
||||||
|
|
||||||
@@ -1219,7 +1273,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||||||
multiResult.cleanupContextDeadline()
|
multiResult.cleanupContextDeadline()
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
|
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
|
||||||
<-pgConn.controller
|
pgConn.unlock()
|
||||||
return multiResult
|
return multiResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -484,6 +484,29 @@ func TestConnExecBatch(t *testing.T) {
|
|||||||
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
|
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnLocking(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
|
||||||
|
results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, "connection busy", err.Error())
|
||||||
|
|
||||||
|
results, err = mrr.ReadAll()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, results, 1)
|
||||||
|
assert.Nil(t, results[0].Err)
|
||||||
|
assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
|
||||||
|
assert.Len(t, results[0].Rows, 1)
|
||||||
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCommandTag(t *testing.T) {
|
func TestCommandTag(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user