diff --git a/batch_test.go b/batch_test.go index e12e4f32..998e8764 100644 --- a/batch_test.go +++ b/batch_test.go @@ -476,3 +476,52 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) { t.Error("conn should be dead, but was alive") } } + +func TestConnBeginBatchSelectInsert(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("select 1", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", + []interface{}{"q1", 1}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, + nil, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + var value int + err = batch.QueryRowResults().Scan(&value) + if err != nil { + t.Error(err) + } + + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2) + } + + batch.Close() + + ensureConnValid(t, conn) +} diff --git a/query.go b/query.go index e37e6120..07b42e59 100644 --- a/query.go +++ b/query.go @@ -34,6 +34,8 @@ func (r *Row) Scan(dest ...interface{}) (err error) { } rows.Scan(dest...) + for rows.Next() { + } rows.Close() return rows.Err() }