3ab8941921
Previously, stdlib.RegisterConnConfig would sometimes reuse the same connection string for different ConnConfig options (specifically, it happened when a connection was open and then closed, and then a new, different connection was opened). This behavior interferes with callers that expect that two connections with the same data source name are connecting to the same backend database in the same way. This fix updates stdlib.RegisterConnConfig to use an incrementing sequence counter to uniquify all returned connection strings. Fixes #947
795 lines
20 KiB
Go
795 lines
20 KiB
Go
// Package stdlib is the compatibility layer from pgx to database/sql.
|
|
//
|
|
// A database/sql connection can be established through sql.Open.
|
|
//
|
|
// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
//
|
|
// Or from a DSN string.
|
|
//
|
|
// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
//
|
|
// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
|
|
// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
|
|
// with sql.Open.
|
|
//
|
|
// connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
|
|
// connConfig.Logger = myLogger
|
|
// connStr := stdlib.RegisterConnConfig(connConfig)
|
|
// db, _ := sql.Open("pgx", connStr)
|
|
//
|
|
// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2.
|
|
// It does not support named parameters.
|
|
//
|
|
// db.QueryRow("select * from users where id=$1", userID)
|
|
//
|
|
// In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard
|
|
// database/sql.DB connection pool. This allows operations that use pgx specific functionality.
|
|
//
|
|
// // Given db is a *sql.DB
|
|
// conn, err := db.Conn(context.Background())
|
|
// if err != nil {
|
|
// // handle error from acquiring connection from DB pool
|
|
// }
|
|
//
|
|
// err = conn.Raw(func(driverConn interface{}) error {
|
|
// conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn
|
|
// // Do pgx specific stuff with conn
|
|
// conn.CopyFrom(...)
|
|
// return nil
|
|
// })
|
|
// if err != nil {
|
|
// // handle error that occurred while using *pgx.Conn
|
|
// }
|
|
package stdlib
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jackc/pgconn"
|
|
"github.com/jackc/pgtype"
|
|
"github.com/jackc/pgx/v4"
|
|
)
|
|
|
|
// Only intrinsic types should be binary format with database/sql.
|
|
var databaseSQLResultFormats pgx.QueryResultFormatsByOID
|
|
|
|
var pgxDriver *Driver
|
|
|
|
type ctxKey int
|
|
|
|
var ctxKeyFakeTx ctxKey = 0
|
|
|
|
var ErrNotPgx = errors.New("not pgx *sql.DB")
|
|
|
|
func init() {
|
|
pgxDriver = &Driver{
|
|
configs: make(map[string]*pgx.ConnConfig),
|
|
}
|
|
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
|
sql.Register("pgx", pgxDriver)
|
|
|
|
databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
|
|
pgtype.BoolOID: 1,
|
|
pgtype.ByteaOID: 1,
|
|
pgtype.CIDOID: 1,
|
|
pgtype.DateOID: 1,
|
|
pgtype.Float4OID: 1,
|
|
pgtype.Float8OID: 1,
|
|
pgtype.Int2OID: 1,
|
|
pgtype.Int4OID: 1,
|
|
pgtype.Int8OID: 1,
|
|
pgtype.OIDOID: 1,
|
|
pgtype.TimestampOID: 1,
|
|
pgtype.TimestamptzOID: 1,
|
|
pgtype.XIDOID: 1,
|
|
}
|
|
}
|
|
|
|
var (
|
|
fakeTxMutex sync.Mutex
|
|
fakeTxConns map[*pgx.Conn]*sql.Tx
|
|
)
|
|
|
|
// OptionOpenDB options for configuring the driver when opening a new db pool.
|
|
type OptionOpenDB func(*connector)
|
|
|
|
// OptionAfterConnect provide a callback for after connect.
|
|
func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB {
|
|
return func(dc *connector) {
|
|
dc.AfterConnect = ac
|
|
}
|
|
}
|
|
|
|
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
|
|
c := connector{
|
|
ConnConfig: config,
|
|
AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default
|
|
driver: pgxDriver,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(&c)
|
|
}
|
|
|
|
return sql.OpenDB(c)
|
|
}
|
|
|
|
type connector struct {
|
|
pgx.ConnConfig
|
|
AfterConnect func(context.Context, *pgx.Conn) error // function to call on every new connection
|
|
driver *Driver
|
|
}
|
|
|
|
// Connect implement driver.Connector interface
|
|
func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
|
var (
|
|
err error
|
|
conn *pgx.Conn
|
|
)
|
|
|
|
if conn, err = pgx.ConnectConfig(ctx, &c.ConnConfig); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err = c.AfterConnect(ctx, conn); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Conn{conn: conn, driver: c.driver, connConfig: c.ConnConfig}, nil
|
|
}
|
|
|
|
// Driver implement driver.Connector interface
|
|
func (c connector) Driver() driver.Driver {
|
|
return c.driver
|
|
}
|
|
|
|
// GetDefaultDriver returns the driver initialized in the init function
|
|
// and used when the pgx driver is registered.
|
|
func GetDefaultDriver() driver.Driver {
|
|
return pgxDriver
|
|
}
|
|
|
|
type Driver struct {
|
|
configMutex sync.Mutex
|
|
configs map[string]*pgx.ConnConfig
|
|
sequence int
|
|
}
|
|
|
|
func (d *Driver) Open(name string) (driver.Conn, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
|
|
defer cancel()
|
|
|
|
connector, err := d.OpenConnector(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return connector.Connect(ctx)
|
|
}
|
|
|
|
func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
|
|
return &driverConnector{driver: d, name: name}, nil
|
|
}
|
|
|
|
func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string {
|
|
d.configMutex.Lock()
|
|
connStr := fmt.Sprintf("registeredConnConfig%d", d.sequence)
|
|
d.sequence++
|
|
d.configs[connStr] = c
|
|
d.configMutex.Unlock()
|
|
return connStr
|
|
}
|
|
|
|
func (d *Driver) unregisterConnConfig(connStr string) {
|
|
d.configMutex.Lock()
|
|
delete(d.configs, connStr)
|
|
d.configMutex.Unlock()
|
|
}
|
|
|
|
type driverConnector struct {
|
|
driver *Driver
|
|
name string
|
|
}
|
|
|
|
func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
|
var connConfig *pgx.ConnConfig
|
|
|
|
dc.driver.configMutex.Lock()
|
|
connConfig = dc.driver.configs[dc.name]
|
|
dc.driver.configMutex.Unlock()
|
|
|
|
if connConfig == nil {
|
|
var err error
|
|
connConfig, err = pgx.ParseConfig(dc.name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
conn, err := pgx.ConnectConfig(ctx, connConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c := &Conn{conn: conn, driver: dc.driver, connConfig: *connConfig}
|
|
return c, nil
|
|
}
|
|
|
|
func (dc *driverConnector) Driver() driver.Driver {
|
|
return dc.driver
|
|
}
|
|
|
|
// RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open.
|
|
func RegisterConnConfig(c *pgx.ConnConfig) string {
|
|
return pgxDriver.registerConnConfig(c)
|
|
}
|
|
|
|
// UnregisterConnConfig removes the ConnConfig registration for connStr.
|
|
func UnregisterConnConfig(connStr string) {
|
|
pgxDriver.unregisterConnConfig(connStr)
|
|
}
|
|
|
|
type Conn struct {
|
|
conn *pgx.Conn
|
|
psCount int64 // Counter used for creating unique prepared statement names
|
|
driver *Driver
|
|
connConfig pgx.ConnConfig
|
|
}
|
|
|
|
// Conn returns the underlying *pgx.Conn
|
|
func (c *Conn) Conn() *pgx.Conn {
|
|
return c.conn
|
|
}
|
|
|
|
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
|
return c.PrepareContext(context.Background(), query)
|
|
}
|
|
|
|
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
|
if c.conn.IsClosed() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
name := fmt.Sprintf("pgx_%d", c.psCount)
|
|
c.psCount++
|
|
|
|
sd, err := c.conn.Prepare(ctx, name, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Stmt{sd: sd, conn: c}, nil
|
|
}
|
|
|
|
func (c *Conn) Close() error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
return c.conn.Close(ctx)
|
|
}
|
|
|
|
func (c *Conn) Begin() (driver.Tx, error) {
|
|
return c.BeginTx(context.Background(), driver.TxOptions{})
|
|
}
|
|
|
|
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
|
if c.conn.IsClosed() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok {
|
|
*pconn = c.conn
|
|
return fakeTx{}, nil
|
|
}
|
|
|
|
var pgxOpts pgx.TxOptions
|
|
switch sql.IsolationLevel(opts.Isolation) {
|
|
case sql.LevelDefault:
|
|
case sql.LevelReadUncommitted:
|
|
pgxOpts.IsoLevel = pgx.ReadUncommitted
|
|
case sql.LevelReadCommitted:
|
|
pgxOpts.IsoLevel = pgx.ReadCommitted
|
|
case sql.LevelRepeatableRead, sql.LevelSnapshot:
|
|
pgxOpts.IsoLevel = pgx.RepeatableRead
|
|
case sql.LevelSerializable:
|
|
pgxOpts.IsoLevel = pgx.Serializable
|
|
default:
|
|
return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
|
|
}
|
|
|
|
if opts.ReadOnly {
|
|
pgxOpts.AccessMode = pgx.ReadOnly
|
|
}
|
|
|
|
tx, err := c.conn.BeginTx(ctx, pgxOpts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return wrapTx{ctx: ctx, tx: tx}, nil
|
|
}
|
|
|
|
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
|
|
if c.conn.IsClosed() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
args := namedValueToInterface(argsV)
|
|
|
|
commandTag, err := c.conn.Exec(ctx, query, args...)
|
|
// if we got a network error before we had a chance to send the query, retry
|
|
if err != nil {
|
|
if pgconn.SafeToRetry(err) {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
}
|
|
return driver.RowsAffected(commandTag.RowsAffected()), err
|
|
}
|
|
|
|
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
|
if c.conn.IsClosed() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
args := []interface{}{databaseSQLResultFormats}
|
|
args = append(args, namedValueToInterface(argsV)...)
|
|
|
|
rows, err := c.conn.Query(ctx, query, args...)
|
|
if err != nil {
|
|
if pgconn.SafeToRetry(err) {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
|
more := rows.Next()
|
|
if err = rows.Err(); err != nil {
|
|
rows.Close()
|
|
return nil, err
|
|
}
|
|
return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil
|
|
}
|
|
|
|
func (c *Conn) Ping(ctx context.Context) error {
|
|
if c.conn.IsClosed() {
|
|
return driver.ErrBadConn
|
|
}
|
|
|
|
err := c.conn.Ping(ctx)
|
|
if err != nil {
|
|
// A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the
|
|
// failure, but manually close it just to be sure.
|
|
c.Close()
|
|
return driver.ErrBadConn
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) CheckNamedValue(*driver.NamedValue) error {
|
|
// Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly.
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) ResetSession(ctx context.Context) error {
|
|
if c.conn.IsClosed() {
|
|
return driver.ErrBadConn
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type Stmt struct {
|
|
sd *pgconn.StatementDescription
|
|
conn *Conn
|
|
}
|
|
|
|
func (s *Stmt) Close() error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
return s.conn.conn.Deallocate(ctx, s.sd.Name)
|
|
}
|
|
|
|
func (s *Stmt) NumInput() int {
|
|
return len(s.sd.ParamOIDs)
|
|
}
|
|
|
|
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
|
return nil, errors.New("Stmt.Exec deprecated and not implemented")
|
|
}
|
|
|
|
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
|
|
return s.conn.ExecContext(ctx, s.sd.Name, argsV)
|
|
}
|
|
|
|
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
|
return nil, errors.New("Stmt.Query deprecated and not implemented")
|
|
}
|
|
|
|
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
|
|
return s.conn.QueryContext(ctx, s.sd.Name, argsV)
|
|
}
|
|
|
|
type rowValueFunc func(src []byte) (driver.Value, error)
|
|
|
|
type Rows struct {
|
|
conn *Conn
|
|
rows pgx.Rows
|
|
valueFuncs []rowValueFunc
|
|
skipNext bool
|
|
skipNextMore bool
|
|
|
|
columnNames []string
|
|
}
|
|
|
|
func (r *Rows) Columns() []string {
|
|
if r.columnNames == nil {
|
|
fields := r.rows.FieldDescriptions()
|
|
r.columnNames = make([]string, len(fields))
|
|
for i, fd := range fields {
|
|
r.columnNames[i] = string(fd.Name)
|
|
}
|
|
}
|
|
|
|
return r.columnNames
|
|
}
|
|
|
|
// ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned.
|
|
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
|
|
if dt, ok := r.conn.conn.ConnInfo().DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok {
|
|
return strings.ToUpper(dt.Name)
|
|
}
|
|
|
|
return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10)
|
|
}
|
|
|
|
const varHeaderSize = 4
|
|
|
|
// ColumnTypeLength returns the length of the column type if the column is a
|
|
// variable length type. If the column is not a variable length type ok
|
|
// should return false.
|
|
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
|
|
fd := r.rows.FieldDescriptions()[index]
|
|
|
|
switch fd.DataTypeOID {
|
|
case pgtype.TextOID, pgtype.ByteaOID:
|
|
return math.MaxInt64, true
|
|
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
|
|
return int64(fd.TypeModifier - varHeaderSize), true
|
|
default:
|
|
return 0, false
|
|
}
|
|
}
|
|
|
|
// ColumnTypePrecisionScale should return the precision and scale for decimal
|
|
// types. If not applicable, ok should be false.
|
|
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
|
fd := r.rows.FieldDescriptions()[index]
|
|
|
|
switch fd.DataTypeOID {
|
|
case pgtype.NumericOID:
|
|
mod := fd.TypeModifier - varHeaderSize
|
|
precision = int64((mod >> 16) & 0xffff)
|
|
scale = int64(mod & 0xffff)
|
|
return precision, scale, true
|
|
default:
|
|
return 0, 0, false
|
|
}
|
|
}
|
|
|
|
// ColumnTypeScanType returns the value type that can be used to scan types into.
|
|
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
|
|
fd := r.rows.FieldDescriptions()[index]
|
|
|
|
switch fd.DataTypeOID {
|
|
case pgtype.Float8OID:
|
|
return reflect.TypeOf(float64(0))
|
|
case pgtype.Float4OID:
|
|
return reflect.TypeOf(float32(0))
|
|
case pgtype.Int8OID:
|
|
return reflect.TypeOf(int64(0))
|
|
case pgtype.Int4OID:
|
|
return reflect.TypeOf(int32(0))
|
|
case pgtype.Int2OID:
|
|
return reflect.TypeOf(int16(0))
|
|
case pgtype.BoolOID:
|
|
return reflect.TypeOf(false)
|
|
case pgtype.NumericOID:
|
|
return reflect.TypeOf(float64(0))
|
|
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
|
|
return reflect.TypeOf(time.Time{})
|
|
case pgtype.ByteaOID:
|
|
return reflect.TypeOf([]byte(nil))
|
|
default:
|
|
return reflect.TypeOf("")
|
|
}
|
|
}
|
|
|
|
func (r *Rows) Close() error {
|
|
r.rows.Close()
|
|
return r.rows.Err()
|
|
}
|
|
|
|
func (r *Rows) Next(dest []driver.Value) error {
|
|
ci := r.conn.conn.ConnInfo()
|
|
fieldDescriptions := r.rows.FieldDescriptions()
|
|
|
|
if r.valueFuncs == nil {
|
|
r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions))
|
|
|
|
for i, fd := range fieldDescriptions {
|
|
dataTypeOID := fd.DataTypeOID
|
|
format := fd.Format
|
|
|
|
switch fd.DataTypeOID {
|
|
case pgtype.BoolOID:
|
|
var d bool
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
return d, err
|
|
}
|
|
case pgtype.ByteaOID:
|
|
var d []byte
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
return d, err
|
|
}
|
|
case pgtype.CIDOID:
|
|
var d pgtype.CID
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.Value()
|
|
}
|
|
case pgtype.DateOID:
|
|
var d pgtype.Date
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.Value()
|
|
}
|
|
case pgtype.Float4OID:
|
|
var d float32
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
return float64(d), err
|
|
}
|
|
case pgtype.Float8OID:
|
|
var d float64
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
return d, err
|
|
}
|
|
case pgtype.Int2OID:
|
|
var d int16
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
return int64(d), err
|
|
}
|
|
case pgtype.Int4OID:
|
|
var d int32
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
return int64(d), err
|
|
}
|
|
case pgtype.Int8OID:
|
|
var d int64
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
return d, err
|
|
}
|
|
case pgtype.JSONOID:
|
|
var d pgtype.JSON
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.Value()
|
|
}
|
|
case pgtype.JSONBOID:
|
|
var d pgtype.JSONB
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.Value()
|
|
}
|
|
case pgtype.OIDOID:
|
|
var d pgtype.OIDValue
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.Value()
|
|
}
|
|
case pgtype.TimestampOID:
|
|
var d pgtype.Timestamp
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.Value()
|
|
}
|
|
case pgtype.TimestamptzOID:
|
|
var d pgtype.Timestamptz
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.Value()
|
|
}
|
|
case pgtype.XIDOID:
|
|
var d pgtype.XID
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return d.Value()
|
|
}
|
|
default:
|
|
var d string
|
|
scanPlan := ci.PlanScan(dataTypeOID, format, &d)
|
|
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
|
|
return d, err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
var more bool
|
|
if r.skipNext {
|
|
more = r.skipNextMore
|
|
r.skipNext = false
|
|
} else {
|
|
more = r.rows.Next()
|
|
}
|
|
|
|
if !more {
|
|
if r.rows.Err() == nil {
|
|
return io.EOF
|
|
} else {
|
|
return r.rows.Err()
|
|
}
|
|
}
|
|
|
|
for i, rv := range r.rows.RawValues() {
|
|
if rv != nil {
|
|
var err error
|
|
dest[i], err = r.valueFuncs[i](rv)
|
|
if err != nil {
|
|
return fmt.Errorf("convert field %d failed: %v", i, err)
|
|
}
|
|
} else {
|
|
dest[i] = nil
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func valueToInterface(argsV []driver.Value) []interface{} {
|
|
args := make([]interface{}, 0, len(argsV))
|
|
for _, v := range argsV {
|
|
if v != nil {
|
|
args = append(args, v.(interface{}))
|
|
} else {
|
|
args = append(args, nil)
|
|
}
|
|
}
|
|
return args
|
|
}
|
|
|
|
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
|
|
args := make([]interface{}, 0, len(argsV))
|
|
for _, v := range argsV {
|
|
if v.Value != nil {
|
|
args = append(args, v.Value.(interface{}))
|
|
} else {
|
|
args = append(args, nil)
|
|
}
|
|
}
|
|
return args
|
|
}
|
|
|
|
type wrapTx struct {
|
|
ctx context.Context
|
|
tx pgx.Tx
|
|
}
|
|
|
|
func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) }
|
|
|
|
func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) }
|
|
|
|
type fakeTx struct{}
|
|
|
|
func (fakeTx) Commit() error { return nil }
|
|
|
|
func (fakeTx) Rollback() error { return nil }
|
|
|
|
// AcquireConn acquires a *pgx.Conn from database/sql connection pool. It must be released with ReleaseConn.
|
|
//
|
|
// In Go 1.13 this functionality has been incorporated into the standard library in the db.Conn.Raw() method.
|
|
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
|
|
var conn *pgx.Conn
|
|
ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if conn == nil {
|
|
tx.Rollback()
|
|
return nil, ErrNotPgx
|
|
}
|
|
|
|
fakeTxMutex.Lock()
|
|
fakeTxConns[conn] = tx
|
|
fakeTxMutex.Unlock()
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// ReleaseConn releases a *pgx.Conn acquired with AcquireConn.
|
|
func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
|
|
var tx *sql.Tx
|
|
var ok bool
|
|
|
|
if conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
conn.Close(ctx)
|
|
}
|
|
|
|
fakeTxMutex.Lock()
|
|
tx, ok = fakeTxConns[conn]
|
|
if ok {
|
|
delete(fakeTxConns, conn)
|
|
fakeTxMutex.Unlock()
|
|
} else {
|
|
fakeTxMutex.Unlock()
|
|
return fmt.Errorf("can't release conn that is not acquired")
|
|
}
|
|
|
|
return tx.Rollback()
|
|
}
|