2
0

SendBatch now uses pipeline mode to prepare and describe statements

Previously, a batch with 10 unique parameterized statements executed
100 times would entail 11 network round trips. 1 for each prepare /
describe and 1 for executing them all. Now pipeline mode is used to
prepare / describe all statements in a single network round trip. So it
would only take 2 round trips.
This commit is contained in:
Jack Christensen
2022-07-09 09:28:11 -05:00
parent ba58e3d5d2
commit e7aa76ccf9
12 changed files with 694 additions and 612 deletions
+6
View File
@@ -139,6 +139,12 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent
`QueryFunc` has been replaced by using `ForEachScannedRow`. `QueryFunc` has been replaced by using `ForEachScannedRow`.
## SendBatch Uses Pipeline Mode When Appropriate
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
in a single network round trip. So it would only take 2 round trips.
## 3rd Party Logger Integration ## 3rd Party Logger Integration
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
+163
View File
@@ -11,6 +11,7 @@ import (
type batchItem struct { type batchItem struct {
query string query string
arguments []any arguments []any
sd *pgconn.StatementDescription
} }
// Batch queries are a way of bundling multiple queries together to avoid // Batch queries are a way of bundling multiple queries together to avoid
@@ -192,3 +193,165 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
} }
return return
} }
type pipelineBatchResults struct {
ctx context.Context
conn *Conn
pipeline *pgconn.Pipeline
lastRows *baseRows
err error
b *Batch
ix int
closed bool
}
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
if br.err != nil {
return pgconn.CommandTag{}, br.err
}
if br.closed {
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
}
if br.lastRows != nil && br.lastRows.err != nil {
return pgconn.CommandTag{}, br.err
}
query, arguments, _ := br.nextQueryAndArgs()
results, err := br.pipeline.GetResults()
if err != nil {
br.err = err
return pgconn.CommandTag{}, err
}
var commandTag pgconn.CommandTag
switch results := results.(type) {
case *pgconn.ResultReader:
commandTag, err = results.Close()
default:
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
}
if err != nil {
br.err = err
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"err": err,
})
}
} else if br.conn.shouldLog(LogLevelInfo) {
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"commandTag": commandTag,
})
}
return commandTag, err
}
// Query reads the results from the next query in the batch as if the query has been sent with Query.
func (br *pipelineBatchResults) Query() (Rows, error) {
if br.err != nil {
return &baseRows{err: br.err, closed: true}, br.err
}
if br.closed {
alreadyClosedErr := fmt.Errorf("batch already closed")
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
}
if br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return &baseRows{err: br.err, closed: true}, br.err
}
query, arguments, ok := br.nextQueryAndArgs()
if !ok {
query = "batch query"
}
rows := br.conn.getRows(br.ctx, query, arguments)
br.lastRows = rows
results, err := br.pipeline.GetResults()
if err != nil {
br.err = err
rows.err = err
rows.closed = true
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"err": rows.err,
})
}
} else {
switch results := results.(type) {
case *pgconn.ResultReader:
rows.resultReader = results
default:
err = fmt.Errorf("unexpected pipeline result: %T", results)
br.err = err
rows.err = err
rows.closed = true
}
}
return rows, rows.err
}
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
func (br *pipelineBatchResults) QueryRow() Row {
rows, _ := br.Query()
return (*connRow)(rows.(*baseRows))
}
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
func (br *pipelineBatchResults) Close() error {
if br.err != nil {
return br.err
}
if br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return br.err
}
if br.closed {
return nil
}
br.closed = true
// log any queries that haven't yet been logged by Exec or Query
for {
query, args, ok := br.nextQueryAndArgs()
if !ok {
break
}
if br.conn.shouldLog(LogLevelInfo) {
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{
"sql": query,
"args": logQueryArgs(args),
})
}
}
return br.pipeline.Close()
}
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.ix < len(br.b.items) {
bi := br.b.items[br.ix]
query = bi.query
args = bi.arguments
ok = true
br.ix++
}
return
}
+1 -1
View File
@@ -420,7 +420,7 @@ func TestConnSendBatchQueryError(t *testing.T) {
err = br.Close() err = br.Close()
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
t.Errorf("rows.Err() => %v, want error code %v", err, 22012) t.Errorf("br.Close() => %v, want error code %v", err, 22012)
} }
}) })
+300 -104
View File
@@ -236,11 +236,11 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
c.wbuf = make([]byte, 0, 1024) c.wbuf = make([]byte, 0, 1024)
if c.config.StatementCacheCapacity > 0 { if c.config.StatementCacheCapacity > 0 {
c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity) c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
} }
if c.config.DescriptionCacheCapacity > 0 { if c.config.DescriptionCacheCapacity > 0 {
c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity) c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
} }
return c, nil return c, nil
@@ -382,6 +382,10 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc. // positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return pgconn.CommandTag{}, err
}
startTime := time.Now() startTime := time.Now()
commandTag, err := c.exec(ctx, sql, arguments...) commandTag, err := c.exec(ctx, sql, arguments...)
@@ -437,9 +441,13 @@ optionLoop:
if c.statementCache == nil { if c.statementCache == nil {
return pgconn.CommandTag{}, errDisabledStatementCache return pgconn.CommandTag{}, errDisabledStatementCache
} }
sd, err := c.statementCache.Get(ctx, sql) sd := c.statementCache.Get(sql)
if err != nil { if sd == nil {
return pgconn.CommandTag{}, err sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
if err != nil {
return pgconn.CommandTag{}, err
}
c.statementCache.Put(sd)
} }
return c.execPrepared(ctx, sd, arguments) return c.execPrepared(ctx, sd, arguments)
@@ -447,9 +455,12 @@ optionLoop:
if c.descriptionCache == nil { if c.descriptionCache == nil {
return pgconn.CommandTag{}, errDisabledDescriptionCache return pgconn.CommandTag{}, errDisabledDescriptionCache
} }
sd, err := c.descriptionCache.Get(ctx, sql) sd := c.descriptionCache.Get(sql)
if err != nil { if sd == nil {
return pgconn.CommandTag{}, err sd, err = c.Prepare(ctx, "", sql)
if err != nil {
return pgconn.CommandTag{}, err
}
} }
return c.execParams(ctx, sd, arguments) return c.execParams(ctx, sd, arguments)
@@ -620,6 +631,10 @@ type QueryRewriter interface {
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details. // needed. See the documentation for those types for details.
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return &baseRows{err: err, closed: true}, err
}
var resultFormats QueryResultFormats var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID var resultFormatsByOID QueryResultFormatsByOID
mode := c.config.DefaultQueryExecMode mode := c.config.DefaultQueryExecMode
@@ -649,6 +664,11 @@ optionLoop:
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
} }
// Bypass any statement caching.
if sql == "" {
mode = QueryExecModeSimpleProtocol
}
c.eqb.reset() c.eqb.reset()
anynil.NormalizeSlice(args) anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args) rows := c.getRows(ctx, sql, args)
@@ -664,10 +684,14 @@ optionLoop:
rows.fatal(err) rows.fatal(err)
return rows, err return rows, err
} }
sd, err = c.statementCache.Get(ctx, sql) sd = c.statementCache.Get(sql)
if err != nil { if sd == nil {
rows.fatal(err) sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
return rows, err if err != nil {
rows.fatal(err)
return rows, err
}
c.statementCache.Put(sd)
} }
case QueryExecModeCacheDescribe: case QueryExecModeCacheDescribe:
if c.descriptionCache == nil { if c.descriptionCache == nil {
@@ -675,10 +699,14 @@ optionLoop:
rows.fatal(err) rows.fatal(err)
return rows, err return rows, err
} }
sd, err = c.descriptionCache.Get(ctx, sql) sd = c.descriptionCache.Get(sql)
if err != nil { if sd == nil {
rows.fatal(err) sd, err = c.Prepare(ctx, "", sql)
return rows, err if err != nil {
rows.fatal(err)
return rows, err
}
c.descriptionCache.Put(sd)
} }
case QueryExecModeDescribeExec: case QueryExecModeDescribeExec:
sd, err = c.Prepare(ctx, "", sql) sd, err = c.Prepare(ctx, "", sql)
@@ -767,6 +795,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// is used again. // is used again.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
mode := c.config.DefaultQueryExecMode mode := c.config.DefaultQueryExecMode
for _, bi := range b.items { for _, bi := range b.items {
@@ -794,105 +826,70 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
} }
if mode == QueryExecModeSimpleProtocol { if mode == QueryExecModeSimpleProtocol {
var sb strings.Builder return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
for i, bi := range b.items { }
if i > 0 {
sb.WriteByte(';') // All other modes use extended protocol and thus can use prepared statements.
} for _, bi := range b.items {
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) if sd, ok := c.preparedStatements[bi.query]; ok {
if err != nil { bi.sd = sd
return &batchResults{ctx: ctx, conn: c, err: err}
}
sb.WriteString(sql)
}
mrr := c.pgConn.Exec(ctx, sb.String())
return &batchResults{
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
ix: 0,
} }
} }
switch mode {
case QueryExecModeExec:
return c.sendBatchQueryExecModeExec(ctx, b)
case QueryExecModeCacheStatement:
return c.sendBatchQueryExecModeCacheStatement(ctx, b)
case QueryExecModeCacheDescribe:
return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
case QueryExecModeDescribeExec:
return c.sendBatchQueryExecModeDescribeExec(ctx, b)
default:
panic("unknown QueryExecMode")
}
}
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
var sb strings.Builder
for i, bi := range b.items {
if i > 0 {
sb.WriteByte(';')
}
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
sb.WriteString(sql)
}
mrr := c.pgConn.Exec(ctx, sb.String())
return &batchResults{
ctx: ctx,
conn: c,
mrr: mrr,
b: b,
ix: 0,
}
}
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
batch := &pgconn.Batch{} batch := &pgconn.Batch{}
if mode == QueryExecModeExec { for _, bi := range b.items {
for _, bi := range b.items { sd := bi.sd
c.eqb.reset() if sd != nil {
anynil.NormalizeSlice(bi.arguments)
sd := c.preparedStatements[bi.query]
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
} else {
distinctUnpreparedQueries := map[string]struct{}{}
for _, bi := range b.items {
if _, ok := c.preparedStatements[bi.query]; ok {
continue
}
distinctUnpreparedQueries[bi.query] = struct{}{}
}
var stmtCache stmtcache.Cache
if len(distinctUnpreparedQueries) > 0 {
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.statementCache
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.descriptionCache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
}
for _, bi := range b.items {
c.eqb.reset()
sd := c.preparedStatements[bi.query]
if sd == nil {
var err error
sd, err = stmtCache.Get(ctx, bi.query)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
if len(sd.ParamOIDs) != len(bi.arguments) {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
}
err := c.eqb.Build(c.typeMap, sd, bi.arguments) err := c.eqb.Build(c.typeMap, sd, bi.arguments)
if err != nil { if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err} return &batchResults{ctx: ctx, conn: c, err: err}
} }
if sd.Name == "" { batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) } else {
} else { err := c.eqb.Build(c.typeMap, nil, bi.arguments)
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
} }
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
} }
} }
@@ -909,6 +906,171 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
} }
} }
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
if c.statementCache == nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
}
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.items {
if bi.sd == nil {
sd := c.statementCache.Get(bi.query)
if sd != nil {
bi.sd = sd
} else {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd = &pgconn.StatementDescription{
Name: stmtcache.NextStatementName(),
SQL: bi.query,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
bi.sd = sd
}
}
}
}
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache)
}
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
if c.descriptionCache == nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
}
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.items {
if bi.sd == nil {
sd := c.descriptionCache.Get(bi.query)
if sd != nil {
bi.sd = sd
} else {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd = &pgconn.StatementDescription{
SQL: bi.query,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
bi.sd = sd
}
}
}
}
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache)
}
func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.items {
if bi.sd == nil {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd := &pgconn.StatementDescription{
SQL: bi.query,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
bi.sd = sd
}
}
}
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil)
}
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
pipeline := c.pgConn.StartPipeline(context.Background())
defer func() {
if pbr.err != nil {
pipeline.Close()
}
}()
// Prepare any needed queries
if len(distinctNewQueries) > 0 {
for _, sd := range distinctNewQueries {
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
}
err := pipeline.Sync()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
resultSD, ok := results.(*pgconn.StatementDescription)
if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
}
// Fill in the previously empty / pending statement descriptions.
sd.ParamOIDs = resultSD.ParamOIDs
sd.Fields = resultSD.Fields
}
results, err := pipeline.GetResults()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
_, ok := results.(*pgconn.PipelineSync)
if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
}
}
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
if sdCache != nil {
for _, sd := range distinctNewQueries {
c.statementCache.Put(sd)
}
}
// Queue the queries.
for _, bi := range b.items {
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
if bi.sd.Name == "" {
pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
err := pipeline.Sync()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
}
return &pipelineBatchResults{
ctx: ctx,
conn: c,
pipeline: pipeline,
b: b,
}
}
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
@@ -1015,3 +1177,37 @@ order by attnum`,
return fields, nil return fields, nil
} }
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
if c.descriptionCache != nil {
c.descriptionCache.HandleInvalidated()
}
var invalidatedStatements []*pgconn.StatementDescription
if c.statementCache != nil {
invalidatedStatements = c.statementCache.HandleInvalidated()
}
if len(invalidatedStatements) == 0 {
return nil
}
pipeline := c.pgConn.StartPipeline(ctx)
defer pipeline.Close()
for _, sd := range invalidatedStatements {
pipeline.SendDeallocate(sd.Name)
}
err := pipeline.Sync()
if err != nil {
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
}
err = pipeline.Close()
if err != nil {
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
}
return nil
}
+9 -8
View File
@@ -931,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
rows, err := conn.Query(ctx, getSQL, 1) rows, err := conn.Query(ctx, getSQL, 1)
require.NoError(t, err) require.NoError(t, err)
rows.Close() rows.Close()
require.NoError(t, rows.Err())
// Now, change the schema of the table out from under the statement, making it invalid. // Now, change the schema of the table out from under the statement, making it invalid.
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
@@ -948,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
rows.Close() rows.Close()
for _, err := range []error{nextErr, rows.Err()} { for _, err := range []error{nextErr, rows.Err()} {
if err == nil { if err == nil {
t.Fatal("expected InvalidCachedStatementPlanError: no error") t.Fatal(`expected "cached plan must not change result type": no error`)
} }
if !strings.Contains(err.Error(), "cached plan must not change result type") { if !strings.Contains(err.Error(), "cached plan must not change result type") {
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
} }
} }
@@ -995,6 +996,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
rows, err := tx.Query(ctx, getSQL, 1) rows, err := tx.Query(ctx, getSQL, 1)
require.NoError(t, err) require.NoError(t, err)
rows.Close() rows.Close()
require.NoError(t, rows.Err())
// Now, change the schema of the table out from under the statement, making it invalid. // Now, change the schema of the table out from under the statement, making it invalid.
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
@@ -1012,18 +1014,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
rows.Close() rows.Close()
for _, err := range []error{nextErr, rows.Err()} { for _, err := range []error{nextErr, rows.Err()} {
if err == nil { if err == nil {
t.Fatal("expected InvalidCachedStatementPlanError: no error") t.Fatal(`expected "cached plan must not change result type": no error`)
} }
if !strings.Contains(err.Error(), "cached plan must not change result type") { if !strings.Contains(err.Error(), "cached plan must not change result type") {
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
} }
} }
rows, err = tx.Query(ctx, getSQL, 1) rows, _ = tx.Query(ctx, getSQL, 1)
require.NoError(t, err) // error does not pop up immediately rows.Close()
rows.Next()
err = rows.Err() err = rows.Err()
// Retries within the same transaction are errors (really anything except a rollbakc // Retries within the same transaction are errors (really anything except a rollback
// will be an error in this transaction). // will be an error in this transaction).
require.Error(t, err) require.Error(t, err)
rows.Close() rows.Close()
-169
View File
@@ -1,169 +0,0 @@
package stmtcache
import (
"container/list"
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgx/v5/pgconn"
)
var lruCount uint64
// LRU implements Cache with a Least Recently Used (LRU) cache.
type LRU struct {
conn *pgconn.PgConn
mode int
cap int
prepareCount int
m map[string]*list.Element
l *list.List
psNamePrefix string
stmtsToClear []string
}
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
mustBeValidMode(mode)
mustBeValidCap(cap)
n := atomic.AddUint64(&lruCount, 1)
return &LRU{
conn: conn,
mode: mode,
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
}
}
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
if ctx != context.Background() {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// flush an outstanding bad statements
txStatus := c.conn.TxStatus()
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
for _, stmt := range c.stmtsToClear {
err := c.clearStmt(ctx, stmt)
if err != nil {
return nil, err
}
}
}
if el, ok := c.m[sql]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription), nil
}
if c.l.Len() == c.cap {
err := c.removeOldest(ctx)
if err != nil {
return nil, err
}
}
psd, err := c.prepare(ctx, sql)
if err != nil {
return nil, err
}
el := c.l.PushFront(psd)
c.m[sql] = el
return psd, nil
}
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
func (c *LRU) Clear(ctx context.Context) error {
for c.l.Len() > 0 {
err := c.removeOldest(ctx)
if err != nil {
return err
}
}
return nil
}
func (c *LRU) StatementErrored(sql string, err error) {
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return
}
// https://github.com/jackc/pgx/issues/1162
//
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
// have so it should be safe.
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
if possibleInvalidCachedPlanError {
c.stmtsToClear = append(c.stmtsToClear, sql)
}
}
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
elem, inMap := c.m[sql]
if !inMap {
// The statement probably fell off the back of the list. In that case, we've
// ensured that it isn't in the cache, so we can declare victory.
return nil
}
c.l.Remove(elem)
psd := elem.Value.(*pgconn.StatementDescription)
delete(c.m, psd.SQL)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
}
return nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRU) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRU) Cap() int {
return c.cap
}
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
func (c *LRU) Mode() int {
return c.mode
}
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
var name string
if c.mode == ModePrepare {
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
c.prepareCount += 1
}
return c.conn.Prepare(ctx, name, sql, nil)
}
func (c *LRU) removeOldest(ctx context.Context) error {
oldest := c.l.Back()
c.l.Remove(oldest)
psd := oldest.Value.(*pgconn.StatementDescription)
delete(c.m, psd.SQL)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
}
return nil
}
+98
View File
@@ -0,0 +1,98 @@
package stmtcache
import (
"container/list"
"github.com/jackc/pgx/v5/pgconn"
)
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
type LRUCache struct {
cap int
m map[string]*list.Element
l *list.List
invalidStmts []*pgconn.StatementDescription
}
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
func NewLRUCache(cap int) *LRUCache {
return &LRUCache{
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
}
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
if el, ok := c.m[key]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription)
}
return nil
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
}
if _, present := c.m[sd.SQL]; present {
return
}
if c.l.Len() == c.cap {
c.invalidateOldest()
}
el := c.l.PushFront(sd)
c.m[sd.SQL] = el
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *LRUCache) Invalidate(sql string) {
if el, ok := c.m[sql]; ok {
delete(c.m, sql)
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
c.l.Remove(el)
}
}
// InvalidateAll invalidates all statement descriptions.
func (c *LRUCache) InvalidateAll() {
el := c.l.Front()
for el != nil {
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
el = el.Next()
}
c.m = make(map[string]*list.Element)
c.l = list.New()
}
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRUCache) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRUCache) Cap() int {
return c.cap
}
func (c *LRUCache) invalidateOldest() {
oldest := c.l.Back()
sd := oldest.Value.(*pgconn.StatementDescription)
c.invalidStmts = append(c.invalidStmts, sd)
delete(c.m, sd.SQL)
c.l.Remove(oldest)
}
-292
View File
@@ -1,292 +0,0 @@
package stmtcache_test
import (
"context"
"fmt"
"math/rand"
"os"
"regexp"
"testing"
"time"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
)
func TestLRUModePrepare(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 2, cache.Cap())
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
psd, err := cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 3")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
}
func TestLRUStmtInvalidation(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
// we construct a fake error because its not super straightforward to actually call
// a prepared statement from the LRU cache without the helper routines which live
// in pgx proper.
fakeInvalidCachePlanError := &pgconn.PgError{
Severity: "ERROR",
Code: "0A000",
Message: "cached plan must not change result type",
}
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
//
// outside of a transaction, we eagerly flush the statement
//
_, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
_, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
//
// within an errored transaction, we defer the flush to after the first get
// that happens after the transaction is rolled back
//
_, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
res := conn.Exec(ctx, "begin")
require.NoError(t, res.Close())
require.Equal(t, byte('T'), conn.TxStatus())
res = conn.Exec(ctx, "selec")
require.Error(t, res.Close())
require.Equal(t, byte('E'), conn.TxStatus())
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
require.EqualValues(t, 1, cache.Len())
res = conn.Exec(ctx, "rollback")
require.NoError(t, res.Close())
_, err = cache.Get(ctx, "select 2")
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
}
func TestLRUStmtInvalidationIntegration(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
sql := "select * from stmtcache_table"
sd1, err := cache.Get(ctx, sql)
require.NoError(t, err)
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
require.NoError(t, result.Err)
result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)")
cache.StatementErrored(sql, result.Err)
sd2, err := cache.Get(ctx, sql)
require.NoError(t, err)
require.NotEqual(t, sd1.Name, sd2.Name)
result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read()
require.NoError(t, result.Err)
}
func TestLRUModePrepareStress(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 8, cache.Cap())
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
for i := 0; i < 1000; i++ {
psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50)))
require.NoError(t, err)
require.NotNil(t, psd)
result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read()
require.NoError(t, result.Err)
}
}
func TestLRUModeDescribe(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 2, cache.Cap())
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
psd, err := cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 3")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
}
func TestLRUContext(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
// test 1 : getting a value for the first time with a cancelled context returns an error
ctx1, cancel1 := context.WithCancel(ctx)
cancel1()
desc, err := cache.Get(ctx1, "SELECT 1")
require.Error(t, err)
require.Nil(t, desc)
// test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error
ctx2, cancel2 := context.WithCancel(ctx)
desc, err = cache.Get(ctx2, "SELECT 2")
require.NoError(t, err)
require.NotNil(t, desc)
cancel2()
desc, err = cache.Get(ctx2, "SELECT 2")
require.Error(t, err)
require.Nil(t, desc)
}
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
var statements []string
for _, r := range result.Rows {
statement := string(r[0])
if conn.ParameterStatus("crdb_version") != "" {
if statement == "PREPARE AS select statement from pg_prepared_statements" {
// CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it.
continue
}
// CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended
// protocol will PostgreSQL does not. Normalize the statement.
re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `)
statement = re.ReplaceAllString(statement, "")
}
statements = append(statements, statement)
}
return statements
}
+34 -35
View File
@@ -2,57 +2,56 @@
package stmtcache package stmtcache
import ( import (
"context" "strconv"
"sync/atomic"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
) )
const ( var stmtCounter int64
ModePrepare = iota // Cache should prepare named statements.
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
)
// Cache prepares and caches prepared statement descriptions. // NextStatementName returns a statement name that will be unique for the lifetime of the program.
func NextStatementName() string {
n := atomic.AddInt64(&stmtCounter, 1)
return "stmtcache_" + strconv.FormatInt(n, 10)
}
// Cache caches statement descriptions.
type Cache interface { type Cache interface {
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. // Get returns the statement description for sql. Returns nil if not found.
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) Get(sql string) *pgconn.StatementDescription
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
Clear(ctx context.Context) error Put(sd *pgconn.StatementDescription)
// StatementErrored informs the cache that the given statement resulted in an error when it // Invalidate invalidates statement description identified by sql. Does nothing if not found.
// was last used against the database. In some cases, this will cause the cache to maer that Invalidate(sql string)
// statement as bad. The bad statement will instead be flushed during the next call to Get
// that occurs outside of a failed transaction. // InvalidateAll invalidates all statement descriptions.
StatementErrored(sql string, err error) InvalidateAll()
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
HandleInvalidated() []*pgconn.StatementDescription
// Len returns the number of cached prepared statement descriptions. // Len returns the number of cached prepared statement descriptions.
Len() int Len() int
// Cap returns the maximum number of cached prepared statement descriptions. // Cap returns the maximum number of cached prepared statement descriptions.
Cap() int Cap() int
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
Mode() int
} }
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is func IsStatementInvalid(err error) bool {
// the maximum size of the cache. pgErr, ok := err.(*pgconn.PgError)
func New(conn *pgconn.PgConn, mode int, cap int) Cache { if !ok {
mustBeValidMode(mode) return false
mustBeValidCap(cap)
return NewLRU(conn, mode, cap)
}
func mustBeValidMode(mode int) {
if mode != ModePrepare && mode != ModeDescribe {
panic("mode must be ModePrepare or ModeDescribe")
} }
}
func mustBeValidCap(cap int) { // https://github.com/jackc/pgx/issues/1162
if cap < 1 { //
panic("cache must have cap of >= 1") // We used to look for the message "cached plan must not change result type". However, that message can be localized.
} // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
// have so it should be safe.
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
return possibleInvalidCachedPlanError
} }
+71
View File
@@ -0,0 +1,71 @@
package stmtcache
import (
"math"
"github.com/jackc/pgx/v5/pgconn"
)
// UnlimitedCache implements Cache with no capacity limit.
type UnlimitedCache struct {
m map[string]*pgconn.StatementDescription
invalidStmts []*pgconn.StatementDescription
}
// NewUnlimitedCache creates a new UnlimitedCache.
func NewUnlimitedCache() *UnlimitedCache {
return &UnlimitedCache{
m: make(map[string]*pgconn.StatementDescription),
}
}
// Get returns the statement description for sql. Returns nil if not found.
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
return c.m[sql]
}
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
if sd.SQL == "" {
panic("cannot store statement description with empty SQL")
}
if _, present := c.m[sd.SQL]; present {
return
}
c.m[sd.SQL] = sd
}
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
func (c *UnlimitedCache) Invalidate(sql string) {
if sd, ok := c.m[sql]; ok {
delete(c.m, sql)
c.invalidStmts = append(c.invalidStmts, sd)
}
}
// InvalidateAll invalidates all statement descriptions.
func (c *UnlimitedCache) InvalidateAll() {
for _, sd := range c.m {
c.invalidStmts = append(c.invalidStmts, sd)
}
c.m = make(map[string]*pgconn.StatementDescription)
}
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions.
func (c *UnlimitedCache) Len() int {
return len(c.m)
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *UnlimitedCache) Cap() int {
return math.MaxInt
}
+1 -1
View File
@@ -1909,10 +1909,10 @@ func (p *Pipeline) Close() error {
for p.expectedReadyForQueryCount > 0 { for p.expectedReadyForQueryCount > 0 {
_, err := p.GetResults() _, err := p.GetResults()
if err != nil { if err != nil {
p.err = err
var pgErr *PgError var pgErr *PgError
if !errors.As(err, &pgErr) { if !errors.As(err, &pgErr) {
p.conn.asyncClose() p.conn.asyncClose()
p.err = err
break break
} }
} }
+11 -2
View File
@@ -7,6 +7,7 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/jackc/pgx/v5/internal/stmtcache"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype"
@@ -173,8 +174,16 @@ func (rows *baseRows) Close() {
} }
} }
if rows.err != nil && rows.conn != nil && rows.conn.statementCache != nil { if rows.err != nil && rows.conn != nil && rows.sql != "" {
rows.conn.statementCache.StatementErrored(rows.sql, rows.err) if stmtcache.IsStatementInvalid(rows.err) {
if sc := rows.conn.statementCache; sc != nil {
sc.Invalidate(rows.sql)
}
if sc := rows.conn.descriptionCache; sc != nil {
sc.Invalidate(rows.sql)
}
}
} }
} }