Add callback functions to queued queries
Improve batch query ergonomics by allowing the code to handle the results of a query to be right next to the query.
This commit is contained in:
@@ -153,6 +153,14 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent
|
||||
Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and
|
||||
`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`.
|
||||
|
||||
## Improved Batch Query Ergonomics
|
||||
|
||||
Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the
|
||||
results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code
|
||||
to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can
|
||||
be used to register a callback function that will handle the result. Callback functions are called automatically when
|
||||
`BatchResults.Close` is called.
|
||||
|
||||
## SendBatch Uses Pipeline Mode When Appropriate
|
||||
|
||||
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
|
||||
|
||||
@@ -8,44 +8,99 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type batchItem struct {
|
||||
// QueuedQuery is a query that has been queued for execution via a Batch.
|
||||
type QueuedQuery struct {
|
||||
query string
|
||||
arguments []any
|
||||
fn batchItemFunc
|
||||
sd *pgconn.StatementDescription
|
||||
}
|
||||
|
||||
type batchItemFunc func(br BatchResults) error
|
||||
|
||||
// Query sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
rows, err := br.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
err = fn(rows)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Query sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
row := br.QueryRow()
|
||||
return fn(row)
|
||||
}
|
||||
}
|
||||
|
||||
// Exec sets fn to be called when the response to qq is received.
|
||||
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
||||
qq.fn = func(br BatchResults) error {
|
||||
ct, err := br.Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fn(ct)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips. A Batch must only be sent once.
|
||||
type Batch struct {
|
||||
items []*batchItem
|
||||
queuedQueries []*QueuedQuery
|
||||
}
|
||||
|
||||
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
|
||||
func (b *Batch) Queue(query string, arguments ...any) {
|
||||
b.items = append(b.items, &batchItem{
|
||||
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
|
||||
qq := &QueuedQuery{
|
||||
query: query,
|
||||
arguments: arguments,
|
||||
})
|
||||
}
|
||||
b.queuedQueries = append(b.queuedQueries, qq)
|
||||
return qq
|
||||
}
|
||||
|
||||
// Len returns number of queries that have been queued so far.
|
||||
func (b *Batch) Len() int {
|
||||
return len(b.items)
|
||||
return len(b.queuedQueries)
|
||||
}
|
||||
|
||||
type BatchResults interface {
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec.
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
|
||||
// calling Exec on the QueuedQuery.
|
||||
Exec() (pgconn.CommandTag, error)
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query.
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
|
||||
// calling Query on the QueuedQuery.
|
||||
Query() (Rows, error)
|
||||
|
||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
|
||||
// Prefer calling QueryRow on the QueuedQuery.
|
||||
QueryRow() Row
|
||||
|
||||
// Close closes the batch operation. This must be called before the underlying connection can be used again. Any error
|
||||
// that occurred during a batch operation may have made it impossible to resyncronize the connection with the server.
|
||||
// In this case the underlying connection will have been closed. Close is safe to call multiple times.
|
||||
// Close closes the batch operation. All unread results are read and any callback functions registered with
|
||||
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
|
||||
// error or the batch encounters an error subsequent callback functions will not be called.
|
||||
//
|
||||
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
|
||||
// operation may have made it impossible to resyncronize the connection with the server. In this case the underlying
|
||||
// connection will have been closed.
|
||||
//
|
||||
// Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback
|
||||
// functions will not be rerun.
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -55,7 +110,7 @@ type batchResults struct {
|
||||
mrr *pgconn.MultiResultReader
|
||||
err error
|
||||
b *Batch
|
||||
ix int
|
||||
qqIdx int
|
||||
closed bool
|
||||
endTraced bool
|
||||
}
|
||||
@@ -169,9 +224,14 @@ func (br *batchResults) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// consume and log any queries that haven't yet been logged by Exec or Query
|
||||
if br.conn.batchTracer != nil {
|
||||
for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) {
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||
if err != nil && br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
} else {
|
||||
br.Exec()
|
||||
}
|
||||
}
|
||||
@@ -191,12 +251,12 @@ func (br *batchResults) earlyError() error {
|
||||
}
|
||||
|
||||
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.ix < len(br.b.items) {
|
||||
bi := br.b.items[br.ix]
|
||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
bi := br.b.queuedQueries[br.qqIdx]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
ok = true
|
||||
br.ix++
|
||||
br.qqIdx++
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -208,7 +268,7 @@ type pipelineBatchResults struct {
|
||||
lastRows *baseRows
|
||||
err error
|
||||
b *Batch
|
||||
ix int
|
||||
qqIdx int
|
||||
closed bool
|
||||
endTraced bool
|
||||
}
|
||||
@@ -337,12 +397,18 @@ func (br *pipelineBatchResults) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// consume and log any queries that haven't yet been logged by Exec or Query
|
||||
if br.conn.batchTracer != nil {
|
||||
for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) {
|
||||
// Read and run fn for all remaining items
|
||||
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
if br.b.queuedQueries[br.qqIdx].fn != nil {
|
||||
err := br.b.queuedQueries[br.qqIdx].fn(br)
|
||||
if err != nil && br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
} else {
|
||||
br.Exec()
|
||||
}
|
||||
}
|
||||
|
||||
br.closed = true
|
||||
|
||||
err := br.pipeline.Close()
|
||||
@@ -358,12 +424,12 @@ func (br *pipelineBatchResults) earlyError() error {
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.ix < len(br.b.items) {
|
||||
bi := br.b.items[br.ix]
|
||||
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
|
||||
bi := br.b.queuedQueries[br.qqIdx]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
ok = true
|
||||
br.ix++
|
||||
br.qqIdx++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
+150
@@ -3,6 +3,7 @@ package pgx_test
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@@ -148,6 +149,99 @@ func TestConnSendBatch(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchQueuedQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error {
|
||||
assert.EqualValues(t, 1, ct.RowsAffected())
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error {
|
||||
assert.EqualValues(t, 1, ct.RowsAffected())
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error {
|
||||
assert.EqualValues(t, 1, ct.RowsAffected())
|
||||
return nil
|
||||
})
|
||||
|
||||
selectFromLedgerExpectedRows := []struct {
|
||||
id int32
|
||||
description string
|
||||
amount int32
|
||||
}{
|
||||
{1, "q1", 1},
|
||||
{2, "q2", 2},
|
||||
{3, "q3", 3},
|
||||
}
|
||||
|
||||
batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
|
||||
rowCount := 0
|
||||
var id int32
|
||||
var description string
|
||||
var amount int32
|
||||
_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
|
||||
rowCount++
|
||||
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error {
|
||||
rowCount := 0
|
||||
var id int32
|
||||
var description string
|
||||
var amount int32
|
||||
_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
|
||||
assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
|
||||
rowCount++
|
||||
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error {
|
||||
err := row.Scan(nil, nil, nil)
|
||||
assert.ErrorIs(t, err, pgx.ErrNoRows)
|
||||
return nil
|
||||
})
|
||||
|
||||
batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error {
|
||||
var sumAmount int32
|
||||
err := row.Scan(&sumAmount)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 6, sumAmount)
|
||||
return nil
|
||||
})
|
||||
|
||||
err := conn.SendBatch(context.Background(), batch).Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -773,3 +867,59 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
|
||||
assert.EqualValues(t, 3, values[0])
|
||||
assert.False(t, rows.Next())
|
||||
}
|
||||
|
||||
func ExampleConn_SendBatch() {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to establish connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(n)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(n)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(n)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
err = conn.SendBatch(context.Background(), batch).Close()
|
||||
if err != nil {
|
||||
fmt.Printf("SendBatch error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Output:
|
||||
// 2
|
||||
// 3
|
||||
// 5
|
||||
}
|
||||
|
||||
@@ -801,7 +801,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
|
||||
for _, bi := range b.items {
|
||||
for _, bi := range b.queuedQueries {
|
||||
var queryRewriter QueryRewriter
|
||||
sql := bi.query
|
||||
arguments := bi.arguments
|
||||
@@ -830,7 +830,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
}
|
||||
|
||||
// All other modes use extended protocol and thus can use prepared statements.
|
||||
for _, bi := range b.items {
|
||||
for _, bi := range b.queuedQueries {
|
||||
if sd, ok := c.preparedStatements[bi.query]; ok {
|
||||
bi.sd = sd
|
||||
}
|
||||
@@ -852,7 +852,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.items {
|
||||
for i, bi := range b.queuedQueries {
|
||||
if i > 0 {
|
||||
sb.WriteByte(';')
|
||||
}
|
||||
@@ -864,18 +864,18 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
|
||||
}
|
||||
mrr := c.pgConn.Exec(ctx, sb.String())
|
||||
return &batchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
b: b,
|
||||
ix: 0,
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
b: b,
|
||||
qqIdx: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
for _, bi := range b.queuedQueries {
|
||||
sd := bi.sd
|
||||
if sd != nil {
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
||||
@@ -898,11 +898,11 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR
|
||||
mrr := c.pgConn.ExecBatch(ctx, batch)
|
||||
|
||||
return &batchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
b: b,
|
||||
ix: 0,
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
b: b,
|
||||
qqIdx: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -914,7 +914,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
for _, bi := range b.queuedQueries {
|
||||
if bi.sd == nil {
|
||||
sd := c.statementCache.Get(bi.query)
|
||||
if sd != nil {
|
||||
@@ -946,7 +946,7 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
for _, bi := range b.queuedQueries {
|
||||
if bi.sd == nil {
|
||||
sd := c.descriptionCache.Get(bi.query)
|
||||
if sd != nil {
|
||||
@@ -973,7 +973,7 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
for _, bi := range b.queuedQueries {
|
||||
if bi.sd == nil {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
@@ -1045,7 +1045,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
||||
}
|
||||
|
||||
// Queue the queries.
|
||||
for _, bi := range b.items {
|
||||
for _, bi := range b.queuedQueries {
|
||||
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
|
||||
Reference in New Issue
Block a user