diff --git a/pgconn.go b/pgconn.go index 6ff0d39f..7a9a42e4 100644 --- a/pgconn.go +++ b/pgconn.go @@ -20,6 +20,13 @@ import ( "github.com/jackc/pgproto3/v2" ) +const ( + connStatusUninitialized = iota + connStatusClosed + connStatusIdle + connStatusBusy +) + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -88,8 +95,7 @@ type PgConn struct { Config *Config - locked bool - closed bool + status byte // One of connStatus* constants bufferingReceive bool bufferingReceiveMux sync.Mutex @@ -217,6 +223,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, err } case *pgproto3.ReadyForQuery: + pgConn.status = connStatusIdle if config.AfterConnectFunc != nil { err := config.AfterConnectFunc(ctx, pgConn) if err != nil { @@ -373,10 +380,10 @@ func (pgConn *PgConn) SecretKey() uint32 { // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. func (pgConn *PgConn) Close(ctx context.Context) error { - if pgConn.closed { + if pgConn.status == connStatusClosed { return nil } - pgConn.closed = true + pgConn.status = connStatusClosed defer pgConn.conn.Close() @@ -398,34 +405,41 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // hardClose closes the underlying connection without sending the exit message. func (pgConn *PgConn) hardClose() error { - if pgConn.closed { + if pgConn.status == connStatusClosed { return nil } - pgConn.closed = true + pgConn.status = connStatusClosed return pgConn.conn.Close() } // TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of // underlying connection. func (pgConn *PgConn) IsAlive() bool { - return !pgConn.closed + return pgConn.status >= connStatusIdle } // lock locks the connection. It panics if the connection is already locked or is closed. -func (pgConn *PgConn) lock() { - if pgConn.locked { - panic("connection busy") // This only should be possible in case of an application bug. +func (pgConn *PgConn) lock() error { + switch pgConn.status { + case connStatusBusy: + return errors.New("connection busy") // This only should be possible in case of an application bug. + case connStatusClosed: + return errors.New("conn closed") + case connStatusUninitialized: + return errors.New("conn uninitialized") } - - pgConn.locked = true + pgConn.status = connStatusBusy + return nil } func (pgConn *PgConn) unlock() { - if !pgConn.locked { + switch pgConn.status { + case connStatusBusy: + pgConn.status = connStatusIdle + case connStatusClosed: + default: panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. } - - pgConn.locked = false } // ParameterStatus returns the value of a parameter reported by the server (e.g. @@ -470,7 +484,9 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } defer pgConn.unlock() select { @@ -590,7 +606,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return err + } defer pgConn.unlock() select { @@ -621,7 +639,12 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, @@ -716,7 +739,12 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &ResultReader{ + closed: true, + err: err, + } + } pgConn.resultReader = ResultReader{ pgConn: pgConn, @@ -761,7 +789,9 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } select { case <-ctx.Done(): @@ -818,7 +848,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } defer pgConn.unlock() select { @@ -1197,7 +1229,12 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, diff --git a/pgconn_test.go b/pgconn_test.go index 2b1e68a3..2ad02830 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -690,9 +690,11 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - require.Panics(t, func() { pgConn.Exec(context.Background(), "select 'Hello, world'") }) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "connection busy", err.Error()) - results, err := mrr.ReadAll() + results, err = mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err)