+11
-2
@@ -4,6 +4,9 @@ go:
|
||||
- 1.x
|
||||
- tip
|
||||
|
||||
git:
|
||||
depth: 1
|
||||
|
||||
# Derived from https://github.com/lib/pq/blob/master/.travis.yml
|
||||
before_install:
|
||||
- ./travis/before_install.bash
|
||||
@@ -11,6 +14,8 @@ before_install:
|
||||
env:
|
||||
global:
|
||||
- GO111MODULE=on
|
||||
- GOPROXY=https://proxy.golang.org
|
||||
- GOFLAGS=-mod=readonly
|
||||
- PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
- PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test"
|
||||
- PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test
|
||||
@@ -25,11 +30,15 @@ env:
|
||||
- PGVERSION=9.4
|
||||
- PGVERSION=9.3
|
||||
|
||||
cache:
|
||||
directories:
|
||||
- $HOME/.cache/go-build
|
||||
- $HOME/gopath/pkg/mod
|
||||
|
||||
before_script:
|
||||
- ./travis/before_script.bash
|
||||
|
||||
install:
|
||||
- ./travis/install.bash
|
||||
install: go mod download
|
||||
|
||||
script:
|
||||
- ./travis/script.bash
|
||||
|
||||
+24
-14
@@ -31,7 +31,7 @@ const clientNonceLen = 18
|
||||
|
||||
// Perform SCRAM authentication.
|
||||
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
sc, err := newScramClient(serverAuthMechanisms, c.Config.Password)
|
||||
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
}
|
||||
|
||||
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue)
|
||||
saslContinue, err := c.rxSASLContinue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sc.recvServerFirstMessage(authMsg.SASLData)
|
||||
err = sc.recvServerFirstMessage(saslContinue.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
}
|
||||
|
||||
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal)
|
||||
saslFinal, err := c.rxSASLFinal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sc.recvServerFinalMessage(authMsg.SASLData)
|
||||
return sc.recvServerFinalMessage(saslFinal.Data)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
|
||||
msg, err := c.ReceiveMessage()
|
||||
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authMsg, ok := msg.(*pgproto3.Authentication)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected message type")
|
||||
}
|
||||
if authMsg.Type != typ {
|
||||
return nil, errors.New("unexpected auth type")
|
||||
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
|
||||
if ok {
|
||||
return saslContinue, nil
|
||||
}
|
||||
|
||||
return authMsg, nil
|
||||
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
|
||||
if ok {
|
||||
return saslFinal, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
|
||||
}
|
||||
|
||||
type scramClient struct {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
@@ -17,22 +18,26 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/chunkreader/v2"
|
||||
"github.com/jackc/pgpassfile"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server.
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and
|
||||
// then it can be modified. A manually initialized Config will cause ConnectConfig to panic.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
BuildFrontend BuildFrontendFunc
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
Fallbacks []*FallbackConfig
|
||||
@@ -52,6 +57,8 @@ type Config struct {
|
||||
|
||||
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||
OnNotification NotificationHandler
|
||||
|
||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
@@ -134,36 +141,48 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
//
|
||||
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
|
||||
// does not.
|
||||
//
|
||||
// In addition, ParseConfig accepts the following options:
|
||||
//
|
||||
// min_read_buffer_size
|
||||
// The minimum size of the internal read buffer. Default 8192.
|
||||
func ParseConfig(connString string) (*Config, error) {
|
||||
settings := defaultSettings()
|
||||
addEnvSettings(settings)
|
||||
|
||||
if connString != "" {
|
||||
// connString may be a database URL or a DSN
|
||||
if strings.HasPrefix(connString, "postgres://") {
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
err := addURLSettings(settings, connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||
}
|
||||
} else {
|
||||
err := addDSNSettings(settings, connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Database: settings["database"],
|
||||
User: settings["user"],
|
||||
Password: settings["password"],
|
||||
RuntimeParams: make(map[string]string),
|
||||
createdByParseConfig: true,
|
||||
Database: settings["database"],
|
||||
User: settings["user"],
|
||||
Password: settings["password"],
|
||||
RuntimeParams: make(map[string]string),
|
||||
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
|
||||
}
|
||||
|
||||
if connectTimeout, present := settings["connect_timeout"]; present {
|
||||
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||
}
|
||||
config.DialFunc = dialFunc
|
||||
} else {
|
||||
@@ -184,6 +203,7 @@ func ParseConfig(connString string) (*Config, error) {
|
||||
"sslcert": struct{}{},
|
||||
"sslrootcert": struct{}{},
|
||||
"target_session_attrs": struct{}{},
|
||||
"min_read_buffer_size": struct{}{},
|
||||
}
|
||||
|
||||
for k, v := range settings {
|
||||
@@ -208,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) {
|
||||
|
||||
port, err := parsePort(portStr)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("invalid port: %w", err)
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
|
||||
}
|
||||
|
||||
var tlsConfigs []*tls.Config
|
||||
@@ -220,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) {
|
||||
var err error
|
||||
tlsConfigs, err = configTLS(settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) {
|
||||
if settings["target_session_attrs"] == "read-write" {
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||
} else if settings["target_session_attrs"] != "any" {
|
||||
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
|
||||
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
@@ -276,6 +296,8 @@ func defaultSettings() map[string]string {
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
settings["min_read_buffer_size"] = "8192"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
@@ -473,6 +495,18 @@ func makeDefaultDialer() *net.Dialer {
|
||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
}
|
||||
|
||||
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
||||
return func(r io.Reader, w io.Writer) Frontend {
|
||||
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
|
||||
}
|
||||
frontend := pgproto3.NewFrontend(cr, w)
|
||||
|
||||
return frontend
|
||||
}
|
||||
}
|
||||
|
||||
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
|
||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
|
||||
@@ -214,6 +214,18 @@ func TestParseConfig(t *testing.T) {
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "database url postgresql protocol",
|
||||
connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable",
|
||||
config: &pgconn.Config{
|
||||
User: "jack",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
TLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DSN everything",
|
||||
connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema",
|
||||
@@ -561,3 +573,15 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
|
||||
|
||||
assertConfigsEqual(t, expected, actual, "passfile")
|
||||
}
|
||||
|
||||
func TestParseConfigExtractsMinReadBufferSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := pgconn.ParseConfig("min_read_buffer_size=0")
|
||||
require.NoError(t, err)
|
||||
_, present := config.RuntimeParams["min_read_buffer_size"]
|
||||
require.False(t, present)
|
||||
|
||||
// The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param
|
||||
// was removed.
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ reads all rows into memory.
|
||||
|
||||
Executing Multiple Queries in a Single Round Trip
|
||||
|
||||
Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query
|
||||
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
|
||||
result. The ReadAll method reads all query results into memory.
|
||||
|
||||
Context Support
|
||||
|
||||
@@ -2,22 +2,31 @@ package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
||||
// PostgreSQL server refuses to use TLS
|
||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||
func SafeToRetry(err error) bool {
|
||||
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
|
||||
return e.SafeToRetry()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
|
||||
// action is attempted.
|
||||
var ErrConnBusy = errors.New("conn is busy")
|
||||
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a
|
||||
// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true.
|
||||
func Timeout(err error) bool {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
|
||||
// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used
|
||||
// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error.
|
||||
var ErrNoBytesSent = errors.New("no bytes sent to server")
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr) && netErr.Timeout()
|
||||
}
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||
@@ -46,44 +55,107 @@ func (pe *PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
// linkedError connects two errors as if err wrapped next.
|
||||
type linkedError struct {
|
||||
err error
|
||||
next error
|
||||
type connectError struct {
|
||||
config *Config
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (le *linkedError) Error() string {
|
||||
return le.err.Error()
|
||||
}
|
||||
|
||||
func (le *linkedError) Is(target error) bool {
|
||||
return errors.Is(le.err, target)
|
||||
}
|
||||
|
||||
func (le *linkedError) As(target interface{}) bool {
|
||||
return errors.As(le.err, target)
|
||||
}
|
||||
|
||||
func (le *linkedError) Unwrap() error {
|
||||
return le.next
|
||||
}
|
||||
|
||||
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
||||
// true. Otherwise returns err.
|
||||
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
func (e *connectError) Error() string {
|
||||
sb := &strings.Builder{}
|
||||
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
|
||||
if e.err != nil {
|
||||
fmt.Fprintf(sb, " (%s)", e.err.Error())
|
||||
}
|
||||
return err
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned.
|
||||
func linkErrors(outer, inner error) error {
|
||||
if outer == nil {
|
||||
return inner
|
||||
func (e *connectError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type connLockError struct {
|
||||
status string
|
||||
}
|
||||
|
||||
func (e *connLockError) SafeToRetry() bool {
|
||||
return true // a lock failure by definition happens before the connection is used.
|
||||
}
|
||||
|
||||
func (e *connLockError) Error() string {
|
||||
return e.status
|
||||
}
|
||||
|
||||
type parseConfigError struct {
|
||||
connString string
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Error() string {
|
||||
if e.err == nil {
|
||||
return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg)
|
||||
}
|
||||
if inner == nil {
|
||||
return outer
|
||||
return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type pgconnError struct {
|
||||
msg string
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *pgconnError) Error() string {
|
||||
if e.msg == "" {
|
||||
return e.err.Error()
|
||||
}
|
||||
return &linkedError{err: outer, next: inner}
|
||||
if e.err == nil {
|
||||
return e.msg
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *pgconnError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *pgconnError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type contextAlreadyDoneError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Error() string {
|
||||
return fmt.Sprintf("context already done: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) SafeToRetry() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type writeError struct {
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *writeError) Error() string {
|
||||
return fmt.Sprintf("write failed: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *writeError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *writeError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
@@ -3,11 +3,13 @@ module github.com/jackc/pgconn
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/jackc/chunkreader/v2 v2.0.0
|
||||
github.com/jackc/pgio v1.0.0
|
||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2
|
||||
github.com/jackc/pgpassfile v1.0.0
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711
|
||||
github.com/stretchr/testify v1.3.0
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a
|
||||
golang.org/x/text v0.3.0
|
||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29
|
||||
github.com/stretchr/testify v1.4.0
|
||||
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586
|
||||
golang.org/x/text v0.3.2
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
|
||||
)
|
||||
|
||||
@@ -1,31 +1,105 @@
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
||||
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
|
||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
|
||||
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
|
||||
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
|
||||
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
|
||||
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA=
|
||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0=
|
||||
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
|
||||
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
|
||||
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
|
||||
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
|
||||
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
|
||||
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
|
||||
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
|
||||
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
|
||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
|
||||
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts=
|
||||
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0=
|
||||
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4=
|
||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
||||
@@ -40,9 +40,12 @@ type Notification struct {
|
||||
Payload string
|
||||
}
|
||||
|
||||
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
||||
// DialFunc is a function that can be used to connect to a PostgreSQL server.
|
||||
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
|
||||
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
|
||||
|
||||
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
|
||||
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
|
||||
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
|
||||
@@ -55,16 +58,21 @@ type NoticeHandler func(*PgConn, *Notice)
|
||||
// notice event.
|
||||
type NotificationHandler func(*PgConn, *Notification)
|
||||
|
||||
// Frontend used to receive messages from backend.
|
||||
type Frontend interface {
|
||||
Receive() (pgproto3.BackendMessage, error)
|
||||
}
|
||||
|
||||
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
type PgConn struct {
|
||||
conn net.Conn // the underlying TCP or unix domain socket connection
|
||||
pid uint32 // backend pid
|
||||
secretKey uint32 // key to use to send a cancel query message to the server
|
||||
parameterStatuses map[string]string // parameters that have been reported by the server
|
||||
TxStatus byte
|
||||
Frontend *pgproto3.Frontend
|
||||
txStatus byte
|
||||
frontend Frontend
|
||||
|
||||
Config *Config
|
||||
config *Config
|
||||
|
||||
status byte // One of connStatus* constants
|
||||
|
||||
@@ -91,22 +99,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
|
||||
return ConnectConfig(ctx, config)
|
||||
}
|
||||
|
||||
// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt.
|
||||
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
|
||||
// ParseConfig. ctx can be used to cancel a connect attempt.
|
||||
//
|
||||
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
|
||||
// authentication error will terminate the chain of attempts (like libpq:
|
||||
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
|
||||
// if all attempts fail the last error is returned.
|
||||
func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) {
|
||||
// For convenience set a few defaults if not already set. This makes it simpler to directly construct a config.
|
||||
if config.Port == 0 {
|
||||
config.Port = 5432
|
||||
}
|
||||
if config.DialFunc == nil {
|
||||
config.DialFunc = makeDefaultDialer().DialContext
|
||||
}
|
||||
if config.RuntimeParams == nil {
|
||||
config.RuntimeParams = make(map[string]string)
|
||||
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
|
||||
// zero values.
|
||||
if !config.createdByParseConfig {
|
||||
panic("config must be created by ParseConfig")
|
||||
}
|
||||
|
||||
// Simplify usage by treating primary config and fallbacks the same.
|
||||
@@ -124,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
||||
if err == nil {
|
||||
break
|
||||
} else if err, ok := err.(*PgError); ok {
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "server error", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
|
||||
}
|
||||
|
||||
if config.AfterConnect != nil {
|
||||
err := config.AfterConnect(ctx, pgConn)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, errors.Errorf("AfterConnect: %v", err)
|
||||
return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,14 +149,14 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
|
||||
|
||||
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
|
||||
pgConn := new(PgConn)
|
||||
pgConn.Config = config
|
||||
pgConn.config = config
|
||||
pgConn.wbuf = make([]byte, 0, 1024)
|
||||
|
||||
var err error
|
||||
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
||||
pgConn.conn, err = config.DialFunc(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||
}
|
||||
|
||||
pgConn.parameterStatuses = make(map[string]string)
|
||||
@@ -160,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||
if fallbackConfig.TLSConfig != nil {
|
||||
if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "tls error", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,10 +174,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||
func() { pgConn.conn.SetDeadline(time.Time{}) },
|
||||
)
|
||||
|
||||
pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
|
||||
|
||||
startupMsg := pgproto3.StartupMessage{
|
||||
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||
@@ -192,32 +196,52 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||
|
||||
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
if err, ok := err.(*PgError); ok {
|
||||
return nil, err
|
||||
}
|
||||
return nil, &connectError{config: config, msg: "failed to receive message", err: err}
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.BackendKeyData:
|
||||
pgConn.pid = msg.ProcessID
|
||||
pgConn.secretKey = msg.SecretKey
|
||||
case *pgproto3.Authentication:
|
||||
if err = pgConn.rxAuthenticationX(msg); err != nil {
|
||||
|
||||
case *pgproto3.AuthenticationOk:
|
||||
case *pgproto3.AuthenticationCleartextPassword:
|
||||
err = pgConn.txPasswordMessage(pgConn.config.Password)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, err
|
||||
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
|
||||
}
|
||||
case *pgproto3.AuthenticationMD5Password:
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
|
||||
err = pgConn.txPasswordMessage(digestedPassword)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
|
||||
}
|
||||
case *pgproto3.AuthenticationSASL:
|
||||
err = pgConn.scramAuth(msg.AuthMechanisms)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
|
||||
}
|
||||
|
||||
case *pgproto3.ReadyForQuery:
|
||||
pgConn.status = connStatusIdle
|
||||
if config.ValidateConnect != nil {
|
||||
err := config.ValidateConnect(ctx, pgConn)
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, errors.Errorf("ValidateConnect: %v", err)
|
||||
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
|
||||
}
|
||||
}
|
||||
return pgConn, nil
|
||||
@@ -225,10 +249,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||
// handled by ReceiveMessage
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgConn.conn.Close()
|
||||
return nil, errorResponseToPgError(msg)
|
||||
return nil, ErrorResponseToPgError(msg)
|
||||
default:
|
||||
pgConn.conn.Close()
|
||||
return nil, errors.New("unexpected message")
|
||||
return nil, &connectError{config: config, msg: "received unexpected message", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,7 +269,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
||||
}
|
||||
|
||||
if response[0] != 'S' {
|
||||
return ErrTLSRefused
|
||||
return errors.New("server refused TLS connection")
|
||||
}
|
||||
|
||||
pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
|
||||
@@ -253,23 +277,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
||||
switch msg.Type {
|
||||
case pgproto3.AuthTypeOk:
|
||||
case pgproto3.AuthTypeCleartextPassword:
|
||||
err = pgConn.txPasswordMessage(pgConn.Config.Password)
|
||||
case pgproto3.AuthTypeMD5Password:
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:]))
|
||||
err = pgConn.txPasswordMessage(digestedPassword)
|
||||
case pgproto3.AuthTypeSASL:
|
||||
err = pgConn.scramAuth(msg.SASLAuthMechanisms)
|
||||
default:
|
||||
err = errors.New("Received unknown authentication message")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||
msg := &pgproto3.PasswordMessage{Password: password}
|
||||
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))
|
||||
@@ -292,7 +299,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
|
||||
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive()
|
||||
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
|
||||
pgConn.bufferingReceiveMux.Unlock()
|
||||
close(ch)
|
||||
}()
|
||||
@@ -300,7 +307,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
|
||||
return ch
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
||||
// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
|
||||
// error to call SendBytes while reading the result of a query.
|
||||
//
|
||||
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
|
||||
// See https://www.postgresql.org/docs/current/protocol.html.
|
||||
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return &writeError{err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
|
||||
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
|
||||
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
|
||||
// the OnNotification callback.
|
||||
//
|
||||
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
|
||||
// See https://www.postgresql.org/docs/current/protocol.html.
|
||||
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true}
|
||||
}
|
||||
return msg, err
|
||||
}
|
||||
|
||||
// receiveMessage receives a message without setting up context cancellation
|
||||
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||
var msg pgproto3.BackendMessage
|
||||
var err error
|
||||
if pgConn.bufferingReceive {
|
||||
@@ -312,10 +376,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
||||
|
||||
// If a timeout error happened in the background try the read again.
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
msg, err = pgConn.Frontend.Receive()
|
||||
msg, err = pgConn.frontend.Receive()
|
||||
}
|
||||
} else {
|
||||
msg, err = pgConn.Frontend.Receive()
|
||||
msg, err = pgConn.frontend.Receive()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -329,21 +393,21 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
pgConn.TxStatus = msg.TxStatus
|
||||
pgConn.txStatus = msg.TxStatus
|
||||
case *pgproto3.ParameterStatus:
|
||||
pgConn.parameterStatuses[msg.Name] = msg.Value
|
||||
case *pgproto3.ErrorResponse:
|
||||
if msg.Severity == "FATAL" {
|
||||
pgConn.hardClose()
|
||||
return nil, errorResponseToPgError(msg)
|
||||
return nil, ErrorResponseToPgError(msg)
|
||||
}
|
||||
case *pgproto3.NoticeResponse:
|
||||
if pgConn.Config.OnNotice != nil {
|
||||
pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg))
|
||||
if pgConn.config.OnNotice != nil {
|
||||
pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg))
|
||||
}
|
||||
case *pgproto3.NotificationResponse:
|
||||
if pgConn.Config.OnNotification != nil {
|
||||
pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
|
||||
if pgConn.config.OnNotification != nil {
|
||||
pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,6 +424,11 @@ func (pgConn *PgConn) PID() uint32 {
|
||||
return pgConn.pid
|
||||
}
|
||||
|
||||
// TxStatus returns the current TxStatus as reported by the server.
|
||||
func (pgConn *PgConn) TxStatus() byte {
|
||||
return pgConn.txStatus
|
||||
}
|
||||
|
||||
// SecretKey returns the backend secret key used to send a cancel query message to the server.
|
||||
func (pgConn *PgConn) SecretKey() uint32 {
|
||||
return pgConn.secretKey
|
||||
@@ -381,12 +450,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
|
||||
|
||||
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
|
||||
if err != nil {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = pgConn.conn.Read(make([]byte, 1))
|
||||
if err != io.EOF {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
return pgConn.conn.Close()
|
||||
@@ -402,21 +471,20 @@ func (pgConn *PgConn) hardClose() error {
|
||||
return pgConn.conn.Close()
|
||||
}
|
||||
|
||||
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of
|
||||
// underlying connection.
|
||||
func (pgConn *PgConn) IsAlive() bool {
|
||||
return pgConn.status >= connStatusIdle
|
||||
// IsClosed reports if the connection has been closed.
|
||||
func (pgConn *PgConn) IsClosed() bool {
|
||||
return pgConn.status < connStatusIdle
|
||||
}
|
||||
|
||||
// lock locks the connection. It panics if the connection is already locked or is closed.
|
||||
// lock locks the connection.
|
||||
func (pgConn *PgConn) lock() error {
|
||||
switch pgConn.status {
|
||||
case connStatusBusy:
|
||||
return ErrConnBusy // This only should be possible in case of an application bug.
|
||||
return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug.
|
||||
case connStatusClosed:
|
||||
return errors.New("conn closed")
|
||||
return &connLockError{status: "conn closed"}
|
||||
case connStatusUninitialized:
|
||||
return errors.New("conn uninitialized")
|
||||
return &connLockError{status: "conn uninitialized"}
|
||||
}
|
||||
pgConn.status = connStatusBusy
|
||||
return nil
|
||||
@@ -456,23 +524,24 @@ func (ct CommandTag) String() string {
|
||||
return string(ct)
|
||||
}
|
||||
|
||||
type PreparedStatementDescription struct {
|
||||
type StatementDescription struct {
|
||||
Name string
|
||||
SQL string
|
||||
ParamOIDs []uint32
|
||||
Fields []pgproto3.FieldDescription
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement.
|
||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
|
||||
// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This
|
||||
// allows Prepare to also to describe statements without creating a server-side prepared statement.
|
||||
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
return nil, err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
return nil, &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
@@ -486,22 +555,19 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
psd := &PreparedStatementDescription{Name: name, SQL: sql}
|
||||
psd := &StatementDescription{Name: name, SQL: sql}
|
||||
|
||||
var parseErr error
|
||||
|
||||
readloop:
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
@@ -512,7 +578,7 @@ readloop:
|
||||
psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields))
|
||||
copy(psd.Fields, msg.Fields)
|
||||
case *pgproto3.ErrorResponse:
|
||||
parseErr = errorResponseToPgError(msg)
|
||||
parseErr = ErrorResponseToPgError(msg)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
break readloop
|
||||
}
|
||||
@@ -524,7 +590,8 @@ readloop:
|
||||
return psd, nil
|
||||
}
|
||||
|
||||
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||
// ErrorResponseToPgError converts a wire protocol error message to a *PgError.
|
||||
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||
return &PgError{
|
||||
Severity: msg.Severity,
|
||||
Code: string(msg.Code),
|
||||
@@ -547,7 +614,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||
}
|
||||
|
||||
func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
|
||||
pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg))
|
||||
pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg))
|
||||
return (*Notice)(pgerr)
|
||||
}
|
||||
|
||||
@@ -559,7 +626,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
||||
// the connection config. This is important in high availability configurations where fallback connections may be
|
||||
// specified or DNS may be used to load balance.
|
||||
serverAddr := pgConn.conn.RemoteAddr()
|
||||
cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
|
||||
cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -579,12 +646,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
|
||||
_, err = cancelConn.Write(buf)
|
||||
if err != nil {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = cancelConn.Read(buf)
|
||||
if err != io.EOF {
|
||||
return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -608,9 +675,9 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
return linkErrors(ctx.Err(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg.(type) {
|
||||
@@ -629,7 +696,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
err: linkErrors(err, ErrNoBytesSent),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -642,7 +709,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
default:
|
||||
@@ -657,10 +724,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
||||
pgConn.hardClose()
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
multiResult.closed = true
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
multiResult.err = linkErrors(ctx.Err(), err)
|
||||
multiResult.err = &writeError{err: err, safeToRetry: n == 0}
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
}
|
||||
@@ -729,19 +793,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &ResultReader{
|
||||
closed: true,
|
||||
err: linkErrors(err, ErrNoBytesSent),
|
||||
}
|
||||
}
|
||||
|
||||
pgConn.resultReader = ResultReader{
|
||||
pgConn: pgConn,
|
||||
ctx: ctx,
|
||||
}
|
||||
result := &pgConn.resultReader
|
||||
|
||||
if err := pgConn.lock(); err != nil {
|
||||
result.concludeCommand(nil, err)
|
||||
result.closed = true
|
||||
return result
|
||||
}
|
||||
|
||||
if len(paramValues) > math.MaxUint16 {
|
||||
result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
|
||||
result.closed = true
|
||||
@@ -751,7 +814,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent))
|
||||
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
return result
|
||||
@@ -770,10 +833,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
|
||||
result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
result.closed = true
|
||||
pgConn.unlock()
|
||||
@@ -783,13 +843,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
|
||||
// CopyTo executes the copy command sql and copies the results to w.
|
||||
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pgConn.unlock()
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
return nil, &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
@@ -803,20 +863,17 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
pgConn.unlock()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, &writeError{err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
// Read results
|
||||
var commandTag CommandTag
|
||||
var pgErr error
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
@@ -833,7 +890,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
||||
case *pgproto3.CommandComplete:
|
||||
commandTag = CommandTag(msg.CommandTag)
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr = errorResponseToPgError(msg)
|
||||
pgErr = ErrorResponseToPgError(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -844,13 +901,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
||||
// could still block.
|
||||
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return nil, linkErrors(err, ErrNoBytesSent)
|
||||
return nil, err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
return nil, &contextAlreadyDoneError{err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
@@ -863,10 +920,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
n, err := pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
if n == 0 {
|
||||
err = linkErrors(err, ErrNoBytesSent)
|
||||
}
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, &writeError{err: err, safeToRetry: n == 0}
|
||||
}
|
||||
|
||||
// Read until copy in response or error.
|
||||
@@ -874,17 +928,17 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
var pgErr error
|
||||
pendingCopyInResponse := true
|
||||
for pendingCopyInResponse {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CopyInResponse:
|
||||
pendingCopyInResponse = false
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr = errorResponseToPgError(msg)
|
||||
pgErr = ErrorResponseToPgError(msg)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
return commandTag, pgErr
|
||||
}
|
||||
@@ -906,21 +960,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
_, err = pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-signalMessageChan:
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr = errorResponseToPgError(msg)
|
||||
pgErr = ErrorResponseToPgError(msg)
|
||||
}
|
||||
default:
|
||||
}
|
||||
@@ -937,15 +991,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
_, err = pgConn.conn.Write(buf)
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read results
|
||||
for {
|
||||
msg, err := pgConn.ReceiveMessage()
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.hardClose()
|
||||
return nil, linkErrors(ctx.Err(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
@@ -954,7 +1008,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||
case *pgproto3.CommandComplete:
|
||||
commandTag = CommandTag(msg.CommandTag)
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr = errorResponseToPgError(msg)
|
||||
pgErr = ErrorResponseToPgError(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -983,11 +1037,11 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
|
||||
}
|
||||
|
||||
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||
msg, err := mrr.pgConn.ReceiveMessage()
|
||||
msg, err := mrr.pgConn.receiveMessage()
|
||||
|
||||
if err != nil {
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
|
||||
mrr.err = err
|
||||
mrr.closed = true
|
||||
mrr.pgConn.hardClose()
|
||||
return nil, mrr.err
|
||||
@@ -999,7 +1053,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
||||
mrr.closed = true
|
||||
mrr.pgConn.unlock()
|
||||
case *pgproto3.ErrorResponse:
|
||||
mrr.err = errorResponseToPgError(msg)
|
||||
mrr.err = ErrorResponseToPgError(msg)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
@@ -1151,7 +1205,10 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
||||
return nil, rr.err
|
||||
}
|
||||
|
||||
switch msg.(type) {
|
||||
switch msg := msg.(type) {
|
||||
// Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete.
|
||||
case *pgproto3.ErrorResponse:
|
||||
rr.err = ErrorResponseToPgError(msg)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
rr.pgConn.contextWatcher.Unwatch()
|
||||
rr.pgConn.unlock()
|
||||
@@ -1165,7 +1222,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
|
||||
|
||||
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
|
||||
if rr.multiResultReader == nil {
|
||||
msg, err = rr.pgConn.ReceiveMessage()
|
||||
msg, err = rr.pgConn.receiveMessage()
|
||||
} else {
|
||||
msg, err = rr.multiResultReader.receiveMessage()
|
||||
}
|
||||
@@ -1187,7 +1244,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
|
||||
case *pgproto3.CommandComplete:
|
||||
rr.concludeCommand(CommandTag(msg.CommandTag), nil)
|
||||
case *pgproto3.ErrorResponse:
|
||||
rr.concludeCommand(nil, errorResponseToPgError(msg))
|
||||
rr.concludeCommand(nil, ErrorResponseToPgError(msg))
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
@@ -1199,7 +1256,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
|
||||
}
|
||||
|
||||
rr.commandTag = commandTag
|
||||
rr.err = preferContextOverNetTimeoutError(rr.ctx, err)
|
||||
rr.err = err
|
||||
rr.fieldDescriptions = nil
|
||||
rr.rowValues = nil
|
||||
rr.commandConcluded = true
|
||||
@@ -1229,7 +1286,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &MultiResultReader{
|
||||
closed: true,
|
||||
err: linkErrors(err, ErrNoBytesSent),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1242,7 +1299,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
multiResult.closed = true
|
||||
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
|
||||
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
|
||||
pgConn.unlock()
|
||||
return multiResult
|
||||
default:
|
||||
|
||||
+219
-27
@@ -18,6 +18,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgmock"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
errors "golang.org/x/xerrors"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -72,6 +74,67 @@ func TestConnectTLS(t *testing.T) {
|
||||
closeConn(t, conn)
|
||||
}
|
||||
|
||||
type pgmockWaitStep time.Duration
|
||||
|
||||
func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
|
||||
time.Sleep(time.Duration(s))
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConnectWithContextThatTimesOut(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
script := &pgmock.Script{
|
||||
Steps: []pgmock.Step{
|
||||
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
||||
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
|
||||
pgmockWaitStep(time.Millisecond * 500),
|
||||
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
||||
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
serverErrChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrChan)
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
|
||||
if err != nil {
|
||||
serverErrChan <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
parts := strings.Split(ln.Addr().String(), ":")
|
||||
host := parts[0]
|
||||
port := parts[1]
|
||||
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
||||
tooLate := time.Now().Add(time.Millisecond * 500)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
|
||||
defer cancel()
|
||||
_, err = pgconn.Connect(ctx, connStr)
|
||||
require.True(t, pgconn.Timeout(err), err)
|
||||
require.True(t, time.Now().Before(tooLate))
|
||||
}
|
||||
|
||||
func TestConnectInvalidUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -85,14 +148,11 @@ func TestConnectInvalidUser(t *testing.T) {
|
||||
|
||||
config.User = "pgxinvalidusertest"
|
||||
|
||||
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
if err == nil {
|
||||
conn.Close(context.Background())
|
||||
t.Fatal("expected err but got none")
|
||||
}
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
_, err = pgconn.ConnectConfig(context.Background(), config)
|
||||
require.Error(t, err)
|
||||
pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
|
||||
if !ok {
|
||||
t.Fatalf("Expected to receive a PgError, instead received: %v", err)
|
||||
t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
|
||||
}
|
||||
if pgErr.Code != "28000" && pgErr.Code != "28P01" {
|
||||
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
|
||||
@@ -262,6 +322,14 @@ func TestConnectWithAfterConnect(t *testing.T) {
|
||||
assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
|
||||
}
|
||||
|
||||
func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &pgconn.Config{}
|
||||
|
||||
require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) })
|
||||
}
|
||||
|
||||
func TestConnPrepareSyntaxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -289,7 +357,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
|
||||
assert.Nil(t, psd)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
@@ -381,6 +449,34 @@ func TestConnExecMultipleQueriesError(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
setupSQL := `create temporary table t (
|
||||
id text primary key,
|
||||
n int not null,
|
||||
unique (n) deferrable initially deferred
|
||||
);
|
||||
|
||||
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
||||
|
||||
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll()
|
||||
require.NotNil(t, err)
|
||||
|
||||
var pgErr *pgconn.PgError
|
||||
require.True(t, errors.As(err, &pgErr))
|
||||
require.Equal(t, "23505", pgErr.Code)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -395,8 +491,8 @@ func TestConnExecContextCanceled(t *testing.T) {
|
||||
for multiResult.NextResult() {
|
||||
}
|
||||
err = multiResult.Close()
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
func TestConnExecContextPrecanceled(t *testing.T) {
|
||||
@@ -411,7 +507,7 @@ func TestConnExecContextPrecanceled(t *testing.T) {
|
||||
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
@@ -437,6 +533,33 @@ func TestConnExecParams(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecParamsDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
setupSQL := `create temporary table t (
|
||||
id text primary key,
|
||||
n int not null,
|
||||
unique (n) deferrable initially deferred
|
||||
);
|
||||
|
||||
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
||||
|
||||
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
|
||||
require.NotNil(t, result.Err)
|
||||
var pgErr *pgconn.PgError
|
||||
require.True(t, errors.As(result.Err, &pgErr))
|
||||
require.Equal(t, "23505", pgErr.Code)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -500,9 +623,9 @@ func TestConnExecParamsCanceled(t *testing.T) {
|
||||
assert.Equal(t, 0, rowCount)
|
||||
commandTag, err := result.Close()
|
||||
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
func TestConnExecParamsPrecanceled(t *testing.T) {
|
||||
@@ -517,7 +640,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) {
|
||||
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
assert.True(t, errors.Is(result.Err, context.Canceled))
|
||||
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(result.Err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
@@ -627,8 +750,8 @@ func TestConnExecPreparedCanceled(t *testing.T) {
|
||||
assert.Equal(t, 0, rowCount)
|
||||
commandTag, err := result.Close()
|
||||
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
func TestConnExecPreparedPrecanceled(t *testing.T) {
|
||||
@@ -646,7 +769,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) {
|
||||
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
||||
require.Error(t, result.Err)
|
||||
assert.True(t, errors.Is(result.Err, context.Canceled))
|
||||
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(result.Err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
@@ -683,6 +806,36 @@ func TestConnExecBatch(t *testing.T) {
|
||||
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
|
||||
}
|
||||
|
||||
func TestConnExecBatchDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
setupSQL := `create temporary table t (
|
||||
id text primary key,
|
||||
n int not null,
|
||||
unique (n) deferrable initially deferred
|
||||
);
|
||||
|
||||
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
||||
|
||||
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
|
||||
_, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
|
||||
require.NotNil(t, err)
|
||||
var pgErr *pgconn.PgError
|
||||
require.True(t, errors.As(err, &pgErr))
|
||||
require.Equal(t, "23505", pgErr.Code)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExecBatchPrecanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -704,7 +857,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) {
|
||||
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
@@ -777,8 +930,8 @@ func TestConnLocking(t *testing.T) {
|
||||
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
|
||||
_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, pgconn.ErrConnBusy))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.Equal(t, "conn busy", err.Error())
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
|
||||
results, err := mrr.ReadAll()
|
||||
assert.NoError(t, err)
|
||||
@@ -935,7 +1088,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
err = pgConn.WaitForNotification(ctx)
|
||||
cancel()
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||
assert.True(t, pgconn.Timeout(err))
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
@@ -1045,10 +1198,10 @@ func TestConnCopyToCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, pgconn.CommandTag(nil), res)
|
||||
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
func TestConnCopyToPrecanceled(t *testing.T) {
|
||||
@@ -1065,7 +1218,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
|
||||
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
assert.Equal(t, pgconn.CommandTag(nil), res)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
@@ -1137,9 +1290,9 @@ func TestConnCopyFromCanceled(t *testing.T) {
|
||||
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
cancel()
|
||||
assert.Equal(t, int64(0), ct.RowsAffected())
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.False(t, pgConn.IsAlive())
|
||||
assert.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
func TestConnCopyFromPrecanceled(t *testing.T) {
|
||||
@@ -1173,7 +1326,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
|
||||
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, context.Canceled))
|
||||
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
|
||||
assert.True(t, pgconn.SafeToRetry(err))
|
||||
assert.Equal(t, pgconn.CommandTag(nil), ct)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
@@ -1331,6 +1484,45 @@ func TestConnCancelRequest(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
queryMsg := pgproto3.Query{String: "select 42"}
|
||||
buf := queryMsg.Encode(nil)
|
||||
|
||||
err = pgConn.SendBytes(ctx, buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg, err := pgConn.ReceiveMessage(ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok := msg.(*pgproto3.RowDescription)
|
||||
require.True(t, ok)
|
||||
|
||||
msg, err = pgConn.ReceiveMessage(ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok = msg.(*pgproto3.DataRow)
|
||||
require.True(t, ok)
|
||||
|
||||
msg, err = pgConn.ReceiveMessage(ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok = msg.(*pgproto3.CommandComplete)
|
||||
require.True(t, ok)
|
||||
|
||||
msg, err = pgConn.ReceiveMessage(ctx)
|
||||
require.NoError(t, err)
|
||||
_, ok = msg.(*pgproto3.ReadyForQuery)
|
||||
require.True(t, ok)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func Example() {
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
)
|
||||
|
||||
var lruCount uint64
|
||||
|
||||
// LRU implements Cache with a Least Recently Used (LRU) cache.
|
||||
type LRU struct {
|
||||
conn *pgconn.PgConn
|
||||
mode int
|
||||
cap int
|
||||
prepareCount int
|
||||
m map[string]*list.Element
|
||||
l *list.List
|
||||
psNamePrefix string
|
||||
}
|
||||
|
||||
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
|
||||
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
|
||||
mustBeValidMode(mode)
|
||||
mustBeValidCap(cap)
|
||||
|
||||
n := atomic.AddUint64(&lruCount, 1)
|
||||
|
||||
return &LRU{
|
||||
conn: conn,
|
||||
mode: mode,
|
||||
cap: cap,
|
||||
m: make(map[string]*list.Element),
|
||||
l: list.New(),
|
||||
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||
if el, ok := c.m[sql]; ok {
|
||||
c.l.MoveToFront(el)
|
||||
return el.Value.(*pgconn.StatementDescription), nil
|
||||
}
|
||||
|
||||
if c.l.Len() == c.cap {
|
||||
err := c.removeOldest(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
psd, err := c.prepare(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
el := c.l.PushFront(psd)
|
||||
c.m[sql] = el
|
||||
|
||||
return psd, nil
|
||||
}
|
||||
|
||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||
func (c *LRU) Clear(ctx context.Context) error {
|
||||
for c.l.Len() > 0 {
|
||||
err := c.removeOldest(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *LRU) Len() int {
|
||||
return c.l.Len()
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *LRU) Cap() int {
|
||||
return c.cap
|
||||
}
|
||||
|
||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||
func (c *LRU) Mode() int {
|
||||
return c.mode
|
||||
}
|
||||
|
||||
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||
var name string
|
||||
if c.mode == ModePrepare {
|
||||
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
|
||||
c.prepareCount += 1
|
||||
}
|
||||
|
||||
return c.conn.Prepare(ctx, name, sql, nil)
|
||||
}
|
||||
|
||||
func (c *LRU) removeOldest(ctx context.Context) error {
|
||||
oldest := c.l.Back()
|
||||
c.l.Remove(oldest)
|
||||
if c.mode == ModePrepare {
|
||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package stmtcache_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgconn/stmtcache"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLRUModePrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.EqualValues(t, 2, cache.Cap())
|
||||
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
|
||||
|
||||
psd, err := cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 3")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
}
|
||||
|
||||
func TestLRUModeDescribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(ctx)
|
||||
|
||||
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.EqualValues(t, 2, cache.Cap())
|
||||
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
|
||||
|
||||
psd, err := cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 1, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
psd, err = cache.Get(ctx, "select 3")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, psd)
|
||||
require.EqualValues(t, 2, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, cache.Len())
|
||||
require.Empty(t, fetchServerStatements(t, ctx, conn))
|
||||
}
|
||||
|
||||
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
|
||||
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
var statements []string
|
||||
for _, r := range result.Rows {
|
||||
statements = append(statements, string(r[0]))
|
||||
}
|
||||
return statements
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
)
|
||||
|
||||
const (
|
||||
ModePrepare = iota // Cache should prepare named statements.
|
||||
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
|
||||
)
|
||||
|
||||
// Cache prepares and caches prepared statement descriptions.
|
||||
type Cache interface {
|
||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
|
||||
|
||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
Len() int
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
Cap() int
|
||||
|
||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||
Mode() int
|
||||
}
|
||||
|
||||
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
|
||||
// the maximum size of the cache.
|
||||
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
|
||||
mustBeValidMode(mode)
|
||||
mustBeValidCap(cap)
|
||||
|
||||
return NewLRU(conn, mode, cap)
|
||||
}
|
||||
|
||||
func mustBeValidMode(mode int) {
|
||||
if mode != ModePrepare && mode != ModeDescribe {
|
||||
panic("mode must be ModePrepare or ModeDescribe")
|
||||
}
|
||||
}
|
||||
|
||||
func mustBeValidCap(cap int) {
|
||||
if cap < 1 {
|
||||
panic("cache must have cap of >= 1")
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -eux
|
||||
|
||||
go get -u github.com/cockroachdb/apd
|
||||
go get -u github.com/shopspring/decimal
|
||||
go get -u gopkg.in/inconshreveable/log15.v2
|
||||
go get -u github.com/jackc/fake
|
||||
go get -u github.com/lib/pq
|
||||
go get -u github.com/hashicorp/go-version
|
||||
go get -u github.com/satori/go.uuid
|
||||
go get -u github.com/sirupsen/logrus
|
||||
go get -u github.com/pkg/errors
|
||||
go get -u go.uber.org/zap
|
||||
go get -u github.com/rs/zerolog
|
||||
Reference in New Issue
Block a user