diff --git a/tx.go b/tx.go index 404bf7c9..0944cc14 100644 --- a/tx.go +++ b/tx.go @@ -159,6 +159,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error { } _, err := tx.conn.Exec(ctx, "rollback") + tx.closed = true if err != nil { // A rollback failure leaves the connection in an undefined state tx.conn.die(errors.Errorf("rollback failed: %w", err))