Inital pass at converting stdlib
Multiple tests still failing
This commit is contained in:
+36
-141
@@ -72,13 +72,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@@ -99,9 +99,7 @@ var ctxKeyFakeTx ctxKey = 0
|
||||
var ErrNotPgx = errors.New("not pgx *sql.DB")
|
||||
|
||||
func init() {
|
||||
pgxDriver = &Driver{
|
||||
configs: make(map[int64]*DriverConfig),
|
||||
}
|
||||
pgxDriver = &Driver{}
|
||||
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
|
||||
sql.Register("pgx", pgxDriver)
|
||||
|
||||
@@ -126,97 +124,25 @@ var (
|
||||
fakeTxConns map[*pgx.Conn]*sql.Tx
|
||||
)
|
||||
|
||||
type Driver struct {
|
||||
configMutex sync.Mutex
|
||||
configCount int64
|
||||
configs map[int64]*DriverConfig
|
||||
}
|
||||
type Driver struct{}
|
||||
|
||||
func (d *Driver) Open(name string) (driver.Conn, error) {
|
||||
var (
|
||||
connConfig pgx.ConnConfig
|
||||
afterConnect func(*pgx.Conn) error
|
||||
)
|
||||
|
||||
if len(name) >= 9 && name[0] == 0 {
|
||||
idBuf := []byte(name)[1:9]
|
||||
id := int64(binary.BigEndian.Uint64(idBuf))
|
||||
d.configMutex.Lock()
|
||||
connConfig = d.configs[id].ConnConfig
|
||||
afterConnect = d.configs[id].AfterConnect
|
||||
d.configMutex.Unlock()
|
||||
name = name[9:]
|
||||
}
|
||||
|
||||
parsedConfig, err := pgx.ParseConnectionString(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connConfig = connConfig.Merge(parsedConfig)
|
||||
|
||||
conn, err := pgx.Connect(connConfig)
|
||||
connConfig, err := pgx.ParseConfig(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if afterConnect != nil {
|
||||
err = afterConnect(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
|
||||
defer cancel()
|
||||
conn, err := pgx.ConnectConfig(ctx, connConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &Conn{conn: conn, driver: d, connConfig: connConfig}
|
||||
c := &Conn{conn: conn, driver: d, connConfig: *connConfig}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type DriverConfig struct {
|
||||
pgx.ConnConfig
|
||||
AfterConnect func(*pgx.Conn) error // function to call on every new connection
|
||||
driver *Driver
|
||||
id int64
|
||||
}
|
||||
|
||||
// ConnectionString encodes the DriverConfig into the original connection
|
||||
// string. DriverConfig must be registered before calling ConnectionString.
|
||||
func (c *DriverConfig) ConnectionString(original string) string {
|
||||
if c.driver == nil {
|
||||
panic("DriverConfig must be registered before calling ConnectionString")
|
||||
}
|
||||
|
||||
buf := make([]byte, 9)
|
||||
binary.BigEndian.PutUint64(buf[1:], uint64(c.id))
|
||||
buf = append(buf, original...)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func (d *Driver) registerDriverConfig(c *DriverConfig) {
|
||||
d.configMutex.Lock()
|
||||
|
||||
c.driver = d
|
||||
c.id = d.configCount
|
||||
d.configs[d.configCount] = c
|
||||
d.configCount++
|
||||
|
||||
d.configMutex.Unlock()
|
||||
}
|
||||
|
||||
func (d *Driver) unregisterDriverConfig(c *DriverConfig) {
|
||||
d.configMutex.Lock()
|
||||
delete(d.configs, c.id)
|
||||
d.configMutex.Unlock()
|
||||
}
|
||||
|
||||
// RegisterDriverConfig registers a DriverConfig for use with Open.
|
||||
func RegisterDriverConfig(c *DriverConfig) {
|
||||
pgxDriver.registerDriverConfig(c)
|
||||
}
|
||||
|
||||
// UnregisterDriverConfig removes a DriverConfig registration.
|
||||
func UnregisterDriverConfig(c *DriverConfig) {
|
||||
pgxDriver.unregisterDriverConfig(c)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
conn *pgx.Conn
|
||||
psCount int64 // Counter used for creating unique prepared statement names
|
||||
@@ -247,7 +173,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.conn.Close()
|
||||
return c.conn.Close(context.Background())
|
||||
}
|
||||
|
||||
func (c *Conn) Begin() (driver.Tx, error) {
|
||||
@@ -283,23 +209,12 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
pgxOpts.AccessMode = pgx.ReadOnly
|
||||
}
|
||||
|
||||
return c.conn.BeginEx(ctx, &pgxOpts)
|
||||
}
|
||||
|
||||
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
tx, err := c.conn.BeginEx(ctx, &pgxOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
args := valueToInterface(argsV)
|
||||
commandTag, err := c.conn.Exec(query, args...)
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if err != nil && !c.conn.LastStmtSent() {
|
||||
if _, is := err.(net.Error); is {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
}
|
||||
return driver.RowsAffected(commandTag.RowsAffected()), err
|
||||
return wrapTx{tx: tx}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
|
||||
@@ -309,7 +224,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
|
||||
|
||||
args := namedValueToInterface(argsV)
|
||||
|
||||
commandTag, err := c.conn.ExecEx(ctx, query, nil, args...)
|
||||
commandTag, err := c.conn.Exec(ctx, query, args...)
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if err != nil && !c.conn.LastStmtSent() {
|
||||
if _, is := err.(net.Error); is {
|
||||
@@ -319,44 +234,16 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
|
||||
return driver.RowsAffected(commandTag.RowsAffected()), err
|
||||
}
|
||||
|
||||
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
if !c.connConfig.PreferSimpleProtocol {
|
||||
ps, err := c.conn.Prepare("", query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
return c.queryPrepared("", argsV)
|
||||
}
|
||||
|
||||
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
|
||||
if err != nil {
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if !c.conn.LastStmtSent() {
|
||||
if _, is := err.(net.Error); is {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
||||
more := rows.Next()
|
||||
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||
if !c.conn.IsAlive() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
var rows pgx.Rows
|
||||
|
||||
if !c.connConfig.PreferSimpleProtocol {
|
||||
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
|
||||
c.conn.Deallocate("stdlibtemp")
|
||||
ps, err := c.conn.PrepareEx(ctx, "stdlibtemp", query, nil)
|
||||
if err != nil {
|
||||
// since PrepareEx failed, we didn't actually get to send the values, so
|
||||
// we can safely retry
|
||||
@@ -367,10 +254,10 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
return c.queryPreparedContext(ctx, "", argsV)
|
||||
return c.queryPreparedContext(ctx, "stdlibtemp", argsV)
|
||||
}
|
||||
|
||||
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
|
||||
rows, err := c.conn.Query(ctx, query, namedValueToInterface(argsV)...)
|
||||
if err != nil {
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if !c.conn.LastStmtSent() {
|
||||
@@ -393,7 +280,7 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
|
||||
|
||||
args := valueToInterface(argsV)
|
||||
|
||||
rows, err := c.conn.Query(name, args...)
|
||||
rows, err := c.conn.Query(context.Background(), name, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -408,12 +295,14 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr
|
||||
|
||||
args := namedValueToInterface(argsV)
|
||||
|
||||
rows, err := c.conn.QueryEx(ctx, name, nil, args...)
|
||||
rows, err := c.conn.Query(ctx, name, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Rows{rows: rows}, nil
|
||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
||||
more := rows.Next()
|
||||
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Ping(ctx context.Context) error {
|
||||
@@ -450,7 +339,7 @@ func (s *Stmt) NumInput() int {
|
||||
}
|
||||
|
||||
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
||||
return s.conn.Exec(s.ps.Name, argsV)
|
||||
return nil, errors.New("Stmt.Exec deprecated and not implemented")
|
||||
}
|
||||
|
||||
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
|
||||
@@ -458,7 +347,7 @@ func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driv
|
||||
}
|
||||
|
||||
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
||||
return s.conn.queryPrepared(s.ps.Name, argsV)
|
||||
return nil, errors.New("Stmt.Query deprecated and not implemented")
|
||||
}
|
||||
|
||||
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||
@@ -466,7 +355,7 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri
|
||||
}
|
||||
|
||||
type Rows struct {
|
||||
rows *pgx.Rows
|
||||
rows pgx.Rows
|
||||
values []interface{}
|
||||
skipNext bool
|
||||
skipNextMore bool
|
||||
@@ -605,6 +494,12 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
|
||||
return args
|
||||
}
|
||||
|
||||
type wrapTx struct{ tx *pgx.Tx }
|
||||
|
||||
func (wtx wrapTx) Commit() error { return wtx.tx.Commit(context.Background()) }
|
||||
|
||||
func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(context.Background()) }
|
||||
|
||||
type fakeTx struct{}
|
||||
|
||||
func (fakeTx) Commit() error { return nil }
|
||||
|
||||
Reference in New Issue
Block a user