diff --git a/tx.go b/tx.go index aee9a2d9..91ffab9a 100644 --- a/tx.go +++ b/tx.go @@ -261,7 +261,7 @@ func (sp *dbSavepoint) Begin(ctx context.Context) (Tx, error) { return nil, ErrTxClosed } - return sp.Begin(ctx) + return sp.tx.Begin(ctx) } // Commit releases the savepoint essentially committing the pseudo nested transaction. diff --git a/tx_test.go b/tx_test.go index d40fce7a..47880bb0 100644 --- a/tx_test.go +++ b/tx_test.go @@ -269,6 +269,21 @@ func TestTxNestedTransactionCommit(t *testing.T) { t.Fatalf("nestedTx.Exec failed: %v", err) } + doubleNestedTx, err := nestedTx.Begin(context.Background()) + if err != nil { + t.Fatal(err) + } + + _, err = doubleNestedTx.Exec(context.Background(), "insert into foo(id) values (3)") + if err != nil { + t.Fatalf("doubleNestedTx.Exec failed: %v", err) + } + + err = doubleNestedTx.Commit(context.Background()) + if err != nil { + t.Fatalf("doubleNestedTx.Commit failed: %v", err) + } + err = nestedTx.Commit(context.Background()) if err != nil { t.Fatalf("nestedTx.Commit failed: %v", err) @@ -284,7 +299,7 @@ func TestTxNestedTransactionCommit(t *testing.T) { if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } - if n != 2 { + if n != 3 { t.Fatalf("Did not receive correct number of rows: %v", n) } }