SendBatch now uses pipeline mode to prepare and describe statements
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements in a single network round trip. So it would only take 2 round trips.
This commit is contained in:
@@ -139,6 +139,12 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent
|
|||||||
|
|
||||||
`QueryFunc` has been replaced by using `ForEachScannedRow`.
|
`QueryFunc` has been replaced by using `ForEachScannedRow`.
|
||||||
|
|
||||||
|
## SendBatch Uses Pipeline Mode When Appropriate
|
||||||
|
|
||||||
|
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
|
||||||
|
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
|
||||||
|
in a single network round trip. So it would only take 2 round trips.
|
||||||
|
|
||||||
## 3rd Party Logger Integration
|
## 3rd Party Logger Integration
|
||||||
|
|
||||||
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
|
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
type batchItem struct {
|
type batchItem struct {
|
||||||
query string
|
query string
|
||||||
arguments []any
|
arguments []any
|
||||||
|
sd *pgconn.StatementDescription
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch queries are a way of bundling multiple queries together to avoid
|
// Batch queries are a way of bundling multiple queries together to avoid
|
||||||
@@ -192,3 +193,165 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type pipelineBatchResults struct {
|
||||||
|
ctx context.Context
|
||||||
|
conn *Conn
|
||||||
|
pipeline *pgconn.Pipeline
|
||||||
|
lastRows *baseRows
|
||||||
|
err error
|
||||||
|
b *Batch
|
||||||
|
ix int
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||||
|
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
|
||||||
|
if br.err != nil {
|
||||||
|
return pgconn.CommandTag{}, br.err
|
||||||
|
}
|
||||||
|
if br.closed {
|
||||||
|
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
|
||||||
|
}
|
||||||
|
if br.lastRows != nil && br.lastRows.err != nil {
|
||||||
|
return pgconn.CommandTag{}, br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
query, arguments, _ := br.nextQueryAndArgs()
|
||||||
|
|
||||||
|
results, err := br.pipeline.GetResults()
|
||||||
|
if err != nil {
|
||||||
|
br.err = err
|
||||||
|
return pgconn.CommandTag{}, err
|
||||||
|
}
|
||||||
|
var commandTag pgconn.CommandTag
|
||||||
|
switch results := results.(type) {
|
||||||
|
case *pgconn.ResultReader:
|
||||||
|
commandTag, err = results.Close()
|
||||||
|
default:
|
||||||
|
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
br.err = err
|
||||||
|
if br.conn.shouldLog(LogLevelError) {
|
||||||
|
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
|
||||||
|
"sql": query,
|
||||||
|
"args": logQueryArgs(arguments),
|
||||||
|
"err": err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else if br.conn.shouldLog(LogLevelInfo) {
|
||||||
|
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{
|
||||||
|
"sql": query,
|
||||||
|
"args": logQueryArgs(arguments),
|
||||||
|
"commandTag": commandTag,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return commandTag, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||||
|
func (br *pipelineBatchResults) Query() (Rows, error) {
|
||||||
|
if br.err != nil {
|
||||||
|
return &baseRows{err: br.err, closed: true}, br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.closed {
|
||||||
|
alreadyClosedErr := fmt.Errorf("batch already closed")
|
||||||
|
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.lastRows != nil && br.lastRows.err != nil {
|
||||||
|
br.err = br.lastRows.err
|
||||||
|
return &baseRows{err: br.err, closed: true}, br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
query, arguments, ok := br.nextQueryAndArgs()
|
||||||
|
if !ok {
|
||||||
|
query = "batch query"
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||||
|
br.lastRows = rows
|
||||||
|
|
||||||
|
results, err := br.pipeline.GetResults()
|
||||||
|
if err != nil {
|
||||||
|
br.err = err
|
||||||
|
rows.err = err
|
||||||
|
rows.closed = true
|
||||||
|
if br.conn.shouldLog(LogLevelError) {
|
||||||
|
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{
|
||||||
|
"sql": query,
|
||||||
|
"args": logQueryArgs(arguments),
|
||||||
|
"err": rows.err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch results := results.(type) {
|
||||||
|
case *pgconn.ResultReader:
|
||||||
|
rows.resultReader = results
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unexpected pipeline result: %T", results)
|
||||||
|
br.err = err
|
||||||
|
rows.err = err
|
||||||
|
rows.closed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows, rows.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||||
|
func (br *pipelineBatchResults) QueryRow() Row {
|
||||||
|
rows, _ := br.Query()
|
||||||
|
return (*connRow)(rows.(*baseRows))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||||
|
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||||
|
func (br *pipelineBatchResults) Close() error {
|
||||||
|
if br.err != nil {
|
||||||
|
return br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.lastRows != nil && br.lastRows.err != nil {
|
||||||
|
br.err = br.lastRows.err
|
||||||
|
return br.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
br.closed = true
|
||||||
|
|
||||||
|
// log any queries that haven't yet been logged by Exec or Query
|
||||||
|
for {
|
||||||
|
query, args, ok := br.nextQueryAndArgs()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if br.conn.shouldLog(LogLevelInfo) {
|
||||||
|
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{
|
||||||
|
"sql": query,
|
||||||
|
"args": logQueryArgs(args),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return br.pipeline.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||||
|
if br.b != nil && br.ix < len(br.b.items) {
|
||||||
|
bi := br.b.items[br.ix]
|
||||||
|
query = bi.query
|
||||||
|
args = bi.arguments
|
||||||
|
ok = true
|
||||||
|
br.ix++
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
+1
-1
@@ -420,7 +420,7 @@ func TestConnSendBatchQueryError(t *testing.T) {
|
|||||||
|
|
||||||
err = br.Close()
|
err = br.Close()
|
||||||
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
||||||
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
t.Errorf("br.Close() => %v, want error code %v", err, 22012)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -236,11 +236,11 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
|||||||
c.wbuf = make([]byte, 0, 1024)
|
c.wbuf = make([]byte, 0, 1024)
|
||||||
|
|
||||||
if c.config.StatementCacheCapacity > 0 {
|
if c.config.StatementCacheCapacity > 0 {
|
||||||
c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity)
|
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.config.DescriptionCacheCapacity > 0 {
|
if c.config.DescriptionCacheCapacity > 0 {
|
||||||
c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity)
|
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
@@ -382,6 +382,10 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
|
|||||||
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
|
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
|
||||||
// positionally from the sql string as $1, $2, etc.
|
// positionally from the sql string as $1, $2, etc.
|
||||||
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
|
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
|
||||||
|
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||||
|
return pgconn.CommandTag{}, err
|
||||||
|
}
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
commandTag, err := c.exec(ctx, sql, arguments...)
|
commandTag, err := c.exec(ctx, sql, arguments...)
|
||||||
@@ -437,9 +441,13 @@ optionLoop:
|
|||||||
if c.statementCache == nil {
|
if c.statementCache == nil {
|
||||||
return pgconn.CommandTag{}, errDisabledStatementCache
|
return pgconn.CommandTag{}, errDisabledStatementCache
|
||||||
}
|
}
|
||||||
sd, err := c.statementCache.Get(ctx, sql)
|
sd := c.statementCache.Get(sql)
|
||||||
if err != nil {
|
if sd == nil {
|
||||||
return pgconn.CommandTag{}, err
|
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||||
|
if err != nil {
|
||||||
|
return pgconn.CommandTag{}, err
|
||||||
|
}
|
||||||
|
c.statementCache.Put(sd)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.execPrepared(ctx, sd, arguments)
|
return c.execPrepared(ctx, sd, arguments)
|
||||||
@@ -447,9 +455,12 @@ optionLoop:
|
|||||||
if c.descriptionCache == nil {
|
if c.descriptionCache == nil {
|
||||||
return pgconn.CommandTag{}, errDisabledDescriptionCache
|
return pgconn.CommandTag{}, errDisabledDescriptionCache
|
||||||
}
|
}
|
||||||
sd, err := c.descriptionCache.Get(ctx, sql)
|
sd := c.descriptionCache.Get(sql)
|
||||||
if err != nil {
|
if sd == nil {
|
||||||
return pgconn.CommandTag{}, err
|
sd, err = c.Prepare(ctx, "", sql)
|
||||||
|
if err != nil {
|
||||||
|
return pgconn.CommandTag{}, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.execParams(ctx, sd, arguments)
|
return c.execParams(ctx, sd, arguments)
|
||||||
@@ -620,6 +631,10 @@ type QueryRewriter interface {
|
|||||||
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
|
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
|
||||||
// needed. See the documentation for those types for details.
|
// needed. See the documentation for those types for details.
|
||||||
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
|
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
|
||||||
|
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||||
|
return &baseRows{err: err, closed: true}, err
|
||||||
|
}
|
||||||
|
|
||||||
var resultFormats QueryResultFormats
|
var resultFormats QueryResultFormats
|
||||||
var resultFormatsByOID QueryResultFormatsByOID
|
var resultFormatsByOID QueryResultFormatsByOID
|
||||||
mode := c.config.DefaultQueryExecMode
|
mode := c.config.DefaultQueryExecMode
|
||||||
@@ -649,6 +664,11 @@ optionLoop:
|
|||||||
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bypass any statement caching.
|
||||||
|
if sql == "" {
|
||||||
|
mode = QueryExecModeSimpleProtocol
|
||||||
|
}
|
||||||
|
|
||||||
c.eqb.reset()
|
c.eqb.reset()
|
||||||
anynil.NormalizeSlice(args)
|
anynil.NormalizeSlice(args)
|
||||||
rows := c.getRows(ctx, sql, args)
|
rows := c.getRows(ctx, sql, args)
|
||||||
@@ -664,10 +684,14 @@ optionLoop:
|
|||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
sd, err = c.statementCache.Get(ctx, sql)
|
sd = c.statementCache.Get(sql)
|
||||||
if err != nil {
|
if sd == nil {
|
||||||
rows.fatal(err)
|
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||||
return rows, err
|
if err != nil {
|
||||||
|
rows.fatal(err)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
c.statementCache.Put(sd)
|
||||||
}
|
}
|
||||||
case QueryExecModeCacheDescribe:
|
case QueryExecModeCacheDescribe:
|
||||||
if c.descriptionCache == nil {
|
if c.descriptionCache == nil {
|
||||||
@@ -675,10 +699,14 @@ optionLoop:
|
|||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
sd, err = c.descriptionCache.Get(ctx, sql)
|
sd = c.descriptionCache.Get(sql)
|
||||||
if err != nil {
|
if sd == nil {
|
||||||
rows.fatal(err)
|
sd, err = c.Prepare(ctx, "", sql)
|
||||||
return rows, err
|
if err != nil {
|
||||||
|
rows.fatal(err)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
c.descriptionCache.Put(sd)
|
||||||
}
|
}
|
||||||
case QueryExecModeDescribeExec:
|
case QueryExecModeDescribeExec:
|
||||||
sd, err = c.Prepare(ctx, "", sql)
|
sd, err = c.Prepare(ctx, "", sql)
|
||||||
@@ -767,6 +795,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
|
|||||||
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
||||||
// is used again.
|
// is used again.
|
||||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||||
|
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||||
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
mode := c.config.DefaultQueryExecMode
|
mode := c.config.DefaultQueryExecMode
|
||||||
|
|
||||||
for _, bi := range b.items {
|
for _, bi := range b.items {
|
||||||
@@ -794,105 +826,70 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if mode == QueryExecModeSimpleProtocol {
|
if mode == QueryExecModeSimpleProtocol {
|
||||||
var sb strings.Builder
|
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
|
||||||
for i, bi := range b.items {
|
}
|
||||||
if i > 0 {
|
|
||||||
sb.WriteByte(';')
|
// All other modes use extended protocol and thus can use prepared statements.
|
||||||
}
|
for _, bi := range b.items {
|
||||||
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
if sd, ok := c.preparedStatements[bi.query]; ok {
|
||||||
if err != nil {
|
bi.sd = sd
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
||||||
}
|
|
||||||
sb.WriteString(sql)
|
|
||||||
}
|
|
||||||
mrr := c.pgConn.Exec(ctx, sb.String())
|
|
||||||
return &batchResults{
|
|
||||||
ctx: ctx,
|
|
||||||
conn: c,
|
|
||||||
mrr: mrr,
|
|
||||||
b: b,
|
|
||||||
ix: 0,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch mode {
|
||||||
|
case QueryExecModeExec:
|
||||||
|
return c.sendBatchQueryExecModeExec(ctx, b)
|
||||||
|
case QueryExecModeCacheStatement:
|
||||||
|
return c.sendBatchQueryExecModeCacheStatement(ctx, b)
|
||||||
|
case QueryExecModeCacheDescribe:
|
||||||
|
return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
|
||||||
|
case QueryExecModeDescribeExec:
|
||||||
|
return c.sendBatchQueryExecModeDescribeExec(ctx, b)
|
||||||
|
default:
|
||||||
|
panic("unknown QueryExecMode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
||||||
|
var sb strings.Builder
|
||||||
|
for i, bi := range b.items {
|
||||||
|
if i > 0 {
|
||||||
|
sb.WriteByte(';')
|
||||||
|
}
|
||||||
|
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
||||||
|
if err != nil {
|
||||||
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
sb.WriteString(sql)
|
||||||
|
}
|
||||||
|
mrr := c.pgConn.Exec(ctx, sb.String())
|
||||||
|
return &batchResults{
|
||||||
|
ctx: ctx,
|
||||||
|
conn: c,
|
||||||
|
mrr: mrr,
|
||||||
|
b: b,
|
||||||
|
ix: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
||||||
batch := &pgconn.Batch{}
|
batch := &pgconn.Batch{}
|
||||||
|
|
||||||
if mode == QueryExecModeExec {
|
for _, bi := range b.items {
|
||||||
for _, bi := range b.items {
|
sd := bi.sd
|
||||||
c.eqb.reset()
|
if sd != nil {
|
||||||
anynil.NormalizeSlice(bi.arguments)
|
|
||||||
|
|
||||||
sd := c.preparedStatements[bi.query]
|
|
||||||
if sd != nil {
|
|
||||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
|
||||||
if err != nil {
|
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
|
||||||
} else {
|
|
||||||
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
|
||||||
if err != nil {
|
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
||||||
}
|
|
||||||
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
distinctUnpreparedQueries := map[string]struct{}{}
|
|
||||||
|
|
||||||
for _, bi := range b.items {
|
|
||||||
if _, ok := c.preparedStatements[bi.query]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
distinctUnpreparedQueries[bi.query] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var stmtCache stmtcache.Cache
|
|
||||||
if len(distinctUnpreparedQueries) > 0 {
|
|
||||||
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
|
|
||||||
stmtCache = c.statementCache
|
|
||||||
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
|
|
||||||
stmtCache = c.descriptionCache
|
|
||||||
} else {
|
|
||||||
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
|
|
||||||
}
|
|
||||||
|
|
||||||
for sql, _ := range distinctUnpreparedQueries {
|
|
||||||
_, err := stmtCache.Get(ctx, sql)
|
|
||||||
if err != nil {
|
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, bi := range b.items {
|
|
||||||
c.eqb.reset()
|
|
||||||
|
|
||||||
sd := c.preparedStatements[bi.query]
|
|
||||||
if sd == nil {
|
|
||||||
var err error
|
|
||||||
sd, err = stmtCache.Get(ctx, bi.query)
|
|
||||||
if err != nil {
|
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(sd.ParamOIDs) != len(bi.arguments) {
|
|
||||||
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sd.Name == "" {
|
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||||
batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
} else {
|
||||||
} else {
|
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
||||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
if err != nil {
|
||||||
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
}
|
}
|
||||||
|
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -909,6 +906,171 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||||
|
if c.statementCache == nil {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
|
||||||
|
}
|
||||||
|
|
||||||
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||||
|
distinctNewQueriesIdxMap := make(map[string]int)
|
||||||
|
|
||||||
|
for _, bi := range b.items {
|
||||||
|
if bi.sd == nil {
|
||||||
|
sd := c.statementCache.Get(bi.query)
|
||||||
|
if sd != nil {
|
||||||
|
bi.sd = sd
|
||||||
|
} else {
|
||||||
|
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||||
|
bi.sd = distinctNewQueries[idx]
|
||||||
|
} else {
|
||||||
|
sd = &pgconn.StatementDescription{
|
||||||
|
Name: stmtcache.NextStatementName(),
|
||||||
|
SQL: bi.query,
|
||||||
|
}
|
||||||
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||||
|
distinctNewQueries = append(distinctNewQueries, sd)
|
||||||
|
bi.sd = sd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||||
|
if c.descriptionCache == nil {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
|
||||||
|
}
|
||||||
|
|
||||||
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||||
|
distinctNewQueriesIdxMap := make(map[string]int)
|
||||||
|
|
||||||
|
for _, bi := range b.items {
|
||||||
|
if bi.sd == nil {
|
||||||
|
sd := c.descriptionCache.Get(bi.query)
|
||||||
|
if sd != nil {
|
||||||
|
bi.sd = sd
|
||||||
|
} else {
|
||||||
|
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||||
|
bi.sd = distinctNewQueries[idx]
|
||||||
|
} else {
|
||||||
|
sd = &pgconn.StatementDescription{
|
||||||
|
SQL: bi.query,
|
||||||
|
}
|
||||||
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||||
|
distinctNewQueries = append(distinctNewQueries, sd)
|
||||||
|
bi.sd = sd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||||
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||||
|
distinctNewQueriesIdxMap := make(map[string]int)
|
||||||
|
|
||||||
|
for _, bi := range b.items {
|
||||||
|
if bi.sd == nil {
|
||||||
|
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||||
|
bi.sd = distinctNewQueries[idx]
|
||||||
|
} else {
|
||||||
|
sd := &pgconn.StatementDescription{
|
||||||
|
SQL: bi.query,
|
||||||
|
}
|
||||||
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||||
|
distinctNewQueries = append(distinctNewQueries, sd)
|
||||||
|
bi.sd = sd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
|
||||||
|
pipeline := c.pgConn.StartPipeline(context.Background())
|
||||||
|
defer func() {
|
||||||
|
if pbr.err != nil {
|
||||||
|
pipeline.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Prepare any needed queries
|
||||||
|
if len(distinctNewQueries) > 0 {
|
||||||
|
for _, sd := range distinctNewQueries {
|
||||||
|
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := pipeline.Sync()
|
||||||
|
if err != nil {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sd := range distinctNewQueries {
|
||||||
|
results, err := pipeline.GetResults()
|
||||||
|
if err != nil {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
resultSD, ok := results.(*pgconn.StatementDescription)
|
||||||
|
if !ok {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill in the previously empty / pending statement descriptions.
|
||||||
|
sd.ParamOIDs = resultSD.ParamOIDs
|
||||||
|
sd.Fields = resultSD.Fields
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := pipeline.GetResults()
|
||||||
|
if err != nil {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := results.(*pgconn.PipelineSync)
|
||||||
|
if !ok {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
|
||||||
|
if sdCache != nil {
|
||||||
|
for _, sd := range distinctNewQueries {
|
||||||
|
c.statementCache.Put(sd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Queue the queries.
|
||||||
|
for _, bi := range b.items {
|
||||||
|
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
|
||||||
|
if err != nil {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
if bi.sd.Name == "" {
|
||||||
|
pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||||
|
} else {
|
||||||
|
pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := pipeline.Sync()
|
||||||
|
if err != nil {
|
||||||
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pipelineBatchResults{
|
||||||
|
ctx: ctx,
|
||||||
|
conn: c,
|
||||||
|
pipeline: pipeline,
|
||||||
|
b: b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
|
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
|
||||||
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
|
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
|
||||||
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
||||||
@@ -1015,3 +1177,37 @@ order by attnum`,
|
|||||||
|
|
||||||
return fields, nil
|
return fields, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
|
||||||
|
if c.descriptionCache != nil {
|
||||||
|
c.descriptionCache.HandleInvalidated()
|
||||||
|
}
|
||||||
|
|
||||||
|
var invalidatedStatements []*pgconn.StatementDescription
|
||||||
|
if c.statementCache != nil {
|
||||||
|
invalidatedStatements = c.statementCache.HandleInvalidated()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(invalidatedStatements) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pipeline := c.pgConn.StartPipeline(ctx)
|
||||||
|
defer pipeline.Close()
|
||||||
|
|
||||||
|
for _, sd := range invalidatedStatements {
|
||||||
|
pipeline.SendDeallocate(sd.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := pipeline.Sync()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pipeline.Close()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
+9
-8
@@ -931,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
|
|||||||
rows, err := conn.Query(ctx, getSQL, 1)
|
rows, err := conn.Query(ctx, getSQL, 1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
|
||||||
// Now, change the schema of the table out from under the statement, making it invalid.
|
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||||
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||||
@@ -948,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
|
|||||||
rows.Close()
|
rows.Close()
|
||||||
for _, err := range []error{nextErr, rows.Err()} {
|
for _, err := range []error{nextErr, rows.Err()} {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected InvalidCachedStatementPlanError: no error")
|
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||||
}
|
}
|
||||||
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||||
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
|
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -995,6 +996,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
|||||||
rows, err := tx.Query(ctx, getSQL, 1)
|
rows, err := tx.Query(ctx, getSQL, 1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
|
||||||
// Now, change the schema of the table out from under the statement, making it invalid.
|
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||||
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||||
@@ -1012,18 +1014,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
|||||||
rows.Close()
|
rows.Close()
|
||||||
for _, err := range []error{nextErr, rows.Err()} {
|
for _, err := range []error{nextErr, rows.Err()} {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected InvalidCachedStatementPlanError: no error")
|
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||||
}
|
}
|
||||||
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||||
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
|
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err = tx.Query(ctx, getSQL, 1)
|
rows, _ = tx.Query(ctx, getSQL, 1)
|
||||||
require.NoError(t, err) // error does not pop up immediately
|
rows.Close()
|
||||||
rows.Next()
|
|
||||||
err = rows.Err()
|
err = rows.Err()
|
||||||
// Retries within the same transaction are errors (really anything except a rollbakc
|
// Retries within the same transaction are errors (really anything except a rollback
|
||||||
// will be an error in this transaction).
|
// will be an error in this transaction).
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
|||||||
@@ -1,169 +0,0 @@
|
|||||||
package stmtcache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"container/list"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
|
||||||
)
|
|
||||||
|
|
||||||
var lruCount uint64
|
|
||||||
|
|
||||||
// LRU implements Cache with a Least Recently Used (LRU) cache.
|
|
||||||
type LRU struct {
|
|
||||||
conn *pgconn.PgConn
|
|
||||||
mode int
|
|
||||||
cap int
|
|
||||||
prepareCount int
|
|
||||||
m map[string]*list.Element
|
|
||||||
l *list.List
|
|
||||||
psNamePrefix string
|
|
||||||
stmtsToClear []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
|
|
||||||
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
|
|
||||||
mustBeValidMode(mode)
|
|
||||||
mustBeValidCap(cap)
|
|
||||||
|
|
||||||
n := atomic.AddUint64(&lruCount, 1)
|
|
||||||
|
|
||||||
return &LRU{
|
|
||||||
conn: conn,
|
|
||||||
mode: mode,
|
|
||||||
cap: cap,
|
|
||||||
m: make(map[string]*list.Element),
|
|
||||||
l: list.New(),
|
|
||||||
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
|
||||||
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
|
||||||
if ctx != context.Background() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// flush an outstanding bad statements
|
|
||||||
txStatus := c.conn.TxStatus()
|
|
||||||
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
|
|
||||||
for _, stmt := range c.stmtsToClear {
|
|
||||||
err := c.clearStmt(ctx, stmt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if el, ok := c.m[sql]; ok {
|
|
||||||
c.l.MoveToFront(el)
|
|
||||||
return el.Value.(*pgconn.StatementDescription), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.l.Len() == c.cap {
|
|
||||||
err := c.removeOldest(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
psd, err := c.prepare(ctx, sql)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
el := c.l.PushFront(psd)
|
|
||||||
c.m[sql] = el
|
|
||||||
|
|
||||||
return psd, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
|
||||||
func (c *LRU) Clear(ctx context.Context) error {
|
|
||||||
for c.l.Len() > 0 {
|
|
||||||
err := c.removeOldest(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LRU) StatementErrored(sql string, err error) {
|
|
||||||
pgErr, ok := err.(*pgconn.PgError)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/jackc/pgx/issues/1162
|
|
||||||
//
|
|
||||||
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
|
||||||
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
|
||||||
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
|
||||||
// have so it should be safe.
|
|
||||||
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
|
||||||
if possibleInvalidCachedPlanError {
|
|
||||||
c.stmtsToClear = append(c.stmtsToClear, sql)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
|
|
||||||
elem, inMap := c.m[sql]
|
|
||||||
if !inMap {
|
|
||||||
// The statement probably fell off the back of the list. In that case, we've
|
|
||||||
// ensured that it isn't in the cache, so we can declare victory.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c.l.Remove(elem)
|
|
||||||
|
|
||||||
psd := elem.Value.(*pgconn.StatementDescription)
|
|
||||||
delete(c.m, psd.SQL)
|
|
||||||
if c.mode == ModePrepare {
|
|
||||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Len returns the number of cached prepared statement descriptions.
|
|
||||||
func (c *LRU) Len() int {
|
|
||||||
return c.l.Len()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
|
||||||
func (c *LRU) Cap() int {
|
|
||||||
return c.cap
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
|
||||||
func (c *LRU) Mode() int {
|
|
||||||
return c.mode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
|
||||||
var name string
|
|
||||||
if c.mode == ModePrepare {
|
|
||||||
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
|
|
||||||
c.prepareCount += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.conn.Prepare(ctx, name, sql, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *LRU) removeOldest(ctx context.Context) error {
|
|
||||||
oldest := c.l.Back()
|
|
||||||
c.l.Remove(oldest)
|
|
||||||
psd := oldest.Value.(*pgconn.StatementDescription)
|
|
||||||
delete(c.m, psd.SQL)
|
|
||||||
if c.mode == ModePrepare {
|
|
||||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
|
||||||
|
type LRUCache struct {
|
||||||
|
cap int
|
||||||
|
m map[string]*list.Element
|
||||||
|
l *list.List
|
||||||
|
invalidStmts []*pgconn.StatementDescription
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
|
||||||
|
func NewLRUCache(cap int) *LRUCache {
|
||||||
|
return &LRUCache{
|
||||||
|
cap: cap,
|
||||||
|
m: make(map[string]*list.Element),
|
||||||
|
l: list.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the statement description for sql. Returns nil if not found.
|
||||||
|
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
|
||||||
|
if el, ok := c.m[key]; ok {
|
||||||
|
c.l.MoveToFront(el)
|
||||||
|
return el.Value.(*pgconn.StatementDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||||
|
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
|
||||||
|
if sd.SQL == "" {
|
||||||
|
panic("cannot store statement description with empty SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, present := c.m[sd.SQL]; present {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.l.Len() == c.cap {
|
||||||
|
c.invalidateOldest()
|
||||||
|
}
|
||||||
|
|
||||||
|
el := c.l.PushFront(sd)
|
||||||
|
c.m[sd.SQL] = el
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||||
|
func (c *LRUCache) Invalidate(sql string) {
|
||||||
|
if el, ok := c.m[sql]; ok {
|
||||||
|
delete(c.m, sql)
|
||||||
|
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||||
|
c.l.Remove(el)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAll invalidates all statement descriptions.
|
||||||
|
func (c *LRUCache) InvalidateAll() {
|
||||||
|
el := c.l.Front()
|
||||||
|
for el != nil {
|
||||||
|
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||||
|
el = el.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.m = make(map[string]*list.Element)
|
||||||
|
c.l = list.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||||
|
invalidStmts := c.invalidStmts
|
||||||
|
c.invalidStmts = nil
|
||||||
|
return invalidStmts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
func (c *LRUCache) Len() int {
|
||||||
|
return c.l.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
func (c *LRUCache) Cap() int {
|
||||||
|
return c.cap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRUCache) invalidateOldest() {
|
||||||
|
oldest := c.l.Back()
|
||||||
|
sd := oldest.Value.(*pgconn.StatementDescription)
|
||||||
|
c.invalidStmts = append(c.invalidStmts, sd)
|
||||||
|
delete(c.m, sd.SQL)
|
||||||
|
c.l.Remove(oldest)
|
||||||
|
}
|
||||||
@@ -1,292 +0,0 @@
|
|||||||
package stmtcache_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/internal/stmtcache"
|
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLRUModePrepare(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close(ctx)
|
|
||||||
|
|
||||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
|
||||||
require.EqualValues(t, 0, cache.Len())
|
|
||||||
require.EqualValues(t, 2, cache.Cap())
|
|
||||||
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
|
||||||
|
|
||||||
psd, err := cache.Get(ctx, "select 1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
psd, err = cache.Get(ctx, "select 1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
psd, err = cache.Get(ctx, "select 2")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
require.EqualValues(t, 2, cache.Len())
|
|
||||||
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
psd, err = cache.Get(ctx, "select 3")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
require.EqualValues(t, 2, cache.Len())
|
|
||||||
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
err = cache.Clear(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 0, cache.Len())
|
|
||||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLRUStmtInvalidation(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close(ctx)
|
|
||||||
|
|
||||||
// we construct a fake error because its not super straightforward to actually call
|
|
||||||
// a prepared statement from the LRU cache without the helper routines which live
|
|
||||||
// in pgx proper.
|
|
||||||
fakeInvalidCachePlanError := &pgconn.PgError{
|
|
||||||
Severity: "ERROR",
|
|
||||||
Code: "0A000",
|
|
||||||
Message: "cached plan must not change result type",
|
|
||||||
}
|
|
||||||
|
|
||||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
|
||||||
|
|
||||||
//
|
|
||||||
// outside of a transaction, we eagerly flush the statement
|
|
||||||
//
|
|
||||||
|
|
||||||
_, err = cache.Get(ctx, "select 1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
|
|
||||||
_, err = cache.Get(ctx, "select 2")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
err = cache.Clear(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
//
|
|
||||||
// within an errored transaction, we defer the flush to after the first get
|
|
||||||
// that happens after the transaction is rolled back
|
|
||||||
//
|
|
||||||
|
|
||||||
_, err = cache.Get(ctx, "select 1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
res := conn.Exec(ctx, "begin")
|
|
||||||
require.NoError(t, res.Close())
|
|
||||||
require.Equal(t, byte('T'), conn.TxStatus())
|
|
||||||
|
|
||||||
res = conn.Exec(ctx, "selec")
|
|
||||||
require.Error(t, res.Close())
|
|
||||||
require.Equal(t, byte('E'), conn.TxStatus())
|
|
||||||
|
|
||||||
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
|
|
||||||
res = conn.Exec(ctx, "rollback")
|
|
||||||
require.NoError(t, res.Close())
|
|
||||||
|
|
||||||
_, err = cache.Get(ctx, "select 2")
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLRUStmtInvalidationIntegration(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close(ctx)
|
|
||||||
|
|
||||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
|
||||||
|
|
||||||
result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read()
|
|
||||||
require.NoError(t, result.Err)
|
|
||||||
|
|
||||||
sql := "select * from stmtcache_table"
|
|
||||||
sd1, err := cache.Get(ctx, sql)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
|
|
||||||
require.NoError(t, result.Err)
|
|
||||||
|
|
||||||
result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read()
|
|
||||||
require.NoError(t, result.Err)
|
|
||||||
|
|
||||||
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
|
|
||||||
require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)")
|
|
||||||
|
|
||||||
cache.StatementErrored(sql, result.Err)
|
|
||||||
|
|
||||||
sd2, err := cache.Get(ctx, sql)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEqual(t, sd1.Name, sd2.Name)
|
|
||||||
|
|
||||||
result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read()
|
|
||||||
require.NoError(t, result.Err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLRUModePrepareStress(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close(ctx)
|
|
||||||
|
|
||||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8)
|
|
||||||
require.EqualValues(t, 0, cache.Len())
|
|
||||||
require.EqualValues(t, 8, cache.Cap())
|
|
||||||
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
|
||||||
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50)))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read()
|
|
||||||
require.NoError(t, result.Err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLRUModeDescribe(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close(ctx)
|
|
||||||
|
|
||||||
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
|
||||||
require.EqualValues(t, 0, cache.Len())
|
|
||||||
require.EqualValues(t, 2, cache.Cap())
|
|
||||||
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
|
|
||||||
|
|
||||||
psd, err := cache.Get(ctx, "select 1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
psd, err = cache.Get(ctx, "select 1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
require.EqualValues(t, 1, cache.Len())
|
|
||||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
psd, err = cache.Get(ctx, "select 2")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
require.EqualValues(t, 2, cache.Len())
|
|
||||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
psd, err = cache.Get(ctx, "select 3")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, psd)
|
|
||||||
require.EqualValues(t, 2, cache.Len())
|
|
||||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
|
||||||
|
|
||||||
err = cache.Clear(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.EqualValues(t, 0, cache.Len())
|
|
||||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLRUContext(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close(ctx)
|
|
||||||
|
|
||||||
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
|
||||||
|
|
||||||
// test 1 : getting a value for the first time with a cancelled context returns an error
|
|
||||||
ctx1, cancel1 := context.WithCancel(ctx)
|
|
||||||
cancel1()
|
|
||||||
|
|
||||||
desc, err := cache.Get(ctx1, "SELECT 1")
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Nil(t, desc)
|
|
||||||
|
|
||||||
// test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error
|
|
||||||
ctx2, cancel2 := context.WithCancel(ctx)
|
|
||||||
|
|
||||||
desc, err = cache.Get(ctx2, "SELECT 2")
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, desc)
|
|
||||||
|
|
||||||
cancel2()
|
|
||||||
|
|
||||||
desc, err = cache.Get(ctx2, "SELECT 2")
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Nil(t, desc)
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
|
|
||||||
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
|
|
||||||
require.NoError(t, result.Err)
|
|
||||||
var statements []string
|
|
||||||
for _, r := range result.Rows {
|
|
||||||
statement := string(r[0])
|
|
||||||
if conn.ParameterStatus("crdb_version") != "" {
|
|
||||||
if statement == "PREPARE AS select statement from pg_prepared_statements" {
|
|
||||||
// CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended
|
|
||||||
// protocol will PostgreSQL does not. Normalize the statement.
|
|
||||||
re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `)
|
|
||||||
statement = re.ReplaceAllString(statement, "")
|
|
||||||
}
|
|
||||||
statements = append(statements, statement)
|
|
||||||
}
|
|
||||||
return statements
|
|
||||||
}
|
|
||||||
@@ -2,57 +2,56 @@
|
|||||||
package stmtcache
|
package stmtcache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"strconv"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
var stmtCounter int64
|
||||||
ModePrepare = iota // Cache should prepare named statements.
|
|
||||||
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
|
|
||||||
)
|
|
||||||
|
|
||||||
// Cache prepares and caches prepared statement descriptions.
|
// NextStatementName returns a statement name that will be unique for the lifetime of the program.
|
||||||
|
func NextStatementName() string {
|
||||||
|
n := atomic.AddInt64(&stmtCounter, 1)
|
||||||
|
return "stmtcache_" + strconv.FormatInt(n, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache caches statement descriptions.
|
||||||
type Cache interface {
|
type Cache interface {
|
||||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
// Get returns the statement description for sql. Returns nil if not found.
|
||||||
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
|
Get(sql string) *pgconn.StatementDescription
|
||||||
|
|
||||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||||
Clear(ctx context.Context) error
|
Put(sd *pgconn.StatementDescription)
|
||||||
|
|
||||||
// StatementErrored informs the cache that the given statement resulted in an error when it
|
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||||
// was last used against the database. In some cases, this will cause the cache to maer that
|
Invalidate(sql string)
|
||||||
// statement as bad. The bad statement will instead be flushed during the next call to Get
|
|
||||||
// that occurs outside of a failed transaction.
|
// InvalidateAll invalidates all statement descriptions.
|
||||||
StatementErrored(sql string, err error)
|
InvalidateAll()
|
||||||
|
|
||||||
|
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
||||||
|
HandleInvalidated() []*pgconn.StatementDescription
|
||||||
|
|
||||||
// Len returns the number of cached prepared statement descriptions.
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
Len() int
|
Len() int
|
||||||
|
|
||||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
Cap() int
|
Cap() int
|
||||||
|
|
||||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
|
||||||
Mode() int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
|
func IsStatementInvalid(err error) bool {
|
||||||
// the maximum size of the cache.
|
pgErr, ok := err.(*pgconn.PgError)
|
||||||
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
|
if !ok {
|
||||||
mustBeValidMode(mode)
|
return false
|
||||||
mustBeValidCap(cap)
|
|
||||||
|
|
||||||
return NewLRU(conn, mode, cap)
|
|
||||||
}
|
|
||||||
|
|
||||||
func mustBeValidMode(mode int) {
|
|
||||||
if mode != ModePrepare && mode != ModeDescribe {
|
|
||||||
panic("mode must be ModePrepare or ModeDescribe")
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func mustBeValidCap(cap int) {
|
// https://github.com/jackc/pgx/issues/1162
|
||||||
if cap < 1 {
|
//
|
||||||
panic("cache must have cap of >= 1")
|
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
||||||
}
|
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
||||||
|
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
||||||
|
// have so it should be safe.
|
||||||
|
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
||||||
|
return possibleInvalidCachedPlanError
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UnlimitedCache implements Cache with no capacity limit.
|
||||||
|
type UnlimitedCache struct {
|
||||||
|
m map[string]*pgconn.StatementDescription
|
||||||
|
invalidStmts []*pgconn.StatementDescription
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUnlimitedCache creates a new UnlimitedCache.
|
||||||
|
func NewUnlimitedCache() *UnlimitedCache {
|
||||||
|
return &UnlimitedCache{
|
||||||
|
m: make(map[string]*pgconn.StatementDescription),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the statement description for sql. Returns nil if not found.
|
||||||
|
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
|
||||||
|
return c.m[sql]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||||
|
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
|
||||||
|
if sd.SQL == "" {
|
||||||
|
panic("cannot store statement description with empty SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, present := c.m[sd.SQL]; present {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.m[sd.SQL] = sd
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||||
|
func (c *UnlimitedCache) Invalidate(sql string) {
|
||||||
|
if sd, ok := c.m[sql]; ok {
|
||||||
|
delete(c.m, sql)
|
||||||
|
c.invalidStmts = append(c.invalidStmts, sd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateAll invalidates all statement descriptions.
|
||||||
|
func (c *UnlimitedCache) InvalidateAll() {
|
||||||
|
for _, sd := range c.m {
|
||||||
|
c.invalidStmts = append(c.invalidStmts, sd)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.m = make(map[string]*pgconn.StatementDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||||
|
invalidStmts := c.invalidStmts
|
||||||
|
c.invalidStmts = nil
|
||||||
|
return invalidStmts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
func (c *UnlimitedCache) Len() int {
|
||||||
|
return len(c.m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
func (c *UnlimitedCache) Cap() int {
|
||||||
|
return math.MaxInt
|
||||||
|
}
|
||||||
+1
-1
@@ -1909,10 +1909,10 @@ func (p *Pipeline) Close() error {
|
|||||||
for p.expectedReadyForQueryCount > 0 {
|
for p.expectedReadyForQueryCount > 0 {
|
||||||
_, err := p.GetResults()
|
_, err := p.GetResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
p.err = err
|
||||||
var pgErr *PgError
|
var pgErr *PgError
|
||||||
if !errors.As(err, &pgErr) {
|
if !errors.As(err, &pgErr) {
|
||||||
p.conn.asyncClose()
|
p.conn.asyncClose()
|
||||||
p.err = err
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/stmtcache"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
"github.com/jackc/pgx/v5/pgproto3"
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
@@ -173,8 +174,16 @@ func (rows *baseRows) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.err != nil && rows.conn != nil && rows.conn.statementCache != nil {
|
if rows.err != nil && rows.conn != nil && rows.sql != "" {
|
||||||
rows.conn.statementCache.StatementErrored(rows.sql, rows.err)
|
if stmtcache.IsStatementInvalid(rows.err) {
|
||||||
|
if sc := rows.conn.statementCache; sc != nil {
|
||||||
|
sc.Invalidate(rows.sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sc := rows.conn.descriptionCache; sc != nil {
|
||||||
|
sc.Invalidate(rows.sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user