diff --git a/CHANGELOG.md b/CHANGELOG.md index a0b95203..8b6c3a96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -139,6 +139,12 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent `QueryFunc` has been replaced by using `ForEachScannedRow`. +## SendBatch Uses Pipeline Mode When Appropriate + +Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 +for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements +in a single network round trip. So it would only take 2 round trips. + ## 3rd Party Logger Integration All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency diff --git a/batch.go b/batch.go index 21830a1f..f2a9b4c8 100644 --- a/batch.go +++ b/batch.go @@ -11,6 +11,7 @@ import ( type batchItem struct { query string arguments []any + sd *pgconn.StatementDescription } // Batch queries are a way of bundling multiple queries together to avoid @@ -192,3 +193,165 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { } return } + +type pipelineBatchResults struct { + ctx context.Context + conn *Conn + pipeline *pgconn.Pipeline + lastRows *baseRows + err error + b *Batch + ix int + closed bool +} + +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { + if br.err != nil { + return pgconn.CommandTag{}, br.err + } + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") + } + if br.lastRows != nil && br.lastRows.err != nil { + return pgconn.CommandTag{}, br.err + } + + query, arguments, _ := br.nextQueryAndArgs() + + results, err := br.pipeline.GetResults() + if err != nil { + br.err = err + return pgconn.CommandTag{}, err + } + var commandTag pgconn.CommandTag + switch results := results.(type) { + case *pgconn.ResultReader: + commandTag, err = results.Close() + default: + return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results) + } + + if err != nil { + br.err = err + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ + "sql": query, + "args": logQueryArgs(arguments), + "err": err, + }) + } + } else if br.conn.shouldLog(LogLevelInfo) { + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{ + "sql": query, + "args": logQueryArgs(arguments), + "commandTag": commandTag, + }) + } + + return commandTag, err +} + +// Query reads the results from the next query in the batch as if the query has been sent with Query. +func (br *pipelineBatchResults) Query() (Rows, error) { + if br.err != nil { + return &baseRows{err: br.err, closed: true}, br.err + } + + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + } + + if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err + return &baseRows{err: br.err, closed: true}, br.err + } + + query, arguments, ok := br.nextQueryAndArgs() + if !ok { + query = "batch query" + } + + rows := br.conn.getRows(br.ctx, query, arguments) + br.lastRows = rows + + results, err := br.pipeline.GetResults() + if err != nil { + br.err = err + rows.err = err + rows.closed = true + if br.conn.shouldLog(LogLevelError) { + br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{ + "sql": query, + "args": logQueryArgs(arguments), + "err": rows.err, + }) + } + } else { + switch results := results.(type) { + case *pgconn.ResultReader: + rows.resultReader = results + default: + err = fmt.Errorf("unexpected pipeline result: %T", results) + br.err = err + rows.err = err + rows.closed = true + } + } + + return rows, rows.err +} + +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *pipelineBatchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) + +} + +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *pipelineBatchResults) Close() error { + if br.err != nil { + return br.err + } + + if br.lastRows != nil && br.lastRows.err != nil { + br.err = br.lastRows.err + return br.err + } + + if br.closed { + return nil + } + br.closed = true + + // log any queries that haven't yet been logged by Exec or Query + for { + query, args, ok := br.nextQueryAndArgs() + if !ok { + break + } + + if br.conn.shouldLog(LogLevelInfo) { + br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{ + "sql": query, + "args": logQueryArgs(args), + }) + } + } + + return br.pipeline.Close() +} + +func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { + if br.b != nil && br.ix < len(br.b.items) { + bi := br.b.items[br.ix] + query = bi.query + args = bi.arguments + ok = true + br.ix++ + } + return +} diff --git a/batch_test.go b/batch_test.go index abe9f915..f5409d25 100644 --- a/batch_test.go +++ b/batch_test.go @@ -420,7 +420,7 @@ func TestConnSendBatchQueryError(t *testing.T) { err = br.Close() if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", err, 22012) + t.Errorf("br.Close() => %v, want error code %v", err, 22012) } }) diff --git a/conn.go b/conn.go index d8ab21d7..997a84d3 100644 --- a/conn.go +++ b/conn.go @@ -236,11 +236,11 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.wbuf = make([]byte, 0, 1024) if c.config.StatementCacheCapacity > 0 { - c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity) + c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity) } if c.config.DescriptionCacheCapacity > 0 { - c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity) + c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity) } return c, nil @@ -382,6 +382,10 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return pgconn.CommandTag{}, err + } + startTime := time.Now() commandTag, err := c.exec(ctx, sql, arguments...) @@ -437,9 +441,13 @@ optionLoop: if c.statementCache == nil { return pgconn.CommandTag{}, errDisabledStatementCache } - sd, err := c.statementCache.Get(ctx, sql) - if err != nil { - return pgconn.CommandTag{}, err + sd := c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + if err != nil { + return pgconn.CommandTag{}, err + } + c.statementCache.Put(sd) } return c.execPrepared(ctx, sd, arguments) @@ -447,9 +455,12 @@ optionLoop: if c.descriptionCache == nil { return pgconn.CommandTag{}, errDisabledDescriptionCache } - sd, err := c.descriptionCache.Get(ctx, sql) - if err != nil { - return pgconn.CommandTag{}, err + sd := c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + return pgconn.CommandTag{}, err + } } return c.execParams(ctx, sd, arguments) @@ -620,6 +631,10 @@ type QueryRewriter interface { // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return &baseRows{err: err, closed: true}, err + } + var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID mode := c.config.DefaultQueryExecMode @@ -649,6 +664,11 @@ optionLoop: sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) } + // Bypass any statement caching. + if sql == "" { + mode = QueryExecModeSimpleProtocol + } + c.eqb.reset() anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) @@ -664,10 +684,14 @@ optionLoop: rows.fatal(err) return rows, err } - sd, err = c.statementCache.Get(ctx, sql) - if err != nil { - rows.fatal(err) - return rows, err + sd = c.statementCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql) + if err != nil { + rows.fatal(err) + return rows, err + } + c.statementCache.Put(sd) } case QueryExecModeCacheDescribe: if c.descriptionCache == nil { @@ -675,10 +699,14 @@ optionLoop: rows.fatal(err) return rows, err } - sd, err = c.descriptionCache.Get(ctx, sql) - if err != nil { - rows.fatal(err) - return rows, err + sd = c.descriptionCache.Get(sql) + if sd == nil { + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + rows.fatal(err) + return rows, err + } + c.descriptionCache.Put(sd) } case QueryExecModeDescribeExec: sd, err = c.Prepare(ctx, "", sql) @@ -767,6 +795,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + mode := c.config.DefaultQueryExecMode for _, bi := range b.items { @@ -794,105 +826,70 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } if mode == QueryExecModeSimpleProtocol { - var sb strings.Builder - for i, bi := range b.items { - if i > 0 { - sb.WriteByte(';') - } - sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - sb.WriteString(sql) - } - mrr := c.pgConn.Exec(ctx, sb.String()) - return &batchResults{ - ctx: ctx, - conn: c, - mrr: mrr, - b: b, - ix: 0, + return c.sendBatchQueryExecModeSimpleProtocol(ctx, b) + } + + // All other modes use extended protocol and thus can use prepared statements. + for _, bi := range b.items { + if sd, ok := c.preparedStatements[bi.query]; ok { + bi.sd = sd } } + switch mode { + case QueryExecModeExec: + return c.sendBatchQueryExecModeExec(ctx, b) + case QueryExecModeCacheStatement: + return c.sendBatchQueryExecModeCacheStatement(ctx, b) + case QueryExecModeCacheDescribe: + return c.sendBatchQueryExecModeCacheDescribe(ctx, b) + case QueryExecModeDescribeExec: + return c.sendBatchQueryExecModeDescribeExec(ctx, b) + default: + panic("unknown QueryExecMode") + } +} + +func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults { + var sb strings.Builder + for i, bi := range b.items { + if i > 0 { + sb.WriteByte(';') + } + sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + sb.WriteString(sql) + } + mrr := c.pgConn.Exec(ctx, sb.String()) + return &batchResults{ + ctx: ctx, + conn: c, + mrr: mrr, + b: b, + ix: 0, + } +} + +func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults { batch := &pgconn.Batch{} - if mode == QueryExecModeExec { - for _, bi := range b.items { - c.eqb.reset() - anynil.NormalizeSlice(bi.arguments) - - sd := c.preparedStatements[bi.query] - if sd != nil { - err := c.eqb.Build(c.typeMap, sd, bi.arguments) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - - batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) - } else { - err := c.eqb.Build(c.typeMap, nil, bi.arguments) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) - } - } - } else { - - distinctUnpreparedQueries := map[string]struct{}{} - - for _, bi := range b.items { - if _, ok := c.preparedStatements[bi.query]; ok { - continue - } - distinctUnpreparedQueries[bi.query] = struct{}{} - } - - var stmtCache stmtcache.Cache - if len(distinctUnpreparedQueries) > 0 { - if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.statementCache - } else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.descriptionCache - } else { - stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) - } - - for sql, _ := range distinctUnpreparedQueries { - _, err := stmtCache.Get(ctx, sql) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - } - - for _, bi := range b.items { - c.eqb.reset() - - sd := c.preparedStatements[bi.query] - if sd == nil { - var err error - sd, err = stmtCache.Get(ctx, bi.query) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - - if len(sd.ParamOIDs) != len(bi.arguments) { - return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} - } - + for _, bi := range b.items { + sd := bi.sd + if sd != nil { err := c.eqb.Build(c.typeMap, sd, bi.arguments) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } - if sd.Name == "" { - batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) - } else { - batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + err := c.eqb.Build(c.typeMap, nil, bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} } + batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats) } } @@ -909,6 +906,171 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } } +func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.statementCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} + } + + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.items { + if bi.sd == nil { + sd := c.statementCache.Get(bi.query) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + Name: stmtcache.NextStatementName(), + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache) +} + +func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + if c.descriptionCache == nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} + } + + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.items { + if bi.sd == nil { + sd := c.descriptionCache.Get(bi.query) + if sd != nil { + bi.sd = sd + } else { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd = &pgconn.StatementDescription{ + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache) +} + +func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { + distinctNewQueries := []*pgconn.StatementDescription{} + distinctNewQueriesIdxMap := make(map[string]int) + + for _, bi := range b.items { + if bi.sd == nil { + if idx, present := distinctNewQueriesIdxMap[bi.query]; present { + bi.sd = distinctNewQueries[idx] + } else { + sd := &pgconn.StatementDescription{ + SQL: bi.query, + } + distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries) + distinctNewQueries = append(distinctNewQueries, sd) + bi.sd = sd + } + } + } + + return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil) +} + +func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) { + pipeline := c.pgConn.StartPipeline(context.Background()) + defer func() { + if pbr.err != nil { + pipeline.Close() + } + }() + + // Prepare any needed queries + if len(distinctNewQueries) > 0 { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } + + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} + } + + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } + + results, err := pipeline.GetResults() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + _, ok := results.(*pgconn.PipelineSync) + if !ok { + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} + } + } + + // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + c.statementCache.Put(sd) + } + } + + // Queue the queries. + for _, bi := range b.items { + err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments) + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + if bi.sd.Name == "" { + pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats) + } else { + pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats) + } + } + + err := pipeline.Sync() + if err != nil { + return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + } + + return &pipelineBatchResults{ + ctx: ctx, + conn: c, + pipeline: pipeline, + b: b, + } +} + func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) { if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") @@ -1015,3 +1177,37 @@ order by attnum`, return fields, nil } + +func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { + if c.descriptionCache != nil { + c.descriptionCache.HandleInvalidated() + } + + var invalidatedStatements []*pgconn.StatementDescription + if c.statementCache != nil { + invalidatedStatements = c.statementCache.HandleInvalidated() + } + + if len(invalidatedStatements) == 0 { + return nil + } + + pipeline := c.pgConn.StartPipeline(ctx) + defer pipeline.Close() + + for _, sd := range invalidatedStatements { + pipeline.SendDeallocate(sd.Name) + } + + err := pipeline.Sync() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) + } + + err = pipeline.Close() + if err != nil { + return fmt.Errorf("failed to deallocate cached statement(s): %w", err) + } + + return nil +} diff --git a/conn_test.go b/conn_test.go index 79697cbd..a023a7d7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -931,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) { rows, err := conn.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Close() + require.NoError(t, rows.Err()) // Now, change the schema of the table out from under the statement, making it invalid. _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") @@ -948,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) { rows.Close() for _, err := range []error{nextErr, rows.Err()} { if err == nil { - t.Fatal("expected InvalidCachedStatementPlanError: no error") + t.Fatal(`expected "cached plan must not change result type": no error`) } if !strings.Contains(err.Error(), "cached plan must not change result type") { - t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) } } @@ -995,6 +996,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) { rows, err := tx.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Close() + require.NoError(t, rows.Err()) // Now, change the schema of the table out from under the statement, making it invalid. _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") @@ -1012,18 +1014,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) { rows.Close() for _, err := range []error{nextErr, rows.Err()} { if err == nil { - t.Fatal("expected InvalidCachedStatementPlanError: no error") + t.Fatal(`expected "cached plan must not change result type": no error`) } if !strings.Contains(err.Error(), "cached plan must not change result type") { - t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) } } - rows, err = tx.Query(ctx, getSQL, 1) - require.NoError(t, err) // error does not pop up immediately - rows.Next() + rows, _ = tx.Query(ctx, getSQL, 1) + rows.Close() err = rows.Err() - // Retries within the same transaction are errors (really anything except a rollbakc + // Retries within the same transaction are errors (really anything except a rollback // will be an error in this transaction). require.Error(t, err) rows.Close() diff --git a/internal/stmtcache/lru.go b/internal/stmtcache/lru.go deleted file mode 100644 index a3378c86..00000000 --- a/internal/stmtcache/lru.go +++ /dev/null @@ -1,169 +0,0 @@ -package stmtcache - -import ( - "container/list" - "context" - "fmt" - "sync/atomic" - - "github.com/jackc/pgx/v5/pgconn" -) - -var lruCount uint64 - -// LRU implements Cache with a Least Recently Used (LRU) cache. -type LRU struct { - conn *pgconn.PgConn - mode int - cap int - prepareCount int - m map[string]*list.Element - l *list.List - psNamePrefix string - stmtsToClear []string -} - -// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. -func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { - mustBeValidMode(mode) - mustBeValidCap(cap) - - n := atomic.AddUint64(&lruCount, 1) - - return &LRU{ - conn: conn, - mode: mode, - cap: cap, - m: make(map[string]*list.Element), - l: list.New(), - psNamePrefix: fmt.Sprintf("lrupsc_%d", n), - } -} - -// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. -func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { - if ctx != context.Background() { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - } - - // flush an outstanding bad statements - txStatus := c.conn.TxStatus() - if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { - for _, stmt := range c.stmtsToClear { - err := c.clearStmt(ctx, stmt) - if err != nil { - return nil, err - } - } - } - - if el, ok := c.m[sql]; ok { - c.l.MoveToFront(el) - return el.Value.(*pgconn.StatementDescription), nil - } - - if c.l.Len() == c.cap { - err := c.removeOldest(ctx) - if err != nil { - return nil, err - } - } - - psd, err := c.prepare(ctx, sql) - if err != nil { - return nil, err - } - - el := c.l.PushFront(psd) - c.m[sql] = el - - return psd, nil -} - -// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. -func (c *LRU) Clear(ctx context.Context) error { - for c.l.Len() > 0 { - err := c.removeOldest(ctx) - if err != nil { - return err - } - } - - return nil -} - -func (c *LRU) StatementErrored(sql string, err error) { - pgErr, ok := err.(*pgconn.PgError) - if !ok { - return - } - - // https://github.com/jackc/pgx/issues/1162 - // - // We used to look for the message "cached plan must not change result type". However, that message can be localized. - // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to - // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't - // have so it should be safe. - possibleInvalidCachedPlanError := pgErr.Code == "0A000" - if possibleInvalidCachedPlanError { - c.stmtsToClear = append(c.stmtsToClear, sql) - } -} - -func (c *LRU) clearStmt(ctx context.Context, sql string) error { - elem, inMap := c.m[sql] - if !inMap { - // The statement probably fell off the back of the list. In that case, we've - // ensured that it isn't in the cache, so we can declare victory. - return nil - } - - c.l.Remove(elem) - - psd := elem.Value.(*pgconn.StatementDescription) - delete(c.m, psd.SQL) - if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() - } - return nil -} - -// Len returns the number of cached prepared statement descriptions. -func (c *LRU) Len() int { - return c.l.Len() -} - -// Cap returns the maximum number of cached prepared statement descriptions. -func (c *LRU) Cap() int { - return c.cap -} - -// Mode returns the mode of the cache (ModePrepare or ModeDescribe) -func (c *LRU) Mode() int { - return c.mode -} - -func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { - var name string - if c.mode == ModePrepare { - name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) - c.prepareCount += 1 - } - - return c.conn.Prepare(ctx, name, sql, nil) -} - -func (c *LRU) removeOldest(ctx context.Context) error { - oldest := c.l.Back() - c.l.Remove(oldest) - psd := oldest.Value.(*pgconn.StatementDescription) - delete(c.m, psd.SQL) - if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() - } - return nil -} diff --git a/internal/stmtcache/lru_cache.go b/internal/stmtcache/lru_cache.go new file mode 100644 index 00000000..a25cc8b1 --- /dev/null +++ b/internal/stmtcache/lru_cache.go @@ -0,0 +1,98 @@ +package stmtcache + +import ( + "container/list" + + "github.com/jackc/pgx/v5/pgconn" +) + +// LRUCache implements Cache with a Least Recently Used (LRU) cache. +type LRUCache struct { + cap int + m map[string]*list.Element + l *list.List + invalidStmts []*pgconn.StatementDescription +} + +// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache. +func NewLRUCache(cap int) *LRUCache { + return &LRUCache{ + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *LRUCache) Get(key string) *pgconn.StatementDescription { + if el, ok := c.m[key]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.StatementDescription) + } + + return nil + +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +func (c *LRUCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + if c.l.Len() == c.cap { + c.invalidateOldest() + } + + el := c.l.PushFront(sd) + c.m[sd.SQL] = el +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *LRUCache) Invalidate(sql string) { + if el, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + c.l.Remove(el) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *LRUCache) InvalidateAll() { + el := c.l.Front() + for el != nil { + c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + el = el.Next() + } + + c.m = make(map[string]*list.Element) + c.l = list.New() +} + +func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { + invalidStmts := c.invalidStmts + c.invalidStmts = nil + return invalidStmts +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRUCache) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRUCache) Cap() int { + return c.cap +} + +func (c *LRUCache) invalidateOldest() { + oldest := c.l.Back() + sd := oldest.Value.(*pgconn.StatementDescription) + c.invalidStmts = append(c.invalidStmts, sd) + delete(c.m, sd.SQL) + c.l.Remove(oldest) +} diff --git a/internal/stmtcache/lru_test.go b/internal/stmtcache/lru_test.go deleted file mode 100644 index 7690a2b0..00000000 --- a/internal/stmtcache/lru_test.go +++ /dev/null @@ -1,292 +0,0 @@ -package stmtcache_test - -import ( - "context" - "fmt" - "math/rand" - "os" - "regexp" - "testing" - "time" - - "github.com/jackc/pgx/v5/internal/stmtcache" - "github.com/jackc/pgx/v5/pgconn" - - "github.com/stretchr/testify/require" -) - -func TestLRUModePrepare(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) - require.EqualValues(t, 0, cache.Len()) - require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) - - psd, err := cache.Get(ctx, "select 1") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 1") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 2") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 2, cache.Len()) - require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 3") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 2, cache.Len()) - require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn)) - - err = cache.Clear(ctx) - require.NoError(t, err) - require.EqualValues(t, 0, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) -} - -func TestLRUStmtInvalidation(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - // we construct a fake error because its not super straightforward to actually call - // a prepared statement from the LRU cache without the helper routines which live - // in pgx proper. - fakeInvalidCachePlanError := &pgconn.PgError{ - Severity: "ERROR", - Code: "0A000", - Message: "cached plan must not change result type", - } - - cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) - - // - // outside of a transaction, we eagerly flush the statement - // - - _, err = cache.Get(ctx, "select 1") - require.NoError(t, err) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - - cache.StatementErrored("select 1", fakeInvalidCachePlanError) - _, err = cache.Get(ctx, "select 2") - require.NoError(t, err) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) - - err = cache.Clear(ctx) - require.NoError(t, err) - - // - // within an errored transaction, we defer the flush to after the first get - // that happens after the transaction is rolled back - // - - _, err = cache.Get(ctx, "select 1") - require.NoError(t, err) - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - - res := conn.Exec(ctx, "begin") - require.NoError(t, res.Close()) - require.Equal(t, byte('T'), conn.TxStatus()) - - res = conn.Exec(ctx, "selec") - require.Error(t, res.Close()) - require.Equal(t, byte('E'), conn.TxStatus()) - - cache.StatementErrored("select 1", fakeInvalidCachePlanError) - require.EqualValues(t, 1, cache.Len()) - - res = conn.Exec(ctx, "rollback") - require.NoError(t, res.Close()) - - _, err = cache.Get(ctx, "select 2") - require.EqualValues(t, 1, cache.Len()) - require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) -} - -func TestLRUStmtInvalidationIntegration(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) - - result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - sql := "select * from stmtcache_table" - sd1, err := cache.Get(ctx, sql) - require.NoError(t, err) - - result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() - require.NoError(t, result.Err) - - result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() - require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)") - - cache.StatementErrored(sql, result.Err) - - sd2, err := cache.Get(ctx, sql) - require.NoError(t, err) - require.NotEqual(t, sd1.Name, sd2.Name) - - result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read() - require.NoError(t, result.Err) -} - -func TestLRUModePrepareStress(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8) - require.EqualValues(t, 0, cache.Len()) - require.EqualValues(t, 8, cache.Cap()) - require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) - - for i := 0; i < 1000; i++ { - psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50))) - require.NoError(t, err) - require.NotNil(t, psd) - result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read() - require.NoError(t, result.Err) - } -} - -func TestLRUModeDescribe(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) - require.EqualValues(t, 0, cache.Len()) - require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) - - psd, err := cache.Get(ctx, "select 1") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 1, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 1") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 1, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 2") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 2, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) - - psd, err = cache.Get(ctx, "select 3") - require.NoError(t, err) - require.NotNil(t, psd) - require.EqualValues(t, 2, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) - - err = cache.Clear(ctx) - require.NoError(t, err) - require.EqualValues(t, 0, cache.Len()) - require.Empty(t, fetchServerStatements(t, ctx, conn)) -} - -func TestLRUContext(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer conn.Close(ctx) - - cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) - - // test 1 : getting a value for the first time with a cancelled context returns an error - ctx1, cancel1 := context.WithCancel(ctx) - cancel1() - - desc, err := cache.Get(ctx1, "SELECT 1") - require.Error(t, err) - require.Nil(t, desc) - - // test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error - ctx2, cancel2 := context.WithCancel(ctx) - - desc, err = cache.Get(ctx2, "SELECT 2") - require.NoError(t, err) - require.NotNil(t, desc) - - cancel2() - - desc, err = cache.Get(ctx2, "SELECT 2") - require.Error(t, err) - require.Nil(t, desc) -} - -func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { - result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - var statements []string - for _, r := range result.Rows { - statement := string(r[0]) - if conn.ParameterStatus("crdb_version") != "" { - if statement == "PREPARE AS select statement from pg_prepared_statements" { - // CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it. - continue - } - - // CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended - // protocol will PostgreSQL does not. Normalize the statement. - re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `) - statement = re.ReplaceAllString(statement, "") - } - statements = append(statements, statement) - } - return statements -} diff --git a/internal/stmtcache/stmtcache.go b/internal/stmtcache/stmtcache.go index a2582019..f975273e 100644 --- a/internal/stmtcache/stmtcache.go +++ b/internal/stmtcache/stmtcache.go @@ -2,57 +2,56 @@ package stmtcache import ( - "context" + "strconv" + "sync/atomic" "github.com/jackc/pgx/v5/pgconn" ) -const ( - ModePrepare = iota // Cache should prepare named statements. - ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. -) +var stmtCounter int64 -// Cache prepares and caches prepared statement descriptions. +// NextStatementName returns a statement name that will be unique for the lifetime of the program. +func NextStatementName() string { + n := atomic.AddInt64(&stmtCounter, 1) + return "stmtcache_" + strconv.FormatInt(n, 10) +} + +// Cache caches statement descriptions. type Cache interface { - // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. - Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) + // Get returns the statement description for sql. Returns nil if not found. + Get(sql string) *pgconn.StatementDescription - // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. - Clear(ctx context.Context) error + // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. + Put(sd *pgconn.StatementDescription) - // StatementErrored informs the cache that the given statement resulted in an error when it - // was last used against the database. In some cases, this will cause the cache to maer that - // statement as bad. The bad statement will instead be flushed during the next call to Get - // that occurs outside of a failed transaction. - StatementErrored(sql string, err error) + // Invalidate invalidates statement description identified by sql. Does nothing if not found. + Invalidate(sql string) + + // InvalidateAll invalidates all statement descriptions. + InvalidateAll() + + // HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. + HandleInvalidated() []*pgconn.StatementDescription // Len returns the number of cached prepared statement descriptions. Len() int // Cap returns the maximum number of cached prepared statement descriptions. Cap() int - - // Mode returns the mode of the cache (ModePrepare or ModeDescribe) - Mode() int } -// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is -// the maximum size of the cache. -func New(conn *pgconn.PgConn, mode int, cap int) Cache { - mustBeValidMode(mode) - mustBeValidCap(cap) - - return NewLRU(conn, mode, cap) -} - -func mustBeValidMode(mode int) { - if mode != ModePrepare && mode != ModeDescribe { - panic("mode must be ModePrepare or ModeDescribe") +func IsStatementInvalid(err error) bool { + pgErr, ok := err.(*pgconn.PgError) + if !ok { + return false } -} -func mustBeValidCap(cap int) { - if cap < 1 { - panic("cache must have cap of >= 1") - } + // https://github.com/jackc/pgx/issues/1162 + // + // We used to look for the message "cached plan must not change result type". However, that message can be localized. + // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to + // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't + // have so it should be safe. + possibleInvalidCachedPlanError := pgErr.Code == "0A000" + return possibleInvalidCachedPlanError } diff --git a/internal/stmtcache/unlimited_cache.go b/internal/stmtcache/unlimited_cache.go new file mode 100644 index 00000000..f5f59396 --- /dev/null +++ b/internal/stmtcache/unlimited_cache.go @@ -0,0 +1,71 @@ +package stmtcache + +import ( + "math" + + "github.com/jackc/pgx/v5/pgconn" +) + +// UnlimitedCache implements Cache with no capacity limit. +type UnlimitedCache struct { + m map[string]*pgconn.StatementDescription + invalidStmts []*pgconn.StatementDescription +} + +// NewUnlimitedCache creates a new UnlimitedCache. +func NewUnlimitedCache() *UnlimitedCache { + return &UnlimitedCache{ + m: make(map[string]*pgconn.StatementDescription), + } +} + +// Get returns the statement description for sql. Returns nil if not found. +func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { + return c.m[sql] +} + +// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. +func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { + if sd.SQL == "" { + panic("cannot store statement description with empty SQL") + } + + if _, present := c.m[sd.SQL]; present { + return + } + + c.m[sd.SQL] = sd +} + +// Invalidate invalidates statement description identified by sql. Does nothing if not found. +func (c *UnlimitedCache) Invalidate(sql string) { + if sd, ok := c.m[sql]; ok { + delete(c.m, sql) + c.invalidStmts = append(c.invalidStmts, sd) + } +} + +// InvalidateAll invalidates all statement descriptions. +func (c *UnlimitedCache) InvalidateAll() { + for _, sd := range c.m { + c.invalidStmts = append(c.invalidStmts, sd) + } + + c.m = make(map[string]*pgconn.StatementDescription) +} + +func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { + invalidStmts := c.invalidStmts + c.invalidStmts = nil + return invalidStmts +} + +// Len returns the number of cached prepared statement descriptions. +func (c *UnlimitedCache) Len() int { + return len(c.m) +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *UnlimitedCache) Cap() int { + return math.MaxInt +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index bb4d35a9..65fb015a 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1909,10 +1909,10 @@ func (p *Pipeline) Close() error { for p.expectedReadyForQueryCount > 0 { _, err := p.GetResults() if err != nil { + p.err = err var pgErr *PgError if !errors.As(err, &pgErr) { p.conn.asyncClose() - p.err = err break } } diff --git a/rows.go b/rows.go index a1492c3e..4d4c5ec6 100644 --- a/rows.go +++ b/rows.go @@ -7,6 +7,7 @@ import ( "reflect" "time" + "github.com/jackc/pgx/v5/internal/stmtcache" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" @@ -173,8 +174,16 @@ func (rows *baseRows) Close() { } } - if rows.err != nil && rows.conn != nil && rows.conn.statementCache != nil { - rows.conn.statementCache.StatementErrored(rows.sql, rows.err) + if rows.err != nil && rows.conn != nil && rows.sql != "" { + if stmtcache.IsStatementInvalid(rows.err) { + if sc := rows.conn.statementCache; sc != nil { + sc.Invalidate(rows.sql) + } + + if sc := rows.conn.descriptionCache; sc != nil { + sc.Invalidate(rows.sql) + } + } } }