2
0

Add BeginFunc and BeginTxFunc

fixes #821
This commit is contained in:
Jack Christensen
2021-02-20 18:30:18 -06:00
parent 373bb84e9d
commit ac2918b9a3
7 changed files with 340 additions and 0 deletions
+91
View File
@@ -2,6 +2,7 @@ package pgxpool_test
import (
"context"
"errors"
"fmt"
"os"
"testing"
@@ -668,3 +669,93 @@ func TestConnReleaseWhenBeginFail(t *testing.T) {
assert.EqualValues(t, 0, db.Stat().TotalConns())
}
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer db.Close()
createSql := `
drop table if exists pgxpooltx;
create temporary table pgxpooltx(
id integer,
unique (id) initially deferred
);
`
_, err = db.Exec(context.Background(), createSql)
require.NoError(t, err)
defer func() {
db.Exec(context.Background(), "drop table pgxpooltx")
}()
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
require.NoError(t, err)
return nil
})
return nil
})
require.NoError(t, err)
return nil
})
require.NoError(t, err)
var n int64
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 3, n)
}
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer db.Close()
createSql := `
drop table if exists pgxpooltx;
create temporary table pgxpooltx(
id integer,
unique (id) initially deferred
);
`
_, err = db.Exec(context.Background(), createSql)
require.NoError(t, err)
defer func() {
db.Exec(context.Background(), "drop table pgxpooltx")
}()
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)")
require.NoError(t, err)
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
_, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)")
require.NoError(t, err)
return errors.New("do a rollback")
})
require.EqualError(t, err, "do a rollback")
_, err = db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)")
require.NoError(t, err)
return nil
})
var n int64
err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)
}