2
0

Replace chan based conn locking with bool

This is conceptually simpler and will lead to error messages instead of
deadlocks.
This commit is contained in:
Jack Christensen
2019-03-30 17:09:39 -05:00
parent 444bd6deaf
commit 3d9e42d74c
2 changed files with 106 additions and 29 deletions
+83 -29
View File
@@ -89,8 +89,7 @@ type PgConn struct {
Config *Config
controller chan interface{}
locked bool
closed bool
bufferingReceive bool
@@ -153,7 +152,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
pgConn := new(PgConn)
pgConn.Config = config
pgConn.controller = make(chan interface{}, 1)
var err error
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
@@ -405,6 +403,29 @@ func (pgConn *PgConn) IsAlive() bool {
return !pgConn.closed
}
// lock locks the connection. It returns an error if the connection is already locked or is closed.
func (pgConn *PgConn) lock() error {
if pgConn.locked {
return errors.New("connection busy")
}
if pgConn.closed {
return errors.New("connection closed")
}
pgConn.locked = true
return nil
}
func (pgConn *PgConn) unlock() {
if !pgConn.locked {
panic("BUG: cannot unlock unlocked connection")
}
pgConn.locked = false
}
// ParameterStatus returns the value of a parameter reported by the server (e.g.
// server_version). Returns an empty string for unknown parameters.
func (pgConn *PgConn) ParameterStatus(key string) string {
@@ -476,10 +497,14 @@ type PreparedStatementDescription struct {
// Prepare creates a prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
if err := pgConn.lock(); err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case pgConn.controller <- pgConn:
default:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
@@ -521,7 +546,7 @@ readloop:
}
}
<-pgConn.controller
pgConn.unlock()
if parseErr != nil {
return nil, parseErr
@@ -594,14 +619,18 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
// received.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
if err := pgConn.lock(); err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case pgConn.controller <- pgConn:
default:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
defer func() { <-pgConn.controller }()
defer pgConn.unlock()
for {
msg, err := pgConn.ReceiveMessage()
@@ -628,12 +657,18 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
cleanupContextDeadline: func() {},
}
if err := pgConn.lock(); err != nil {
multiResult.closed = true
multiResult.err = err
return multiResult
}
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = ctx.Err()
return multiResult
case pgConn.controller <- multiResult:
default:
}
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
@@ -646,7 +681,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
multiResult.cleanupContextDeadline()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
<-pgConn.controller
pgConn.unlock()
return multiResult
}
@@ -679,12 +714,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
cleanupContextDeadline: func() {},
}
if err := pgConn.lock(); err != nil {
result.concludeCommand("", err)
result.closed = true
return result
}
select {
case <-ctx.Done():
result.concludeCommand("", ctx.Err())
result.closed = true
return result
case pgConn.controller <- result:
default:
}
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
@@ -704,7 +745,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
result.concludeCommand("", err)
result.cleanupContextDeadline()
result.closed = true
<-pgConn.controller
pgConn.unlock()
}
return result
@@ -729,12 +770,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
cleanupContextDeadline: func() {},
}
if err := pgConn.lock(); err != nil {
result.concludeCommand("", err)
result.closed = true
return result
}
select {
case <-ctx.Done():
result.concludeCommand("", ctx.Err())
result.closed = true
return result
case pgConn.controller <- result:
default:
}
result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
@@ -750,7 +797,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
result.concludeCommand("", err)
result.cleanupContextDeadline()
result.closed = true
<-pgConn.controller
pgConn.unlock()
}
return result
@@ -758,10 +805,14 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
// CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return "", err
}
select {
case <-ctx.Done():
return "", ctx.Err()
case pgConn.controller <- pgConn:
default:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
@@ -773,7 +824,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
<-pgConn.controller
pgConn.unlock()
return "", preferContextOverNetTimeoutError(ctx, err)
}
@@ -797,7 +848,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
return "", err
}
case *pgproto3.ReadyForQuery:
<-pgConn.controller
pgConn.unlock()
return commandTag, pgErr
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
@@ -812,10 +863,15 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
// could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return "", err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return "", ctx.Err()
case pgConn.controller <- pgConn:
default:
}
cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn)
defer cleanupContextDeadline()
@@ -827,8 +883,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
}
@@ -849,7 +903,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
case *pgproto3.ReadyForQuery:
<-pgConn.controller
return commandTag, pgErr
}
}
@@ -871,8 +924,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
}
}
@@ -904,8 +955,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
<-pgConn.controller
return "", preferContextOverNetTimeoutError(ctx, err)
}
@@ -919,7 +968,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
<-pgConn.controller
return commandTag, pgErr
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
@@ -968,7 +1016,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
case *pgproto3.ReadyForQuery:
mrr.cleanupContextDeadline()
mrr.closed = true
<-mrr.pgConn.controller
mrr.pgConn.unlock()
case *pgproto3.ErrorResponse:
mrr.err = errorResponseToPgError(msg)
}
@@ -1125,7 +1173,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
switch msg.(type) {
case *pgproto3.ReadyForQuery:
rr.cleanupContextDeadline()
<-rr.pgConn.controller
rr.pgConn.unlock()
return rr.commandTag, rr.err
}
}
@@ -1203,12 +1251,18 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
cleanupContextDeadline: func() {},
}
if err := pgConn.lock(); err != nil {
multiResult.closed = true
multiResult.err = ctx.Err()
return multiResult
}
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = ctx.Err()
return multiResult
case pgConn.controller <- multiResult:
default:
}
multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn)
@@ -1219,7 +1273,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
multiResult.cleanupContextDeadline()
multiResult.closed = true
multiResult.err = preferContextOverNetTimeoutError(ctx, err)
<-pgConn.controller
pgConn.unlock()
return multiResult
}
+23
View File
@@ -484,6 +484,29 @@ func TestConnExecBatch(t *testing.T) {
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
}
func TestConnLocking(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
assert.Error(t, err)
assert.Equal(t, "connection busy", err.Error())
results, err = mrr.ReadAll()
assert.NoError(t, err)
assert.Len(t, results, 1)
assert.Nil(t, results[0].Err)
assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
assert.Len(t, results[0].Rows, 1)
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
ensureConnValid(t, pgConn)
}
func TestCommandTag(t *testing.T) {
t.Parallel()