refactor to use the same connection implementation
This commit is contained in:
committed by
Jack Christensen
parent
3d4540aa1b
commit
51ade172e5
+95
-17
@@ -74,6 +74,7 @@ import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Only intrinsic types should be binary format with database/sql.
|
||||
@@ -125,7 +126,7 @@ func contains(list []string, y string) bool {
|
||||
type OptionOpenDB func(*connector)
|
||||
|
||||
// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
|
||||
// be used to connect, so only its immediate members should be modified.
|
||||
// be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig.
|
||||
func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
|
||||
return func(dc *connector) {
|
||||
dc.BeforeConnect = bc
|
||||
@@ -139,6 +140,20 @@ func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB
|
||||
}
|
||||
}
|
||||
|
||||
// OptionBeforeConnect provides a callback for before acquire. Used only if db is opened with *pgxpool.Pool.
|
||||
func OptionBeforeAcquire(ba func(context.Context, *pgxpool.Pool) error) OptionOpenDB {
|
||||
return func(c *connector) {
|
||||
c.BeforeAcquire = ba
|
||||
}
|
||||
}
|
||||
|
||||
// OptionAfterAcquire provides a callback for after acquire. Used only if db is opened with *pgxpool.Pool.
|
||||
func OptionAfterAcquire(aa func(context.Context, *pgxpool.Conn) error) OptionOpenDB {
|
||||
return func(c *connector) {
|
||||
c.AfterAcquire = aa
|
||||
}
|
||||
}
|
||||
|
||||
// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the
|
||||
// connection if the connection has been used before.
|
||||
// If ResetSessionFunc returns ErrBadConn error the connection will be discarded.
|
||||
@@ -191,15 +206,41 @@ func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector
|
||||
return c
|
||||
}
|
||||
|
||||
func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDB) driver.Connector {
|
||||
c := connector{
|
||||
pool: pool,
|
||||
BeforeAcquire: func(context.Context, *pgxpool.Pool) error { return nil }, // noop before acquire by default
|
||||
AfterAcquire: func(context.Context, *pgxpool.Conn) error { return nil }, // noop after acquire by default
|
||||
ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default
|
||||
driver: pgxDriver,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&c)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
|
||||
c := GetConnector(config, opts...)
|
||||
return sql.OpenDB(c)
|
||||
}
|
||||
|
||||
func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB {
|
||||
c := GetPoolConnector(pool, opts...)
|
||||
db := sql.OpenDB(c)
|
||||
db.SetMaxIdleConns(0)
|
||||
return db
|
||||
}
|
||||
|
||||
type connector struct {
|
||||
pgx.ConnConfig
|
||||
pool *pgxpool.Pool
|
||||
BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
|
||||
AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection
|
||||
BeforeAcquire func(context.Context, *pgxpool.Pool) error // function to call before acquiring of every new connection
|
||||
AfterAcquire func(context.Context, *pgxpool.Conn) error // function to call after acquiring of every new connection
|
||||
ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused
|
||||
driver *Driver
|
||||
}
|
||||
@@ -207,25 +248,60 @@ type connector struct {
|
||||
// Connect implement driver.Connector interface
|
||||
func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
var (
|
||||
err error
|
||||
conn *pgx.Conn
|
||||
connConfig pgx.ConnConfig
|
||||
conn *pgx.Conn
|
||||
close func(context.Context) error
|
||||
err error
|
||||
)
|
||||
|
||||
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
|
||||
connConfig := c.ConnConfig
|
||||
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
|
||||
return nil, err
|
||||
if c.pool == nil {
|
||||
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
|
||||
connConfig = c.ConnConfig
|
||||
|
||||
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = c.AfterConnect(ctx, conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
close = conn.Close
|
||||
} else {
|
||||
var pconn *pgxpool.Conn
|
||||
|
||||
if err = c.BeforeAcquire(ctx, c.pool); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pconn, err = c.pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = c.AfterAcquire(ctx, pconn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn = pconn.Conn()
|
||||
|
||||
close = func(_ context.Context) error {
|
||||
pconn.Release()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = c.AfterConnect(ctx, conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil
|
||||
return &Conn{
|
||||
conn: conn,
|
||||
close: close,
|
||||
driver: c.driver,
|
||||
connConfig: connConfig,
|
||||
resetSessionFunc: c.ResetSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Driver implement driver.Connector interface
|
||||
@@ -302,6 +378,7 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
|
||||
c := &Conn{
|
||||
conn: conn,
|
||||
close: conn.Close,
|
||||
driver: dc.driver,
|
||||
connConfig: *connConfig,
|
||||
resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
|
||||
@@ -326,6 +403,7 @@ func UnregisterConnConfig(connStr string) {
|
||||
|
||||
type Conn struct {
|
||||
conn *pgx.Conn
|
||||
close func(context.Context) error
|
||||
psCount int64 // Counter used for creating unique prepared statement names
|
||||
driver *Driver
|
||||
connConfig pgx.ConnConfig
|
||||
@@ -361,7 +439,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
||||
func (c *Conn) Close() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
return c.conn.Close(ctx)
|
||||
return c.close(ctx)
|
||||
}
|
||||
|
||||
func (c *Conn) Begin() (driver.Tx, error) {
|
||||
|
||||
Reference in New Issue
Block a user