From a885de9c949c36c1359edc6de00cff0bc4b16bb1 Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Mon, 9 Nov 2020 08:20:34 -0500 Subject: [PATCH] stmtcache: add new StatementErrored method This patch adds a new StatementErrored method to the stmtcache. This routine MUST be called by users of the cache whenever the execution of a statement results in an error. This will allow the cache to make an intelligent decision about whether or not the statement needs to be purged from the cache. --- stmtcache/lru.go | 50 ++++++++++++++++++++++++++++++ stmtcache/lru_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++ stmtcache/stmtcache.go | 8 +++++ 3 files changed, 127 insertions(+) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index d82ced19..2f183f90 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -20,6 +20,7 @@ type LRU struct { m map[string]*list.Element l *list.List psNamePrefix string + stmtsToClear []string } // NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. @@ -41,6 +42,17 @@ func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + // flush an outstanding bad statements + txStatus := c.conn.TxStatus() + if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { + for _, stmt := range c.stmtsToClear { + err := c.clearStmt(ctx, stmt) + if err != nil { + return nil, err + } + } + } + if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.StatementDescription), nil @@ -76,6 +88,44 @@ func (c *LRU) Clear(ctx context.Context) error { return nil } +func (c *LRU) StatementErrored(ctx context.Context, sql string, err error) error { + pgErr, ok := err.(*pgconn.PgError) + if !ok { + // we don't know how to handle this error + return nil + } + + isInvalidCachedPlanError := pgErr.Severity == "ERROR" && + pgErr.Code == "0A000" && + pgErr.Message == "cached plan must not change result type" + if !isInvalidCachedPlanError { + // only flush if a plan has been changed out from under us + return nil + } + + c.stmtsToClear = append(c.stmtsToClear, sql) + + return nil +} + +func (c *LRU) clearStmt(ctx context.Context, sql string) error { + elem, inMap := c.m[sql] + if !inMap { + // The statement probably fell off the back of the list. In that case, we've + // ensured that it isn't in the cache, so we can declare victory. + return nil + } + + c.l.Remove(elem) + + psd := elem.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) + if c.mode == ModePrepare { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() + } + return nil +} + // Len returns the number of cached prepared statement descriptions. func (c *LRU) Len() int { return c.l.Len() diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index d2902dbb..75925509 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -59,6 +59,75 @@ func TestLRUModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUStmtInvalidation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + // we construct a fake error because its not super straightforward to actually call + // a prepared statement from the LRU cache without the helper routines which live + // in pgx proper. + fakeInvalidCachePlanError := &pgconn.PgError{ + Severity: "ERROR", + Code: "0A000", + Message: "cached plan must not change result type", + } + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + + // + // outside of a transaction, we eagerly flush the statement + // + + _, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) + require.NoError(t, err) + _, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + + // + // within an errored transaction, we defer the flush to after the first get + // that happens after the transaction is rolled back + // + + _, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + res := conn.Exec(ctx, "begin") + require.NoError(t, res.Close()) + require.Equal(t, byte('T'), conn.TxStatus()) + + res = conn.Exec(ctx, "selec") + require.Error(t, res.Close()) + require.Equal(t, byte('E'), conn.TxStatus()) + + err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) + require.EqualValues(t, 1, cache.Len()) + + res = conn.Exec(ctx, "rollback") + require.NoError(t, res.Close()) + + _, err = cache.Get(ctx, "select 2") + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) +} + func TestLRUModePrepareStress(t *testing.T) { t.Parallel() diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index 96215799..6e88ba54 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -20,6 +20,14 @@ type Cache interface { // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. Clear(ctx context.Context) error + // StatementErrored informs the cache that the given statement resulted in an error when it + // was last used against the database. In some cases, this will cause the cache to flush + // the statement from the cache. It will only do so when the underlying `*pgconn.PgConn` + // is not currently in a transaction. If the connection is in the middle of a transaction, + // the bad statement will instead be flushed during the next call to Get that occurrs outside + // of a transaction. + StatementErrored(ctx context.Context, sql string, err error) error + // Len returns the number of cached prepared statement descriptions. Len() int