SendBatch now uses pipeline mode to prepare and describe statements
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1 for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements in a single network round trip. So it would only take 2 round trips.
This commit is contained in:
@@ -139,6 +139,12 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent
|
||||
|
||||
`QueryFunc` has been replaced by using `ForEachScannedRow`.
|
||||
|
||||
## SendBatch Uses Pipeline Mode When Appropriate
|
||||
|
||||
Previously, a batch with 10 unique parameterized statements executed 100 times would entail 11 network round trips. 1
|
||||
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
|
||||
in a single network round trip. So it would only take 2 round trips.
|
||||
|
||||
## 3rd Party Logger Integration
|
||||
|
||||
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
type batchItem struct {
|
||||
query string
|
||||
arguments []any
|
||||
sd *pgconn.StatementDescription
|
||||
}
|
||||
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
@@ -192,3 +193,165 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type pipelineBatchResults struct {
|
||||
ctx context.Context
|
||||
conn *Conn
|
||||
pipeline *pgconn.Pipeline
|
||||
lastRows *baseRows
|
||||
err error
|
||||
b *Batch
|
||||
ix int
|
||||
closed bool
|
||||
}
|
||||
|
||||
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||
func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
|
||||
if br.err != nil {
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
if br.closed {
|
||||
return pgconn.CommandTag{}, fmt.Errorf("batch already closed")
|
||||
}
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
return pgconn.CommandTag{}, br.err
|
||||
}
|
||||
|
||||
query, arguments, _ := br.nextQueryAndArgs()
|
||||
|
||||
results, err := br.pipeline.GetResults()
|
||||
if err != nil {
|
||||
br.err = err
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
var commandTag pgconn.CommandTag
|
||||
switch results := results.(type) {
|
||||
case *pgconn.ResultReader:
|
||||
commandTag, err = results.Close()
|
||||
default:
|
||||
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
br.err = err
|
||||
if br.conn.shouldLog(LogLevelError) {
|
||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"err": err,
|
||||
})
|
||||
}
|
||||
} else if br.conn.shouldLog(LogLevelInfo) {
|
||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"commandTag": commandTag,
|
||||
})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
// Query reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
func (br *pipelineBatchResults) Query() (Rows, error) {
|
||||
if br.err != nil {
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
alreadyClosedErr := fmt.Errorf("batch already closed")
|
||||
return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr
|
||||
}
|
||||
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
br.err = br.lastRows.err
|
||||
return &baseRows{err: br.err, closed: true}, br.err
|
||||
}
|
||||
|
||||
query, arguments, ok := br.nextQueryAndArgs()
|
||||
if !ok {
|
||||
query = "batch query"
|
||||
}
|
||||
|
||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||
br.lastRows = rows
|
||||
|
||||
results, err := br.pipeline.GetResults()
|
||||
if err != nil {
|
||||
br.err = err
|
||||
rows.err = err
|
||||
rows.closed = true
|
||||
if br.conn.shouldLog(LogLevelError) {
|
||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"err": rows.err,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
switch results := results.(type) {
|
||||
case *pgconn.ResultReader:
|
||||
rows.resultReader = results
|
||||
default:
|
||||
err = fmt.Errorf("unexpected pipeline result: %T", results)
|
||||
br.err = err
|
||||
rows.err = err
|
||||
rows.closed = true
|
||||
}
|
||||
}
|
||||
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||
func (br *pipelineBatchResults) QueryRow() Row {
|
||||
rows, _ := br.Query()
|
||||
return (*connRow)(rows.(*baseRows))
|
||||
|
||||
}
|
||||
|
||||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
func (br *pipelineBatchResults) Close() error {
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
br.err = br.lastRows.err
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
return nil
|
||||
}
|
||||
br.closed = true
|
||||
|
||||
// log any queries that haven't yet been logged by Exec or Query
|
||||
for {
|
||||
query, args, ok := br.nextQueryAndArgs()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if br.conn.shouldLog(LogLevelInfo) {
|
||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(args),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return br.pipeline.Close()
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
if br.b != nil && br.ix < len(br.b.items) {
|
||||
bi := br.b.items[br.ix]
|
||||
query = bi.query
|
||||
args = bi.arguments
|
||||
ok = true
|
||||
br.ix++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
+1
-1
@@ -420,7 +420,7 @@ func TestConnSendBatchQueryError(t *testing.T) {
|
||||
|
||||
err = br.Close()
|
||||
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
||||
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
||||
t.Errorf("br.Close() => %v, want error code %v", err, 22012)
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
@@ -236,11 +236,11 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
||||
c.wbuf = make([]byte, 0, 1024)
|
||||
|
||||
if c.config.StatementCacheCapacity > 0 {
|
||||
c.statementCache = stmtcache.New(c.pgConn, stmtcache.ModePrepare, c.config.StatementCacheCapacity)
|
||||
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
|
||||
}
|
||||
|
||||
if c.config.DescriptionCacheCapacity > 0 {
|
||||
c.descriptionCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, c.config.DescriptionCacheCapacity)
|
||||
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
@@ -382,6 +382,10 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
|
||||
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
|
||||
// positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
commandTag, err := c.exec(ctx, sql, arguments...)
|
||||
@@ -437,9 +441,13 @@ optionLoop:
|
||||
if c.statementCache == nil {
|
||||
return pgconn.CommandTag{}, errDisabledStatementCache
|
||||
}
|
||||
sd, err := c.statementCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
sd := c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
|
||||
return c.execPrepared(ctx, sd, arguments)
|
||||
@@ -447,9 +455,12 @@ optionLoop:
|
||||
if c.descriptionCache == nil {
|
||||
return pgconn.CommandTag{}, errDisabledDescriptionCache
|
||||
}
|
||||
sd, err := c.descriptionCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
sd := c.descriptionCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return c.execParams(ctx, sd, arguments)
|
||||
@@ -620,6 +631,10 @@ type QueryRewriter interface {
|
||||
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
|
||||
// needed. See the documentation for those types for details.
|
||||
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
return &baseRows{err: err, closed: true}, err
|
||||
}
|
||||
|
||||
var resultFormats QueryResultFormats
|
||||
var resultFormatsByOID QueryResultFormatsByOID
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
@@ -649,6 +664,11 @@ optionLoop:
|
||||
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||
}
|
||||
|
||||
// Bypass any statement caching.
|
||||
if sql == "" {
|
||||
mode = QueryExecModeSimpleProtocol
|
||||
}
|
||||
|
||||
c.eqb.reset()
|
||||
anynil.NormalizeSlice(args)
|
||||
rows := c.getRows(ctx, sql, args)
|
||||
@@ -664,10 +684,14 @@ optionLoop:
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
sd, err = c.statementCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
sd = c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeCacheDescribe:
|
||||
if c.descriptionCache == nil {
|
||||
@@ -675,10 +699,14 @@ optionLoop:
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
sd, err = c.descriptionCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
sd = c.descriptionCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
c.descriptionCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeDescribeExec:
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
@@ -767,6 +795,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
|
||||
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
||||
// is used again.
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
|
||||
for _, bi := range b.items {
|
||||
@@ -794,105 +826,70 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
}
|
||||
|
||||
if mode == QueryExecModeSimpleProtocol {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.items {
|
||||
if i > 0 {
|
||||
sb.WriteByte(';')
|
||||
}
|
||||
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
sb.WriteString(sql)
|
||||
}
|
||||
mrr := c.pgConn.Exec(ctx, sb.String())
|
||||
return &batchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
b: b,
|
||||
ix: 0,
|
||||
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
|
||||
}
|
||||
|
||||
// All other modes use extended protocol and thus can use prepared statements.
|
||||
for _, bi := range b.items {
|
||||
if sd, ok := c.preparedStatements[bi.query]; ok {
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case QueryExecModeExec:
|
||||
return c.sendBatchQueryExecModeExec(ctx, b)
|
||||
case QueryExecModeCacheStatement:
|
||||
return c.sendBatchQueryExecModeCacheStatement(ctx, b)
|
||||
case QueryExecModeCacheDescribe:
|
||||
return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
|
||||
case QueryExecModeDescribeExec:
|
||||
return c.sendBatchQueryExecModeDescribeExec(ctx, b)
|
||||
default:
|
||||
panic("unknown QueryExecMode")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.items {
|
||||
if i > 0 {
|
||||
sb.WriteByte(';')
|
||||
}
|
||||
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
sb.WriteString(sql)
|
||||
}
|
||||
mrr := c.pgConn.Exec(ctx, sb.String())
|
||||
return &batchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
b: b,
|
||||
ix: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
if mode == QueryExecModeExec {
|
||||
for _, bi := range b.items {
|
||||
c.eqb.reset()
|
||||
anynil.NormalizeSlice(bi.arguments)
|
||||
|
||||
sd := c.preparedStatements[bi.query]
|
||||
if sd != nil {
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
distinctUnpreparedQueries := map[string]struct{}{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
if _, ok := c.preparedStatements[bi.query]; ok {
|
||||
continue
|
||||
}
|
||||
distinctUnpreparedQueries[bi.query] = struct{}{}
|
||||
}
|
||||
|
||||
var stmtCache stmtcache.Cache
|
||||
if len(distinctUnpreparedQueries) > 0 {
|
||||
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||
stmtCache = c.statementCache
|
||||
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||
stmtCache = c.descriptionCache
|
||||
} else {
|
||||
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
|
||||
}
|
||||
|
||||
for sql, _ := range distinctUnpreparedQueries {
|
||||
_, err := stmtCache.Get(ctx, sql)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, bi := range b.items {
|
||||
c.eqb.reset()
|
||||
|
||||
sd := c.preparedStatements[bi.query]
|
||||
if sd == nil {
|
||||
var err error
|
||||
sd, err = stmtCache.Get(ctx, bi.query)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
if len(sd.ParamOIDs) != len(bi.arguments) {
|
||||
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
|
||||
}
|
||||
|
||||
for _, bi := range b.items {
|
||||
sd := bi.sd
|
||||
if sd != nil {
|
||||
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
if sd.Name == "" {
|
||||
batch.ExecParams(bi.query, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -909,6 +906,171 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
if c.statementCache == nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
|
||||
}
|
||||
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
if bi.sd == nil {
|
||||
sd := c.statementCache.Get(bi.query)
|
||||
if sd != nil {
|
||||
bi.sd = sd
|
||||
} else {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd = &pgconn.StatementDescription{
|
||||
Name: stmtcache.NextStatementName(),
|
||||
SQL: bi.query,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache)
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
if c.descriptionCache == nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
|
||||
}
|
||||
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
if bi.sd == nil {
|
||||
sd := c.descriptionCache.Get(bi.query)
|
||||
if sd != nil {
|
||||
bi.sd = sd
|
||||
} else {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd = &pgconn.StatementDescription{
|
||||
SQL: bi.query,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache)
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
distinctNewQueriesIdxMap := make(map[string]int)
|
||||
|
||||
for _, bi := range b.items {
|
||||
if bi.sd == nil {
|
||||
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
||||
bi.sd = distinctNewQueries[idx]
|
||||
} else {
|
||||
sd := &pgconn.StatementDescription{
|
||||
SQL: bi.query,
|
||||
}
|
||||
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
||||
distinctNewQueries = append(distinctNewQueries, sd)
|
||||
bi.sd = sd
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil)
|
||||
}
|
||||
|
||||
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
|
||||
pipeline := c.pgConn.StartPipeline(context.Background())
|
||||
defer func() {
|
||||
if pbr.err != nil {
|
||||
pipeline.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Prepare any needed queries
|
||||
if len(distinctNewQueries) > 0 {
|
||||
for _, sd := range distinctNewQueries {
|
||||
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
for _, sd := range distinctNewQueries {
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
resultSD, ok := results.(*pgconn.StatementDescription)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
|
||||
}
|
||||
|
||||
// Fill in the previously empty / pending statement descriptions.
|
||||
sd.ParamOIDs = resultSD.ParamOIDs
|
||||
sd.Fields = resultSD.Fields
|
||||
}
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
_, ok := results.(*pgconn.PipelineSync)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
|
||||
}
|
||||
}
|
||||
|
||||
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
|
||||
if sdCache != nil {
|
||||
for _, sd := range distinctNewQueries {
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
}
|
||||
|
||||
// Queue the queries.
|
||||
for _, bi := range b.items {
|
||||
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
if bi.sd.Name == "" {
|
||||
pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
} else {
|
||||
pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
||||
}
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
||||
return &pipelineBatchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
pipeline: pipeline,
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
|
||||
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
|
||||
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
||||
@@ -1015,3 +1177,37 @@ order by attnum`,
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
|
||||
if c.descriptionCache != nil {
|
||||
c.descriptionCache.HandleInvalidated()
|
||||
}
|
||||
|
||||
var invalidatedStatements []*pgconn.StatementDescription
|
||||
if c.statementCache != nil {
|
||||
invalidatedStatements = c.statementCache.HandleInvalidated()
|
||||
}
|
||||
|
||||
if len(invalidatedStatements) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pipeline := c.pgConn.StartPipeline(ctx)
|
||||
defer pipeline.Close()
|
||||
|
||||
for _, sd := range invalidatedStatements {
|
||||
pipeline.SendDeallocate(sd.Name)
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||
}
|
||||
|
||||
err = pipeline.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+9
-8
@@ -931,6 +931,7 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
|
||||
rows, err := conn.Query(ctx, getSQL, 1)
|
||||
require.NoError(t, err)
|
||||
rows.Close()
|
||||
require.NoError(t, rows.Err())
|
||||
|
||||
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||
@@ -948,10 +949,10 @@ func TestStmtCacheInvalidationConn(t *testing.T) {
|
||||
rows.Close()
|
||||
for _, err := range []error{nextErr, rows.Err()} {
|
||||
if err == nil {
|
||||
t.Fatal("expected InvalidCachedStatementPlanError: no error")
|
||||
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
|
||||
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -995,6 +996,7 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
||||
rows, err := tx.Query(ctx, getSQL, 1)
|
||||
require.NoError(t, err)
|
||||
rows.Close()
|
||||
require.NoError(t, rows.Err())
|
||||
|
||||
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||
@@ -1012,18 +1014,17 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
||||
rows.Close()
|
||||
for _, err := range []error{nextErr, rows.Err()} {
|
||||
if err == nil {
|
||||
t.Fatal("expected InvalidCachedStatementPlanError: no error")
|
||||
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||
t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error())
|
||||
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
rows, err = tx.Query(ctx, getSQL, 1)
|
||||
require.NoError(t, err) // error does not pop up immediately
|
||||
rows.Next()
|
||||
rows, _ = tx.Query(ctx, getSQL, 1)
|
||||
rows.Close()
|
||||
err = rows.Err()
|
||||
// Retries within the same transaction are errors (really anything except a rollbakc
|
||||
// Retries within the same transaction are errors (really anything except a rollback
|
||||
// will be an error in this transaction).
|
||||
require.Error(t, err)
|
||||
rows.Close()
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
var lruCount uint64
|
||||
|
||||
// LRU implements Cache with a Least Recently Used (LRU) cache.
|
||||
type LRU struct {
|
||||
conn *pgconn.PgConn
|
||||
mode int
|
||||
cap int
|
||||
prepareCount int
|
||||
m map[string]*list.Element
|
||||
l *list.List
|
||||
psNamePrefix string
|
||||
stmtsToClear []string
|
||||
}
|
||||
|
||||
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
|
||||
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
|
||||
mustBeValidMode(mode)
|
||||
mustBeValidCap(cap)
|
||||
|
||||
n := atomic.AddUint64(&lruCount, 1)
|
||||
|
||||
return &LRU{
|
||||
conn: conn,
|
||||
mode: mode,
|
||||
cap: cap,
|
||||
m: make(map[string]*list.Element),
|
||||
l: list.New(),
|
||||
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||
if ctx != context.Background() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// flush an outstanding bad statements
|
||||
txStatus := c.conn.TxStatus()
|
||||
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
|
||||
for _, stmt := range c.stmtsToClear {
|
||||
err := c.clearStmt(ctx, stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if el, ok := c.m[sql]; ok {
|
||||
c.l.MoveToFront(el)
|
||||
return el.Value.(*pgconn.StatementDescription), nil
|
||||
}
|
||||
|
||||
if c.l.Len() == c.cap {
|
||||
err := c.removeOldest(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
psd, err := c.prepare(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
el := c.l.PushFront(psd)
|
||||
c.m[sql] = el
|
||||
|
||||
return psd, nil
|
||||
}
|
||||
|
||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||
func (c *LRU) Clear(ctx context.Context) error {
|
||||
for c.l.Len() > 0 {
|
||||
err := c.removeOldest(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LRU) StatementErrored(sql string, err error) {
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1162
|
||||
//
|
||||
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
||||
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
||||
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
||||
// have so it should be safe.
|
||||
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
||||
if possibleInvalidCachedPlanError {
|
||||
c.stmtsToClear = append(c.stmtsToClear, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
|
||||
elem, inMap := c.m[sql]
|
||||
if !inMap {
|
||||
// The statement probably fell off the back of the list. In that case, we've
|
||||
// ensured that it isn't in the cache, so we can declare victory.
|
||||
return nil
|
||||
}
|
||||
|
||||
c.l.Remove(elem)
|
||||
|
||||
psd := elem.Value.(*pgconn.StatementDescription)
|
||||
delete(c.m, psd.SQL)
|
||||
if c.mode == ModePrepare {
|
||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *LRU) Len() int {
|
||||
return c.l.Len()
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *LRU) Cap() int {
|
||||
return c.cap
|
||||
}
|
||||
|
||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||
func (c *LRU) Mode() int {
|
||||
return c.mode
|
||||
}
|
||||
|
||||
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||
var name string
|
||||
if c.mode == ModePrepare {
|
||||
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
|
||||
c.prepareCount += 1
|
||||
}
|
||||
|
||||
return c.conn.Prepare(ctx, name, sql, nil)
|
||||
}
|
||||
|
||||
func (c *LRU) removeOldest(ctx context.Context) error {
|
||||
oldest := c.l.Back()
|
||||
c.l.Remove(oldest)
|
||||
psd := oldest.Value.(*pgconn.StatementDescription)
|
||||
delete(c.m, psd.SQL)
|
||||
if c.mode == ModePrepare {
|
||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// LRUCache implements Cache with a Least Recently Used (LRU) cache.
|
||||
type LRUCache struct {
|
||||
cap int
|
||||
m map[string]*list.Element
|
||||
l *list.List
|
||||
invalidStmts []*pgconn.StatementDescription
|
||||
}
|
||||
|
||||
// NewLRUCache creates a new LRUCache. cap is the maximum size of the cache.
|
||||
func NewLRUCache(cap int) *LRUCache {
|
||||
return &LRUCache{
|
||||
cap: cap,
|
||||
m: make(map[string]*list.Element),
|
||||
l: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
|
||||
if el, ok := c.m[key]; ok {
|
||||
c.l.MoveToFront(el)
|
||||
return el.Value.(*pgconn.StatementDescription)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
|
||||
if sd.SQL == "" {
|
||||
panic("cannot store statement description with empty SQL")
|
||||
}
|
||||
|
||||
if _, present := c.m[sd.SQL]; present {
|
||||
return
|
||||
}
|
||||
|
||||
if c.l.Len() == c.cap {
|
||||
c.invalidateOldest()
|
||||
}
|
||||
|
||||
el := c.l.PushFront(sd)
|
||||
c.m[sd.SQL] = el
|
||||
}
|
||||
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
func (c *LRUCache) Invalidate(sql string) {
|
||||
if el, ok := c.m[sql]; ok {
|
||||
delete(c.m, sql)
|
||||
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||
c.l.Remove(el)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
func (c *LRUCache) InvalidateAll() {
|
||||
el := c.l.Front()
|
||||
for el != nil {
|
||||
c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription))
|
||||
el = el.Next()
|
||||
}
|
||||
|
||||
c.m = make(map[string]*list.Element)
|
||||
c.l = list.New()
|
||||
}
|
||||
|
||||
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||
invalidStmts := c.invalidStmts
|
||||
c.invalidStmts = nil
|
||||
return invalidStmts
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *LRUCache) Len() int {
|
||||
return c.l.Len()
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *LRUCache) Cap() int {
|
||||
return c.cap
|
||||
}
|
||||
|
||||
func (c *LRUCache) invalidateOldest() {
|
||||
oldest := c.l.Back()
|
||||
sd := oldest.Value.(*pgconn.StatementDescription)
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
delete(c.m, sd.SQL)
|
||||
c.l.Remove(oldest)
|
||||
}
|
||||
@@ -1,292 +0,0 @@
|
||||
package stmtcache_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/stmtcache"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLRUModePrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.EqualValues(t, 2, cache.Cap())
|
||||
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
||||
|
||||
psd, err := cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 3")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
}
|
||||
|
||||
func TestLRUStmtInvalidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
// we construct a fake error because its not super straightforward to actually call
|
||||
// a prepared statement from the LRU cache without the helper routines which live
|
||||
// in pgx proper.
|
||||
fakeInvalidCachePlanError := &pgconn.PgError{
|
||||
Severity: "ERROR",
|
||||
Code: "0A000",
|
||||
Message: "cached plan must not change result type",
|
||||
}
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||
|
||||
//
|
||||
// outside of a transaction, we eagerly flush the statement
|
||||
//
|
||||
|
||||
_, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
|
||||
_, err = cache.Get(ctx, "select 2")
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
//
|
||||
// within an errored transaction, we defer the flush to after the first get
|
||||
// that happens after the transaction is rolled back
|
||||
//
|
||||
|
||||
_, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
res := conn.Exec(ctx, "begin")
|
||||
require.NoError(t, res.Close())
|
||||
require.Equal(t, byte('T'), conn.TxStatus())
|
||||
|
||||
res = conn.Exec(ctx, "selec")
|
||||
require.Error(t, res.Close())
|
||||
require.Equal(t, byte('E'), conn.TxStatus())
|
||||
|
||||
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
|
||||
res = conn.Exec(ctx, "rollback")
|
||||
require.NoError(t, res.Close())
|
||||
|
||||
_, err = cache.Get(ctx, "select 2")
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
|
||||
}
|
||||
|
||||
func TestLRUStmtInvalidationIntegration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||
|
||||
result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
sql := "select * from stmtcache_table"
|
||||
sd1, err := cache.Get(ctx, sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
|
||||
require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)")
|
||||
|
||||
cache.StatementErrored(sql, result.Err)
|
||||
|
||||
sd2, err := cache.Get(ctx, sql)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, sd1.Name, sd2.Name)
|
||||
|
||||
result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
}
|
||||
|
||||
func TestLRUModePrepareStress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.EqualValues(t, 8, cache.Cap())
|
||||
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50)))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUModeDescribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.EqualValues(t, 2, cache.Cap())
|
||||
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
|
||||
|
||||
psd, err := cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 3")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
}
|
||||
|
||||
func TestLRUContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
||||
|
||||
// test 1 : getting a value for the first time with a cancelled context returns an error
|
||||
ctx1, cancel1 := context.WithCancel(ctx)
|
||||
cancel1()
|
||||
|
||||
desc, err := cache.Get(ctx1, "SELECT 1")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, desc)
|
||||
|
||||
// test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error
|
||||
ctx2, cancel2 := context.WithCancel(ctx)
|
||||
|
||||
desc, err = cache.Get(ctx2, "SELECT 2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, desc)
|
||||
|
||||
cancel2()
|
||||
|
||||
desc, err = cache.Get(ctx2, "SELECT 2")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, desc)
|
||||
}
|
||||
|
||||
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
|
||||
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
var statements []string
|
||||
for _, r := range result.Rows {
|
||||
statement := string(r[0])
|
||||
if conn.ParameterStatus("crdb_version") != "" {
|
||||
if statement == "PREPARE AS select statement from pg_prepared_statements" {
|
||||
// CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it.
|
||||
continue
|
||||
}
|
||||
|
||||
// CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended
|
||||
// protocol will PostgreSQL does not. Normalize the statement.
|
||||
re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `)
|
||||
statement = re.ReplaceAllString(statement, "")
|
||||
}
|
||||
statements = append(statements, statement)
|
||||
}
|
||||
return statements
|
||||
}
|
||||
@@ -2,57 +2,56 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
const (
|
||||
ModePrepare = iota // Cache should prepare named statements.
|
||||
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
|
||||
)
|
||||
var stmtCounter int64
|
||||
|
||||
// Cache prepares and caches prepared statement descriptions.
|
||||
// NextStatementName returns a statement name that will be unique for the lifetime of the program.
|
||||
func NextStatementName() string {
|
||||
n := atomic.AddInt64(&stmtCounter, 1)
|
||||
return "stmtcache_" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
// Cache caches statement descriptions.
|
||||
type Cache interface {
|
||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
Get(sql string) *pgconn.StatementDescription
|
||||
|
||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||
Clear(ctx context.Context) error
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
Put(sd *pgconn.StatementDescription)
|
||||
|
||||
// StatementErrored informs the cache that the given statement resulted in an error when it
|
||||
// was last used against the database. In some cases, this will cause the cache to maer that
|
||||
// statement as bad. The bad statement will instead be flushed during the next call to Get
|
||||
// that occurs outside of a failed transaction.
|
||||
StatementErrored(sql string, err error)
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
Invalidate(sql string)
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
InvalidateAll()
|
||||
|
||||
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
|
||||
HandleInvalidated() []*pgconn.StatementDescription
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
Len() int
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
Cap() int
|
||||
|
||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||
Mode() int
|
||||
}
|
||||
|
||||
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
|
||||
// the maximum size of the cache.
|
||||
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
|
||||
mustBeValidMode(mode)
|
||||
mustBeValidCap(cap)
|
||||
|
||||
return NewLRU(conn, mode, cap)
|
||||
}
|
||||
|
||||
func mustBeValidMode(mode int) {
|
||||
if mode != ModePrepare && mode != ModeDescribe {
|
||||
panic("mode must be ModePrepare or ModeDescribe")
|
||||
func IsStatementInvalid(err error) bool {
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func mustBeValidCap(cap int) {
|
||||
if cap < 1 {
|
||||
panic("cache must have cap of >= 1")
|
||||
}
|
||||
// https://github.com/jackc/pgx/issues/1162
|
||||
//
|
||||
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
||||
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
||||
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
||||
// have so it should be safe.
|
||||
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
||||
return possibleInvalidCachedPlanError
|
||||
}
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
// UnlimitedCache implements Cache with no capacity limit.
|
||||
type UnlimitedCache struct {
|
||||
m map[string]*pgconn.StatementDescription
|
||||
invalidStmts []*pgconn.StatementDescription
|
||||
}
|
||||
|
||||
// NewUnlimitedCache creates a new UnlimitedCache.
|
||||
func NewUnlimitedCache() *UnlimitedCache {
|
||||
return &UnlimitedCache{
|
||||
m: make(map[string]*pgconn.StatementDescription),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the statement description for sql. Returns nil if not found.
|
||||
func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription {
|
||||
return c.m[sql]
|
||||
}
|
||||
|
||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
|
||||
func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) {
|
||||
if sd.SQL == "" {
|
||||
panic("cannot store statement description with empty SQL")
|
||||
}
|
||||
|
||||
if _, present := c.m[sd.SQL]; present {
|
||||
return
|
||||
}
|
||||
|
||||
c.m[sd.SQL] = sd
|
||||
}
|
||||
|
||||
// Invalidate invalidates statement description identified by sql. Does nothing if not found.
|
||||
func (c *UnlimitedCache) Invalidate(sql string) {
|
||||
if sd, ok := c.m[sql]; ok {
|
||||
delete(c.m, sql)
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateAll invalidates all statement descriptions.
|
||||
func (c *UnlimitedCache) InvalidateAll() {
|
||||
for _, sd := range c.m {
|
||||
c.invalidStmts = append(c.invalidStmts, sd)
|
||||
}
|
||||
|
||||
c.m = make(map[string]*pgconn.StatementDescription)
|
||||
}
|
||||
|
||||
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
|
||||
invalidStmts := c.invalidStmts
|
||||
c.invalidStmts = nil
|
||||
return invalidStmts
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *UnlimitedCache) Len() int {
|
||||
return len(c.m)
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *UnlimitedCache) Cap() int {
|
||||
return math.MaxInt
|
||||
}
|
||||
+1
-1
@@ -1909,10 +1909,10 @@ func (p *Pipeline) Close() error {
|
||||
for p.expectedReadyForQueryCount > 0 {
|
||||
_, err := p.GetResults()
|
||||
if err != nil {
|
||||
p.err = err
|
||||
var pgErr *PgError
|
||||
if !errors.As(err, &pgErr) {
|
||||
p.conn.asyncClose()
|
||||
p.err = err
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/stmtcache"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
@@ -173,8 +174,16 @@ func (rows *baseRows) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
if rows.err != nil && rows.conn != nil && rows.conn.statementCache != nil {
|
||||
rows.conn.statementCache.StatementErrored(rows.sql, rows.err)
|
||||
if rows.err != nil && rows.conn != nil && rows.sql != "" {
|
||||
if stmtcache.IsStatementInvalid(rows.err) {
|
||||
if sc := rows.conn.statementCache; sc != nil {
|
||||
sc.Invalidate(rows.sql)
|
||||
}
|
||||
|
||||
if sc := rows.conn.descriptionCache; sc != nil {
|
||||
sc.Invalidate(rows.sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user