Merge branch 'context' into v3-experimental
This commit is contained in:
@@ -51,6 +51,8 @@ install:
|
|||||||
- go get -u github.com/shopspring/decimal
|
- go get -u github.com/shopspring/decimal
|
||||||
- go get -u gopkg.in/inconshreveable/log15.v2
|
- go get -u gopkg.in/inconshreveable/log15.v2
|
||||||
- go get -u github.com/jackc/fake
|
- go get -u github.com/jackc/fake
|
||||||
|
- go get -u golang.org/x/net/context
|
||||||
|
- go get -u github.com/jackc/pgmock/pgmsg
|
||||||
|
|
||||||
script:
|
script:
|
||||||
- go test -v -race -short ./...
|
- go test -v -race -short ./...
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
Extract all locking state into a separate struct that will encapsulate locking and state change behavior.
|
||||||
|
|
||||||
|
This struct should add or subsume at least the following:
|
||||||
|
* alive
|
||||||
|
* closingLock
|
||||||
|
* ctxInProgress (though this may be restructured because it's possible a Tx may have a ctx and a query run in that Tx could have one)
|
||||||
|
* busy
|
||||||
|
* lock/unlock
|
||||||
|
* Tx in-progress
|
||||||
|
* Rows in-progress
|
||||||
|
* ConnPool checked-out or checked-in - maybe include reference to conn pool
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/net/context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -17,9 +18,17 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
connStatusUninitialized = iota
|
||||||
|
connStatusClosed
|
||||||
|
connStatusIdle
|
||||||
|
connStatusBusy
|
||||||
|
)
|
||||||
|
|
||||||
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
||||||
type DialFunc func(network, addr string) (net.Conn, error)
|
type DialFunc func(network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
@@ -39,13 +48,28 @@ type ConnConfig struct {
|
|||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cc *ConnConfig) networkAddress() (network, address string) {
|
||||||
|
network = "tcp"
|
||||||
|
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
||||||
|
// See if host is a valid path, if yes connect with a socket
|
||||||
|
if _, err := os.Stat(cc.Host); err == nil {
|
||||||
|
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||||
|
network = "unix"
|
||||||
|
address = cc.Host
|
||||||
|
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||||
|
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return network, address
|
||||||
|
}
|
||||||
|
|
||||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
// Use ConnPool to manage access to multiple database connections from multiple
|
// Use ConnPool to manage access to multiple database connections from multiple
|
||||||
// goroutines.
|
// goroutines.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||||
lastActivityTime time.Time // the last time the connection was used
|
lastActivityTime time.Time // the last time the connection was used
|
||||||
reader *bufio.Reader // buffered reader to improve read performance
|
|
||||||
wbuf [1024]byte
|
wbuf [1024]byte
|
||||||
writeBuf WriteBuf
|
writeBuf WriteBuf
|
||||||
pid int32 // backend pid
|
pid int32 // backend pid
|
||||||
@@ -57,17 +81,26 @@ type Conn struct {
|
|||||||
preparedStatements map[string]*PreparedStatement
|
preparedStatements map[string]*PreparedStatement
|
||||||
channels map[string]struct{}
|
channels map[string]struct{}
|
||||||
notifications []*Notification
|
notifications []*Notification
|
||||||
alive bool
|
|
||||||
causeOfDeath error
|
|
||||||
logger Logger
|
logger Logger
|
||||||
logLevel int
|
logLevel int
|
||||||
mr msgReader
|
mr msgReader
|
||||||
fp *fastpath
|
fp *fastpath
|
||||||
pgsqlAfInet *byte
|
pgsqlAfInet *byte
|
||||||
pgsqlAfInet6 *byte
|
pgsqlAfInet6 *byte
|
||||||
busy bool
|
|
||||||
poolResetCount int
|
poolResetCount int
|
||||||
preallocatedRows []Rows
|
preallocatedRows []Rows
|
||||||
|
|
||||||
|
status int32 // One of connStatus* constants
|
||||||
|
causeOfDeath error
|
||||||
|
|
||||||
|
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
|
||||||
|
cancelQueryInProgress int32
|
||||||
|
cancelQueryCompleted chan struct{}
|
||||||
|
|
||||||
|
// context support
|
||||||
|
ctxInProgress bool
|
||||||
|
doneChan chan struct{}
|
||||||
|
closedChan chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
// PreparedStatement is a description of a prepared statement
|
// PreparedStatement is a description of a prepared statement
|
||||||
@@ -194,17 +227,7 @@ func connect(config ConnConfig, pgTypes map[OID]PgType, pgsqlAfInet *byte, pgsql
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
network := "tcp"
|
network, address := c.config.networkAddress()
|
||||||
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
|
|
||||||
// See if host is a valid path, if yes connect with a socket
|
|
||||||
if _, err := os.Stat(c.config.Host); err == nil {
|
|
||||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
|
||||||
network = "unix"
|
|
||||||
address = c.config.Host
|
|
||||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
|
||||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.config.Dial == nil {
|
if c.config.Dial == nil {
|
||||||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||||
}
|
}
|
||||||
@@ -238,15 +261,18 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||||||
defer func() {
|
defer func() {
|
||||||
if c != nil && err != nil {
|
if c != nil && err != nil {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
c.alive = false
|
atomic.StoreInt32(&c.status, connStatusClosed)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.RuntimeParams = make(map[string]string)
|
c.RuntimeParams = make(map[string]string)
|
||||||
c.preparedStatements = make(map[string]*PreparedStatement)
|
c.preparedStatements = make(map[string]*PreparedStatement)
|
||||||
c.channels = make(map[string]struct{})
|
c.channels = make(map[string]struct{})
|
||||||
c.alive = true
|
atomic.StoreInt32(&c.status, connStatusIdle)
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
|
c.cancelQueryCompleted = make(chan struct{}, 1)
|
||||||
|
c.doneChan = make(chan struct{})
|
||||||
|
c.closedChan = make(chan error)
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
if c.shouldLog(LogLevelDebug) {
|
if c.shouldLog(LogLevelDebug) {
|
||||||
@@ -257,8 +283,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.reader = bufio.NewReader(c.conn)
|
c.mr.reader = bufio.NewReader(c.conn)
|
||||||
c.mr.reader = c.reader
|
|
||||||
|
|
||||||
msg := newStartupMessage()
|
msg := newStartupMessage()
|
||||||
|
|
||||||
@@ -389,14 +414,17 @@ func (c *Conn) PID() int32 {
|
|||||||
// Close closes a connection. It is safe to call Close on a already closed
|
// Close closes a connection. It is safe to call Close on a already closed
|
||||||
// connection.
|
// connection.
|
||||||
func (c *Conn) Close() (err error) {
|
func (c *Conn) Close() (err error) {
|
||||||
if !c.IsAlive() {
|
for {
|
||||||
return nil
|
status := atomic.LoadInt32(&c.status)
|
||||||
|
if status < connStatusIdle {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
wbuf := newWriteBuf(c, 'X')
|
_, err = c.conn.Write([]byte{'X', 0, 0, 0, 4})
|
||||||
wbuf.closeMsg()
|
|
||||||
|
|
||||||
_, err = c.conn.Write(wbuf.buf)
|
|
||||||
|
|
||||||
c.die(errors.New("Closed"))
|
c.die(errors.New("Closed"))
|
||||||
if c.shouldLog(LogLevelInfo) {
|
if c.shouldLog(LogLevelInfo) {
|
||||||
@@ -614,12 +642,36 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
|
|||||||
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
|
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
|
||||||
// concern for if the statement has already been prepared.
|
// concern for if the statement has already been prepared.
|
||||||
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||||
|
return c.PrepareExContext(context.Background(), name, sql, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||||
|
err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.initContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ps, err = c.prepareEx(name, sql, opts)
|
||||||
|
err = c.termContext(err)
|
||||||
|
return ps, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||||
if name != "" {
|
if name != "" {
|
||||||
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
|
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
|
||||||
return ps, nil
|
return ps, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if c.shouldLog(LogLevelError) {
|
if c.shouldLog(LogLevelError) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -659,6 +711,7 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||||||
c.die(err)
|
c.die(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
c.readyForQuery = false
|
||||||
|
|
||||||
ps = &PreparedStatement{Name: name, SQL: sql}
|
ps = &PreparedStatement{Name: name, SQL: sql}
|
||||||
|
|
||||||
@@ -673,7 +726,6 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch t {
|
switch t {
|
||||||
case parseComplete:
|
|
||||||
case parameterDescription:
|
case parameterDescription:
|
||||||
ps.ParameterOIDs = c.rxParameterDescription(r)
|
ps.ParameterOIDs = c.rxParameterDescription(r)
|
||||||
|
|
||||||
@@ -687,7 +739,6 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||||||
ps.FieldDescriptions[i].DataTypeName = t.Name
|
ps.FieldDescriptions[i].DataTypeName = t.Name
|
||||||
ps.FieldDescriptions[i].FormatCode = t.DefaultFormat
|
ps.FieldDescriptions[i].FormatCode = t.DefaultFormat
|
||||||
}
|
}
|
||||||
case noData:
|
|
||||||
case readyForQuery:
|
case readyForQuery:
|
||||||
c.rxReadyForQuery(r)
|
c.rxReadyForQuery(r)
|
||||||
|
|
||||||
@@ -705,7 +756,29 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Deallocate released a prepared statement
|
// Deallocate released a prepared statement
|
||||||
func (c *Conn) Deallocate(name string) (err error) {
|
func (c *Conn) Deallocate(name string) error {
|
||||||
|
return c.deallocateContext(context.Background(), name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - consider making this public
|
||||||
|
func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) {
|
||||||
|
err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.initContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = c.termContext(err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
delete(c.preparedStatements, name)
|
delete(c.preparedStatements, name)
|
||||||
|
|
||||||
// close
|
// close
|
||||||
@@ -776,6 +849,17 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
|
|||||||
return notification, nil
|
return notification, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancelFn := context.WithTimeout(context.Background(), timeout)
|
||||||
|
if err := c.waitForPreviousCancelQuery(ctx); err != nil {
|
||||||
|
cancelFn()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cancelFn()
|
||||||
|
|
||||||
|
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
stopTime := time.Now().Add(timeout)
|
stopTime := time.Now().Add(timeout)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -835,7 +919,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wait until there is a byte available before continuing onto the normal msg reading path
|
// Wait until there is a byte available before continuing onto the normal msg reading path
|
||||||
_, err = c.reader.Peek(1)
|
_, err = c.mr.reader.Peek(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
|
c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
|
||||||
if err, ok := err.(*net.OpError); ok && err.Timeout() {
|
if err, ok := err.(*net.OpError); ok && err.Timeout() {
|
||||||
@@ -868,7 +952,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) IsAlive() bool {
|
func (c *Conn) IsAlive() bool {
|
||||||
return c.alive
|
return atomic.LoadInt32(&c.status) >= connStatusIdle
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) CauseOfDeath() error {
|
func (c *Conn) CauseOfDeath() error {
|
||||||
@@ -883,6 +967,9 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
||||||
|
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
wbuf := newWriteBuf(c, 'Q')
|
wbuf := newWriteBuf(c, 'Q')
|
||||||
@@ -894,6 +981,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
|||||||
c.die(err)
|
c.die(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
c.readyForQuery = false
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -911,6 +999,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments))
|
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOIDs), len(arguments))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// bind
|
// bind
|
||||||
wbuf := newWriteBuf(c, 'B')
|
wbuf := newWriteBuf(c, 'B')
|
||||||
wbuf.WriteByte(0)
|
wbuf.WriteByte(0)
|
||||||
@@ -958,6 +1050,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
c.die(err)
|
c.die(err)
|
||||||
}
|
}
|
||||||
|
c.readyForQuery = false
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -965,91 +1058,52 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||||||
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
||||||
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
||||||
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||||
if err = c.lock(); err != nil {
|
return c.ExecContext(context.Background(), sql, arguments...)
|
||||||
return commandTag, err
|
|
||||||
}
|
|
||||||
|
|
||||||
startTime := time.Now()
|
|
||||||
c.lastActivityTime = startTime
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err == nil {
|
|
||||||
if c.shouldLog(LogLevelInfo) {
|
|
||||||
endTime := time.Now()
|
|
||||||
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if c.shouldLog(LogLevelError) {
|
|
||||||
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
|
||||||
err = unlockErr
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var softErr error
|
|
||||||
|
|
||||||
for {
|
|
||||||
var t byte
|
|
||||||
var r *msgReader
|
|
||||||
t, r, err = c.rxMsg()
|
|
||||||
if err != nil {
|
|
||||||
return commandTag, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch t {
|
|
||||||
case readyForQuery:
|
|
||||||
c.rxReadyForQuery(r)
|
|
||||||
return commandTag, softErr
|
|
||||||
case rowDescription:
|
|
||||||
case dataRow:
|
|
||||||
case bindComplete:
|
|
||||||
case commandComplete:
|
|
||||||
commandTag = CommandTag(r.readCString())
|
|
||||||
default:
|
|
||||||
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
|
||||||
softErr = e
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Processes messages that are not exclusive to one context such as
|
// Processes messages that are not exclusive to one context such as
|
||||||
// authentication or query response. The response to these messages
|
// authentication or query response. The response to these messages is the same
|
||||||
// is the same regardless of when they occur.
|
// regardless of when they occur. It also ignores messages that are only
|
||||||
|
// meaningful in a given context. These messages can occur due to a context
|
||||||
|
// deadline interrupting message processing. For example, an interrupted query
|
||||||
|
// may have left DataRow messages on the wire.
|
||||||
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
|
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
|
||||||
switch t {
|
switch t {
|
||||||
case 'S':
|
case bindComplete:
|
||||||
c.rxParameterStatus(r)
|
case commandComplete:
|
||||||
return nil
|
case dataRow:
|
||||||
|
case emptyQueryResponse:
|
||||||
case errorResponse:
|
case errorResponse:
|
||||||
return c.rxErrorResponse(r)
|
return c.rxErrorResponse(r)
|
||||||
|
case noData:
|
||||||
case noticeResponse:
|
case noticeResponse:
|
||||||
return nil
|
|
||||||
case emptyQueryResponse:
|
|
||||||
return nil
|
|
||||||
case notificationResponse:
|
case notificationResponse:
|
||||||
c.rxNotificationResponse(r)
|
c.rxNotificationResponse(r)
|
||||||
return nil
|
case parameterDescription:
|
||||||
|
case parseComplete:
|
||||||
|
case readyForQuery:
|
||||||
|
c.rxReadyForQuery(r)
|
||||||
|
case rowDescription:
|
||||||
|
case 'S':
|
||||||
|
c.rxParameterStatus(r)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Received unknown message type: %c", t)
|
return fmt.Errorf("Received unknown message type: %c", t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
|
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
|
||||||
if !c.alive {
|
if atomic.LoadInt32(&c.status) < connStatusIdle {
|
||||||
return 0, nil, ErrDeadConn
|
return 0, nil, ErrDeadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err = c.mr.rxMsg()
|
t, err = c.mr.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.die(err)
|
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||||
|
c.die(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
@@ -1150,6 +1204,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxReadyForQuery(r *msgReader) {
|
func (c *Conn) rxReadyForQuery(r *msgReader) {
|
||||||
|
c.readyForQuery = true
|
||||||
c.txStatus = r.readByte()
|
c.txStatus = r.readByte()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1230,25 +1285,23 @@ func (c *Conn) txPasswordMessage(password string) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) die(err error) {
|
func (c *Conn) die(err error) {
|
||||||
c.alive = false
|
atomic.StoreInt32(&c.status, connStatusClosed)
|
||||||
c.causeOfDeath = err
|
c.causeOfDeath = err
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) lock() error {
|
func (c *Conn) lock() error {
|
||||||
if c.busy {
|
if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) {
|
||||||
return ErrConnBusy
|
return nil
|
||||||
}
|
}
|
||||||
c.busy = true
|
return ErrConnBusy
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) unlock() error {
|
func (c *Conn) unlock() error {
|
||||||
if !c.busy {
|
if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) {
|
||||||
return errors.New("unlock conn that is not busy")
|
return nil
|
||||||
}
|
}
|
||||||
c.busy = false
|
return errors.New("unlock conn that is not busy")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) shouldLog(lvl int) bool {
|
func (c *Conn) shouldLog(lvl int) bool {
|
||||||
@@ -1286,3 +1339,229 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) {
|
|||||||
func quoteIdentifier(s string) string {
|
func quoteIdentifier(s string) string {
|
||||||
return `"` + strings.Replace(s, `"`, `""`, -1) + `"`
|
return `"` + strings.Replace(s, `"`, `""`, -1) + `"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cancelQuery sends a cancel request to the PostgreSQL server. It returns an
|
||||||
|
// error if unable to deliver the cancel request, but lack of an error does not
|
||||||
|
// ensure that the query was canceled. As specified in the documentation, there
|
||||||
|
// is no way to be sure a query was canceled. See
|
||||||
|
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
|
||||||
|
func (c *Conn) cancelQuery() {
|
||||||
|
if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) {
|
||||||
|
panic("cancelQuery when cancelQueryInProgress")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.conn.SetDeadline(time.Now()); err != nil {
|
||||||
|
c.Close() // Close connection if unable to set deadline
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
doCancel := func() error {
|
||||||
|
network, address := c.config.networkAddress()
|
||||||
|
cancelConn, err := c.config.Dial(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cancelConn.Close()
|
||||||
|
|
||||||
|
// If server doesn't process cancellation request in bounded time then abort.
|
||||||
|
err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 16)
|
||||||
|
binary.BigEndian.PutUint32(buf[0:4], 16)
|
||||||
|
binary.BigEndian.PutUint32(buf[4:8], 80877102)
|
||||||
|
binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid))
|
||||||
|
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
|
||||||
|
_, err = cancelConn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = cancelConn.Read(buf)
|
||||||
|
if err != io.EOF {
|
||||||
|
return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := doCancel()
|
||||||
|
if err != nil {
|
||||||
|
c.Close() // Something is very wrong. Terminate the connection.
|
||||||
|
}
|
||||||
|
c.cancelQueryCompleted <- struct{}{}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Ping() error {
|
||||||
|
return c.PingContext(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) PingContext(ctx context.Context) error {
|
||||||
|
_, err := c.ExecContext(ctx, ";")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||||
|
err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.initContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = c.termContext(err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = c.lock(); err != nil {
|
||||||
|
return commandTag, err
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
c.lastActivityTime = startTime
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
|
if c.shouldLog(LogLevelInfo) {
|
||||||
|
endTime := time.Now()
|
||||||
|
c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if c.shouldLog(LogLevelError) {
|
||||||
|
c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if unlockErr := c.unlock(); unlockErr != nil && err == nil {
|
||||||
|
err = unlockErr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var softErr error
|
||||||
|
|
||||||
|
for {
|
||||||
|
var t byte
|
||||||
|
var r *msgReader
|
||||||
|
t, r, err = c.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
return commandTag, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case readyForQuery:
|
||||||
|
c.rxReadyForQuery(r)
|
||||||
|
return commandTag, softErr
|
||||||
|
case commandComplete:
|
||||||
|
commandTag = CommandTag(r.readCString())
|
||||||
|
default:
|
||||||
|
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
|
||||||
|
softErr = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return commandTag, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) initContext(ctx context.Context) error {
|
||||||
|
if c.ctxInProgress {
|
||||||
|
return errors.New("ctx already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Done() == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
c.ctxInProgress = true
|
||||||
|
|
||||||
|
go c.contextHandler(ctx)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) termContext(opErr error) error {
|
||||||
|
if !c.ctxInProgress {
|
||||||
|
return opErr
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-c.closedChan:
|
||||||
|
if opErr == nil {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
case c.doneChan <- struct{}{}:
|
||||||
|
err = opErr
|
||||||
|
}
|
||||||
|
|
||||||
|
c.ctxInProgress = false
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) contextHandler(ctx context.Context) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.cancelQuery()
|
||||||
|
c.closedChan <- ctx.Err()
|
||||||
|
case <-c.doneChan:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
|
||||||
|
if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.cancelQueryCompleted:
|
||||||
|
atomic.StoreInt32(&c.cancelQueryInProgress, 0)
|
||||||
|
if err := c.conn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
c.Close() // Close connection if unable to disable deadline
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ensureConnectionReadyForQuery() error {
|
||||||
|
for !c.readyForQuery {
|
||||||
|
t, r, err := c.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case errorResponse:
|
||||||
|
pgErr := c.rxErrorResponse(r)
|
||||||
|
if pgErr.Severity == "FATAL" {
|
||||||
|
return pgErr
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = c.processContextFreeMsg(t, r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package pgx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"golang.org/x/net/context"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -181,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
|||||||
|
|
||||||
// Release gives up use of a connection.
|
// Release gives up use of a connection.
|
||||||
func (p *ConnPool) Release(conn *Conn) {
|
func (p *ConnPool) Release(conn *Conn) {
|
||||||
|
if conn.ctxInProgress {
|
||||||
|
panic("should never release when context is in progress")
|
||||||
|
}
|
||||||
|
|
||||||
if conn.txStatus != 'I' {
|
if conn.txStatus != 'I' {
|
||||||
conn.Exec("rollback")
|
conn.Exec("rollback")
|
||||||
}
|
}
|
||||||
@@ -357,6 +362,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
|
|||||||
return c.Exec(sql, arguments...)
|
return c.Exec(sql, arguments...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||||
|
var c *Conn
|
||||||
|
if c, err = p.Acquire(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer p.Release(c)
|
||||||
|
|
||||||
|
return c.ExecContext(ctx, sql, arguments...)
|
||||||
|
}
|
||||||
|
|
||||||
// Query acquires a connection and delegates the call to that connection. When
|
// Query acquires a connection and delegates the call to that connection. When
|
||||||
// *Rows are closed, the connection is released automatically.
|
// *Rows are closed, the connection is released automatically.
|
||||||
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||||
@@ -377,6 +392,24 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
|||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
|
||||||
|
c, err := p.Acquire()
|
||||||
|
if err != nil {
|
||||||
|
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||||
|
return &Rows{closed: true, err: err}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := c.QueryContext(ctx, sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
p.Release(c)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.AfterClose(p.rowsAfterClose)
|
||||||
|
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
// QueryRow acquires a connection and delegates the call to that connection. The
|
// QueryRow acquires a connection and delegates the call to that connection. The
|
||||||
// connection is released automatically after Scan is called on the returned
|
// connection is released automatically after Scan is called on the returned
|
||||||
// *Row.
|
// *Row.
|
||||||
@@ -385,6 +418,11 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
|
|||||||
return (*Row)(rows)
|
return (*Row)(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
|
||||||
|
rows, _ := p.QueryContext(ctx, sql, args...)
|
||||||
|
return (*Row)(rows)
|
||||||
|
}
|
||||||
|
|
||||||
// Begin acquires a connection and begins a transaction on it. When the
|
// Begin acquires a connection and begins a transaction on it. When the
|
||||||
// transaction is closed the connection will be automatically released.
|
// transaction is closed the connection will be automatically released.
|
||||||
func (p *ConnPool) Begin() (*Tx, error) {
|
func (p *ConnPool) Begin() (*Tx, error) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package pgx_test
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/net/context"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -816,6 +817,64 @@ func TestExecFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if commandTag != "CREATE TABLE" {
|
||||||
|
t.Fatalf("Unexpected results from ExecContext: %v", commandTag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
if _, err := conn.ExecContext(ctx, "selct;"); err == nil {
|
||||||
|
t.Fatal("Expected SQL syntax error")
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, _ := conn.Query("select 1")
|
||||||
|
rows.Close()
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecContextCancelationCancelsQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
cancelFunc()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := conn.ExecContext(ctx, "select pg_sleep(60)")
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Fatal("Expected context.Canceled err, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPrepare(t *testing.T) {
|
func TestPrepare(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
Add more testing
|
||||||
|
- stress test style
|
||||||
|
- pgmock
|
||||||
|
|
||||||
|
Add documentation
|
||||||
|
|
||||||
|
Add PrepareContext
|
||||||
|
Add context methods to ConnPool
|
||||||
|
Add context methods to Tx
|
||||||
|
Add context support database/sql
|
||||||
|
|
||||||
|
Benchmark - possibly cache done channel on Conn
|
||||||
@@ -66,7 +66,6 @@ func (ct *copyTo) readUntilReadyForQuery() {
|
|||||||
ct.conn.rxReadyForQuery(r)
|
ct.conn.rxReadyForQuery(r)
|
||||||
close(ct.readerErrChan)
|
close(ct.readerErrChan)
|
||||||
return
|
return
|
||||||
case commandComplete:
|
|
||||||
case errorResponse:
|
case errorResponse:
|
||||||
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
|
ct.readerErrChan <- ct.conn.rxErrorResponse(r)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -48,6 +48,10 @@ func fpInt64Arg(n int64) fpArg {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *fastpath) Call(oid OID, args []fpArg) (res []byte, err error) {
|
func (f *fastpath) Call(oid OID, args []fpArg) (res []byte, err error) {
|
||||||
|
if err := f.cn.ensureConnectionReadyForQuery(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
wbuf := newWriteBuf(f.cn, 'F') // function call
|
wbuf := newWriteBuf(f.cn, 'F') // function call
|
||||||
wbuf.WriteInt32(int32(oid)) // function object id
|
wbuf.WriteInt32(int32(oid)) // function object id
|
||||||
wbuf.WriteInt16(1) // # of argument format codes
|
wbuf.WriteInt16(1) // # of argument format codes
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio
|
|||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func closeConn(t testing.TB, conn *pgx.Conn) {
|
func closeConn(t testing.TB, conn *pgx.Conn) {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
+23
-8
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// msgReader is a helper that reads values from a PostgreSQL message.
|
// msgReader is a helper that reads values from a PostgreSQL message.
|
||||||
@@ -16,11 +17,6 @@ type msgReader struct {
|
|||||||
shouldLog func(lvl int) bool
|
shouldLog func(lvl int) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err returns any error that the msgReader has experienced
|
|
||||||
func (r *msgReader) Err() error {
|
|
||||||
return r.err
|
|
||||||
}
|
|
||||||
|
|
||||||
// fatal tells rc that a Fatal error has occurred
|
// fatal tells rc that a Fatal error has occurred
|
||||||
func (r *msgReader) fatal(err error) {
|
func (r *msgReader) fatal(err error) {
|
||||||
if r.shouldLog(LogLevelTrace) {
|
if r.shouldLog(LogLevelTrace) {
|
||||||
@@ -40,20 +36,39 @@ func (r *msgReader) rxMsg() (byte, error) {
|
|||||||
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := r.reader.Discard(int(r.msgBytesRemaining))
|
n, err := r.reader.Discard(int(r.msgBytesRemaining))
|
||||||
|
r.msgBytesRemaining -= int32(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||||
|
r.fatal(err)
|
||||||
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := r.reader.Peek(5)
|
b, err := r.reader.Peek(5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.fatal(err)
|
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||||
|
r.fatal(err)
|
||||||
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msgType := b[0]
|
msgType := b[0]
|
||||||
r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4
|
payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4
|
||||||
|
|
||||||
|
// Try to preload bufio.Reader with entire message
|
||||||
|
b, err = r.reader.Peek(5 + int(payloadSize))
|
||||||
|
if err != nil && err != bufio.ErrBufferFull {
|
||||||
|
if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
|
||||||
|
r.fatal(err)
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.msgBytesRemaining = payloadSize
|
||||||
r.reader.Discard(5)
|
r.reader.Discard(5)
|
||||||
|
|
||||||
return msgType, nil
|
return msgType, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,189 @@
|
|||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgmock/pgmsg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMsgReaderPrebuffersWhenPossible(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
msgType byte
|
||||||
|
payloadSize int32
|
||||||
|
buffered bool
|
||||||
|
}{
|
||||||
|
{1, 50, true},
|
||||||
|
{2, 0, true},
|
||||||
|
{3, 500, true},
|
||||||
|
{4, 1050, true},
|
||||||
|
{5, 1500, true},
|
||||||
|
{6, 1500, true},
|
||||||
|
{7, 4000, true},
|
||||||
|
{8, 24000, false},
|
||||||
|
{9, 4000, true},
|
||||||
|
{1, 1500, true},
|
||||||
|
{2, 0, true},
|
||||||
|
{3, 500, true},
|
||||||
|
{4, 1050, true},
|
||||||
|
{5, 1500, true},
|
||||||
|
{6, 1500, true},
|
||||||
|
{7, 4000, true},
|
||||||
|
{8, 14000, false},
|
||||||
|
{9, 0, true},
|
||||||
|
{1, 500, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var bigEndian pgmsg.BigEndianBuf
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
_, err = conn.Write([]byte{tt.msgType})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, int(tt.payloadSize))
|
||||||
|
_, err = conn.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
mr := &msgReader{
|
||||||
|
reader: bufio.NewReader(conn),
|
||||||
|
shouldLog: func(int) bool { return false },
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
msgType, err := mr.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%d. Unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != tt.msgType {
|
||||||
|
t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered {
|
||||||
|
t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
testCount := 10000
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var bigEndian pgmsg.BigEndianBuf
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for i := 0; i < testCount; i++ {
|
||||||
|
msgType := byte(i)
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte{msgType})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgSize := i % 4000
|
||||||
|
|
||||||
|
_, err = conn.Write(bigEndian.Int32(int32(msgSize + 4)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, msgSize)
|
||||||
|
_, err = conn.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
mr := &msgReader{
|
||||||
|
reader: bufio.NewReader(conn),
|
||||||
|
shouldLog: func(int) bool { return false },
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for {
|
||||||
|
msgType, err := mr.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
conn.SetReadDeadline(time.Now().Add(time.Millisecond))
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
t.Fatalf("%d. Unexpected error: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMsgType := byte(i)
|
||||||
|
if msgType != expectedMsgType {
|
||||||
|
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMsgSize := i % 4000
|
||||||
|
payload := mr.readBytes(mr.msgBytesRemaining)
|
||||||
|
if mr.err != nil {
|
||||||
|
t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err)
|
||||||
|
}
|
||||||
|
if len(payload) != expectedMsgSize {
|
||||||
|
t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
i++
|
||||||
|
if i == testCount {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/net/context"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,7 +56,9 @@ func (rows *Rows) FieldDescriptions() []FieldDescription {
|
|||||||
return rows.fields
|
return rows.fields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *Rows) close() {
|
// Close closes the rows, making the connection ready for use again. It is safe
|
||||||
|
// to call Close after rows is already closed.
|
||||||
|
func (rows *Rows) Close() {
|
||||||
if rows.closed {
|
if rows.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -67,6 +70,8 @@ func (rows *Rows) close() {
|
|||||||
|
|
||||||
rows.closed = true
|
rows.closed = true
|
||||||
|
|
||||||
|
rows.err = rows.conn.termContext(rows.err)
|
||||||
|
|
||||||
if rows.err == nil {
|
if rows.err == nil {
|
||||||
if rows.conn.shouldLog(LogLevelInfo) {
|
if rows.conn.shouldLog(LogLevelInfo) {
|
||||||
endTime := time.Now()
|
endTime := time.Now()
|
||||||
@@ -81,63 +86,10 @@ func (rows *Rows) close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *Rows) readUntilReadyForQuery() {
|
|
||||||
for {
|
|
||||||
t, r, err := rows.conn.rxMsg()
|
|
||||||
if err != nil {
|
|
||||||
rows.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch t {
|
|
||||||
case readyForQuery:
|
|
||||||
rows.conn.rxReadyForQuery(r)
|
|
||||||
rows.close()
|
|
||||||
return
|
|
||||||
case rowDescription:
|
|
||||||
case dataRow:
|
|
||||||
case commandComplete:
|
|
||||||
case bindComplete:
|
|
||||||
case errorResponse:
|
|
||||||
err = rows.conn.rxErrorResponse(r)
|
|
||||||
if rows.err == nil {
|
|
||||||
rows.err = err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
err = rows.conn.processContextFreeMsg(t, r)
|
|
||||||
if err != nil {
|
|
||||||
rows.close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the rows, making the connection ready for use again. It is safe
|
|
||||||
// to call Close after rows is already closed.
|
|
||||||
func (rows *Rows) Close() {
|
|
||||||
if rows.closed {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
rows.readUntilReadyForQuery()
|
|
||||||
rows.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rows *Rows) Err() error {
|
func (rows *Rows) Err() error {
|
||||||
return rows.err
|
return rows.err
|
||||||
}
|
}
|
||||||
|
|
||||||
// abort signals that the query was not successfully sent to the server.
|
|
||||||
// This differs from Fatal in that it is not necessary to readUntilReadyForQuery
|
|
||||||
func (rows *Rows) abort(err error) {
|
|
||||||
if rows.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rows.err = err
|
|
||||||
rows.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fatal signals an error occurred after the query was sent to the server. It
|
// Fatal signals an error occurred after the query was sent to the server. It
|
||||||
// closes the rows automatically.
|
// closes the rows automatically.
|
||||||
func (rows *Rows) Fatal(err error) {
|
func (rows *Rows) Fatal(err error) {
|
||||||
@@ -169,10 +121,6 @@ func (rows *Rows) Next() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch t {
|
switch t {
|
||||||
case readyForQuery:
|
|
||||||
rows.conn.rxReadyForQuery(r)
|
|
||||||
rows.close()
|
|
||||||
return false
|
|
||||||
case dataRow:
|
case dataRow:
|
||||||
fieldCount := r.readInt16()
|
fieldCount := r.readInt16()
|
||||||
if int(fieldCount) != len(rows.fields) {
|
if int(fieldCount) != len(rows.fields) {
|
||||||
@@ -183,7 +131,9 @@ func (rows *Rows) Next() bool {
|
|||||||
rows.mr = r
|
rows.mr = r
|
||||||
return true
|
return true
|
||||||
case commandComplete:
|
case commandComplete:
|
||||||
case bindComplete:
|
rows.Close()
|
||||||
|
return false
|
||||||
|
|
||||||
default:
|
default:
|
||||||
err = rows.conn.processContextFreeMsg(t, r)
|
err = rows.conn.processContextFreeMsg(t, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -441,32 +391,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) {
|
|||||||
// be returned in an error state. So it is allowed to ignore the error returned
|
// be returned in an error state. So it is allowed to ignore the error returned
|
||||||
// from Query and handle it in *Rows.
|
// from Query and handle it in *Rows.
|
||||||
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||||
c.lastActivityTime = time.Now()
|
return c.QueryContext(context.Background(), sql, args...)
|
||||||
|
|
||||||
rows := c.getRows(sql, args)
|
|
||||||
|
|
||||||
if err := c.lock(); err != nil {
|
|
||||||
rows.abort(err)
|
|
||||||
return rows, err
|
|
||||||
}
|
|
||||||
rows.unlockConn = true
|
|
||||||
|
|
||||||
ps, ok := c.preparedStatements[sql]
|
|
||||||
if !ok {
|
|
||||||
var err error
|
|
||||||
ps, err = c.Prepare("", sql)
|
|
||||||
if err != nil {
|
|
||||||
rows.abort(err)
|
|
||||||
return rows, rows.err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rows.sql = ps.SQL
|
|
||||||
rows.fields = ps.FieldDescriptions
|
|
||||||
err := c.sendPreparedQuery(ps, args...)
|
|
||||||
if err != nil {
|
|
||||||
rows.abort(err)
|
|
||||||
}
|
|
||||||
return rows, rows.err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) getRows(sql string, args []interface{}) *Rows {
|
func (c *Conn) getRows(sql string, args []interface{}) *Rows {
|
||||||
@@ -492,3 +417,51 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
|
|||||||
rows, _ := c.Query(sql, args...)
|
rows, _ := c.Query(sql, args...)
|
||||||
return (*Row)(rows)
|
return (*Row)(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
|
||||||
|
err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.lastActivityTime = time.Now()
|
||||||
|
|
||||||
|
rows = c.getRows(sql, args)
|
||||||
|
|
||||||
|
if err := c.lock(); err != nil {
|
||||||
|
rows.Fatal(err)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
rows.unlockConn = true
|
||||||
|
|
||||||
|
ps, ok := c.preparedStatements[sql]
|
||||||
|
if !ok {
|
||||||
|
var err error
|
||||||
|
ps, err = c.PrepareExContext(ctx, "", sql, nil)
|
||||||
|
if err != nil {
|
||||||
|
rows.Fatal(err)
|
||||||
|
return rows, rows.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rows.sql = ps.SQL
|
||||||
|
rows.fields = ps.FieldDescriptions
|
||||||
|
|
||||||
|
err = c.initContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
rows.Fatal(err)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.sendPreparedQuery(ps, args...)
|
||||||
|
if err != nil {
|
||||||
|
rows.Fatal(err)
|
||||||
|
err = c.termContext(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
|
||||||
|
rows, _ := c.QueryContext(ctx, sql, args...)
|
||||||
|
return (*Row)(rows)
|
||||||
|
}
|
||||||
|
|||||||
+163
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/net/context"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -1412,3 +1413,165 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
|
|||||||
|
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryContextSuccess(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
rows, err := conn.QueryContext(ctx, "select 42::integer")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result, rowCount int
|
||||||
|
for rows.Next() {
|
||||||
|
err = rows.Scan(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
rowCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Fatal(rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowCount != 1 {
|
||||||
|
t.Fatalf("Expected 1 row, got %d", rowCount)
|
||||||
|
}
|
||||||
|
if result != 42 {
|
||||||
|
t.Fatalf("Expected result 42, got %d", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result, rowCount int
|
||||||
|
for rows.Next() {
|
||||||
|
err = rows.Scan(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
rowCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" {
|
||||||
|
t.Fatalf("Expected division by zero error, but got %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
if rowCount != 9 {
|
||||||
|
t.Fatalf("Expected 9 rows, got %d", rowCount)
|
||||||
|
}
|
||||||
|
if result != 10 {
|
||||||
|
t.Fatalf("Expected result 10, got %d", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryContextCancelationCancelsQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
cancelFunc()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rows, err := conn.QueryContext(ctx, "select pg_sleep(5)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
t.Fatal("No rows should ever be ready -- context cancel apparently did not happen")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != context.Canceled {
|
||||||
|
t.Fatal("Expected context.Canceled error, got %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowContextSuccess(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
var result int
|
||||||
|
err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if result != 42 {
|
||||||
|
t.Fatalf("Expected result 42, got %d", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer cancelFunc()
|
||||||
|
|
||||||
|
var result int
|
||||||
|
err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result)
|
||||||
|
if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
|
||||||
|
t.Fatalf("Expected division by zero error, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
cancelFunc()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var result []byte
|
||||||
|
err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result)
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Fatal("Expected context.Canceled error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|||||||
+4
-4
@@ -289,7 +289,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wait until there is a byte available before continuing onto the normal msg reading path
|
// Wait until there is a byte available before continuing onto the normal msg reading path
|
||||||
_, err = rc.c.reader.Peek(1)
|
_, err = rc.c.mr.reader.Peek(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
|
rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline
|
||||||
if err, ok := err.(*net.OpError); ok && err.Timeout() {
|
if err, ok := err.(*net.OpError); ok && err.Timeout() {
|
||||||
@@ -312,14 +312,14 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
|
|||||||
rows := rc.c.getRows(sql, nil)
|
rows := rc.c.getRows(sql, nil)
|
||||||
|
|
||||||
if err := rc.c.lock(); err != nil {
|
if err := rc.c.lock(); err != nil {
|
||||||
rows.abort(err)
|
rows.Fatal(err)
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
rows.unlockConn = true
|
rows.unlockConn = true
|
||||||
|
|
||||||
err := rc.c.sendSimpleQuery(sql)
|
err := rc.c.sendSimpleQuery(sql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.abort(err)
|
rows.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var t byte
|
var t byte
|
||||||
@@ -337,7 +337,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) {
|
|||||||
// only Oids. Not much we can do about this.
|
// only Oids. Not much we can do about this.
|
||||||
default:
|
default:
|
||||||
if e := rc.c.processContextFreeMsg(t, r); e != nil {
|
if e := rc.c.processContextFreeMsg(t, r); e != nil {
|
||||||
rows.abort(e)
|
rows.Fatal(e)
|
||||||
return rows, e
|
return rows, e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,7 @@
|
|||||||
package stdlib
|
package stdlib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
|||||||
return c.queryPrepared("", argsV)
|
return c.queryPrepared("", argsV)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
if !c.conn.IsAlive() {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
|
||||||
|
ps, err := c.conn.PrepareExContext(ctx, "", query, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
restrictBinaryToDatabaseSqlTypes(ps)
|
||||||
|
|
||||||
|
return c.queryPreparedContext(ctx, "", argsV)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
||||||
if !c.conn.IsAlive() {
|
if !c.conn.IsAlive() {
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
@@ -226,6 +242,22 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er
|
|||||||
return &Rows{rows: rows}, nil
|
return &Rows{rows: rows}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||||
|
if !c.conn.IsAlive() {
|
||||||
|
return nil, driver.ErrBadConn
|
||||||
|
}
|
||||||
|
|
||||||
|
args := namedValueToInterface(argsV)
|
||||||
|
|
||||||
|
rows, err := c.conn.QueryContext(ctx, name, args...)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Rows{rows: rows}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Anything that isn't a database/sql compatible type needs to be forced to
|
// Anything that isn't a database/sql compatible type needs to be forced to
|
||||||
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
// text format so that pgx.Rows.Values doesn't decode it into a native type
|
||||||
// (e.g. []int32)
|
// (e.g. []int32)
|
||||||
@@ -318,6 +350,18 @@ func valueToInterface(argsV []driver.Value) []interface{} {
|
|||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
|
||||||
|
args := make([]interface{}, 0, len(argsV))
|
||||||
|
for _, v := range argsV {
|
||||||
|
if v.Value != nil {
|
||||||
|
args = append(args, v.Value.(interface{}))
|
||||||
|
} else {
|
||||||
|
args = append(args, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
type Tx struct {
|
type Tx struct {
|
||||||
conn *pgx.Conn
|
conn *pgx.Conn
|
||||||
}
|
}
|
||||||
|
|||||||
+44
-1
@@ -3,6 +3,7 @@ package pgx_test
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"golang.org/x/net/context"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -44,6 +45,8 @@ func TestStressConnPool(t *testing.T) {
|
|||||||
{"listenAndPoolUnlistens", listenAndPoolUnlistens},
|
{"listenAndPoolUnlistens", listenAndPoolUnlistens},
|
||||||
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
|
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
|
||||||
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
|
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
|
||||||
|
{"canceledQueryContext", canceledQueryContext},
|
||||||
|
{"canceledExecContext", canceledExecContext},
|
||||||
}
|
}
|
||||||
|
|
||||||
var timer *time.Timer
|
var timer *time.Timer
|
||||||
@@ -63,7 +66,7 @@ func TestStressConnPool(t *testing.T) {
|
|||||||
action := actions[rand.Intn(len(actions))]
|
action := actions[rand.Intn(len(actions))]
|
||||||
err := action.fn(pool, n)
|
err := action.fn(pool, n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- err
|
errChan <- fmt.Errorf("%s: %v", action.name, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -344,3 +347,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
|
|||||||
|
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||||
|
cancelFunc()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
|
||||||
|
if err == context.Canceled {
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("Only allowed error is context.Canceled, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
return errors.New("should never receive row")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != context.Canceled {
|
||||||
|
return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||||
|
cancelFunc()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := pool.ExecContext(ctx, "select pg_sleep(2)")
|
||||||
|
if err != context.Canceled {
|
||||||
|
return fmt.Errorf("Expected context.Canceled error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user