From 210a217818576dcd21659594d289917b05ee7654 Mon Sep 17 00:00:00 2001 From: Robert Froehlich Date: Sat, 2 Jan 2021 15:08:59 -0800 Subject: [PATCH] Add BeforeConnect callback to pgxpool.Config. This allows for connection settings to be updated without having to create a new pool. The callback is passed a copy of the pgx.ConnConfig and will not impact existing live connections. --- pgxpool/pool.go | 17 ++++++++++++++++- pgxpool/pool_test.go | 21 +++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index a9d1df65..503b8617 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -70,6 +70,7 @@ func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { type Pool struct { p *puddle.Pool config *Config + beforeConnect func(context.Context, *pgx.ConnConfig) error afterConnect func(context.Context, *pgx.Conn) error beforeAcquire func(context.Context, *pgx.Conn) bool afterRelease func(*pgx.Conn) bool @@ -85,6 +86,10 @@ type Pool struct { type Config struct { ConnConfig *pgx.ConnConfig + // BeforeConnect is called before a new connection is made. It is passed a copy of the underlying pgx.ConnConfig and + // will not impact any existing open connections. + BeforeConnect func(context.Context, *pgx.ConnConfig) error + // AfterConnect is called after a connection is established, but before it is added to the pool. AfterConnect func(context.Context, *pgx.Conn) error @@ -155,6 +160,7 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { p := &Pool{ config: config, + beforeConnect: config.BeforeConnect, afterConnect: config.AfterConnect, beforeAcquire: config.BeforeAcquire, afterRelease: config.AfterRelease, @@ -167,7 +173,16 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { p.p = puddle.NewPool( func(ctx context.Context) (interface{}, error) { - conn, err := pgx.ConnectConfig(ctx, config.ConnConfig) + connConfig := p.config.ConnConfig + + if p.beforeConnect != nil { + connConfig = p.config.ConnConfig.Copy() + if err := p.beforeConnect(ctx, connConfig); err != nil { + return nil, err + } + } + + conn, err := pgx.ConnectConfig(ctx, connConfig) if err != nil { return nil, err } diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 4cc1e1a3..55e931cb 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -112,6 +112,27 @@ func TestPoolAcquireAndConnRelease(t *testing.T) { c.Release() } +func TestPoolBeforeConnect(t *testing.T) { + t.Parallel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { + cfg.Config.RuntimeParams["application_name"] = "pgx" + return nil + } + + db, err := pgxpool.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer db.Close() + + var str string + err = db.QueryRow(context.Background(), "SHOW application_name").Scan(&str) + require.NoError(t, err) + assert.EqualValues(t, "pgx", str) +} + func TestPoolAfterConnect(t *testing.T) { t.Parallel()