2
0

Unify locked and closed into status

No longer panic on locking busy conn
This commit is contained in:
Jack Christensen
2019-04-19 15:52:12 -05:00
parent 16412e56e2
commit 7bb6c2f3e9
2 changed files with 63 additions and 24 deletions
+59 -22
View File
@@ -20,6 +20,13 @@ import (
"github.com/jackc/pgproto3/v2"
)
const (
connStatusUninitialized = iota
connStatusClosed
connStatusIdle
connStatusBusy
)
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
@@ -88,8 +95,7 @@ type PgConn struct {
Config *Config
locked bool
closed bool
status byte // One of connStatus* constants
bufferingReceive bool
bufferingReceiveMux sync.Mutex
@@ -217,6 +223,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, err
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
if config.AfterConnectFunc != nil {
err := config.AfterConnectFunc(ctx, pgConn)
if err != nil {
@@ -373,10 +380,10 @@ func (pgConn *PgConn) SecretKey() uint32 {
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
// underlying net.Conn.Close() will always be called regardless of any other errors.
func (pgConn *PgConn) Close(ctx context.Context) error {
if pgConn.closed {
if pgConn.status == connStatusClosed {
return nil
}
pgConn.closed = true
pgConn.status = connStatusClosed
defer pgConn.conn.Close()
@@ -398,34 +405,41 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
// hardClose closes the underlying connection without sending the exit message.
func (pgConn *PgConn) hardClose() error {
if pgConn.closed {
if pgConn.status == connStatusClosed {
return nil
}
pgConn.closed = true
pgConn.status = connStatusClosed
return pgConn.conn.Close()
}
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of
// underlying connection.
func (pgConn *PgConn) IsAlive() bool {
return !pgConn.closed
return pgConn.status >= connStatusIdle
}
// lock locks the connection. It panics if the connection is already locked or is closed.
func (pgConn *PgConn) lock() {
if pgConn.locked {
panic("connection busy") // This only should be possible in case of an application bug.
func (pgConn *PgConn) lock() error {
switch pgConn.status {
case connStatusBusy:
return errors.New("connection busy") // This only should be possible in case of an application bug.
case connStatusClosed:
return errors.New("conn closed")
case connStatusUninitialized:
return errors.New("conn uninitialized")
}
pgConn.locked = true
pgConn.status = connStatusBusy
return nil
}
func (pgConn *PgConn) unlock() {
if !pgConn.locked {
switch pgConn.status {
case connStatusBusy:
pgConn.status = connStatusIdle
case connStatusClosed:
default:
panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package.
}
pgConn.locked = false
}
// ParameterStatus returns the value of a parameter reported by the server (e.g.
@@ -470,7 +484,9 @@ type PreparedStatementDescription struct {
// Prepare creates a prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
pgConn.lock()
if err := pgConn.lock(); err != nil {
return nil, err
}
defer pgConn.unlock()
select {
@@ -590,7 +606,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not
// received.
func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
pgConn.lock()
if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock()
select {
@@ -621,7 +639,12 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
//
// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries.
func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.lock()
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: err,
}
}
pgConn.multiResultReader = MultiResultReader{
pgConn: pgConn,
@@ -716,7 +739,12 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
}
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
pgConn.lock()
if err := pgConn.lock(); err != nil {
return &ResultReader{
closed: true,
err: err,
}
}
pgConn.resultReader = ResultReader{
pgConn: pgConn,
@@ -761,7 +789,9 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
// CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
pgConn.lock()
if err := pgConn.lock(); err != nil {
return nil, err
}
select {
case <-ctx.Done():
@@ -818,7 +848,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r
// could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
pgConn.lock()
if err := pgConn.lock(); err != nil {
return nil, err
}
defer pgConn.unlock()
select {
@@ -1197,7 +1229,12 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
// transaction is already in progress or SQL contains transaction control statements.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
pgConn.lock()
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: err,
}
}
pgConn.multiResultReader = MultiResultReader{
pgConn: pgConn,
+4 -2
View File
@@ -690,9 +690,11 @@ func TestConnLocking(t *testing.T) {
defer closeConn(t, pgConn)
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
require.Panics(t, func() { pgConn.Exec(context.Background(), "select 'Hello, world'") })
results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
assert.Error(t, err)
assert.Equal(t, "connection busy", err.Error())
results, err := mrr.ReadAll()
results, err = mrr.ReadAll()
assert.NoError(t, err)
assert.Len(t, results, 1)
assert.Nil(t, results[0].Err)