diff --git a/batch_test.go b/batch_test.go index 9fcbdf0d..1cc27c4b 100644 --- a/batch_test.go +++ b/batch_test.go @@ -535,7 +535,7 @@ func TestTxBeginBatch(t *testing.T) { } batch.Close() - tx.Commit() + tx.Commit(context.Background()) var count int conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) @@ -581,7 +581,7 @@ func TestTxBeginBatchRollback(t *testing.T) { t.Error(err) } batch.Close() - tx.Rollback() + tx.Rollback(context.Background()) row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) var count int diff --git a/bench_test.go b/bench_test.go index 5f64ba7b..1eb492a9 100644 --- a/bench_test.go +++ b/bench_test.go @@ -361,7 +361,7 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { } } - err = tx.Commit() + err = tx.Commit(context.Background()) if err != nil { b.Fatal(err) } @@ -392,7 +392,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc if err != nil { return 0, err } - defer tx.Rollback() + defer tx.Rollback(context.Background()) for rowSrc.Next() { if rowsThisInsert > 0 { @@ -437,7 +437,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc rowCount += rowsThisInsert } - if err := tx.Commit(); err != nil { + if err := tx.Commit(context.Background()); err != nil { return 0, nil } diff --git a/doc.go b/doc.go index 7da1cd88..c45759fb 100644 --- a/doc.go +++ b/doc.go @@ -184,12 +184,12 @@ can create a transaction with a specified isolation level. // the tx commits successfully, this is a no-op defer tx.Rollback() - _, err = tx.Exec("insert into foo(id) values (1)") + _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { return err } - err = tx.Commit() + err = tx.Commit(context.Background()) if err != nil { return err } diff --git a/large_objects_test.go b/large_objects_test.go index 2c8651bb..97f68fa6 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -160,7 +160,7 @@ func TestLargeObjectsMultipleTransactions(t *testing.T) { } // Commit the first transaction - err = tx.Commit() + err = tx.Commit(context.Background()) if err != nil { t.Fatal(err) } diff --git a/pool/tx.go b/pool/tx.go index 59527080..2898a21d 100644 --- a/pool/tx.go +++ b/pool/tx.go @@ -12,8 +12,8 @@ type Tx struct { c *Conn } -func (tx *Tx) Commit() error { - err := tx.t.Commit() +func (tx *Tx) Commit(ctx context.Context) error { + err := tx.t.Commit(ctx) if tx.c != nil { tx.c.Release() tx.c = nil @@ -21,8 +21,8 @@ func (tx *Tx) Commit() error { return err } -func (tx *Tx) Rollback() error { - err := tx.t.Rollback() +func (tx *Tx) Rollback(ctx context.Context) error { + err := tx.t.Rollback(ctx) if tx.c != nil { tx.c.Release() tx.c = nil diff --git a/pool/tx_test.go b/pool/tx_test.go index 20a7ec55..14ea739a 100644 --- a/pool/tx_test.go +++ b/pool/tx_test.go @@ -16,7 +16,7 @@ func TestTxExec(t *testing.T) { tx, err := pool.Begin() require.NoError(t, err) - defer tx.Rollback() + defer tx.Rollback(context.Background()) testExec(t, tx) } @@ -28,7 +28,7 @@ func TestTxQuery(t *testing.T) { tx, err := pool.Begin() require.NoError(t, err) - defer tx.Rollback() + defer tx.Rollback(context.Background()) testQuery(t, tx) } @@ -40,7 +40,7 @@ func TestTxQueryRow(t *testing.T) { tx, err := pool.Begin() require.NoError(t, err) - defer tx.Rollback() + defer tx.Rollback(context.Background()) testQueryRow(t, tx) } diff --git a/stress_test.go b/stress_test.go index 182b725d..e0292d3b 100644 --- a/stress_test.go +++ b/stress_test.go @@ -278,7 +278,7 @@ package pgx_test // return err // } -// return tx.Commit() +// return tx.Commit(context.Background()) // } // func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { @@ -317,7 +317,7 @@ package pgx_test // } // } -// return tx.Commit() +// return tx.Commit(context.Background()) // } // func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error { diff --git a/tx.go b/tx.go index 14d356b1..def6dbad 100644 --- a/tx.go +++ b/tx.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "time" "github.com/jackc/pgconn" "github.com/pkg/errors" @@ -110,13 +109,8 @@ type Tx struct { status int8 } -// Commit commits the transaction -func (tx *Tx) Commit() error { - return tx.CommitEx(context.Background()) -} - -// CommitEx commits the transaction with a context. -func (tx *Tx) CommitEx(ctx context.Context) error { +// Commit commits the transaction. +func (tx *Tx) Commit(ctx context.Context) error { if tx.status != TxStatusInProgress { return ErrTxClosed } @@ -141,14 +135,7 @@ func (tx *Tx) CommitEx(ctx context.Context) error { // Tx is already closed, but is otherwise safe to call multiple times. Hence, a // defer tx.Rollback() is safe even if tx.Commit() will be called first in a // non-error condition. -func (tx *Tx) Rollback() error { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - defer cancel() - return tx.RollbackEx(ctx) -} - -// RollbackEx is the context version of Rollback -func (tx *Tx) RollbackEx(ctx context.Context) error { +func (tx *Tx) Rollback(ctx context.Context) error { if tx.status != TxStatusInProgress { return ErrTxClosed } diff --git a/tx_test.go b/tx_test.go index 8244649e..e33a9c27 100644 --- a/tx_test.go +++ b/tx_test.go @@ -36,7 +36,7 @@ func TestTransactionSuccessfulCommit(t *testing.T) { t.Fatalf("tx.Exec failed: %v", err) } - err = tx.Commit() + err = tx.Commit(context.Background()) if err != nil { t.Fatalf("tx.Commit failed: %v", err) } @@ -82,7 +82,7 @@ func TestTxCommitWhenTxBroken(t *testing.T) { t.Fatal("Unexpected success") } - err = tx.Commit() + err = tx.Commit(context.Background()) if err != pgx.ErrTxCommitRollback { t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) } @@ -117,13 +117,13 @@ func TestTxCommitSerializationFailure(t *testing.T) { if err != nil { t.Fatalf("BeginEx failed: %v", err) } - defer tx1.Rollback() + defer tx1.Rollback(context.Background()) tx2, err := c2.BeginEx(context.Background(), &pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("BeginEx failed: %v", err) } - defer tx2.Rollback() + defer tx2.Rollback(context.Background()) _, err = tx1.Exec(context.Background(), `insert into tx_serializable_sums(num) select sum(num) from tx_serializable_sums`) if err != nil { @@ -135,12 +135,12 @@ func TestTxCommitSerializationFailure(t *testing.T) { t.Fatalf("Exec failed: %v", err) } - err = tx1.Commit() + err = tx1.Commit(context.Background()) if err != nil { t.Fatalf("Commit failed: %v", err) } - err = tx2.Commit() + err = tx2.Commit(context.Background()) if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" { t.Fatalf("Expected serialization error 40001, got %#v", err) } @@ -173,7 +173,7 @@ func TestTransactionSuccessfulRollback(t *testing.T) { t.Fatalf("tx.Exec failed: %v", err) } - err = tx.Rollback() + err = tx.Rollback(context.Background()) if err != nil { t.Fatalf("tx.Rollback failed: %v", err) } @@ -207,7 +207,7 @@ func TestBeginExIsoLevels(t *testing.T) { t.Errorf("Expected to be in isolation level %v but was %v", iso, level) } - err = tx.Rollback() + err = tx.Rollback(context.Background()) if err != nil { t.Fatalf("tx.Rollback failed: %v", err) } @@ -224,7 +224,7 @@ func TestBeginExReadOnly(t *testing.T) { if err != nil { t.Fatalf("conn.BeginEx failed: %v", err) } - defer tx.Rollback() + defer tx.Rollback(context.Background()) _, err = conn.Exec(context.Background(), "create table foo(id serial primary key)") if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" { @@ -247,7 +247,7 @@ func TestTxStatus(t *testing.T) { t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) } - if err := tx.Rollback(); err != nil { + if err := tx.Rollback(context.Background()); err != nil { t.Fatal(err) } @@ -294,7 +294,7 @@ func TestTxStatusErrorInTransactions(t *testing.T) { t.Fatalf("Expected status to be %v, but it was %v", pgx.TxStatusInProgress, status) } - if err := tx.Rollback(); err != nil { + if err := tx.Rollback(context.Background()); err != nil { t.Fatal(err) } @@ -319,7 +319,7 @@ func TestTxErr(t *testing.T) { t.Fatal("Unexpected success") } - if err := tx.Commit(); err != pgx.ErrTxCommitRollback { + if err := tx.Commit(context.Background()); err != pgx.ErrTxCommitRollback { t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) }