diff --git a/pgconn.go b/pgconn.go index 8b0ddcb4..e246bcdd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -89,8 +89,7 @@ type PgConn struct { Config *Config - controller chan interface{} - + locked bool closed bool bufferingReceive bool @@ -153,7 +152,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config - pgConn.controller = make(chan interface{}, 1) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -405,6 +403,29 @@ func (pgConn *PgConn) IsAlive() bool { return !pgConn.closed } +// lock locks the connection. It returns an error if the connection is already locked or is closed. +func (pgConn *PgConn) lock() error { + if pgConn.locked { + return errors.New("connection busy") + } + + if pgConn.closed { + return errors.New("connection closed") + } + + pgConn.locked = true + + return nil +} + +func (pgConn *PgConn) unlock() { + if !pgConn.locked { + panic("BUG: cannot unlock unlocked connection") + } + + pgConn.locked = false +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -476,10 +497,14 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + select { case <-ctx.Done(): return nil, ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -521,7 +546,7 @@ readloop: } } - <-pgConn.controller + pgConn.unlock() if parseErr != nil { return nil, parseErr @@ -594,14 +619,18 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + if err := pgConn.lock(); err != nil { + return err + } + select { case <-ctx.Done(): return ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() - defer func() { <-pgConn.controller }() + defer pgConn.unlock() for { msg, err := pgConn.ReceiveMessage() @@ -628,12 +657,18 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + multiResult.closed = true + multiResult.err = err + return multiResult + } + select { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() return multiResult - case pgConn.controller <- multiResult: + default: } multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -646,7 +681,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) - <-pgConn.controller + pgConn.unlock() return multiResult } @@ -679,12 +714,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + result.concludeCommand("", err) + result.closed = true + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true return result - case pgConn.controller <- result: + default: } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -704,7 +745,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true - <-pgConn.controller + pgConn.unlock() } return result @@ -729,12 +770,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + result.concludeCommand("", err) + result.closed = true + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true return result - case pgConn.controller <- result: + default: } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -750,7 +797,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true - <-pgConn.controller + pgConn.unlock() } return result @@ -758,10 +805,14 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return "", err + } + select { case <-ctx.Done(): return "", ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -773,7 +824,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller + pgConn.unlock() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -797,7 +848,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return "", err } case *pgproto3.ReadyForQuery: - <-pgConn.controller + pgConn.unlock() return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) @@ -812,10 +863,15 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return "", err + } + defer pgConn.unlock() + select { case <-ctx.Done(): return "", ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -827,8 +883,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } @@ -849,7 +903,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.ErrorResponse: pgErr = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: - <-pgConn.controller return commandTag, pgErr } } @@ -871,8 +924,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } } @@ -904,8 +955,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } @@ -919,7 +968,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - <-pgConn.controller return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) @@ -968,7 +1016,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) case *pgproto3.ReadyForQuery: mrr.cleanupContextDeadline() mrr.closed = true - <-mrr.pgConn.controller + mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = errorResponseToPgError(msg) } @@ -1125,7 +1173,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: rr.cleanupContextDeadline() - <-rr.pgConn.controller + rr.pgConn.unlock() return rr.commandTag, rr.err } } @@ -1203,12 +1251,18 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + } + select { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() return multiResult - case pgConn.controller <- multiResult: + default: } multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -1219,7 +1273,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) - <-pgConn.controller + pgConn.unlock() return multiResult } diff --git a/pgconn_test.go b/pgconn_test.go index 88c6f7c4..53e3b9d8 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -484,6 +484,29 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnLocking(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "connection busy", err.Error()) + + results, err = mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + func TestCommandTag(t *testing.T) { t.Parallel()