2
0

Add callback functions to queued queries

Improve batch query ergonomics by allowing the code to handle the
results of a query to be right next to the query.
This commit is contained in:
Jack Christensen
2022-07-16 17:46:47 -05:00
parent 78875bb95a
commit 29254180ca
4 changed files with 267 additions and 43 deletions
+150
View File
@@ -3,6 +3,7 @@ package pgx_test
import (
"context"
"errors"
"fmt"
"os"
"testing"
@@ -148,6 +149,99 @@ func TestConnSendBatch(t *testing.T) {
})
}
func TestConnSendBatchQueuedQuery(t *testing.T) {
t.Parallel()
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := &pgx.Batch{}
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error {
assert.EqualValues(t, 1, ct.RowsAffected())
return nil
})
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error {
assert.EqualValues(t, 1, ct.RowsAffected())
return nil
})
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error {
assert.EqualValues(t, 1, ct.RowsAffected())
return nil
})
selectFromLedgerExpectedRows := []struct {
id int32
description string
amount int32
}{
{1, "q1", 1},
{2, "q2", 2},
{3, "q3", 3},
}
batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
rowCount := 0
var id int32
var description string
var amount int32
_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
rowCount++
return nil
})
assert.NoError(t, err)
return nil
})
batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
rowCount := 0
var id int32
var description string
var amount int32
_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
rowCount++
return nil
})
assert.NoError(t, err)
return nil
})
batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error {
err := row.Scan(nil, nil, nil)
assert.ErrorIs(t, err, pgx.ErrNoRows)
return nil
})
batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error {
var sumAmount int32
err := row.Scan(&sumAmount)
assert.NoError(t, err)
assert.EqualValues(t, 6, sumAmount)
return nil
})
err := conn.SendBatch(context.Background(), batch).Close()
assert.NoError(t, err)
})
}
func TestConnSendBatchMany(t *testing.T) {
t.Parallel()
@@ -773,3 +867,59 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
assert.EqualValues(t, 3, values[0])
assert.False(t, rows.Next())
}
func ExampleConn_SendBatch() {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
}
batch := &pgx.Batch{}
batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error {
var n int32
err := row.Scan(&n)
if err != nil {
return err
}
fmt.Println(n)
return err
})
batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error {
var n int32
err := row.Scan(&n)
if err != nil {
return err
}
fmt.Println(n)
return err
})
batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error {
var n int32
err := row.Scan(&n)
if err != nil {
return err
}
fmt.Println(n)
return err
})
err = conn.SendBatch(context.Background(), batch).Close()
if err != nil {
fmt.Printf("SendBatch error: %v", err)
return
}
// Output:
// 2
// 3
// 5
}