2
0

Add callback functions to queued queries

Improve batch query ergonomics by allowing the code that handles the
results of a query to sit right next to the query.
This commit is contained in:
Jack Christensen
2022-07-16 17:46:47 -05:00
parent 78875bb95a
commit 29254180ca
4 changed files with 267 additions and 43 deletions
+8
View File
@@ -153,6 +153,14 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent
Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and
`BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`.
## Improved Batch Query Ergonomics
Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the
results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code
to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can
be used to register a callback function that will handle the result. Callback functions are called automatically when
`BatchResults.Close` is called.
## SendBatch Uses Pipeline Mode When Appropriate
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
+91 -25
View File
@@ -8,44 +8,99 @@ import (
"github.com/jackc/pgx/v5/pgconn"
)
type batchItem struct {
// QueuedQuery is a query that has been queued for execution via a Batch. Its Query, QueryRow, and Exec methods
// register a callback that handles the query's result.
type QueuedQuery struct {
// query is the SQL text or prepared statement name given to Batch.Queue.
query string
// arguments are the query arguments given to Batch.Queue.
arguments []any
// fn, when non-nil, is the result callback installed by Query, QueryRow, or Exec.
fn batchItemFunc
// sd is the statement description attached while the batch is being sent. Batch.Queue leaves it nil.
sd *pgconn.StatementDescription
}
// batchItemFunc is the internal form of a callback stored on a QueuedQuery. It is given the BatchResults to read
// its query's result from and returns any error encountered while doing so.
type batchItemFunc func(br BatchResults) error
// Query registers fn as the handler for the rows returned by qq. The rows are always closed once fn has returned.
// An error from reading the result, from fn, or from the rows themselves is returned to the batch.
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
	qq.fn = func(br BatchResults) error {
		rows, queryErr := br.Query()
		if queryErr != nil {
			return queryErr
		}
		// Ensure the rows are closed on every exit path, including an early return from fn.
		defer rows.Close()

		if fnErr := fn(rows); fnErr != nil {
			return fnErr
		}

		// Close before consulting Err so the final status of the rows is what gets returned.
		rows.Close()
		return rows.Err()
	}
}
// QueryRow registers fn as the handler for the single row returned by qq. Any error returned by fn is returned to
// the batch.
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
	qq.fn = func(br BatchResults) error {
		return fn(br.QueryRow())
	}
}
// Exec registers fn as the handler for the command tag produced by qq. fn is not called if reading the result
// fails; that error is returned to the batch instead.
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
	qq.fn = func(br BatchResults) error {
		commandTag, execErr := br.Exec()
		if execErr == nil {
			return fn(commandTag)
		}
		return execErr
	}
}
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. A Batch must only be sent once.
type Batch struct {
items []*batchItem
queuedQueries []*QueuedQuery
}
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
func (b *Batch) Queue(query string, arguments ...any) {
b.items = append(b.items, &batchItem{
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
qq := &QueuedQuery{
query: query,
arguments: arguments,
})
}
b.queuedQueries = append(b.queuedQueries, qq)
return qq
}
// Len returns number of queries that have been queued so far.
func (b *Batch) Len() int {
return len(b.items)
return len(b.queuedQueries)
}
type BatchResults interface {
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec.
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
// calling Exec on the QueuedQuery.
Exec() (pgconn.CommandTag, error)
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query.
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
// calling Query on the QueuedQuery.
Query() (Rows, error)
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
// Prefer calling QueryRow on the QueuedQuery.
QueryRow() Row
// Close closes the batch operation. This must be called before the underlying connection can be used again. Any error
// that occurred during a batch operation may have made it impossible to resynchronize the connection with the server.
// In this case the underlying connection will have been closed. Close is safe to call multiple times.
// Close closes the batch operation. All unread results are read and any callback functions registered with
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
// error or the batch encounters an error subsequent callback functions will not be called.
//
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
// operation may have made it impossible to resynchronize the connection with the server. In this case the underlying
// connection will have been closed.
//
// Close is safe to call multiple times. If it returns an error subsequent calls will return the same error. Callback
// functions will not be rerun.
Close() error
}
@@ -55,7 +110,7 @@ type batchResults struct {
mrr *pgconn.MultiResultReader
err error
b *Batch
ix int
qqIdx int
closed bool
endTraced bool
}
@@ -169,9 +224,14 @@ func (br *batchResults) Close() error {
return nil
}
// consume and log any queries that haven't yet been logged by Exec or Query
if br.conn.batchTracer != nil {
for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) {
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br)
if err != nil && br.err == nil {
br.err = err
}
} else {
br.Exec()
}
}
@@ -191,12 +251,12 @@ func (br *batchResults) earlyError() error {
}
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.ix < len(br.b.items) {
bi := br.b.items[br.ix]
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
bi := br.b.queuedQueries[br.qqIdx]
query = bi.query
args = bi.arguments
ok = true
br.ix++
br.qqIdx++
}
return
}
@@ -208,7 +268,7 @@ type pipelineBatchResults struct {
lastRows *baseRows
err error
b *Batch
ix int
qqIdx int
closed bool
endTraced bool
}
@@ -337,12 +397,18 @@ func (br *pipelineBatchResults) Close() error {
return nil
}
// consume and log any queries that haven't yet been logged by Exec or Query
if br.conn.batchTracer != nil {
for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) {
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br)
if err != nil && br.err == nil {
br.err = err
}
} else {
br.Exec()
}
}
br.closed = true
err := br.pipeline.Close()
@@ -358,12 +424,12 @@ func (br *pipelineBatchResults) earlyError() error {
}
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.ix < len(br.b.items) {
bi := br.b.items[br.ix]
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
bi := br.b.queuedQueries[br.qqIdx]
query = bi.query
args = bi.arguments
ok = true
br.ix++
br.qqIdx++
}
return
}
+150
View File
@@ -3,6 +3,7 @@ package pgx_test
import (
"context"
"errors"
"fmt"
"os"
"testing"
@@ -148,6 +149,99 @@ func TestConnSendBatch(t *testing.T) {
})
}
// TestConnSendBatchQueuedQuery verifies that callbacks registered through QueuedQuery.Exec, QueuedQuery.Query, and
// QueuedQuery.QueryRow are each run with the expected result when the batch results are closed.
func TestConnSendBatchQueuedQuery(t *testing.T) {
	t.Parallel()

	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")

		sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
		mustExec(t, conn, sql)

		batch := &pgx.Batch{}

		// callbacksRun counts callback invocations so the test fails if Close silently skips any of them.
		callbacksRun := 0

		// expectOneRowAffected is shared by the three single-row inserts.
		expectOneRowAffected := func(ct pgconn.CommandTag) error {
			callbacksRun++
			assert.EqualValues(t, 1, ct.RowsAffected())
			return nil
		}

		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(expectOneRowAffected)
		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(expectOneRowAffected)
		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(expectOneRowAffected)

		selectFromLedgerExpectedRows := []struct {
			id          int32
			description string
			amount      int32
		}{
			{1, "q1", 1},
			{2, "q2", 2},
			{3, "q3", 3},
		}

		// expectLedgerRows checks that rows yields exactly selectFromLedgerExpectedRows, in order.
		expectLedgerRows := func(rows pgx.Rows) error {
			callbacksRun++

			rowCount := 0
			var id int32
			var description string
			var amount int32
			_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
				// Guard the index so surplus rows fail the count assertion below instead of panicking.
				if rowCount < len(selectFromLedgerExpectedRows) {
					want := selectFromLedgerExpectedRows[rowCount]
					assert.Equal(t, want.id, id)
					assert.Equal(t, want.description, description)
					assert.Equal(t, want.amount, amount)
				}
				rowCount++
				return nil
			})
			assert.NoError(t, err)
			assert.Equal(t, len(selectFromLedgerExpectedRows), rowCount)
			return nil
		}

		// The same select is queued twice to check that consecutive Query callbacks each get their own result.
		batch.Queue("select id, description, amount from ledger order by id").Query(expectLedgerRows)
		batch.Queue("select id, description, amount from ledger order by id").Query(expectLedgerRows)

		batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error {
			callbacksRun++
			err := row.Scan(nil, nil, nil)
			assert.ErrorIs(t, err, pgx.ErrNoRows)
			return nil
		})

		batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error {
			callbacksRun++
			var sumAmount int32
			err := row.Scan(&sumAmount)
			assert.NoError(t, err)
			assert.EqualValues(t, 6, sumAmount)
			return nil
		})

		err := conn.SendBatch(context.Background(), batch).Close()
		assert.NoError(t, err)

		// Every queued query registered a callback, so each one must have run exactly once.
		assert.Equal(t, batch.Len(), callbacksRun)
	})
}
func TestConnSendBatchMany(t *testing.T) {
t.Parallel()
@@ -773,3 +867,59 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
assert.EqualValues(t, 3, values[0])
assert.False(t, rows.Next())
}
// ExampleConn_SendBatch queues three queries, each with a QueryRow callback placed right next to it, and runs
// them by closing the batch results.
func ExampleConn_SendBatch() {
	conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		fmt.Printf("Unable to establish connection: %v", err)
		return
	}
	// Release the connection when the example finishes.
	defer conn.Close(context.Background())

	batch := &pgx.Batch{}

	batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error {
		var n int32
		if err := row.Scan(&n); err != nil {
			return err
		}
		fmt.Println(n)
		return nil
	})

	batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error {
		var n int32
		if err := row.Scan(&n); err != nil {
			return err
		}
		fmt.Println(n)
		return nil
	})

	batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error {
		var n int32
		if err := row.Scan(&n); err != nil {
			return err
		}
		fmt.Println(n)
		return nil
	})

	// Close reads every remaining result and invokes the registered callbacks in queue order.
	err = conn.SendBatch(context.Background(), batch).Close()
	if err != nil {
		fmt.Printf("SendBatch error: %v", err)
		return
	}

	// Output:
	// 2
	// 3
	// 5
}
+18 -18
View File
@@ -801,7 +801,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
mode := c.config.DefaultQueryExecMode
for _, bi := range b.items {
for _, bi := range b.queuedQueries {
var queryRewriter QueryRewriter
sql := bi.query
arguments := bi.arguments
@@ -830,7 +830,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
}
// All other modes use extended protocol and thus can use prepared statements.
for _, bi := range b.items {
for _, bi := range b.queuedQueries {
if sd, ok := c.preparedStatements[bi.query]; ok {
bi.sd = sd
}
@@ -852,7 +852,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
var sb strings.Builder
for i, bi := range b.items {
for i, bi := range b.queuedQueries {
if i > 0 {
sb.WriteByte(';')
}
@@ -864,18 +864,18 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
}
mrr := c.pgConn.Exec(ctx, sb.String())
return &batchResults{
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
ix: 0,
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
qqIdx: 0,
}
}
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
batch := &pgconn.Batch{}
for _, bi := range b.items {
for _, bi := range b.queuedQueries {
sd := bi.sd
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
@@ -898,11 +898,11 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR
mrr := c.pgConn.ExecBatch(ctx, batch)
return &batchResults{
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
ix: 0,
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
qqIdx: 0,
}
}
@@ -914,7 +914,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.items {
for _, bi := range b.queuedQueries {
if bi.sd == nil {
sd := c.statementCache.Get(bi.query)
if sd != nil {
@@ -946,7 +946,7 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.items {
for _, bi := range b.queuedQueries {
if bi.sd == nil {
sd := c.descriptionCache.Get(bi.query)
if sd != nil {
@@ -973,7 +973,7 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.items {
for _, bi := range b.queuedQueries {
if bi.sd == nil {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
bi.sd = distinctNewQueries[idx]
@@ -1045,7 +1045,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
}
// Queue the queries.
for _, bi := range b.items {
for _, bi := range b.queuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}