@@ -78,6 +78,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
|
||||
return c.Conn().BeginTx(ctx, txOptions)
|
||||
}
|
||||
|
||||
func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
|
||||
return c.Conn().BeginFunc(ctx, f)
|
||||
}
|
||||
|
||||
func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
|
||||
return c.Conn().BeginTxFunc(ctx, txOptions, f)
|
||||
}
|
||||
|
||||
func (c *Conn) Ping(ctx context.Context) error {
|
||||
return c.Conn().Ping(ctx)
|
||||
}
|
||||
|
||||
@@ -496,6 +496,20 @@ func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, er
|
||||
return &Tx{t: t, c: c}, err
|
||||
}
|
||||
|
||||
func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
|
||||
return p.BeginTxFunc(ctx, pgx.TxOptions{}, f)
|
||||
}
|
||||
|
||||
func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
|
||||
c, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer c.Release()
|
||||
|
||||
return c.BeginTxFunc(ctx, txOptions, f)
|
||||
}
|
||||
|
||||
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||
c, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package pgxpool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
@@ -668,3 +669,93 @@ func TestConnReleaseWhenBeginFail(t *testing.T) {
|
||||
|
||||
assert.EqualValues(t, 0, db.Stat().TotalConns())
|
||||
}
|
||||
|
||||
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
|
||||
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
createSql := `
|
||||
drop table if exists pgxpooltx;
|
||||
create temporary table pgxpooltx(
|
||||
id integer,
|
||||
unique (id) initially deferred
|
||||
);
|
||||
`
|
||||
|
||||
_, err = db.Exec(context.Background(), createSql)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
db.Exec(context.Background(), "drop table pgxpooltx")
|
||||
}()
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var n int64
|
||||
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
}
|
||||
|
||||
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
|
||||
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
createSql := `
|
||||
drop table if exists pgxpooltx;
|
||||
create temporary table pgxpooltx(
|
||||
id integer,
|
||||
unique (id) initially deferred
|
||||
);
|
||||
`
|
||||
|
||||
_, err = db.Exec(context.Background(), createSql)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
db.Exec(context.Background(), "drop table pgxpooltx")
|
||||
}()
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
|
||||
require.NoError(t, err)
|
||||
return errors.New("do a rollback")
|
||||
})
|
||||
require.EqualError(t, err, "do a rollback")
|
||||
|
||||
_, err = db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
var n int64
|
||||
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, n)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,10 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) {
|
||||
return tx.t.Begin(ctx)
|
||||
}
|
||||
|
||||
func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error {
|
||||
return tx.t.BeginFunc(ctx, f)
|
||||
}
|
||||
|
||||
func (tx *Tx) Commit(ctx context.Context) error {
|
||||
err := tx.t.Commit(ctx)
|
||||
if tx.c != nil {
|
||||
|
||||
Reference in New Issue
Block a user