+141
@@ -2,6 +2,7 @@ package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@@ -282,6 +283,64 @@ func TestBeginIsoLevels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeginFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
createSql := `
|
||||
create temporary table foo(
|
||||
id integer,
|
||||
unique (id) initially deferred
|
||||
);
|
||||
`
|
||||
|
||||
_, err := conn.Exec(context.Background(), createSql)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
|
||||
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var n int64
|
||||
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, n)
|
||||
}
|
||||
|
||||
func TestBeginFuncRollbackOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
createSql := `
|
||||
create temporary table foo(
|
||||
id integer,
|
||||
unique (id) initially deferred
|
||||
);
|
||||
`
|
||||
|
||||
_, err := conn.Exec(context.Background(), createSql)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error {
|
||||
_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
require.NoError(t, err)
|
||||
return errors.New("some error")
|
||||
})
|
||||
require.EqualError(t, err, "some error")
|
||||
|
||||
var n int64
|
||||
err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, n)
|
||||
}
|
||||
|
||||
func TestBeginReadOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -433,3 +492,85 @@ func TestTxNestedTransactionRollback(t *testing.T) {
|
||||
t.Fatalf("Did not receive correct number of rows: %v", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxBeginFuncNestedTransactionCommit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, db)
|
||||
|
||||
createSql := `
|
||||
create temporary table foo(
|
||||
id integer,
|
||||
unique (id) initially deferred
|
||||
);
|
||||
`
|
||||
|
||||
_, err := db.Exec(context.Background(), createSql)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into foo(id) values (3)")
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var n int64
|
||||
err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
}
|
||||
|
||||
func TestTxBeginFuncNestedTransactionRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, db)
|
||||
|
||||
createSql := `
|
||||
create temporary table foo(
|
||||
id integer,
|
||||
unique (id) initially deferred
|
||||
);
|
||||
`
|
||||
|
||||
_, err := db.Exec(context.Background(), createSql)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into foo(id) values (1)")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.BeginFunc(context.Background(), func(db pgx.Tx) error {
|
||||
_, err := db.Exec(context.Background(), "insert into foo(id) values (2)")
|
||||
require.NoError(t, err)
|
||||
return errors.New("do a rollback")
|
||||
})
|
||||
require.EqualError(t, err, "do a rollback")
|
||||
|
||||
_, err = db.Exec(context.Background(), "insert into foo(id) values (3)")
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
var n int64
|
||||
err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, n)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user