2
0

Inital pass at converting stdlib

Multiple tests still failing
This commit is contained in:
Jack Christensen
2019-04-12 16:57:42 -05:00
parent 3901f3ef88
commit b77f901168
4 changed files with 108 additions and 592 deletions
+36 -141
View File
@@ -72,13 +72,13 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"net"
"reflect"
"strings"
"sync"
"time"
"github.com/pkg/errors"
@@ -99,9 +99,7 @@ var ctxKeyFakeTx ctxKey = 0
var ErrNotPgx = errors.New("not pgx *sql.DB")
func init() {
pgxDriver = &Driver{
configs: make(map[int64]*DriverConfig),
}
pgxDriver = &Driver{}
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
sql.Register("pgx", pgxDriver)
@@ -126,97 +124,25 @@ var (
fakeTxConns map[*pgx.Conn]*sql.Tx
)
type Driver struct {
configMutex sync.Mutex
configCount int64
configs map[int64]*DriverConfig
}
type Driver struct{}
func (d *Driver) Open(name string) (driver.Conn, error) {
var (
connConfig pgx.ConnConfig
afterConnect func(*pgx.Conn) error
)
if len(name) >= 9 && name[0] == 0 {
idBuf := []byte(name)[1:9]
id := int64(binary.BigEndian.Uint64(idBuf))
d.configMutex.Lock()
connConfig = d.configs[id].ConnConfig
afterConnect = d.configs[id].AfterConnect
d.configMutex.Unlock()
name = name[9:]
}
parsedConfig, err := pgx.ParseConnectionString(name)
if err != nil {
return nil, err
}
connConfig = connConfig.Merge(parsedConfig)
conn, err := pgx.Connect(connConfig)
connConfig, err := pgx.ParseConfig(name)
if err != nil {
return nil, err
}
if afterConnect != nil {
err = afterConnect(conn)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout
defer cancel()
conn, err := pgx.ConnectConfig(ctx, connConfig)
if err != nil {
return nil, err
}
c := &Conn{conn: conn, driver: d, connConfig: connConfig}
c := &Conn{conn: conn, driver: d, connConfig: *connConfig}
return c, nil
}
type DriverConfig struct {
pgx.ConnConfig
AfterConnect func(*pgx.Conn) error // function to call on every new connection
driver *Driver
id int64
}
// ConnectionString encodes the DriverConfig into the original connection
// string. DriverConfig must be registered before calling ConnectionString.
func (c *DriverConfig) ConnectionString(original string) string {
if c.driver == nil {
panic("DriverConfig must be registered before calling ConnectionString")
}
buf := make([]byte, 9)
binary.BigEndian.PutUint64(buf[1:], uint64(c.id))
buf = append(buf, original...)
return string(buf)
}
func (d *Driver) registerDriverConfig(c *DriverConfig) {
d.configMutex.Lock()
c.driver = d
c.id = d.configCount
d.configs[d.configCount] = c
d.configCount++
d.configMutex.Unlock()
}
func (d *Driver) unregisterDriverConfig(c *DriverConfig) {
d.configMutex.Lock()
delete(d.configs, c.id)
d.configMutex.Unlock()
}
// RegisterDriverConfig registers a DriverConfig for use with Open.
func RegisterDriverConfig(c *DriverConfig) {
pgxDriver.registerDriverConfig(c)
}
// UnregisterDriverConfig removes a DriverConfig registration.
func UnregisterDriverConfig(c *DriverConfig) {
pgxDriver.unregisterDriverConfig(c)
}
type Conn struct {
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
@@ -247,7 +173,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
}
func (c *Conn) Close() error {
return c.conn.Close()
return c.conn.Close(context.Background())
}
func (c *Conn) Begin() (driver.Tx, error) {
@@ -283,23 +209,12 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
pgxOpts.AccessMode = pgx.ReadOnly
}
return c.conn.BeginEx(ctx, &pgxOpts)
}
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
tx, err := c.conn.BeginEx(ctx, &pgxOpts)
if err != nil {
return nil, err
}
args := valueToInterface(argsV)
commandTag, err := c.conn.Exec(query, args...)
// if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return driver.RowsAffected(commandTag.RowsAffected()), err
return wrapTx{tx: tx}, nil
}
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
@@ -309,7 +224,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
args := namedValueToInterface(argsV)
commandTag, err := c.conn.ExecEx(ctx, query, nil, args...)
commandTag, err := c.conn.Exec(ctx, query, args...)
// if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
@@ -319,44 +234,16 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
return driver.RowsAffected(commandTag.RowsAffected()), err
}
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.Prepare("", query)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPrepared("", argsV)
}
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return nil, err
}
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
var rows pgx.Rows
if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
c.conn.Deallocate("stdlibtemp")
ps, err := c.conn.PrepareEx(ctx, "stdlibtemp", query, nil)
if err != nil {
// since PrepareEx failed, we didn't actually get to send the values, so
// we can safely retry
@@ -367,10 +254,10 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, "", argsV)
return c.queryPreparedContext(ctx, "stdlibtemp", argsV)
}
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
rows, err := c.conn.Query(ctx, query, namedValueToInterface(argsV)...)
if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
@@ -393,7 +280,7 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
args := valueToInterface(argsV)
rows, err := c.conn.Query(name, args...)
rows, err := c.conn.Query(context.Background(), name, args...)
if err != nil {
return nil, err
}
@@ -408,12 +295,14 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr
args := namedValueToInterface(argsV)
rows, err := c.conn.QueryEx(ctx, name, nil, args...)
rows, err := c.conn.Query(ctx, name, args...)
if err != nil {
return nil, err
}
return &Rows{rows: rows}, nil
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) Ping(ctx context.Context) error {
@@ -450,7 +339,7 @@ func (s *Stmt) NumInput() int {
}
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
return s.conn.Exec(s.ps.Name, argsV)
return nil, errors.New("Stmt.Exec deprecated and not implemented")
}
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
@@ -458,7 +347,7 @@ func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driv
}
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
return s.conn.queryPrepared(s.ps.Name, argsV)
return nil, errors.New("Stmt.Query deprecated and not implemented")
}
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
@@ -466,7 +355,7 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri
}
type Rows struct {
rows *pgx.Rows
rows pgx.Rows
values []interface{}
skipNext bool
skipNextMore bool
@@ -605,6 +494,12 @@ func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
return args
}
type wrapTx struct{ tx *pgx.Tx }
func (wtx wrapTx) Commit() error { return wtx.tx.Commit(context.Background()) }
func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(context.Background()) }
type fakeTx struct{}
func (fakeTx) Commit() error { return nil }