2
0

Replace Begin and BeginTx methods with functions

This commit is contained in:
Jack Christensen
2022-07-09 17:25:55 -05:00
parent 62f0347586
commit 31ec18cc65
8 changed files with 82 additions and 117 deletions
+63 -77
View File
@@ -94,39 +94,6 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
return &dbTx{conn: c}, nil
}
// BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns
// an error the transaction is rolled back. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f.
func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
return c.BeginTxFunc(ctx, TxOptions{}, f)
}
// BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return
// an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be
// used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect
// the execution of f.
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) {
var tx Tx
tx, err = c.BeginTx(ctx, txOptions)
if err != nil {
return err
}
defer func() {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
err = rollbackErr
}
}()
fErr := f(tx)
if fErr != nil {
_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return tx.Commit(ctx)
}
// Tx represents a database transaction.
//
// Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx
@@ -138,20 +105,17 @@ type Tx interface {
// Begin starts a pseudo nested transaction.
Begin(ctx context.Context) (Tx, error)
// BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested
// transaction will be committed. If it does then it will be rolled back.
BeginFunc(ctx context.Context, f func(Tx) error) (err error)
// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
// ErrTxCommitRollback will be returned.
// transaction. Commit will return an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is
// otherwise safe to call multiple times. If the commit fails with a rollback status (e.g. the transaction was already
// in a broken state) then an error where errors.Is(ErrTxCommitRollback) is true will be returned.
Commit(ctx context.Context) error
// Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a
// pseudo nested transaction. Rollback will return ErrTxClosed if the Tx is already closed, but is otherwise safe to
// call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error
// condition. Any other failure of a real transaction will result in the connection being closed.
// pseudo nested transaction. Rollback will return an error where errors.Is(ErrTxClosed) is true if the Tx is already
// closed, but is otherwise safe to call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will
// be called first in a non-error condition. Any other failure of a real transaction will result in the connection
// being closed.
Rollback(ctx context.Context) error
CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error)
@@ -194,32 +158,6 @@ func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil
}
func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if tx.closed {
return ErrTxClosed
}
var savepoint Tx
savepoint, err = tx.Begin(ctx)
if err != nil {
return err
}
defer func() {
rollbackErr := savepoint.Rollback(ctx)
if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
err = rollbackErr
}
}()
fErr := f(savepoint)
if fErr != nil {
_ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return savepoint.Commit(ctx)
}
// Commit commits the transaction.
func (tx *dbTx) Commit(ctx context.Context) error {
if tx.closed {
@@ -335,14 +273,6 @@ func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) {
return sp.tx.Begin(ctx)
}
func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
if sp.closed {
return ErrTxClosed
}
return sp.tx.BeginFunc(ctx, f)
}
// Commit releases the savepoint essentially committing the pseudo nested transaction.
func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error {
if sp.closed {
@@ -427,3 +357,59 @@ func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects {
func (sp *dbSimulatedNestedTx) Conn() *Conn {
return sp.tx.Conn()
}
// BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn
// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn.
func BeginFunc(
ctx context.Context,
db interface {
Begin(ctx context.Context) (Tx, error)
},
fn func(Tx) error,
) (err error) {
var tx Tx
tx, err = db.Begin(ctx)
if err != nil {
return err
}
return beginFuncExec(ctx, tx, fn)
}
// BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn
// returns an error it calls Rollback on db. The context will be used when executing the transaction control statements
// (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn.
func BeginTxFunc(
ctx context.Context,
db interface {
BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error)
},
txOptions TxOptions,
fn func(Tx) error,
) (err error) {
var tx Tx
tx, err = db.BeginTx(ctx, txOptions)
if err != nil {
return err
}
return beginFuncExec(ctx, tx, fn)
}
func beginFuncExec(ctx context.Context, tx Tx, fn func(Tx) error) (err error) {
defer func() {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
err = rollbackErr
}
}()
fErr := fn(tx)
if fErr != nil {
_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
return fErr
}
return tx.Commit(ctx)
}