2
0

SendBatch now uses pipeline mode to prepare and describe statements

Previously, a batch with 10 unique parameterized statements executed
100 times would entail 11 network round trips: 1 for each prepare /
describe and 1 for executing them all. Now pipeline mode is used to
prepare / describe all statements in a single network round trip, so it
would only take 2 round trips.
This commit is contained in:
Jack Christensen
2022-07-09 09:28:11 -05:00
parent ba58e3d5d2
commit e7aa76ccf9
12 changed files with 694 additions and 612 deletions
+6
View File
@@ -139,6 +139,12 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent
`QueryFunc` has been replaced by using `ForEachScannedRow`.
## SendBatch Uses Pipeline Mode When Appropriate
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips: 1
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
in a single network round trip, so it would only take 2 round trips.
## 3rd Party Logger Integration
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
+163
View File
@@ -11,6 +11,7 @@ import (
// batchItem is a single queued query within a Batch.
type batchItem struct {
	query     string
	arguments []any
	sd        *pgconn.StatementDescription // statement description; resolved by SendBatch before execution, may be nil until then
}
// Batch queries are a way of bundling multiple queries together to avoid
@@ -192,3 +193,165 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
}
return
}
// pipelineBatchResults is a BatchResults implementation backed by pgconn pipeline mode.
type pipelineBatchResults struct {
	ctx      context.Context
	conn     *Conn
	pipeline *pgconn.Pipeline // underlying pipeline; results are read from it one query at a time
	lastRows *baseRows        // rows returned by the most recent Query; checked for deferred errors
	err      error            // first error encountered; once set, all further operations fail fast
	b        *Batch           // batch being executed; source of query text / args for logging
	ix       int              // index into b.items of the next unread query
	closed   bool
}
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
	if br.err != nil {
		return pgconn.CommandTag{}, br.err
	}
	if br.closed {
		return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
	}
	// Surface an error recorded on the previously returned rows before reading further pipeline
	// results. Bug fix: the original returned br.err here, which the guard above proves is nil,
	// so the error was silently dropped. Record and return the rows' error instead (mirrors Query).
	if br.lastRows != nil && br.lastRows.err != nil {
		br.err = br.lastRows.err
		return pgconn.CommandTag{}, br.err
	}

	// Query text / args are only needed for logging; result consumption is positional.
	query, arguments, _ := br.nextQueryAndArgs()

	results, err := br.pipeline.GetResults()
	if err != nil {
		br.err = err
		return pgconn.CommandTag{}, err
	}
	var commandTag pgconn.CommandTag
	switch results := results.(type) {
	case *pgconn.ResultReader:
		commandTag, err = results.Close()
	default:
		return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
	}

	if err != nil {
		br.err = err
		if br.conn.shouldLog(LogLevelError) {
			br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
				"sql":  query,
				"args": logQueryArgs(arguments),
				"err":  err,
			})
		}
	} else if br.conn.shouldLog(LogLevelInfo) {
		br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{
			"sql":        query,
			"args":       logQueryArgs(arguments),
			"commandTag": commandTag,
		})
	}

	return commandTag, err
}
// Query reads the results from the next query in the batch as if the query has been sent with Query.
func (br *pipelineBatchResults) Query() (Rows, error) {
	if br.err != nil {
		return &baseRows{err: br.err, closed: true}, br.err
	}
	if br.closed {
		alreadyClosedErr := fmt.Errorf("batch already closed")
		return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
	}
	// Surface an error recorded on the previously returned rows before reading further results.
	if br.lastRows != nil && br.lastRows.err != nil {
		br.err = br.lastRows.err
		return &baseRows{err: br.err, closed: true}, br.err
	}

	query, arguments, ok := br.nextQueryAndArgs()
	if !ok {
		// More Query calls than queued queries; placeholder is used for logging only.
		query = "batch query"
	}

	rows := br.conn.getRows(br.ctx, query, arguments)
	br.lastRows = rows

	results, err := br.pipeline.GetResults()
	if err != nil {
		br.err = err
		rows.err = err
		rows.closed = true

		if br.conn.shouldLog(LogLevelError) {
			br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{
				"sql":  query,
				"args": logQueryArgs(arguments),
				"err":  rows.err,
			})
		}
	} else {
		switch results := results.(type) {
		case *pgconn.ResultReader:
			rows.resultReader = results
		default:
			err = fmt.Errorf("unexpected pipeline result: %T", results)
			br.err = err
			rows.err = err
			rows.closed = true
		}
	}

	return rows, rows.err
}
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
func (br *pipelineBatchResults) QueryRow() Row {
	nextRows, _ := br.Query()
	// Query always returns a *baseRows; connRow wraps it to provide Row semantics.
	return (*connRow)(nextRows.(*baseRows))
}
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
// resynchronize the connection with the server. In this case the underlying connection will have been closed.
func (br *pipelineBatchResults) Close() error {
	if br.err != nil {
		// NOTE(review): when err was recorded after construction (by Exec or Query) this early
		// return leaves br.pipeline open — confirm the pipeline is cleaned up elsewhere.
		return br.err
	}

	if br.lastRows != nil && br.lastRows.err != nil {
		br.err = br.lastRows.err
		return br.err
	}

	if br.closed {
		return nil
	}
	br.closed = true

	// log any queries that haven't yet been logged by Exec or Query
	for {
		query, args, ok := br.nextQueryAndArgs()
		if !ok {
			break
		}

		if br.conn.shouldLog(LogLevelInfo) {
			br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{
				"sql":  query,
				"args": logQueryArgs(args),
			})
		}
	}

	return br.pipeline.Close()
}
// nextQueryAndArgs returns the query text and arguments of the next unread batch item.
// ok is false once every item has been consumed.
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
	if br.b == nil || br.ix >= len(br.b.items) {
		return "", nil, false
	}
	item := br.b.items[br.ix]
	br.ix++
	return item.query, item.arguments, true
}
+1 -1
View File
@@ -420,7 +420,7 @@ func TestConnSendBatchQueryError(t *testing.T) {
err = br.Close()
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
t.Errorf("br.Close() => %v, want error code %v", err, 22012)
}
})
+300 -104
View File
@@ -236,11 +236,11 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
c.wbuf = make([]byte, 0, 1024)
if c.config.StatementCacheCapacity > 0 {
c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity)
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
}
if c.config.DescriptionCacheCapacity > 0 {
c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity)
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
}
return c, nil
@@ -382,6 +382,10 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return pgconn.CommandTag{}, err
}
startTime := time.Now()
commandTag, err := c.exec(ctx, sql, arguments...)
@@ -437,9 +441,13 @@ optionLoop:
if c.statementCache == nil {
return pgconn.CommandTag{}, errDisabledStatementCache
}
sd, err := c.statementCache.Get(ctx, sql)
if err != nil {
return pgconn.CommandTag{}, err
sd := c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
return pgconn.CommandTag{}, err
}
c.statementCache.Put(sd)
}
return c.execPrepared(ctx, sd, arguments)
@@ -447,9 +455,12 @@ optionLoop:
if c.descriptionCache == nil {
return pgconn.CommandTag{}, errDisabledDescriptionCache
}
sd, err := c.descriptionCache.Get(ctx, sql)
if err != nil {
return pgconn.CommandTag{}, err
sd := c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
return pgconn.CommandTag{}, err
}
}
return c.execParams(ctx, sd, arguments)
@@ -620,6 +631,10 @@ type QueryRewriter interface {
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details.
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return &baseRows{err: err, closed: true}, err
}
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
mode := c.config.DefaultQueryExecMode
@@ -649,6 +664,11 @@ optionLoop:
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
}
// Bypass any statement caching.
if sql == "" {
mode = QueryExecModeSimpleProtocol
}
c.eqb.reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)
@@ -664,10 +684,14 @@ optionLoop:
rows.fatal(err)
return rows, err
}
sd, err = c.statementCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, err
sd = c.statementCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.statementCache.Put(sd)
}
case QueryExecModeCacheDescribe:
if c.descriptionCache == nil {
@@ -675,10 +699,14 @@ optionLoop:
rows.fatal(err)
return rows, err
}
sd, err = c.descriptionCache.Get(ctx, sql)
if err != nil {
rows.fatal(err)
return rows, err
sd = c.descriptionCache.Get(sql)
if sd == nil {
sd, err = c.Prepare(ctx, "", sql)
if err != nil {
rows.fatal(err)
return rows, err
}
c.descriptionCache.Put(sd)
}
case QueryExecModeDescribeExec:
sd, err = c.Prepare(ctx, "", sql)
@@ -767,6 +795,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// is used again.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
mode := c.config.DefaultQueryExecMode
for _, bi := range b.items {
@@ -794,105 +826,70 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
}
if mode == QueryExecModeSimpleProtocol {
var sb strings.Builder
for i, bi := range b.items {
if i > 0 {
sb.WriteByte(';')
}
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
sb.WriteString(sql)
}
mrr := c.pgConn.Exec(ctx, sb.String())
return &batchResults{
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
ix: 0,
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
}
// All other modes use extended protocol and thus can use prepared statements.
for _, bi := range b.items {
if sd, ok := c.preparedStatements[bi.query]; ok {
bi.sd = sd
}
}
switch mode {
case QueryExecModeExec:
return c.sendBatchQueryExecModeExec(ctx, b)
case QueryExecModeCacheStatement:
return c.sendBatchQueryExecModeCacheStatement(ctx, b)
case QueryExecModeCacheDescribe:
return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
case QueryExecModeDescribeExec:
return c.sendBatchQueryExecModeDescribeExec(ctx, b)
default:
panic("unknown QueryExecMode")
}
}
// sendBatchQueryExecModeSimpleProtocol executes the batch as one simple-protocol
// multi-statement query, joining the sanitized queries with ';'.
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
	var combined strings.Builder
	for i, bi := range b.items {
		if i > 0 {
			combined.WriteByte(';')
		}
		sanitized, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
		if err != nil {
			return &batchResults{ctx: ctx, conn: c, err: err}
		}
		combined.WriteString(sanitized)
	}
	multiResult := c.pgConn.Exec(ctx, combined.String())
	return &batchResults{
		ctx:  ctx,
		conn: c,
		mrr:  multiResult,
		b:    b,
		ix:   0,
	}
}
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
batch := &pgconn.Batch{}
if mode == QueryExecModeExec {
for _, bi := range b.items {
c.eqb.reset()
anynil.NormalizeSlice(bi.arguments)
sd := c.preparedStatements[bi.query]
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
} else {
distinctUnpreparedQueries := map[string]struct{}{}
for _, bi := range b.items {
if _, ok := c.preparedStatements[bi.query]; ok {
continue
}
distinctUnpreparedQueries[bi.query] = struct{}{}
}
var stmtCache stmtcache.Cache
if len(distinctUnpreparedQueries) > 0 {
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.statementCache
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.descriptionCache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
}
for _, bi := range b.items {
c.eqb.reset()
sd := c.preparedStatements[bi.query]
if sd == nil {
var err error
sd, err = stmtCache.Get(ctx, bi.query)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
if len(sd.ParamOIDs) != len(bi.arguments) {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
}
for _, bi := range b.items {
sd := bi.sd
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
if sd.Name == "" {
batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
@@ -909,6 +906,171 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
}
}
// sendBatchQueryExecModeCacheStatement resolves a statement description for every batch item from the connection's
// statement cache, allocating new named (to-be-prepared) descriptions for cache misses, then sends the batch in
// pipeline mode.
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	if c.statementCache == nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
	}

	distinctNewQueries := []*pgconn.StatementDescription{}
	distinctNewQueriesIdxMap := make(map[string]int)

	for _, bi := range b.items {
		if bi.sd == nil {
			sd := c.statementCache.Get(bi.query)
			if sd != nil {
				// Cache hit: reuse the already-prepared statement.
				bi.sd = sd
			} else {
				// Cache miss: deduplicate identical SQL so each distinct query is prepared only once.
				if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
					bi.sd = distinctNewQueries[idx]
				} else {
					// A generated Name means the statement will be prepared (not just described)
					// by sendBatchExtendedWithDescription.
					sd = &pgconn.StatementDescription{
						Name: stmtcache.NextStatementName(),
						SQL:  bi.query,
					}
					distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
					distinctNewQueries = append(distinctNewQueries, sd)
					bi.sd = sd
				}
			}
		}
	}

	return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache)
}
// sendBatchQueryExecModeCacheDescribe resolves a statement description for every batch item from the connection's
// description cache, allocating new unnamed (describe-only) descriptions for cache misses, then sends the batch in
// pipeline mode.
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	if c.descriptionCache == nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
	}

	distinctNewQueries := []*pgconn.StatementDescription{}
	distinctNewQueriesIdxMap := make(map[string]int)

	for _, bi := range b.items {
		if bi.sd == nil {
			sd := c.descriptionCache.Get(bi.query)
			if sd != nil {
				// Cache hit: reuse the existing description.
				bi.sd = sd
			} else {
				// Cache miss: deduplicate identical SQL so each distinct query is described only once.
				if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
					bi.sd = distinctNewQueries[idx]
				} else {
					// No Name: the statement is only described, not server-side prepared.
					sd = &pgconn.StatementDescription{
						SQL: bi.query,
					}
					distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
					distinctNewQueries = append(distinctNewQueries, sd)
					bi.sd = sd
				}
			}
		}
	}

	return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache)
}
// sendBatchQueryExecModeDescribeExec describes every not-yet-described query in the batch (deduplicating identical
// SQL) and sends the batch in pipeline mode. No statement cache is used.
func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
	distinctNewQueries := []*pgconn.StatementDescription{}
	seen := make(map[string]int)

	for _, bi := range b.items {
		if bi.sd != nil {
			continue
		}
		if i, dup := seen[bi.query]; dup {
			// Same SQL already queued for describing; share its description.
			bi.sd = distinctNewQueries[i]
			continue
		}
		sd := &pgconn.StatementDescription{
			SQL: bi.query,
		}
		seen[sd.SQL] = len(distinctNewQueries)
		distinctNewQueries = append(distinctNewQueries, sd)
		bi.sd = sd
	}

	return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil)
}
// sendBatchExtendedWithDescription sends a batch using the extended protocol in pipeline mode. Any queries in
// distinctNewQueries are first prepared / described in a single network round trip. If sdCache is not nil the
// resulting statement descriptions are stored in it.
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
	// NOTE(review): context.Background() appears intentional — the pipeline outlives this call and its
	// results are read later via the returned pipelineBatchResults — but confirm ctx should not be used.
	pipeline := c.pgConn.StartPipeline(context.Background())
	defer func() {
		// Every return path assigns a non-nil *pipelineBatchResults, so pbr is safe to dereference.
		if pbr.err != nil {
			pipeline.Close()
		}
	}()

	// Prepare any needed queries
	if len(distinctNewQueries) > 0 {
		for _, sd := range distinctNewQueries {
			pipeline.SendPrepare(sd.Name, sd.SQL, nil)
		}

		err := pipeline.Sync()
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
		}

		for _, sd := range distinctNewQueries {
			results, err := pipeline.GetResults()
			if err != nil {
				return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
			}

			resultSD, ok := results.(*pgconn.StatementDescription)
			if !ok {
				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
			}

			// Fill in the previously empty / pending statement descriptions.
			sd.ParamOIDs = resultSD.ParamOIDs
			sd.Fields = resultSD.Fields
		}

		results, err := pipeline.GetResults()
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
		}

		_, ok := results.(*pgconn.PipelineSync)
		if !ok {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
		}
	}

	// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
	// Bug fix: store into the cache that was passed in (sdCache), not unconditionally c.statementCache. The caller
	// may have passed c.descriptionCache (CacheDescribe mode), and c.statementCache may be nil.
	if sdCache != nil {
		for _, sd := range distinctNewQueries {
			sdCache.Put(sd)
		}
	}

	// Queue the queries.
	for _, bi := range b.items {
		err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
		if err != nil {
			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
		}

		if bi.sd.Name == "" {
			pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
		} else {
			pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
		}
	}

	err := pipeline.Sync()
	if err != nil {
		return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
	}

	return &pipelineBatchResults{
		ctx:      ctx,
		conn:     c,
		pipeline: pipeline,
		b:        b,
	}
}
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
@@ -1015,3 +1177,37 @@ order by attnum`,
return fields, nil
}
// deallocateInvalidatedCachedStatements deallocates invalidated server-side prepared statements from the statement
// cache in a single pipelined round trip. Invalidated description-cache entries are simply dropped — they appear to
// be unnamed / describe-only (see sendBatchQueryExecModeCacheDescribe), so there is nothing to deallocate on the
// server.
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
	if c.descriptionCache != nil {
		c.descriptionCache.HandleInvalidated()
	}

	var invalidatedStatements []*pgconn.StatementDescription
	if c.statementCache != nil {
		invalidatedStatements = c.statementCache.HandleInvalidated()
	}

	if len(invalidatedStatements) == 0 {
		return nil
	}

	pipeline := c.pgConn.StartPipeline(ctx)
	// Safety net for the error returns below; assumes Close after a successful explicit
	// Close is harmless — TODO confirm pgconn.Pipeline.Close is idempotent.
	defer pipeline.Close()

	for _, sd := range invalidatedStatements {
		pipeline.SendDeallocate(sd.Name)
	}

	err := pipeline.Sync()
	if err != nil {
		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
	}

	err = pipeline.Close()
	if err != nil {
		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
	}

	return nil
}
+9 -8
View File
@@ -931,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
rows, err := conn.Query(ctx, getSQL, 1)
require.NoError(t, err)
rows.Close()
require.NoError(t, rows.Err())
// Now, change the schema of the table out from under the statement, making it invalid.
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
@@ -948,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
rows.Close()
for _, err := range []error{nextErr, rows.Err()} {
if err == nil {
t.Fatal("expected InvalidCachedStatementPlanError: no error")
t.Fatal(`expected "cached plan must not change result type": no error`)
}
if !strings.Contains(err.Error(), "cached plan must not change result type") {
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
}
}
@@ -995,6 +996,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
rows, err := tx.Query(ctx, getSQL, 1)
require.NoError(t, err)
rows.Close()
require.NoError(t, rows.Err())
// Now, change the schema of the table out from under the statement, making it invalid.
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
@@ -1012,18 +1014,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
rows.Close()
for _, err := range []error{nextErr, rows.Err()} {
if err == nil {
t.Fatal("expected InvalidCachedStatementPlanError: no error")
t.Fatal(`expected "cached plan must not change result type": no error`)
}
if !strings.Contains(err.Error(), "cached plan must not change result type") {
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
}
}
rows, err = tx.Query(ctx, getSQL, 1)
require.NoError(t, err) // error does not pop up immediately
rows.Next()
rows, _ = tx.Query(ctx, getSQL, 1)
rows.Close()
err = rows.Err()
// Retries within the same transaction are errors (really anything except a rollbakc
// Retries within the same transaction are errors (really anything except a rollback
// will be an error in this transaction).
require.Error(t, err)
rows.Close()
-169
View File
@@ -1,169 +0,0 @@
package stmtcache
import (
"container/list"
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgx/v5/pgconn"
)
// lruCount is a process-wide counter giving each LRU a unique prepared statement name prefix.
var lruCount uint64

// LRU implements Cache with a Least Recently Used (LRU) cache.
type LRU struct {
	conn         *pgconn.PgConn
	mode         int // ModePrepare or ModeDescribe
	cap          int // maximum number of cached statements
	prepareCount int // monotonically increasing suffix for generated statement names
	m            map[string]*list.Element // sql -> list element holding a *pgconn.StatementDescription
	l            *list.List               // front = most recently used, back = eviction candidate
	psNamePrefix string                   // prefix for generated prepared statement names
	stmtsToClear []string                 // statements flagged invalid by StatementErrored, pending cleanup
}

// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
	mustBeValidMode(mode)
	mustBeValidCap(cap)
	n := atomic.AddUint64(&lruCount, 1)

	return &LRU{
		conn:         conn,
		mode:         mode,
		cap:          cap,
		m:            make(map[string]*list.Element),
		l:            list.New(),
		psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
	}
}
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
	// Fail fast if ctx is already canceled; context.Background() skips the check entirely.
	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}
	}

	// flush an outstanding bad statements
	// Cleanup only runs when the connection is idle ('I') or in a healthy transaction ('T');
	// inside a failed transaction the deallocate would itself error.
	// NOTE(review): stmtsToClear is not reset after flushing, so entries may be re-cleared on
	// later Gets — clearStmt tolerates already-removed statements, so this looks benign. Confirm.
	txStatus := c.conn.TxStatus()
	if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
		for _, stmt := range c.stmtsToClear {
			err := c.clearStmt(ctx, stmt)
			if err != nil {
				return nil, err
			}
		}
	}

	if el, ok := c.m[sql]; ok {
		c.l.MoveToFront(el)
		return el.Value.(*pgconn.StatementDescription), nil
	}

	// At capacity: evict the least recently used entry to make room.
	if c.l.Len() == c.cap {
		err := c.removeOldest(ctx)
		if err != nil {
			return nil, err
		}
	}

	psd, err := c.prepare(ctx, sql)
	if err != nil {
		return nil, err
	}

	el := c.l.PushFront(psd)
	c.m[sql] = el

	return psd, nil
}
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
func (c *LRU) Clear(ctx context.Context) error {
	for c.l.Len() > 0 {
		if err := c.removeOldest(ctx); err != nil {
			return err
		}
	}
	return nil
}
// StatementErrored flags sql for removal from the cache when err indicates the cached
// statement plan may be invalid. The actual cleanup happens on a later Get.
func (c *LRU) StatementErrored(sql string, err error) {
	pgErr, ok := err.(*pgconn.PgError)
	if !ok {
		return
	}

	// https://github.com/jackc/pgx/issues/1162
	//
	// We used to look for the message "cached plan must not change result type". However, that message can be localized.
	// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
	// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
	// have so it should be safe.
	possibleInvalidCachedPlanError := pgErr.Code == "0A000"
	if possibleInvalidCachedPlanError {
		c.stmtsToClear = append(c.stmtsToClear, sql)
	}
}
// clearStmt removes the cached statement for sql, deallocating it on the server when in
// ModePrepare. It is a no-op if sql is not cached (e.g. it was already evicted).
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
	elem, inMap := c.m[sql]
	if !inMap {
		// The statement probably fell off the back of the list. In that case, we've
		// ensured that it isn't in the cache, so we can declare victory.
		return nil
	}

	c.l.Remove(elem)
	psd := elem.Value.(*pgconn.StatementDescription)
	delete(c.m, psd.SQL)
	if c.mode == ModePrepare {
		return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
	}
	return nil
}
// Len returns the number of cached prepared statement descriptions currently held.
func (c *LRU) Len() int {
	return c.l.Len()
}

// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRU) Cap() int {
	return c.cap
}

// Mode returns the mode of the cache (ModePrepare or ModeDescribe).
func (c *LRU) Mode() int {
	return c.mode
}
// prepare creates the server-side statement for sql: named (prepared) in ModePrepare,
// unnamed in ModeDescribe.
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
	name := ""
	if c.mode == ModePrepare {
		name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
		c.prepareCount++
	}
	return c.conn.Prepare(ctx, name, sql, nil)
}
// removeOldest evicts the least recently used entry, deallocating its server-side
// prepared statement when in ModePrepare.
// NOTE(review): assumes the list is non-empty — c.l.Back() is nil on an empty cache. Visible
// callers only invoke this when Len() > 0 (Clear's loop) or at capacity (Get); confirm cap >= 1
// is guaranteed by mustBeValidCap.
func (c *LRU) removeOldest(ctx context.Context) error {
	oldest := c.l.Back()
	c.l.Remove(oldest)
	psd := oldest.Value.(*pgconn.StatementDescription)
	delete(c.m, psd.SQL)
	if c.mode == ModePrepare {
		return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
	}
	return nil
}
+98
View File
@@ -0,0 +1,98 @@
package stmtcache
import (
"container/list"
"github.com/jackc/pgx/v5/pgconn"
)
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
type LRUCache struct {
	cap          int                            // maximum number of entries
	m            map[string]*list.Element       // sql -> list element holding a *pgconn.StatementDescription
	l            *list.List                     // front = most recently used, back = eviction candidate
	invalidStmts []*pgconn.StatementDescription // evicted / invalidated entries awaiting HandleInvalidated
}

// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
func NewLRUCache(cap int) *LRUCache {
	return &LRUCache{
		cap: cap,
		m:   make(map[string]*list.Element),
		l:   list.New(),
	}
}
// Get returns the statement description for sql. Returns nil if not found.
// A hit promotes the entry to most recently used.
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
	el, found := c.m[key]
	if !found {
		return nil
	}
	c.l.MoveToFront(el)
	return el.Value.(*pgconn.StatementDescription)
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
	if sd.SQL == "" {
		panic("cannot store statement description with empty SQL")
	}

	if _, present := c.m[sd.SQL]; present {
		return
	}

	// Evict until there is room. Guarding on a positive capacity prevents a nil dereference in
	// invalidateOldest if the cache was constructed with cap <= 0, and >= (rather than the
	// original ==) also recovers if the list ever grows past capacity.
	for c.cap > 0 && c.l.Len() >= c.cap {
		c.invalidateOldest()
	}

	el := c.l.PushFront(sd)
	c.m[sd.SQL] = el
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *LRUCache) Invalidate(sql string) {
	el, found := c.m[sql]
	if !found {
		return
	}
	delete(c.m, sql)
	sd := el.Value.(*pgconn.StatementDescription)
	c.invalidStmts = append(c.invalidStmts, sd)
	c.l.Remove(el)
}
// InvalidateAll invalidates all statement descriptions, moving every entry to the
// pending-invalidated list and resetting the cache to empty.
func (c *LRUCache) InvalidateAll() {
	for el := c.l.Front(); el != nil; el = el.Next() {
		c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
	}

	c.m = make(map[string]*list.Element)
	c.l = list.New()
}
// HandleInvalidated returns all statement descriptions invalidated since the last call and
// resets the pending list. The caller is responsible for any server-side deallocation.
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
	invalidStmts := c.invalidStmts
	c.invalidStmts = nil
	return invalidStmts
}
// Len returns the number of cached prepared statement descriptions. Invalidated entries
// awaiting HandleInvalidated are not counted.
func (c *LRUCache) Len() int {
	return c.l.Len()
}

// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRUCache) Cap() int {
	return c.cap
}
// invalidateOldest evicts the least recently used entry, moving it to the pending-invalidated
// list for later cleanup via HandleInvalidated.
func (c *LRUCache) invalidateOldest() {
	oldest := c.l.Back()
	if oldest == nil {
		// Empty cache — nothing to evict. Guards the nil dereference the original had when
		// called on an empty list (e.g. Put with cap == 0).
		return
	}
	sd := oldest.Value.(*pgconn.StatementDescription)
	c.invalidStmts = append(c.invalidStmts, sd)
	delete(c.m, sd.SQL)
	c.l.Remove(oldest)
}
-292
View File
@@ -1,292 +0,0 @@
package stmtcache_test
import (
"context"
"fmt"
"math/rand"
"os"
"regexp"
"testing"
"time"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
)
// TestLRUModePrepare verifies basic LRU behavior in ModePrepare: hits don't grow the cache,
// exceeding capacity evicts (and deallocates) the oldest statement, and Clear empties both
// the cache and the server session.
func TestLRUModePrepare(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
	require.EqualValues(t, 0, cache.Len())
	require.EqualValues(t, 2, cache.Cap())
	require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())

	psd, err := cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))

	// Cache hit: size unchanged.
	psd, err = cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))

	psd, err = cache.Get(ctx, "select 2")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 2, cache.Len())
	require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))

	// Third distinct statement exceeds cap 2: "select 1" is evicted and deallocated.
	psd, err = cache.Get(ctx, "select 3")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 2, cache.Len())
	require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))

	err = cache.Clear(ctx)
	require.NoError(t, err)
	require.EqualValues(t, 0, cache.Len())
	require.Empty(t, fetchServerStatements(t, ctx, conn))
}
// TestLRUStmtInvalidation exercises StatementErrored: outside a transaction the flagged
// statement is flushed on the next Get; inside a failed transaction the flush is deferred
// until the first Get after rollback.
func TestLRUStmtInvalidation(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	// we construct a fake error because it's not super straightforward to actually call
	// a prepared statement from the LRU cache without the helper routines which live
	// in pgx proper.
	fakeInvalidCachePlanError := &pgconn.PgError{
		Severity: "ERROR",
		Code:     "0A000",
		Message:  "cached plan must not change result type",
	}

	cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)

	//
	// outside of a transaction, we eagerly flush the statement
	//

	_, err = cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))

	cache.StatementErrored("select 1", fakeInvalidCachePlanError)
	_, err = cache.Get(ctx, "select 2")
	require.NoError(t, err)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))

	err = cache.Clear(ctx)
	require.NoError(t, err)

	//
	// within an errored transaction, we defer the flush to after the first get
	// that happens after the transaction is rolled back
	//

	_, err = cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))

	res := conn.Exec(ctx, "begin")
	require.NoError(t, res.Close())
	require.Equal(t, byte('T'), conn.TxStatus())

	// Force the transaction into the error state ('E') with a bogus statement.
	res = conn.Exec(ctx, "selec")
	require.Error(t, res.Close())
	require.Equal(t, byte('E'), conn.TxStatus())

	// Flagged while the transaction is errored: nothing is flushed yet.
	cache.StatementErrored("select 1", fakeInvalidCachePlanError)
	require.EqualValues(t, 1, cache.Len())

	res = conn.Exec(ctx, "rollback")
	require.NoError(t, res.Close())

	_, err = cache.Get(ctx, "select 2")
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
}
// TestLRUStmtInvalidationIntegration reproduces a real "cached plan must not change result
// type" error by altering a table under a cached statement, then verifies the cache hands
// out a freshly prepared statement after StatementErrored.
func TestLRUStmtInvalidationIntegration(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)

	result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	sql := "select * from stmtcache_table"
	sd1, err := cache.Get(ctx, sql)
	require.NoError(t, err)

	result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	// Changing the result shape invalidates the server-side plan of sd1.
	result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
	require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)")

	cache.StatementErrored(sql, result.Err)

	// The next Get must return a different (re-prepared) statement.
	sd2, err := cache.Get(ctx, sql)
	require.NoError(t, err)
	require.NotEqual(t, sd1.Name, sd2.Name)

	result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read()
	require.NoError(t, result.Err)
}
// TestLRUModePrepareStress hammers a small prepare-mode cache with many random
// statements so eviction churn is exercised heavily; no Get or execution may fail.
func TestLRUModePrepareStress(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
	defer cancel()
	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8)
	require.EqualValues(t, 0, cache.Len())
	require.EqualValues(t, 8, cache.Cap())
	require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())

	// 50 distinct statements against a capacity of 8 guarantees constant eviction.
	for iteration := 0; iteration < 1000; iteration++ {
		stmt := fmt.Sprintf("select %d", rand.Intn(50))
		sd, getErr := cache.Get(ctx, stmt)
		require.NoError(t, getErr)
		require.NotNil(t, sd)
		execResult := conn.ExecPrepared(ctx, sd.Name, nil, nil, nil).Read()
		require.NoError(t, execResult.Err)
	}
}
// TestLRUModeDescribe verifies that a describe-mode cache tracks statement
// descriptions and honors its capacity without ever leaving named prepared
// statements on the server session.
func TestLRUModeDescribe(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()
	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
	require.EqualValues(t, 0, cache.Len())
	require.EqualValues(t, 2, cache.Cap())
	require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())

	// getAndCheck fetches sql through the cache, then asserts the expected
	// cache size and that the server session still holds no named statements.
	getAndCheck := func(sql string, wantLen int) {
		sd, getErr := cache.Get(ctx, sql)
		require.NoError(t, getErr)
		require.NotNil(t, sd)
		require.EqualValues(t, wantLen, cache.Len())
		require.Empty(t, fetchServerStatements(t, ctx, conn))
	}

	getAndCheck("select 1", 1) // first entry
	getAndCheck("select 1", 1) // cache hit; size unchanged
	getAndCheck("select 2", 2) // second entry fills the cache
	getAndCheck("select 3", 2) // third entry evicts the least recently used

	err = cache.Clear(ctx)
	require.NoError(t, err)
	require.EqualValues(t, 0, cache.Len())
	require.Empty(t, fetchServerStatements(t, ctx, conn))
}
// TestLRUContext verifies that Get honors context cancellation both on a cache
// miss (the first fetch of a statement) and on a later fetch of an
// already-cached statement.
func TestLRUContext(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()
	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)
	cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)

	// A context cancelled before the first (uncached) fetch must produce an error.
	cancelledCtx, cancelFirst := context.WithCancel(ctx)
	cancelFirst()
	sd, err := cache.Get(cancelledCtx, "SELECT 1")
	require.Error(t, err)
	require.Nil(t, sd)

	// Populate the cache under a live context, then cancel it: the repeat
	// fetch of the now-cached statement must also fail.
	liveCtx, cancelSecond := context.WithCancel(ctx)
	sd, err = cache.Get(liveCtx, "SELECT 2")
	require.NoError(t, err)
	require.NotNil(t, sd)
	cancelSecond()
	sd, err = cache.Get(liveCtx, "SELECT 2")
	require.Error(t, err)
	require.Nil(t, sd)
}
// fetchServerStatements returns the SQL text of every statement currently
// prepared in the server session, normalized so that PostgreSQL and
// CockroachDB report comparable results.
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
	result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	// Both of these are loop-invariant: detect CockroachDB once and compile the
	// normalization pattern once, instead of re-doing both for every row.
	isCockroachDB := conn.ParameterStatus("crdb_version") != ""
	prepareRE := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `)

	var statements []string
	for _, r := range result.Rows {
		statement := string(r[0])
		if isCockroachDB {
			if statement == "PREPARE AS select statement from pg_prepared_statements" {
				// CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it.
				continue
			}
			// CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended
			// protocol while PostgreSQL does not. Normalize the statement.
			statement = prepareRE.ReplaceAllString(statement, "")
		}
		statements = append(statements, statement)
	}
	return statements
}
+34 -35
View File
@@ -2,57 +2,56 @@
package stmtcache
import (
"context"
"strconv"
"sync/atomic"
"github.com/jackc/pgx/v5/pgconn"
)
const (
ModePrepare = iota // Cache should prepare named statements.
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
)
var stmtCounter int64
// Cache prepares and caches prepared statement descriptions.
// NextStatementName returns a statement name that will be unique for the lifetime of the program.
func NextStatementName() string {
n := atomic.AddInt64(&stmtCounter, 1)
return "stmtcache_" + strconv.FormatInt(n, 10)
}
// Cache caches statement descriptions.
type Cache interface {
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
// Get returns the statement description for sql. Returns nil if not found.
Get(sql string) *pgconn.StatementDescription
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
Clear(ctx context.Context) error
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
Put(sd *pgconn.StatementDescription)
// StatementErrored informs the cache that the given statement resulted in an error when it
// was last used against the database. In some cases, this will cause the cache to mark that
// statement as bad. The bad statement will instead be flushed during the next call to Get
// that occurs outside of a failed transaction.
StatementErrored(sql string, err error)
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
Invalidate(sql string)
// InvalidateAll invalidates all statement descriptions.
InvalidateAll()
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
HandleInvalidated() []*pgconn.StatementDescription
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
Mode() int
}
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
// the maximum size of the cache.
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
mustBeValidMode(mode)
mustBeValidCap(cap)
return NewLRU(conn, mode, cap)
}
func mustBeValidMode(mode int) {
if mode != ModePrepare && mode != ModeDescribe {
panic("mode must be ModePrepare or ModeDescribe")
func IsStatementInvalid(err error) bool {
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return false
}
}
func mustBeValidCap(cap int) {
if cap < 1 {
panic("cache must have cap of >= 1")
}
// https://github.com/jackc/pgx/issues/1162
//
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
// have so it should be safe.
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
return possibleInvalidCachedPlanError
}
+71
View File
@@ -0,0 +1,71 @@
package stmtcache
import (
"math"
"github.com/jackc/pgx/v5/pgconn"
)
// UnlimitedCache implements Cache with no capacity limit.
type UnlimitedCache struct {
// m holds cached statement descriptions keyed by their SQL text.
m map[string]*pgconn.StatementDescription
// invalidStmts accumulates descriptions removed by Invalidate/InvalidateAll
// until they are drained by HandleInvalidated.
invalidStmts []*pgconn.StatementDescription
}
// NewUnlimitedCache creates a new UnlimitedCache ready for use.
func NewUnlimitedCache() *UnlimitedCache {
	c := &UnlimitedCache{}
	c.m = make(map[string]*pgconn.StatementDescription)
	return c
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
	if sd, ok := c.m[sql]; ok {
		return sd
	}
	return nil
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
	if sd.SQL == "" {
		// An empty key would make the entry unreachable via Get; treat it as a
		// programmer error.
		panic("cannot store statement description with empty SQL")
	}
	// First write wins: an existing entry for the same SQL is kept untouched.
	if _, taken := c.m[sd.SQL]; !taken {
		c.m[sd.SQL] = sd
	}
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *UnlimitedCache) Invalidate(sql string) {
	sd, found := c.m[sql]
	if !found {
		return
	}
	// Move the entry from the live map to the pending-invalidation list.
	delete(c.m, sql)
	c.invalidStmts = append(c.invalidStmts, sd)
}
// InvalidateAll invalidates all statement descriptions.
func (c *UnlimitedCache) InvalidateAll() {
	// Queue every live entry for cleanup, then start over with an empty map.
	for _, desc := range c.m {
		c.invalidStmts = append(c.invalidStmts, desc)
	}
	c.m = make(map[string]*pgconn.StatementDescription)
}
// HandleInvalidated returns all statement descriptions invalidated since the
// last call to HandleInvalidated and resets the pending-invalidation list.
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions currently
// held (invalidated statements awaiting cleanup are not counted).
func (c *UnlimitedCache) Len() int {
return len(c.m)
}
// Cap returns the maximum number of cached prepared statement descriptions.
// UnlimitedCache has no capacity limit, so this is always math.MaxInt.
func (c *UnlimitedCache) Cap() int {
return math.MaxInt
}
+1 -1
View File
@@ -1909,10 +1909,10 @@ func (p *Pipeline) Close() error {
for p.expectedReadyForQueryCount > 0 {
_, err := p.GetResults()
if err != nil {
p.err = err
var pgErr *PgError
if !errors.As(err, &pgErr) {
p.conn.asyncClose()
p.err = err
break
}
}
+11 -2
View File
@@ -7,6 +7,7 @@ import (
"reflect"
"time"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
@@ -173,8 +174,16 @@ func (rows *baseRows) Close() {
}
}
if rows.err != nil && rows.conn != nil && rows.conn.statementCache != nil {
rows.conn.statementCache.StatementErrored(rows.sql, rows.err)
if rows.err != nil && rows.conn != nil && rows.sql != "" {
if stmtcache.IsStatementInvalid(rows.err) {
if sc := rows.conn.statementCache; sc != nil {
sc.Invalidate(rows.sql)
}
if sc := rows.conn.descriptionCache; sc != nil {
sc.Invalidate(rows.sql)
}
}
}
}