diff --git a/stdlib/sql.go b/stdlib/sql.go index eca1a863..fa81e73d 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -126,6 +126,15 @@ func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB } } +// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the +// connection if the connection has been used before. +// If ResetSessionFunc returns ErrBadConn error the connection will be discarded. +func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB { + return func(dc *connector) { + dc.ResetSession = rs + } +} + // RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a // new host becomes primary each time. This is useful to distribute connections for multi-master databases like // CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well @@ -159,6 +168,7 @@ func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { ConnConfig: config, BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default + ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default driver: pgxDriver, } @@ -173,6 +183,7 @@ type connector struct { pgx.ConnConfig BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection + ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused driver *Driver } @@ -197,7 +208,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - return &Conn{conn: conn, driver: c.driver, connConfig: connConfig}, nil + return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil } // Driver implement driver.Connector interface @@ -272,7 +283,13 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - c := &Conn{conn: conn, driver: dc.driver, connConfig: *connConfig} + c := &Conn{ + conn: conn, + driver: dc.driver, + connConfig: *connConfig, + resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, + } + return c, nil } @@ -291,10 +308,11 @@ func UnregisterConnConfig(connStr string) { } type Conn struct { - conn *pgx.Conn - psCount int64 // Counter used for creating unique prepared statement names - driver *Driver - connConfig pgx.ConnConfig + conn *pgx.Conn + psCount int64 // Counter used for creating unique prepared statement names + driver *Driver + connConfig pgx.ConnConfig + resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused } // Conn returns the underlying *pgx.Conn @@ -436,7 +454,8 @@ func (c *Conn) ResetSession(ctx context.Context) error { if c.conn.IsClosed() { return driver.ErrBadConn } - return nil + + return c.resetSessionFunc(ctx, c.conn) } type Stmt struct { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index eb08be7d..099320c0 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1202,3 +1202,26 @@ func TestRandomizeHostOrderFunc(t *testing.T) { require.Fail(t, "did not get all hosts as primaries after many randomizations") } + +func TestResetSessionHookCalled(t *testing.T) { + var mockCalled bool + + connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error { + mockCalled = true + + return nil + })) + + defer closeDB(t, db) + + err = db.Ping() + require.NoError(t, err) + + err = db.Ping() + require.NoError(t, err) + + require.True(t, mockCalled) +}