diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a402c40..83c32783 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -153,6 +153,14 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent Rather than every type that implemented `Begin` or `BeginTx` methods also needing to implement `BeginFunc` and `BeginTxFunc` these methods have been converted to functions that take a db that implements `Begin` or `BeginTx`. +## Improved Batch Query Ergonomics + +Previously, the code for building a batch went in one place before the call to `SendBatch`, and the code for reading the +results went in one place after the call to `SendBatch`. This could make it difficult to match up the query and the code +to handle the results. Now `Queue` returns a `QueuedQuery` which has methods `Query`, `QueryRow`, and `Exec` which can +be used to register a callback function that will handle the result. Callback functions are called automatically when +`BatchResults.Close` is called. + ## SendBatch Uses Pipeline Mode When Appropriate Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 diff --git a/batch.go b/batch.go index a6951096..af62039f 100644 --- a/batch.go +++ b/batch.go @@ -8,44 +8,99 @@ import ( "github.com/jackc/pgx/v5/pgconn" ) -type batchItem struct { +// QueuedQuery is a query that has been queued for execution via a Batch. +type QueuedQuery struct { query string arguments []any + fn batchItemFunc sd *pgconn.StatementDescription } +type batchItemFunc func(br BatchResults) error + +// Query sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) Query(fn func(rows Rows) error) { + qq.fn = func(br BatchResults) error { + rows, err := br.Query() + if err != nil { + return err + } + defer rows.Close() + + err = fn(rows) + if err != nil { + return err + } + rows.Close() + + return rows.Err() + } +} + +// QueryRow sets fn to be called when the response to qq is received. 
+func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { + qq.fn = func(br BatchResults) error { + row := br.QueryRow() + return fn(row) + } +} + +// Exec sets fn to be called when the response to qq is received. +func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { + qq.fn = func(br BatchResults) error { + ct, err := br.Exec() + if err != nil { + return err + } + + return fn(ct) + } +} + // Batch queries are a way of bundling multiple queries together to avoid // unnecessary network round trips. A Batch must only be sent once. type Batch struct { - items []*batchItem + queuedQueries []*QueuedQuery } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. -func (b *Batch) Queue(query string, arguments ...any) { - b.items = append(b.items, &batchItem{ +func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { + qq := &QueuedQuery{ query: query, arguments: arguments, - }) + } + b.queuedQueries = append(b.queuedQueries, qq) + return qq } // Len returns number of queries that have been queued so far. func (b *Batch) Len() int { - return len(b.items) + return len(b.queuedQueries) } type BatchResults interface { - // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. + // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer + // calling Exec on the QueuedQuery. Exec() (pgconn.CommandTag, error) - // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. + // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer + // calling Query on the QueuedQuery. Query() (Rows, error) // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. + // Prefer calling QueryRow on the QueuedQuery. QueryRow() Row - // Close closes the batch operation. 
This must be called before the underlying connection can be used again. Any error - // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. - // In this case the underlying connection will have been closed. Close is safe to call multiple times. + // Close closes the batch operation. All unread results are read and any callback functions registered with + // QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an + // error or the batch encounters an error, subsequent callback functions will not be called. + // + // Close must be called before the underlying connection can be used again. Any error that occurred during a batch + // operation may have made it impossible to resynchronize the connection with the server. In this case the underlying + // connection will have been closed. + // + // Close is safe to call multiple times. If it returns an error, subsequent calls will return the same error. Callback + // functions will not be rerun. 
Close() error } @@ -55,7 +110,7 @@ type batchResults struct { mrr *pgconn.MultiResultReader err error b *Batch - ix int + qqIdx int closed bool endTraced bool } @@ -169,9 +224,14 @@ func (br *batchResults) Close() error { return nil } - // consume and log any queries that haven't yet been logged by Exec or Query - if br.conn.batchTracer != nil { - for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) { + // Read and run fn for all remaining items + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + if br.b.queuedQueries[br.qqIdx].fn != nil { + err := br.b.queuedQueries[br.qqIdx].fn(br) + if err != nil && br.err == nil { + br.err = err + } + } else { br.Exec() } } @@ -191,12 +251,12 @@ func (br *batchResults) earlyError() error { } func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.ix < len(br.b.items) { - bi := br.b.items[br.ix] + if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + bi := br.b.queuedQueries[br.qqIdx] query = bi.query args = bi.arguments ok = true - br.ix++ + br.qqIdx++ } return } @@ -208,7 +268,7 @@ type pipelineBatchResults struct { lastRows *baseRows err error b *Batch - ix int + qqIdx int closed bool endTraced bool } @@ -337,12 +397,18 @@ func (br *pipelineBatchResults) Close() error { return nil } - // consume and log any queries that haven't yet been logged by Exec or Query - if br.conn.batchTracer != nil { - for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) { + // Read and run fn for all remaining items + for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + if br.b.queuedQueries[br.qqIdx].fn != nil { + err := br.b.queuedQueries[br.qqIdx].fn(br) + if err != nil && br.err == nil { + br.err = err + } + } else { br.Exec() } } + br.closed = true err := br.pipeline.Close() @@ -358,12 +424,12 @@ func (br *pipelineBatchResults) earlyError() error { } func (br 
*pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.ix < len(br.b.items) { - bi := br.b.items[br.ix] + if br.b != nil && br.qqIdx < len(br.b.queuedQueries) { + bi := br.b.queuedQueries[br.qqIdx] query = bi.query args = bi.arguments ok = true - br.ix++ + br.qqIdx++ } return } diff --git a/batch_test.go b/batch_test.go index 156e8f8f..2ade0d4a 100644 --- a/batch_test.go +++ b/batch_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "context" "errors" + "fmt" "os" "testing" @@ -148,6 +149,99 @@ func TestConnSendBatch(t *testing.T) { }) } +func TestConnSendBatchQueuedQuery(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(func(ct pgconn.CommandTag) error { + assert.EqualValues(t, 1, ct.RowsAffected()) + return nil + }) + + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } + + batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error { + rowCount := 0 + var id int32 + var description string + var amount int32 + _, err := pgx.ForEachRow(rows, 
[]any{&id, &description, &amount}, func() error { + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount) + rowCount++ + + return nil + }) + assert.NoError(t, err) + return nil + }) + + batch.Queue("select id, description, amount from ledger order by id").Query(func(rows pgx.Rows) error { + rowCount := 0 + var id int32 + var description string + var amount int32 + _, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error { + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description) + assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount) + rowCount++ + + return nil + }) + assert.NoError(t, err) + return nil + }) + + batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error { + err := row.Scan(nil, nil, nil) + assert.ErrorIs(t, err, pgx.ErrNoRows) + return nil + }) + + batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error { + var sumAmount int32 + err := row.Scan(&sumAmount) + assert.NoError(t, err) + assert.EqualValues(t, 6, sumAmount) + return nil + }) + + err := conn.SendBatch(context.Background(), batch).Close() + assert.NoError(t, err) + }) +} + func TestConnSendBatchMany(t *testing.T) { t.Parallel() @@ -773,3 +867,59 @@ func TestSendBatchSimpleProtocol(t *testing.T) { assert.EqualValues(t, 3, values[0]) assert.False(t, rows.Next()) } + +func ExampleConn_SendBatch() { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + batch := &pgx.Batch{} + batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) 
+ + batch.Queue("select 1 + 2").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + batch.Queue("select 2 + 3").QueryRow(func(row pgx.Row) error { + var n int32 + err := row.Scan(&n) + if err != nil { + return err + } + + fmt.Println(n) + + return err + }) + + err = conn.SendBatch(context.Background(), batch).Close() + if err != nil { + fmt.Printf("SendBatch error: %v", err) + return + } + + // Output: + // 2 + // 3 + // 5 +} diff --git a/conn.go b/conn.go index b8e0b232..1a43a3ca 100644 --- a/conn.go +++ b/conn.go @@ -801,7 +801,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { mode := c.config.DefaultQueryExecMode - for _, bi := range b.items { + for _, bi := range b.queuedQueries { var queryRewriter QueryRewriter sql := bi.query arguments := bi.arguments @@ -830,7 +830,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { } // All other modes use extended protocol and thus can use prepared statements. 
- for _, bi := range b.items { + for _, bi := range b.queuedQueries { if sd, ok := c.preparedStatements[bi.query]; ok { bi.sd = sd } @@ -852,7 +852,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { var sb strings.Builder - for i, bi := range b.items { + for i, bi := range b.queuedQueries { if i > 0 { sb.WriteByte(';') } @@ -864,18 +864,18 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc } mrr := c.pgConn.Exec(ctx, sb.String()) return &batchResults{ - ctx: ctx, - conn: c, - mrr: mrr, - b: b, - ix: 0, + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, } } func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { batch := &pgconn.Batch{} - for _, bi := range b.items { + for _, bi := range b.queuedQueries { sd := bi.sd if sd != nil { err := c.eqb.Build(c.typeMap, sd, bi.arguments) @@ -898,11 +898,11 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR mrr := c.pgConn.ExecBatch(ctx, batch) return &batchResults{ - ctx: ctx, - conn: c, - mrr: mrr, - b: b, - ix: 0, + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + qqIdx: 0, } } @@ -914,7 +914,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.items { + for _, bi := range b.queuedQueries { if bi.sd == nil { sd := c.statementCache.Get(bi.query) if sd != nil { @@ -946,7 +946,7 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.items { + for _, bi := range b.queuedQueries { if bi.sd == nil { sd := c.descriptionCache.Get(bi.query) if sd != nil { @@ -973,7 +973,7 @@ func (c *Conn) 
sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) distinctNewQueries := []*pgconn.StatementDescription{} distinctNewQueriesIdxMap := make(map[string]int) - for _, bi := range b.items { + for _, bi := range b.queuedQueries { if bi.sd == nil { if idx, present := distinctNewQueriesIdxMap[bi.query]; present { bi.sd = distinctNewQueries[idx] @@ -1045,7 +1045,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d } // Queue the queries. - for _, bi := range b.items { + for _, bi := range b.queuedQueries { err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) if err != nil { return &pipelineBatchResults{ctx: ctx, conn: c, err: err}