Merge pull request #410 from jbowens/acquireconn
stdlib: allow nested database/sql/driver.Drivers
This commit is contained in:
+19
-19
@@ -99,9 +99,9 @@ var ErrNotPgx = errors.New("not pgx *sql.DB")
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
pgxDriver = &Driver{
|
pgxDriver = &Driver{
|
||||||
configs: make(map[int64]*DriverConfig),
|
configs: make(map[int64]*DriverConfig),
|
||||||
fakeTxConns: make(map[*pgx.Conn]*sql.Tx),
|
|
||||||
}
|
}
|
||||||
|
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
||||||
sql.Register("pgx", pgxDriver)
|
sql.Register("pgx", pgxDriver)
|
||||||
|
|
||||||
databaseSqlOIDs = make(map[pgtype.OID]bool)
|
databaseSqlOIDs = make(map[pgtype.OID]bool)
|
||||||
@@ -120,13 +120,15 @@ func init() {
|
|||||||
databaseSqlOIDs[pgtype.XIDOID] = true
|
databaseSqlOIDs[pgtype.XIDOID] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
fakeTxMutex sync.Mutex
|
||||||
|
fakeTxConns map[*pgx.Conn]*sql.Tx
|
||||||
|
)
|
||||||
|
|
||||||
type Driver struct {
|
type Driver struct {
|
||||||
configMutex sync.Mutex
|
configMutex sync.Mutex
|
||||||
configCount int64
|
configCount int64
|
||||||
configs map[int64]*DriverConfig
|
configs map[int64]*DriverConfig
|
||||||
|
|
||||||
fakeTxMutex sync.Mutex
|
|
||||||
fakeTxConns map[*pgx.Conn]*sql.Tx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||||
@@ -575,21 +577,20 @@ func (fakeTx) Commit() error { return nil }
|
|||||||
func (fakeTx) Rollback() error { return nil }
|
func (fakeTx) Rollback() error { return nil }
|
||||||
|
|
||||||
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
|
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
|
||||||
driver, ok := db.Driver().(*Driver)
|
|
||||||
if !ok {
|
|
||||||
return nil, ErrNotPgx
|
|
||||||
}
|
|
||||||
|
|
||||||
var conn *pgx.Conn
|
var conn *pgx.Conn
|
||||||
ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
|
ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
|
||||||
tx, err := db.BeginTx(ctx, nil)
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if conn == nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return nil, ErrNotPgx
|
||||||
|
}
|
||||||
|
|
||||||
driver.fakeTxMutex.Lock()
|
fakeTxMutex.Lock()
|
||||||
driver.fakeTxConns[conn] = tx
|
fakeTxConns[conn] = tx
|
||||||
driver.fakeTxMutex.Unlock()
|
fakeTxMutex.Unlock()
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
@@ -598,14 +599,13 @@ func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
|
|||||||
var tx *sql.Tx
|
var tx *sql.Tx
|
||||||
var ok bool
|
var ok bool
|
||||||
|
|
||||||
driver := db.Driver().(*Driver)
|
fakeTxMutex.Lock()
|
||||||
driver.fakeTxMutex.Lock()
|
tx, ok = fakeTxConns[conn]
|
||||||
tx, ok = driver.fakeTxConns[conn]
|
|
||||||
if ok {
|
if ok {
|
||||||
delete(driver.fakeTxConns, conn)
|
delete(fakeTxConns, conn)
|
||||||
driver.fakeTxMutex.Unlock()
|
fakeTxMutex.Unlock()
|
||||||
} else {
|
} else {
|
||||||
driver.fakeTxMutex.Unlock()
|
fakeTxMutex.Unlock()
|
||||||
return errors.Errorf("can't release conn that is not acquired")
|
return errors.Errorf("can't release conn that is not acquired")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user