Add callback functions to queued queries
Improve batch query ergonomics by allowing the code to handle the results of a query to be right next to the query.
This commit is contained in:
+150
@@ -3,6 +3,7 @@ package pgx_test
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@@ -148,6 +149,99 @@ func TestConnSendBatch(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchQueuedQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error {
|
||||
assert.EqualValues(t, 1, ct.RowsAffected())
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error {
|
||||
assert.EqualValues(t, 1, ct.RowsAffected())
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error {
|
||||
assert.EqualValues(t, 1, ct.RowsAffected())
|
||||
return nil
|
||||
})
|
||||
|
||||
selectFromLedgerExpectedRows := []struct {
|
||||
id int32
|
||||
description string
|
||||
amount int32
|
||||
}{
|
||||
{1, "q1", 1},
|
||||
{2, "q2", 2},
|
||||
{3, "q3", 3},
|
||||
}
|
||||
|
||||
batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
|
||||
rowCount := 0
|
||||
var id int32
|
||||
var description string
|
||||
var amount int32
|
||||
_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
|
||||
rowCount++
|
||||
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
|
||||
rowCount := 0
|
||||
var id int32
|
||||
var description string
|
||||
var amount int32
|
||||
_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
|
||||
rowCount++
|
||||
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error {
|
||||
err := row.Scan(nil, nil, nil)
|
||||
assert.ErrorIs(t, err, pgx.ErrNoRows)
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error {
|
||||
var sumAmount int32
|
||||
err := row.Scan(&sumAmount)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 6, sumAmount)
|
||||
return nil
|
||||
})
|
||||
|
||||
err := conn.SendBatch(context.Background(), batch).Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -773,3 +867,59 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
|
||||
assert.EqualValues(t, 3, values[0])
|
||||
assert.False(t, rows.Next())
|
||||
}
|
||||
|
||||
func ExampleConn_SendBatch() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to establish connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(n)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(n)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(n)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
err = conn.SendBatch(context.Background(), batch).Close()
|
||||
if err != nil {
|
||||
fmt.Printf("SendBatch error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Output:
|
||||
// 2
|
||||
// 3
|
||||
// 5
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user