diff --git a/batch.go b/batch.go index fc6f0d03..36671cfd 100644 --- a/batch.go +++ b/batch.go @@ -17,13 +17,14 @@ type batchItem struct { // Batch queries are a way of bundling multiple queries together to avoid // unnecessary network round trips. type Batch struct { - conn *Conn - connPool *ConnPool - items []*batchItem - resultsRead int - sent bool - ctx context.Context - err error + conn *Conn + connPool *ConnPool + items []*batchItem + resultsRead int + sent bool + pendingCommandComplete bool + ctx context.Context + err error } // BeginBatch returns a *Batch query for c. @@ -145,8 +146,15 @@ func (b *Batch) ExecResults() (CommandTag, error) { default: } + if err := b.ensureCommandComplete(); err != nil { + b.die(err) + return "", err + } + b.resultsRead++ + b.pendingCommandComplete = true + for { msg, err := b.conn.rxMsg() if err != nil { @@ -155,6 +163,7 @@ func (b *Batch) ExecResults() (CommandTag, error) { switch msg := msg.(type) { case *pgproto3.CommandComplete: + b.pendingCommandComplete = false return CommandTag(msg.CommandTag), nil default: if err := b.conn.processContextFreeMsg(msg); err != nil { @@ -182,8 +191,16 @@ func (b *Batch) QueryResults() (*Rows, error) { default: } + if err := b.ensureCommandComplete(); err != nil { + b.die(err) + rows.fatal(err) + return rows, err + } + b.resultsRead++ + b.pendingCommandComplete = true + fieldDescriptions, err := b.conn.readUntilRowDescription() if err != nil { b.die(err) @@ -244,3 +261,25 @@ func (b *Batch) die(err error) { b.connPool.Release(b.conn) } } + +func (b *Batch) ensureCommandComplete() error { + for b.pendingCommandComplete { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + b.pendingCommandComplete = false + return nil + default: + err = b.conn.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/batch_test.go b/batch_test.go index 998e8764..3112f183 100644 --- a/batch_test.go +++ b/batch_test.go @@ -477,7 +477,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) { } } -func TestConnBeginBatchSelectInsert(t *testing.T) { +func TestConnBeginBatchQueryRowInsert(t *testing.T) { t.Parallel() conn := mustConnect(t, *defaultConnConfig) @@ -525,3 +525,52 @@ func TestConnBeginBatchSelectInsert(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("select 1 union all select 2 union all select 3", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", + []interface{}{"q1", 1}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, + nil, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + rows.Close() + + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2) + } + + batch.Close() + + ensureConnValid(t, conn) +} diff --git a/query.go b/query.go index 07b42e59..407a792c 100644 --- a/query.go +++ b/query.go @@ -34,8 +34,6 @@ func (r *Row) Scan(dest ...interface{}) (err error) { } rows.Scan(dest...) - for rows.Next() { - } rows.Close() return rows.Err() } @@ -151,6 +149,9 @@ func (rows *Rows) Next() bool { rows.values = msg.Values return true case *pgproto3.CommandComplete: + if rows.batch != nil { + rows.batch.pendingCommandComplete = false + } rows.Close() return false