diff --git a/batch.go b/batch.go index 36671cfd..02f93daa 100644 --- a/batch.go +++ b/batch.go @@ -21,10 +21,10 @@ type Batch struct { connPool *ConnPool items []*batchItem resultsRead int - sent bool pendingCommandComplete bool ctx context.Context err error + inTx bool } // BeginBatch returns a *Batch query for c. @@ -32,6 +32,10 @@ func (c *Conn) BeginBatch() *Batch { return &Batch{conn: c} } +func (tx *Tx) BeginBatch() *Batch { + return &Batch{conn: tx.conn, inTx: true} +} + // Conn returns the underlying connection that b will or was performed on. func (b *Batch) Conn() *Conn { return b.conn @@ -49,7 +53,8 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt }) } -// Send sends all queued queries to the server at once. All queries are wrapped +// Send sends all queued queries to the server at once. +// If the batch is created from a conn Object then All queries are wrapped // in a transaction. The transaction can optionally be configured with // txOptions. The context is in effect until the Batch is closed. func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { @@ -68,13 +73,16 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { return err } + buf := b.conn.wbuf + if !b.inTx { + buf = appendQuery(buf, txOptions.beginSQL()) + } + err = b.conn.initContext(ctx) if err != nil { return err } - buf := appendQuery(b.conn.wbuf, txOptions.beginSQL()) - for _, bi := range b.items { var psName string var psParameterOIDs []pgtype.OID @@ -98,7 +106,12 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { } buf = appendSync(buf) - buf = appendQuery(buf, "commit") + b.conn.pendingReadyForQueryCount++ + + if !b.inTx { + buf = appendQuery(buf, "commit") + b.conn.pendingReadyForQueryCount++ + } n, err := b.conn.conn.Write(buf) if err != nil { @@ -108,12 +121,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { return err } - // expect ReadyForQuery from sync and from commit - b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2 - - b.sent = true - - for { + for !b.inTx { msg, err := b.conn.rxMsg() if err != nil { return err diff --git a/batch_test.go b/batch_test.go index 3112f183..3b51971c 100644 --- a/batch_test.go +++ b/batch_test.go @@ -574,3 +574,130 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { ensureConnValid(t, conn) } + +func TestTxBeginBatch(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger1( + id serial primary key, + description varchar not null +);` + mustExec(t, conn, sql) + + sql = `create temporary table ledger2( + id int primary key, + amount int not null +);` + mustExec(t, conn, sql) + + tx, _ := conn.Begin() + batch := tx.BeginBatch() + batch.Queue("insert into ledger1(description) values($1) returning id", + []interface{}{"q1"}, + []pgtype.OID{pgtype.VarcharOID}, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + var id int + err = batch.QueryRowResults().Scan(&id) + if err != nil { + t.Error(err) + } + batch.Close() + + batch = tx.BeginBatch() + batch.Queue("insert into ledger2(id,amount) values($1, $2)", + []interface{}{id, 2}, + []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, + nil, + ) + + batch.Queue("select amount from ledger2 where id = $1", + []interface{}{id}, + []pgtype.OID{pgtype.Int4OID}, + nil, + ) + + err = batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } + + var amout int + err = batch.QueryRowResults().Scan(&amout) + if err != nil { + t.Error(err) + } + + batch.Close() + tx.Commit() + + var count int + conn.QueryRow("select count(1) from ledger1 where id = $1", id).Scan(&count) + if count != 1 { + t.Errorf("count => %v, want %v", count, 1) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func TestTxBeginBatchRollback(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger1( + id serial primary key, + description varchar not null +);` + mustExec(t, conn, sql) + + tx, _ := conn.Begin() + batch := tx.BeginBatch() + batch.Queue("insert into ledger1(description) values($1) returning id", + []interface{}{"q1"}, + []pgtype.OID{pgtype.VarcharOID}, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + var id int + err = batch.QueryRowResults().Scan(&id) + if err != nil { + t.Error(err) + } + batch.Close() + tx.Rollback() + + row := conn.QueryRow("select count(1) from ledger1 where id = $1", id) + var count int + row.Scan(&count) + if count != 0 { + t.Errorf("count => %v, want %v", count, 0) + } + + ensureConnValid(t, conn) +} \ No newline at end of file