2
0

Unify locked and closed into status

No longer panic on locking busy conn
This commit is contained in:
Jack Christensen
2019-04-19 15:52:12 -05:00
parent 16412e56e2
commit 7bb6c2f3e9
2 changed files with 63 additions and 24 deletions
+59 -22
View File
@@ -20,6 +20,13 @@ import (
"github.com/jackc/pgproto3/v2" "github.com/jackc/pgproto3/v2"
) )
const (
connStatusUninitialized = iota
connStatusClosed
connStatusIdle
connStatusBusy
)
// PgError represents an error reported by the PostgreSQL server. See // PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description. // detailed field description.
@@ -88,8 +95,7 @@ type PgConn struct {
Config *Config Config *Config
locked bool status byte // One of connStatus* constants
closed bool
bufferingReceive bool bufferingReceive bool
bufferingReceiveMux sync.Mutex bufferingReceiveMux sync.Mutex
@@ -217,6 +223,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, err return nil, err
} }
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
if config.AfterConnectFunc != nil { if config.AfterConnectFunc != nil {
err := config.AfterConnectFunc(ctx, pgConn) err := config.AfterConnectFunc(ctx, pgConn)
if err != nil { if err != nil {
@@ -373,10 +380,10 @@ func (pgConn *PgConn) SecretKey() uint32 {
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
// underlying net.Conn.Close() will always be called regardless of any other errors. // underlying net.Conn.Close() will always be called regardless of any other errors.
func (pgConn *PgConn) Close(ctx context.Context) error { func (pgConn *PgConn) Close(ctx context.Context) error {
if pgConn.closed { if pgConn.status == connStatusClosed {
return nil return nil
} }
pgConn.closed = true pgConn.status = connStatusClosed
defer pgConn.conn.Close() defer pgConn.conn.Close()
@@ -398,34 +405,41 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
// hardClose closes the underlying connection without sending the exit message. // hardClose closes the underlying connection without sending the exit message.
func (pgConn *PgConn) hardClose() error { func (pgConn *PgConn) hardClose() error {
if pgConn.closed { if pgConn.status == connStatusClosed {
return nil return nil
} }
pgConn.closed = true pgConn.status = connStatusClosed
return pgConn.conn.Close() return pgConn.conn.Close()
} }
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of // TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of
// underlying connection. // underlying connection.
func (pgConn *PgConn) IsAlive() bool { func (pgConn *PgConn) IsAlive() bool {
return !pgConn.closed return pgConn.status >= connStatusIdle
} }
// lock locks the connection. It panics if the connection is already locked or is closed. // lock locks the connection. It panics if the connection is already locked or is closed.
func (pgConn *PgConn) lock() { func (pgConn *PgConn) lock() error {
if pgConn.locked { switch pgConn.status {
panic("connection busy") // This only should be possible in case of an application bug. case connStatusBusy:
return errors.New("connection busy") // This only should be possible in case of an application bug.
case connStatusClosed:
return errors.New("conn closed")
case connStatusUninitialized:
return errors.New("conn uninitialized")
} }
pgConn.status = connStatusBusy
pgConn.locked = true return nil
} }
func (pgConn *PgConn) unlock() { func (pgConn *PgConn) unlock() {
if !pgConn.locked { switch pgConn.status {
case connStatusBusy:
pgConn.status = connStatusIdle
case connStatusClosed:
default:
panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package.
} }
pgConn.locked = false
} }
// ParameterStatus returns the value of a parameter reported by the server (e.g. // ParameterStatus returns the value of a parameter reported by the server (e.g.
@@ -470,7 +484,9 @@ type PreparedStatementDescription struct {
// Prepare creates a prepared statement. // Prepare creates a prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
pgConn.lock() if err := pgConn.lock(); err != nil {
return nil, err
}
defer pgConn.unlock() defer pgConn.unlock()
select { select {
@@ -590,7 +606,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
// received. // received.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
pgConn.lock() if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock() defer pgConn.unlock()
select { select {
@@ -621,7 +639,12 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
// //
// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries.
func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.lock() if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: err,
}
}
pgConn.multiResultReader = MultiResultReader{ pgConn.multiResultReader = MultiResultReader{
pgConn: pgConn, pgConn: pgConn,
@@ -716,7 +739,12 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
} }
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
pgConn.lock() if err := pgConn.lock(); err != nil {
return &ResultReader{
closed: true,
err: err,
}
}
pgConn.resultReader = ResultReader{ pgConn.resultReader = ResultReader{
pgConn: pgConn, pgConn: pgConn,
@@ -761,7 +789,9 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
// CopyTo executes the copy command sql and copies the results to w. // CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
pgConn.lock() if err := pgConn.lock(); err != nil {
return nil, err
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -818,7 +848,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
// could still block. // could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
pgConn.lock() if err := pgConn.lock(); err != nil {
return nil, err
}
defer pgConn.unlock() defer pgConn.unlock()
select { select {
@@ -1197,7 +1229,12 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
// transaction is already in progress or SQL contains transaction control statements. // transaction is already in progress or SQL contains transaction control statements.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
pgConn.lock() if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: err,
}
}
pgConn.multiResultReader = MultiResultReader{ pgConn.multiResultReader = MultiResultReader{
pgConn: pgConn, pgConn: pgConn,
+4 -2
View File
@@ -690,9 +690,11 @@ func TestConnLocking(t *testing.T) {
defer closeConn(t, pgConn) defer closeConn(t, pgConn)
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
require.Panics(t, func() { pgConn.Exec(context.Background(), "select 'Hello, world'") }) results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
assert.Error(t, err)
assert.Equal(t, "connection busy", err.Error())
results, err := mrr.ReadAll() results, err = mrr.ReadAll()
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, results, 1) assert.Len(t, results, 1)
assert.Nil(t, results[0].Err) assert.Nil(t, results[0].Err)