2
0
Files
pgx/stdlib/sql.go
T
Jack Christensen 6972a57421 pgtype.OID type should only be used for scanning and encoding values
It was a mistake to use it in other contexts. This made interop
difficult between packages that depended on pgtype such as pgx and
packages that did not like pgconn and pgproto3. In particular this was
awkward for prepared statements.

This is preparation for removing pgx.PreparedStatement in favor of
pgconn.PreparedStatement.
2019-08-24 13:55:57 -05:00

534 lines
13 KiB
Go

// Package stdlib is the compatibility layer from pgx to database/sql.
//
// A database/sql connection can be established through sql.Open.
//
// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
// if err != nil {
// return err
// }
//
// Or from a DSN string.
//
// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
// if err != nil {
// return err
// }
//
// A DriverConfig can be used to further configure the connection process. This
// allows configuring TLS configuration, setting a custom dialer, logging, and
// setting an AfterConnect hook.
//
// driverConfig := stdlib.DriverConfig{
// ConnConfig: pgx.ConnConfig{
// Logger: logger,
// },
// AfterConnect: func(c *pgx.Conn) error {
// // Ensure all connections have this temp table available
// _, err := c.Exec("create temporary table foo(...)")
// return err
// },
// }
//
// stdlib.RegisterDriverConfig(&driverConfig)
//
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
// if err != nil {
// return err
// }
//
// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2.
// It does not support named parameters.
//
// db.QueryRow("select * from users where id=$1", userID)
//
// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard
// database/sql.DB connection pool. This allows operations that must be
// performed on a single connection without running in a transaction, and it
// supports operations that use pgx specific functionality.
//
// conn, err := stdlib.AcquireConn(db)
// if err != nil {
// return err
// }
// defer stdlib.ReleaseConn(db, conn)
//
// // do stuff with pgx.Conn
//
// It can also be used to enable a fast path for pgx while preserving
// compatibility with other drivers and databases.
//
// conn, err := stdlib.AcquireConn(db)
// if err == nil {
// // fast path with pgx
// // ...
// // release conn when done
// stdlib.ReleaseConn(db, conn)
// } else {
// // normal path for other drivers and databases
// }
package stdlib
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"math"
"net"
"reflect"
"strings"
"sync"
"time"
errors "golang.org/x/xerrors"
"github.com/jackc/pgconn"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4"
)
// databaseSQLResultFormats is passed as the first query argument by
// Conn.QueryContext; init fills it with the OIDs that are given format code 1.
// Only intrinsic types should be binary format with database/sql.
var databaseSQLResultFormats pgx.QueryResultFormatsByOID
// pgxDriver is the package-level Driver; init creates it and registers it
// with database/sql under the name "pgx".
var pgxDriver *Driver
// ctxKey is the type of the context keys defined by this package.
type ctxKey int
// ctxKeyFakeTx is the context key under which AcquireConn stores a **pgx.Conn
// so that Conn.BeginTx can hand the underlying connection back to it.
var ctxKeyFakeTx ctxKey = 0
// ErrNotPgx is returned by AcquireConn when beginning a transaction on the
// *sql.DB did not populate the connection pointer, i.e. Conn.BeginTx of this
// package was not the one that handled the call.
var ErrNotPgx = errors.New("not pgx *sql.DB")
func init() {
pgxDriver = &Driver{}
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
sql.Register("pgx", pgxDriver)
databaseSQLResultFormats = pgx.QueryResultFormatsByOID{
pgtype.BoolOID: 1,
pgtype.ByteaOID: 1,
pgtype.CIDOID: 1,
pgtype.DateOID: 1,
pgtype.Float4OID: 1,
pgtype.Float8OID: 1,
pgtype.Int2OID: 1,
pgtype.Int4OID: 1,
pgtype.Int8OID: 1,
pgtype.OIDOID: 1,
pgtype.TimestampOID: 1,
pgtype.TimestamptzOID: 1,
pgtype.XIDOID: 1,
}
}
var (
// fakeTxMutex guards fakeTxConns.
fakeTxMutex sync.Mutex
// fakeTxConns maps each connection handed out by AcquireConn to the *sql.Tx
// that was begun to obtain it; ReleaseConn removes the entry and rolls that
// transaction back.
fakeTxConns map[*pgx.Conn]*sql.Tx
)
// Driver is the type that init registers with database/sql under the name
// "pgx". It carries no state of its own.
type Driver struct{}
// Open parses name as a pgx connection config, connects to the server, and
// returns the connection wrapped in a *Conn. The connection attempt is bounded
// by a 60 second timeout. Parse and connect errors are returned unchanged.
func (d *Driver) Open(name string) (driver.Conn, error) {
	cfg, err := pgx.ParseConfig(name)
	if err != nil {
		return nil, err
	}

	// Ensure the connection attempt eventually times out.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	pgxConn, err := pgx.ConnectConfig(ctx, cfg)
	if err != nil {
		return nil, err
	}

	return &Conn{conn: pgxConn, driver: d, connConfig: *cfg}, nil
}
// Conn is the driver.Conn returned by Driver.Open. It wraps a single
// *pgx.Conn.
type Conn struct {
// conn is the underlying pgx connection.
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
// driver is the Driver that opened this connection.
driver *Driver
// connConfig is a copy of the parsed config the connection was opened with.
connConfig pgx.ConnConfig
}
// Prepare creates a prepared statement for query by delegating to
// PrepareContext with a background context.
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
// PrepareContext prepares query on the underlying connection under a generated
// name of the form "pgx_<n>" and returns it wrapped in a *Stmt.
//
// It returns driver.ErrBadConn if the connection is no longer alive, and the
// error from pgx unchanged if preparing fails.
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	if alive := c.conn.IsAlive(); !alive {
		return nil, driver.ErrBadConn
	}

	// The counter advances before Prepare runs, so a failed Prepare still
	// consumes a name.
	stmtName := fmt.Sprintf("pgx_%d", c.psCount)
	c.psCount++

	prepared, err := c.conn.Prepare(ctx, stmtName, query)
	if err != nil {
		return nil, err
	}

	stmt := &Stmt{ps: prepared, conn: c}
	return stmt, nil
}
// Close closes the underlying pgx connection using a background context and
// returns whatever error that close reports.
func (c *Conn) Close() error {
return c.conn.Close(context.Background())
}
// Begin starts a transaction by delegating to BeginTx with a background
// context and zero-valued driver.TxOptions.
func (c *Conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
// BeginTx starts a transaction on the underlying connection.
//
// If ctx carries a **pgx.Conn under ctxKeyFakeTx (as set by AcquireConn), no
// real transaction is started: the underlying connection is stored through
// that pointer and a no-op fakeTx is returned.
//
// Otherwise opts.Isolation is translated to the matching pgx isolation level
// (sql.LevelSnapshot is treated as repeatable read, sql.LevelDefault leaves the
// level unset) and opts.ReadOnly selects read-only access mode.
//
// It returns driver.ErrBadConn if the connection is not alive, an error for an
// unsupported isolation level, or the error from pgx if beginning fails.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	if !c.conn.IsAlive() {
		return nil, driver.ErrBadConn
	}

	pconn, fake := ctx.Value(ctxKeyFakeTx).(**pgx.Conn)
	if fake {
		*pconn = c.conn
		return fakeTx{}, nil
	}

	var txOpts pgx.TxOptions

	level := sql.IsolationLevel(opts.Isolation)
	switch level {
	case sql.LevelDefault:
		// Leave IsoLevel at its zero value.
	case sql.LevelReadUncommitted:
		txOpts.IsoLevel = pgx.ReadUncommitted
	case sql.LevelReadCommitted:
		txOpts.IsoLevel = pgx.ReadCommitted
	case sql.LevelRepeatableRead, sql.LevelSnapshot:
		txOpts.IsoLevel = pgx.RepeatableRead
	case sql.LevelSerializable:
		txOpts.IsoLevel = pgx.Serializable
	default:
		return nil, errors.Errorf("unsupported isolation: %v", opts.Isolation)
	}

	if opts.ReadOnly {
		txOpts.AccessMode = pgx.ReadOnly
	}

	pgxTx, err := c.conn.BeginEx(ctx, txOpts)
	if err != nil {
		return nil, err
	}

	return wrapTx{tx: pgxTx}, nil
}
// ExecContext runs query with the given arguments on the underlying
// connection and reports the number of affected rows.
//
// It returns driver.ErrBadConn if the connection is not alive, or if the
// failure is a network error that occurred before any bytes were sent. Any
// other error is returned together with the rows-affected value taken from the
// command tag.
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
	if !c.conn.IsAlive() {
		return nil, driver.ErrBadConn
	}

	tag, err := c.conn.Exec(ctx, query, namedValueToInterface(argsV)...)
	if err == nil {
		return driver.RowsAffected(tag.RowsAffected()), nil
	}

	// A network error before the query was sent is reported as a bad
	// connection rather than as the original error.
	var netErr net.Error
	if errors.As(err, &netErr) && errors.Is(err, pgconn.ErrNoBytesSent) {
		return nil, driver.ErrBadConn
	}

	return driver.RowsAffected(tag.RowsAffected()), err
}
// QueryContext runs query on the underlying connection and returns the result
// as a *Rows. databaseSQLResultFormats is always passed as the first argument,
// ahead of the caller's arguments.
//
// It returns driver.ErrBadConn if the connection is not alive or if the query
// failed before any bytes were sent; other query errors are returned as is.
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
	if !c.conn.IsAlive() {
		return nil, driver.ErrBadConn
	}

	queryArgs := make([]interface{}, 0, len(argsV)+1)
	queryArgs = append(queryArgs, databaseSQLResultFormats)
	queryArgs = append(queryArgs, namedValueToInterface(argsV)...)

	pgxRows, err := c.conn.Query(ctx, query, queryArgs...)
	switch {
	case err == nil:
	case errors.Is(err, pgconn.ErrNoBytesSent):
		return nil, driver.ErrBadConn
	default:
		return nil, err
	}

	// Advance to the first row now: otherwise the available columns are not
	// known yet when database/sql asks for them. Rows.Next replays this result
	// on its first call instead of advancing again.
	firstRowAvailable := pgxRows.Next()

	return &Rows{
		conn:         c,
		rows:         pgxRows,
		skipNext:     true,
		skipNextMore: firstRowAvailable,
	}, nil
}
// Ping checks the underlying connection. It returns driver.ErrBadConn when the
// connection is no longer alive and otherwise the result of pinging it with
// ctx.
func (c *Conn) Ping(ctx context.Context) error {
if !c.conn.IsAlive() {
return driver.ErrBadConn
}
return c.conn.Ping(ctx)
}
// Stmt is the driver.Stmt returned by Conn.PrepareContext.
type Stmt struct {
// ps is the prepared statement returned by pgx when the statement was prepared.
ps *pgx.PreparedStatement
// conn is the Conn that prepared the statement; execution and deallocation
// go through it.
conn *Conn
}
// Close deallocates the prepared statement, by name, on the underlying pgx
// connection using a background context.
func (s *Stmt) Close() error {
return s.conn.conn.Deallocate(context.Background(), s.ps.Name)
}
// NumInput returns the number of parameter OIDs recorded for the prepared
// statement.
func (s *Stmt) NumInput() int {
return len(s.ps.ParameterOIDs)
}
// Exec is not implemented; it always returns an error. Use ExecContext.
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
return nil, errors.New("Stmt.Exec deprecated and not implemented")
}
// ExecContext executes the prepared statement by passing its name as the
// query to Conn.ExecContext, along with argsV.
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
return s.conn.ExecContext(ctx, s.ps.Name, argsV)
}
// Query is not implemented; it always returns an error. Use QueryContext.
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
return nil, errors.New("Stmt.Query deprecated and not implemented")
}
// QueryContext runs the prepared statement by passing its name as the query
// to Conn.QueryContext, along with argsV.
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
return s.conn.QueryContext(ctx, s.ps.Name, argsV)
}
// Rows is the driver.Rows returned by Conn.QueryContext. It wraps a pgx.Rows.
type Rows struct {
// conn is the Conn that ran the query; it is used to look up type names.
conn *Conn
// rows is the underlying pgx result set.
rows pgx.Rows
// values holds one scan target per column; Next allocates it on first use.
values []interface{}
// skipNext is set by Conn.QueryContext, which already advanced to the first
// row. While it is true, Next uses skipNextMore instead of advancing again.
skipNext bool
// skipNextMore is the result of that initial advance.
skipNextMore bool
}
// Columns returns the column names of the result set, in field description
// order.
func (r *Rows) Columns() []string {
	fds := r.rows.FieldDescriptions()
	cols := make([]string, len(fds))
	for i := range fds {
		cols[i] = string(fds[i].Name)
	}
	return cols
}
// ColumnTypeDatabaseTypeName returns the upper-cased database type name of the
// column at index, as known to the connection's ConnInfo. It returns the empty
// string when the column's OID is not registered there.
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
	oid := uint32(r.rows.FieldDescriptions()[index].DataTypeOID)

	dt, known := r.conn.conn.ConnInfo.DataTypeForOID(oid)
	if !known {
		return ""
	}
	return strings.ToUpper(dt.Name)
}
// varHeaderSize is subtracted from a column's TypeModifier before the
// modifier is interpreted as a length or as a packed precision/scale.
const varHeaderSize = 4
// ColumnTypeLength returns the length of the column type if the column is a
// variable length type. If the column is not a variable length type ok
// is false.
//
// text and bytea report math.MaxInt64. varchar and bpchar (char(n)) report the
// type modifier minus varHeaderSize. BPCharArrayOID is still accepted because
// earlier versions matched it in place of BPCharOID.
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
	fd := r.rows.FieldDescriptions()[index]

	switch fd.DataTypeOID {
	case pgtype.TextOID, pgtype.ByteaOID:
		return math.MaxInt64, true
	case pgtype.VarcharOID, pgtype.BPCharOID, pgtype.BPCharArrayOID:
		// NOTE(review): a column declared without a length presumably carries
		// no usable modifier; confirm what callers should see in that case.
		return int64(fd.TypeModifier - varHeaderSize), true
	default:
		return 0, false
	}
}
// ColumnTypePrecisionScale returns the precision and scale of the column at
// index when it is a numeric column; for every other type ok is false.
//
// After subtracting varHeaderSize from the type modifier, the precision is
// read from the upper 16 bits and the scale from the lower 16 bits.
//
// NOTE(review): a numeric column declared without precision/scale presumably
// has no meaningful modifier to unpack; confirm how that should be reported.
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
	fd := r.rows.FieldDescriptions()[index]
	if fd.DataTypeOID != pgtype.NumericOID {
		return 0, 0, false
	}

	packed := fd.TypeModifier - varHeaderSize
	return int64((packed >> 16) & 0xffff), int64(packed & 0xffff), true
}
// ColumnTypeScanType returns the Go type that values of the column at index
// can be scanned into.
//
// Character types (varchar, bpchar/char(n), text) map to string; previously
// bpchar was missing and fell through to the empty interface type.
// BPCharArrayOID is kept in the string case for backward compatibility.
// Columns of any type not listed report the empty interface type.
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
	fd := r.rows.FieldDescriptions()[index]

	switch fd.DataTypeOID {
	case pgtype.Float8OID:
		return reflect.TypeOf(float64(0))
	case pgtype.Float4OID:
		return reflect.TypeOf(float32(0))
	case pgtype.Int8OID:
		return reflect.TypeOf(int64(0))
	case pgtype.Int4OID:
		return reflect.TypeOf(int32(0))
	case pgtype.Int2OID:
		return reflect.TypeOf(int16(0))
	case pgtype.VarcharOID, pgtype.BPCharOID, pgtype.BPCharArrayOID, pgtype.TextOID:
		return reflect.TypeOf("")
	case pgtype.BoolOID:
		return reflect.TypeOf(false)
	case pgtype.NumericOID:
		return reflect.TypeOf(float64(0))
	case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
		return reflect.TypeOf(time.Time{})
	case pgtype.ByteaOID:
		return reflect.TypeOf([]byte(nil))
	default:
		// reflect.TypeOf on a nil interface yields nil, so derive the
		// interface{} type from a pointer to it.
		return reflect.TypeOf(new(interface{})).Elem()
	}
}
// Close closes the underlying pgx rows. It always returns nil.
func (r *Rows) Close() error {
r.rows.Close()
return nil
}
// Next advances to the next row and stores its column values in dest.
//
// On the first call it allocates one pgtype scan target per column, chosen by
// the column's OID; columns of any other type are read as generic text. The
// first call also reuses the advance already performed by Conn.QueryContext
// instead of advancing again.
//
// It returns io.EOF when the rows are exhausted without error, the rows' error
// when iteration stopped because of one, and any error from scanning or from
// converting a scanned value to a driver.Value.
func (r *Rows) Next(dest []driver.Value) error {
	if r.values == nil {
		fds := r.rows.FieldDescriptions()
		r.values = make([]interface{}, len(fds))
		for i := range fds {
			var target interface{}
			switch fds[i].DataTypeOID {
			case pgtype.BoolOID:
				target = &pgtype.Bool{}
			case pgtype.ByteaOID:
				target = &pgtype.Bytea{}
			case pgtype.CIDOID:
				target = &pgtype.CID{}
			case pgtype.DateOID:
				target = &pgtype.Date{}
			case pgtype.Float4OID:
				target = &pgtype.Float4{}
			case pgtype.Float8OID:
				target = &pgtype.Float8{}
			case pgtype.Int2OID:
				target = &pgtype.Int2{}
			case pgtype.Int4OID:
				target = &pgtype.Int4{}
			case pgtype.Int8OID:
				target = &pgtype.Int8{}
			case pgtype.JSONOID:
				target = &pgtype.JSON{}
			case pgtype.JSONBOID:
				target = &pgtype.JSONB{}
			case pgtype.OIDOID:
				target = &pgtype.OIDValue{}
			case pgtype.TimestampOID:
				target = &pgtype.Timestamp{}
			case pgtype.TimestamptzOID:
				target = &pgtype.Timestamptz{}
			case pgtype.XIDOID:
				target = &pgtype.XID{}
			default:
				target = &pgtype.GenericText{}
			}
			r.values[i] = target
		}
	}

	var hasRow bool
	if r.skipNext {
		// The first row was already fetched when the query was issued.
		hasRow, r.skipNext = r.skipNextMore, false
	} else {
		hasRow = r.rows.Next()
	}

	if !hasRow {
		if err := r.rows.Err(); err != nil {
			return err
		}
		return io.EOF
	}

	if err := r.rows.Scan(r.values...); err != nil {
		return err
	}

	for i := range r.values {
		val, err := r.values[i].(driver.Valuer).Value()
		dest[i] = val
		if err != nil {
			return err
		}
	}

	return nil
}
// valueToInterface copies argsV into a new []interface{} of the same length,
// preserving order and nil entries. The result is never nil, even for empty
// input.
func valueToInterface(argsV []driver.Value) []interface{} {
	out := make([]interface{}, len(argsV))
	for i, v := range argsV {
		out[i] = v
	}
	return out
}
// namedValueToInterface extracts the Value of every element of argsV into a
// new []interface{} of the same length, preserving order and nil values. Names
// and ordinals are ignored. The result is never nil, even for empty input.
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
	out := make([]interface{}, len(argsV))
	for i := range argsV {
		out[i] = argsV[i].Value
	}
	return out
}
// wrapTx is the transaction type returned by Conn.BeginTx for real
// transactions; it forwards to a pgx.Tx.
type wrapTx struct {
	tx pgx.Tx
}

