2
0

refactor to use the same connection implementation

This commit is contained in:
Lev Zakharov
2023-08-19 21:42:28 +03:00
committed by Jack Christensen
parent 3d4540aa1b
commit 51ade172e5
2 changed files with 95 additions and 548 deletions
+95 -17
View File
@@ -74,6 +74,7 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
)
// Only intrinsic types should be binary format with database/sql.
@@ -125,7 +126,7 @@ func contains(list []string, y string) bool {
type OptionOpenDB func(*connector)
// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
// be used to connect, so only its immediate members should be modified.
// be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig.
func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
return func(dc *connector) {
dc.BeforeConnect = bc
@@ -139,6 +140,20 @@ func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB
}
}
// OptionBeforeConnect provides a callback for before acquire. Used only if db is opened with *pgxpool.Pool.
func OptionBeforeAcquire(ba func(context.Context, *pgxpool.Pool) error) OptionOpenDB {
return func(c *connector) {
c.BeforeAcquire = ba
}
}
// OptionAfterAcquire provides a callback for after acquire. Used only if db is opened with *pgxpool.Pool.
func OptionAfterAcquire(aa func(context.Context, *pgxpool.Conn) error) OptionOpenDB {
return func(c *connector) {
c.AfterAcquire = aa
}
}
// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the
// connection if the connection has been used before.
// If ResetSessionFunc returns ErrBadConn error the connection will be discarded.
@@ -191,15 +206,41 @@ func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector
return c
}
func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDB) driver.Connector {
c := connector{
pool: pool,
BeforeAcquire: func(context.Context, *pgxpool.Pool) error { return nil }, // noop before acquire by default
AfterAcquire: func(context.Context, *pgxpool.Conn) error { return nil }, // noop after acquire by default
ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default
driver: pgxDriver,
}
for _, opt := range opts {
opt(&c)
}
return c
}
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
c := GetConnector(config, opts...)
return sql.OpenDB(c)
}
func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB {
c := GetPoolConnector(pool, opts...)
db := sql.OpenDB(c)
db.SetMaxIdleConns(0)
return db
}
type connector struct {
pgx.ConnConfig
pool *pgxpool.Pool
BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection
BeforeAcquire func(context.Context, *pgxpool.Pool) error // function to call before acquiring of every new connection
AfterAcquire func(context.Context, *pgxpool.Conn) error // function to call after acquiring of every new connection
ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused
driver *Driver
}
@@ -207,25 +248,60 @@ type connector struct {
// Connect implement driver.Connector interface
func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
var (
err error
conn *pgx.Conn
connConfig pgx.ConnConfig
conn *pgx.Conn
close func(context.Context) error
err error
)
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
connConfig := c.ConnConfig
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
return nil, err
if c.pool == nil {
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
connConfig = c.ConnConfig
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
return nil, err
}
if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
return nil, err
}
if err = c.AfterConnect(ctx, conn); err != nil {
return nil, err
}
close = conn.Close
} else {
var pconn *pgxpool.Conn
if err = c.BeforeAcquire(ctx, c.pool); err != nil {
return nil, err
}
pconn, err = c.pool.Acquire(ctx)
if err != nil {
return nil, err
}
if err = c.AfterAcquire(ctx, pconn); err != nil {
return nil, err
}
conn = pconn.Conn()
close = func(_ context.Context) error {
pconn.Release()
return nil
}
}
if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
return nil, err
}
if err = c.AfterConnect(ctx, conn); err != nil {
return nil, err
}
return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil
return &Conn{
conn: conn,
close: close,
driver: c.driver,
connConfig: connConfig,
resetSessionFunc: c.ResetSession,
}, nil
}
// Driver implement driver.Connector interface
@@ -302,6 +378,7 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
c := &Conn{
conn: conn,
close: conn.Close,
driver: dc.driver,
connConfig: *connConfig,
resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
@@ -326,6 +403,7 @@ func UnregisterConnConfig(connStr string) {
type Conn struct {
conn *pgx.Conn
close func(context.Context) error
psCount int64 // Counter used for creating unique prepared statement names
driver *Driver
connConfig pgx.ConnConfig
@@ -361,7 +439,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
func (c *Conn) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
return c.conn.Close(ctx)
return c.close(ctx)
}
func (c *Conn) Begin() (driver.Tx, error) {