Initial privatization of stmtcache
ConnConfig.BuildStatementCache is pending removal once connections always have separate caches for prepared and described statements.
This commit is contained in:
@@ -0,0 +1,165 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
var lruCount uint64
|
||||
|
||||
// LRU implements Cache with a Least Recently Used (LRU) cache.
|
||||
type LRU struct {
|
||||
conn *pgconn.PgConn
|
||||
mode int
|
||||
cap int
|
||||
prepareCount int
|
||||
m map[string]*list.Element
|
||||
l *list.List
|
||||
psNamePrefix string
|
||||
stmtsToClear []string
|
||||
}
|
||||
|
||||
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
|
||||
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
|
||||
mustBeValidMode(mode)
|
||||
mustBeValidCap(cap)
|
||||
|
||||
n := atomic.AddUint64(&lruCount, 1)
|
||||
|
||||
return &LRU{
|
||||
conn: conn,
|
||||
mode: mode,
|
||||
cap: cap,
|
||||
m: make(map[string]*list.Element),
|
||||
l: list.New(),
|
||||
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||
if ctx != context.Background() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// flush an outstanding bad statements
|
||||
txStatus := c.conn.TxStatus()
|
||||
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
|
||||
for _, stmt := range c.stmtsToClear {
|
||||
err := c.clearStmt(ctx, stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if el, ok := c.m[sql]; ok {
|
||||
c.l.MoveToFront(el)
|
||||
return el.Value.(*pgconn.StatementDescription), nil
|
||||
}
|
||||
|
||||
if c.l.Len() == c.cap {
|
||||
err := c.removeOldest(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
psd, err := c.prepare(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
el := c.l.PushFront(psd)
|
||||
c.m[sql] = el
|
||||
|
||||
return psd, nil
|
||||
}
|
||||
|
||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||
func (c *LRU) Clear(ctx context.Context) error {
|
||||
for c.l.Len() > 0 {
|
||||
err := c.removeOldest(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LRU) StatementErrored(sql string, err error) {
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
isInvalidCachedPlanError := pgErr.Severity == "ERROR" &&
|
||||
pgErr.Code == "0A000" &&
|
||||
pgErr.Message == "cached plan must not change result type"
|
||||
if isInvalidCachedPlanError {
|
||||
c.stmtsToClear = append(c.stmtsToClear, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
|
||||
elem, inMap := c.m[sql]
|
||||
if !inMap {
|
||||
// The statement probably fell off the back of the list. In that case, we've
|
||||
// ensured that it isn't in the cache, so we can declare victory.
|
||||
return nil
|
||||
}
|
||||
|
||||
c.l.Remove(elem)
|
||||
|
||||
psd := elem.Value.(*pgconn.StatementDescription)
|
||||
delete(c.m, psd.SQL)
|
||||
if c.mode == ModePrepare {
|
||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *LRU) Len() int {
|
||||
return c.l.Len()
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *LRU) Cap() int {
|
||||
return c.cap
|
||||
}
|
||||
|
||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||
func (c *LRU) Mode() int {
|
||||
return c.mode
|
||||
}
|
||||
|
||||
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||
var name string
|
||||
if c.mode == ModePrepare {
|
||||
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
|
||||
c.prepareCount += 1
|
||||
}
|
||||
|
||||
return c.conn.Prepare(ctx, name, sql, nil)
|
||||
}
|
||||
|
||||
func (c *LRU) removeOldest(ctx context.Context) error {
|
||||
oldest := c.l.Back()
|
||||
c.l.Remove(oldest)
|
||||
psd := oldest.Value.(*pgconn.StatementDescription)
|
||||
delete(c.m, psd.SQL)
|
||||
if c.mode == ModePrepare {
|
||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,292 @@
|
||||
package stmtcache_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/stmtcache"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLRUModePrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.EqualValues(t, 2, cache.Cap())
|
||||
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
||||
|
||||
psd, err := cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 3")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
}
|
||||
|
||||
func TestLRUStmtInvalidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
// we construct a fake error because its not super straightforward to actually call
|
||||
// a prepared statement from the LRU cache without the helper routines which live
|
||||
// in pgx proper.
|
||||
fakeInvalidCachePlanError := &pgconn.PgError{
|
||||
Severity: "ERROR",
|
||||
Code: "0A000",
|
||||
Message: "cached plan must not change result type",
|
||||
}
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||
|
||||
//
|
||||
// outside of a transaction, we eagerly flush the statement
|
||||
//
|
||||
|
||||
_, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
|
||||
_, err = cache.Get(ctx, "select 2")
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
//
|
||||
// within an errored transaction, we defer the flush to after the first get
|
||||
// that happens after the transaction is rolled back
|
||||
//
|
||||
|
||||
_, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
res := conn.Exec(ctx, "begin")
|
||||
require.NoError(t, res.Close())
|
||||
require.Equal(t, byte('T'), conn.TxStatus())
|
||||
|
||||
res = conn.Exec(ctx, "selec")
|
||||
require.Error(t, res.Close())
|
||||
require.Equal(t, byte('E'), conn.TxStatus())
|
||||
|
||||
cache.StatementErrored("select 1", fakeInvalidCachePlanError)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
|
||||
res = conn.Exec(ctx, "rollback")
|
||||
require.NoError(t, res.Close())
|
||||
|
||||
_, err = cache.Get(ctx, "select 2")
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
|
||||
}
|
||||
|
||||
func TestLRUStmtInvalidationIntegration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||
|
||||
result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
sql := "select * from stmtcache_table"
|
||||
sd1, err := cache.Get(ctx, sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
|
||||
require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)")
|
||||
|
||||
cache.StatementErrored(sql, result.Err)
|
||||
|
||||
sd2, err := cache.Get(ctx, sql)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, sd1.Name, sd2.Name)
|
||||
|
||||
result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
}
|
||||
|
||||
func TestLRUModePrepareStress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.EqualValues(t, 8, cache.Cap())
|
||||
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50)))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLRUModeDescribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.EqualValues(t, 2, cache.Cap())
|
||||
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
|
||||
|
||||
psd, err := cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 3")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
}
|
||||
|
||||
func TestLRUContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
||||
|
||||
// test 1 : getting a value for the first time with a cancelled context returns an error
|
||||
ctx1, cancel1 := context.WithCancel(ctx)
|
||||
cancel1()
|
||||
|
||||
desc, err := cache.Get(ctx1, "SELECT 1")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, desc)
|
||||
|
||||
// test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error
|
||||
ctx2, cancel2 := context.WithCancel(ctx)
|
||||
|
||||
desc, err = cache.Get(ctx2, "SELECT 2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, desc)
|
||||
|
||||
cancel2()
|
||||
|
||||
desc, err = cache.Get(ctx2, "SELECT 2")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, desc)
|
||||
}
|
||||
|
||||
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
|
||||
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
var statements []string
|
||||
for _, r := range result.Rows {
|
||||
statement := string(r[0])
|
||||
if conn.ParameterStatus("crdb_version") != "" {
|
||||
if statement == "PREPARE AS select statement from pg_prepared_statements" {
|
||||
// CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it.
|
||||
continue
|
||||
}
|
||||
|
||||
// CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended
|
||||
// protocol will PostgreSQL does not. Normalize the statement.
|
||||
re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `)
|
||||
statement = re.ReplaceAllString(statement, "")
|
||||
}
|
||||
statements = append(statements, statement)
|
||||
}
|
||||
return statements
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
const (
|
||||
ModePrepare = iota // Cache should prepare named statements.
|
||||
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
|
||||
)
|
||||
|
||||
// Cache prepares and caches prepared statement descriptions.
|
||||
type Cache interface {
|
||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
|
||||
|
||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// StatementErrored informs the cache that the given statement resulted in an error when it
|
||||
// was last used against the database. In some cases, this will cause the cache to maer that
|
||||
// statement as bad. The bad statement will instead be flushed during the next call to Get
|
||||
// that occurs outside of a failed transaction.
|
||||
StatementErrored(sql string, err error)
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
Len() int
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
Cap() int
|
||||
|
||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||
Mode() int
|
||||
}
|
||||
|
||||
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
|
||||
// the maximum size of the cache.
|
||||
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
|
||||
mustBeValidMode(mode)
|
||||
mustBeValidCap(cap)
|
||||
|
||||
return NewLRU(conn, mode, cap)
|
||||
}
|
||||
|
||||
func mustBeValidMode(mode int) {
|
||||
if mode != ModePrepare && mode != ModeDescribe {
|
||||
panic("mode must be ModePrepare or ModeDescribe")
|
||||
}
|
||||
}
|
||||
|
||||
func mustBeValidCap(cap int) {
|
||||
if cap < 1 {
|
||||
panic("cache must have cap of >= 1")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user