From f5eead90fca09203d8af956fea01861884ed9a8a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Sep 2019 21:04:14 -0500 Subject: [PATCH] Fix statement cache reuse bug --- stmtcache/lru.go | 4 +++- stmtcache/lru_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index fff4d0b7..d82ced19 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -104,8 +104,10 @@ func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescrip func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) + psd := oldest.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close() + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() } return nil } diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index b518364e..d2902dbb 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -2,6 +2,8 @@ package stmtcache_test import ( "context" + "fmt" + "math/rand" "os" "testing" "time" @@ -57,6 +59,30 @@ func TestLRUModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUModePrepareStress(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 8, cache.Cap()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) + + for i := 0; i < 1000; i++ { + psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50))) + require.NoError(t, err) + require.NotNil(t, psd) + result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) + } +} + func TestLRUModeDescribe(t *testing.T) { t.Parallel()