// Commit commits the wrapped pgx transaction using a background context.
func (wtx wrapTx) Commit() error {
	ctx := context.Background()
	return wtx.tx.Commit(ctx)
}

// Rollback rolls back the wrapped pgx transaction using a background context.
func (wtx wrapTx) Rollback() error {
	ctx := context.Background()
	return wtx.tx.Rollback(ctx)
}
// fakeTx is the no-op transaction that Conn.BeginTx returns when the call
// originates from AcquireConn. Neither method touches the database.
type fakeTx struct{}

// Commit does nothing and always returns nil.
func (fakeTx) Commit() error {
	return nil
}

// Rollback does nothing and always returns nil.
func (fakeTx) Rollback() error {
	return nil
}
// AcquireConn obtains the *pgx.Conn behind db. It begins a transaction whose
// context carries a pointer for Conn.BeginTx to fill in with the underlying
// connection, and records the resulting *sql.Tx so that ReleaseConn can roll
// it back later.
//
// It returns the error from db.BeginTx if beginning fails, and ErrNotPgx if
// the transaction began but no connection was handed back.
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
	var acquired *pgx.Conn
	ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &acquired)

	tx, err := db.BeginTx(ctx, nil)
	switch {
	case err != nil:
		return nil, err
	case acquired == nil:
		// The transaction did not come from this package's Conn; discard it.
		// The rollback error is intentionally not reported.
		tx.Rollback()
		return nil, ErrNotPgx
	}

	fakeTxMutex.Lock()
	defer fakeTxMutex.Unlock()
	fakeTxConns[acquired] = tx

	return acquired, nil
}
// ReleaseConn gives back a connection obtained from AcquireConn. It forgets
// the *sql.Tx recorded for conn and rolls that transaction back, returning the
// rollback's error. The db argument is not used.
//
// It returns an error if conn is not currently recorded as acquired.
func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
	fakeTxMutex.Lock()
	tx, acquired := fakeTxConns[conn]
	if acquired {
		delete(fakeTxConns, conn)
	}
	fakeTxMutex.Unlock()

	if !acquired {
		return errors.Errorf("can't release conn that is not acquired")
	}

	// The rollback runs outside the mutex.
	return tx.Rollback()
}