2
0

Merge pull request #2 from jackc/master

Sync
This commit is contained in:
Artemiy Ryabinkov
2019-09-13 17:20:09 +03:00
committed by GitHub
14 changed files with 1009 additions and 273 deletions
+11 -2
View File
@@ -4,6 +4,9 @@ go:
- 1.x - 1.x
- tip - tip
git:
depth: 1
# Derived from https://github.com/lib/pq/blob/master/.travis.yml # Derived from https://github.com/lib/pq/blob/master/.travis.yml
before_install: before_install:
- ./travis/before_install.bash - ./travis/before_install.bash
@@ -11,6 +14,8 @@ before_install:
env: env:
global: global:
- GO111MODULE=on - GO111MODULE=on
- GOPROXY=https://proxy.golang.org
- GOFLAGS=-mod=readonly
- PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test
- PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test"
- PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test
@@ -25,11 +30,15 @@ env:
- PGVERSION=9.4 - PGVERSION=9.4
- PGVERSION=9.3 - PGVERSION=9.3
cache:
directories:
- $HOME/.cache/go-build
- $HOME/gopath/pkg/mod
before_script: before_script:
- ./travis/before_script.bash - ./travis/before_script.bash
install: install: go mod download
- ./travis/install.bash
script: script:
- ./travis/script.bash - ./travis/script.bash
+24 -14
View File
@@ -31,7 +31,7 @@ const clientNonceLen = 18
// Perform SCRAM authentication. // Perform SCRAM authentication.
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
sc, err := newScramClient(serverAuthMechanisms, c.Config.Password) sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
if err != nil { if err != nil {
return err return err
} }
@@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
} }
// Receive server-first-message payload in a AuthenticationSASLContinue. // Receive server-first-message payload in a AuthenticationSASLContinue.
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) saslContinue, err := c.rxSASLContinue()
if err != nil { if err != nil {
return err return err
} }
err = sc.recvServerFirstMessage(authMsg.SASLData) err = sc.recvServerFirstMessage(saslContinue.Data)
if err != nil { if err != nil {
return err return err
} }
@@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
} }
// Receive server-final-message payload in a AuthenticationSASLFinal. // Receive server-final-message payload in a AuthenticationSASLFinal.
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) saslFinal, err := c.rxSASLFinal()
if err != nil { if err != nil {
return err return err
} }
return sc.recvServerFinalMessage(authMsg.SASLData) return sc.recvServerFinalMessage(saslFinal.Data)
} }
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
msg, err := c.ReceiveMessage() msg, err := c.receiveMessage()
if err != nil { if err != nil {
return nil, err return nil, err
} }
authMsg, ok := msg.(*pgproto3.Authentication) saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
if !ok { if ok {
return nil, errors.New("unexpected message type") return saslContinue, nil
}
if authMsg.Type != typ {
return nil, errors.New("unexpected auth type")
} }
return authMsg, nil return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
if ok {
return saslFinal, nil
}
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
} }
type scramClient struct { type scramClient struct {
+48 -14
View File
@@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"math" "math"
"net" "net"
@@ -17,22 +18,26 @@ import (
"strings" "strings"
"time" "time"
"github.com/jackc/chunkreader/v2"
"github.com/jackc/pgpassfile" "github.com/jackc/pgpassfile"
"github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
) )
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server. // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and
// then it can be modified. A manually initialized Config will cause ConnectConfig to panic.
type Config struct { type Config struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16 Port uint16
Database string Database string
User string User string
Password string Password string
TLSConfig *tls.Config // nil disables TLS TLSConfig *tls.Config // nil disables TLS
DialFunc DialFunc // e.g. net.Dialer.DialContext DialFunc DialFunc // e.g. net.Dialer.DialContext
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Fallbacks []*FallbackConfig Fallbacks []*FallbackConfig
@@ -52,6 +57,8 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler OnNotification NotificationHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
} }
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
@@ -134,36 +141,48 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// //
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
// does not. // does not.
//
// In addition, ParseConfig accepts the following options:
//
// min_read_buffer_size
// The minimum size of the internal read buffer. Default 8192.
func ParseConfig(connString string) (*Config, error) { func ParseConfig(connString string) (*Config, error) {
settings := defaultSettings() settings := defaultSettings()
addEnvSettings(settings) addEnvSettings(settings)
if connString != "" { if connString != "" {
// connString may be a database URL or a DSN // connString may be a database URL or a DSN
if strings.HasPrefix(connString, "postgres://") { if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
err := addURLSettings(settings, connString) err := addURLSettings(settings, connString)
if err != nil { if err != nil {
return nil, err return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
} }
} else { } else {
err := addDSNSettings(settings, connString) err := addDSNSettings(settings, connString)
if err != nil { if err != nil {
return nil, err return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
} }
} }
} }
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
}
config := &Config{ config := &Config{
Database: settings["database"], createdByParseConfig: true,
User: settings["user"], Database: settings["database"],
Password: settings["password"], User: settings["user"],
RuntimeParams: make(map[string]string), Password: settings["password"],
RuntimeParams: make(map[string]string),
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
} }
if connectTimeout, present := settings["connect_timeout"]; present { if connectTimeout, present := settings["connect_timeout"]; present {
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
if err != nil { if err != nil {
return nil, err return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
} }
config.DialFunc = dialFunc config.DialFunc = dialFunc
} else { } else {
@@ -184,6 +203,7 @@ func ParseConfig(connString string) (*Config, error) {
"sslcert": struct{}{}, "sslcert": struct{}{},
"sslrootcert": struct{}{}, "sslrootcert": struct{}{},
"target_session_attrs": struct{}{}, "target_session_attrs": struct{}{},
"min_read_buffer_size": struct{}{},
} }
for k, v := range settings { for k, v := range settings {
@@ -208,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) {
port, err := parsePort(portStr) port, err := parsePort(portStr)
if err != nil { if err != nil {
return nil, errors.Errorf("invalid port: %w", err) return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
} }
var tlsConfigs []*tls.Config var tlsConfigs []*tls.Config
@@ -220,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) {
var err error var err error
tlsConfigs, err = configTLS(settings) tlsConfigs, err = configTLS(settings)
if err != nil { if err != nil {
return nil, err return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
} }
} }
@@ -253,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) {
if settings["target_session_attrs"] == "read-write" { if settings["target_session_attrs"] == "read-write" {
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
} else if settings["target_session_attrs"] != "any" { } else if settings["target_session_attrs"] != "any" {
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
} }
return config, nil return config, nil
@@ -276,6 +296,8 @@ func defaultSettings() map[string]string {
settings["target_session_attrs"] = "any" settings["target_session_attrs"] = "any"
settings["min_read_buffer_size"] = "8192"
return settings return settings
} }
@@ -473,6 +495,18 @@ func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute} return &net.Dialer{KeepAlive: 5 * time.Minute}
} }
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
return func(r io.Reader, w io.Writer) Frontend {
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
if err != nil {
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
}
frontend := pgproto3.NewFrontend(cr, w)
return frontend
}
}
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
timeout, err := strconv.ParseInt(s, 10, 64) timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
+24
View File
@@ -214,6 +214,18 @@ func TestParseConfig(t *testing.T) {
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
}, },
{
name: "database url postgresql protocol",
connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{ {
name: "DSN everything", name: "DSN everything",
connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema",
@@ -561,3 +573,15 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
assertConfigsEqual(t, expected, actual, "passfile") assertConfigsEqual(t, expected, actual, "passfile")
} }
func TestParseConfigExtractsMinReadBufferSize(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig("min_read_buffer_size=0")
require.NoError(t, err)
_, present := config.RuntimeParams["min_read_buffer_size"]
require.False(t, present)
// The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param
// was removed.
}
+1 -1
View File
@@ -15,7 +15,7 @@ reads all rows into memory.
Executing Multiple Queries in a Single Round Trip Executing Multiple Queries in a Single Round Trip
Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
result. The ReadAll method reads all query results into memory. result. The ReadAll method reads all query results into memory.
Context Support Context Support
+114 -42
View File
@@ -2,22 +2,31 @@ package pgconn
import ( import (
"context" "context"
"fmt"
"net" "net"
"strings"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
) )
// ErrTLSRefused occurs when the connection attempt requires TLS and the // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
// PostgreSQL server refuses to use TLS func SafeToRetry(err error) bool {
var ErrTLSRefused = errors.New("server refused TLS connection") if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
}
return false
}
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another // Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a
// action is attempted. // context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true.
var ErrConnBusy = errors.New("conn is busy") func Timeout(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used var netErr net.Error
// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. return errors.As(err, &netErr) && netErr.Timeout()
var ErrNoBytesSent = errors.New("no bytes sent to server") }
// PgError represents an error reported by the PostgreSQL server. See // PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
@@ -46,44 +55,107 @@ func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
} }
// linkedError connects two errors as if err wrapped next. type connectError struct {
type linkedError struct { config *Config
err error msg string
next error err error
} }
func (le *linkedError) Error() string { func (e *connectError) Error() string {
return le.err.Error() sb := &strings.Builder{}
} fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
if e.err != nil {
func (le *linkedError) Is(target error) bool { fmt.Fprintf(sb, " (%s)", e.err.Error())
return errors.Is(le.err, target)
}
func (le *linkedError) As(target interface{}) bool {
return errors.As(le.err, target)
}
func (le *linkedError) Unwrap() error {
return le.next
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
} }
return err return sb.String()
} }
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. func (e *connectError) Unwrap() error {
func linkErrors(outer, inner error) error { return e.err
if outer == nil { }
return inner
type connLockError struct {
status string
}
func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
msg string
err error
}
func (e *parseConfigError) Error() string {
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg)
} }
if inner == nil { return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error())
return outer }
func (e *parseConfigError) Unwrap() error {
return e.err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
} }
return &linkedError{err: outer, next: inner} if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
} }
+7 -5
View File
@@ -3,11 +3,13 @@ module github.com/jackc/pgconn
go 1.12 go 1.12
require ( require (
github.com/jackc/chunkreader/v2 v2.0.0
github.com/jackc/pgio v1.0.0 github.com/jackc/pgio v1.0.0
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2
github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29
github.com/stretchr/testify v1.3.0 github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586
golang.org/x/text v0.3.0 golang.org/x/text v0.3.2
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
) )
+84 -10
View File
@@ -1,31 +1,105 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE=
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0=
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+201 -144
View File
@@ -40,9 +40,12 @@ type Notification struct {
Payload string Payload string
} }
// DialFunc is a function that can be used to connect to a PostgreSQL server // DialFunc is a function that can be used to connect to a PostgreSQL server.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
@@ -55,16 +58,21 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event. // notice event.
type NotificationHandler func(*PgConn, *Notification) type NotificationHandler func(*PgConn, *Notification)
// Frontend used to receive messages from backend.
type Frontend interface {
Receive() (pgproto3.BackendMessage, error)
}
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct { type PgConn struct {
conn net.Conn // the underlying TCP or unix domain socket connection conn net.Conn // the underlying TCP or unix domain socket connection
pid uint32 // backend pid pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server parameterStatuses map[string]string // parameters that have been reported by the server
TxStatus byte txStatus byte
Frontend *pgproto3.Frontend frontend Frontend
Config *Config config *Config
status byte // One of connStatus* constants status byte // One of connStatus* constants
@@ -91,22 +99,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
return ConnectConfig(ctx, config) return ConnectConfig(ctx, config)
} }
// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
// ParseConfig. ctx can be used to cancel a connect attempt.
// //
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
// authentication error will terminate the chain of attempts (like libpq: // authentication error will terminate the chain of attempts (like libpq:
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
// if all attempts fail the last error is returned. // if all attempts fail the last error is returned.
func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) {
// For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
if config.Port == 0 { // zero values.
config.Port = 5432 if !config.createdByParseConfig {
} panic("config must be created by ParseConfig")
if config.DialFunc == nil {
config.DialFunc = makeDefaultDialer().DialContext
}
if config.RuntimeParams == nil {
config.RuntimeParams = make(map[string]string)
} }
// Simplify usage by treating primary config and fallbacks the same. // Simplify usage by treating primary config and fallbacks the same.
@@ -124,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
if err == nil { if err == nil {
break break
} else if err, ok := err.(*PgError); ok { } else if err, ok := err.(*PgError); ok {
return nil, err return nil, &connectError{config: config, msg: "server error", err: err}
} }
} }
if err != nil { if err != nil {
return nil, err return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
} }
if config.AfterConnect != nil { if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn) err := config.AfterConnect(ctx, pgConn)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, errors.Errorf("AfterConnect: %v", err) return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
} }
} }
@@ -145,14 +149,14 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
pgConn := new(PgConn) pgConn := new(PgConn)
pgConn.Config = config pgConn.config = config
pgConn.wbuf = make([]byte, 0, 1024) pgConn.wbuf = make([]byte, 0, 1024)
var err error var err error
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
pgConn.conn, err = config.DialFunc(ctx, network, address) pgConn.conn, err = config.DialFunc(ctx, network, address)
if err != nil { if err != nil {
return nil, err return nil, &connectError{config: config, msg: "dial error", err: err}
} }
pgConn.parameterStatuses = make(map[string]string) pgConn.parameterStatuses = make(map[string]string)
@@ -160,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if fallbackConfig.TLSConfig != nil { if fallbackConfig.TLSConfig != nil {
if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, err return nil, &connectError{config: config, msg: "tls error", err: err}
} }
} }
@@ -170,10 +174,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
func() { pgConn.conn.SetDeadline(time.Time{}) }, func() { pgConn.conn.SetDeadline(time.Time{}) },
) )
pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) pgConn.contextWatcher.Watch(ctx)
if err != nil { defer pgConn.contextWatcher.Unwatch()
return nil, err
} pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
startupMsg := pgproto3.StartupMessage{ startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber, ProtocolVersion: pgproto3.ProtocolVersionNumber,
@@ -192,32 +196,52 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, err return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
} }
for { for {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, err if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &connectError{config: config, msg: "failed to receive message", err: err}
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.BackendKeyData: case *pgproto3.BackendKeyData:
pgConn.pid = msg.ProcessID pgConn.pid = msg.ProcessID
pgConn.secretKey = msg.SecretKey pgConn.secretKey = msg.SecretKey
case *pgproto3.Authentication:
if err = pgConn.rxAuthenticationX(msg); err != nil { case *pgproto3.AuthenticationOk:
case *pgproto3.AuthenticationCleartextPassword:
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, err return nil, &connectError{config: config, msg: "failed to write password message", err: err}
} }
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
}
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle pgConn.status = connStatusIdle
if config.ValidateConnect != nil { if config.ValidateConnect != nil {
err := config.ValidateConnect(ctx, pgConn) err := config.ValidateConnect(ctx, pgConn)
if err != nil { if err != nil {
pgConn.conn.Close() pgConn.conn.Close()
return nil, errors.Errorf("ValidateConnect: %v", err) return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
} }
} }
return pgConn, nil return pgConn, nil
@@ -225,10 +249,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
// handled by ReceiveMessage // handled by ReceiveMessage
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgConn.conn.Close() pgConn.conn.Close()
return nil, errorResponseToPgError(msg) return nil, ErrorResponseToPgError(msg)
default: default:
pgConn.conn.Close() pgConn.conn.Close()
return nil, errors.New("unexpected message") return nil, &connectError{config: config, msg: "received unexpected message", err: err}
} }
} }
} }
@@ -245,7 +269,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
} }
if response[0] != 'S' { if response[0] != 'S' {
return ErrTLSRefused return errors.New("server refused TLS connection")
} }
pgConn.conn = tls.Client(pgConn.conn, tlsConfig) pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
@@ -253,23 +277,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
return nil return nil
} }
func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
switch msg.Type {
case pgproto3.AuthTypeOk:
case pgproto3.AuthTypeCleartextPassword:
err = pgConn.txPasswordMessage(pgConn.Config.Password)
case pgproto3.AuthTypeMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
case pgproto3.AuthTypeSASL:
err = pgConn.scramAuth(msg.SASLAuthMechanisms)
default:
err = errors.New("Received unknown authentication message")
}
return
}
func (pgConn *PgConn) txPasswordMessage(password string) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
msg := &pgproto3.PasswordMessage{Password: password} msg := &pgproto3.PasswordMessage{Password: password}
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))
@@ -292,7 +299,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
ch := make(chan struct{}) ch := make(chan struct{})
go func() { go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
pgConn.bufferingReceiveMux.Unlock() pgConn.bufferingReceiveMux.Unlock()
close(ch) close(ch)
}() }()
@@ -300,7 +307,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
return ch return ch
} }
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { // SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
// error to call SendBytes while reading the result of a query.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
// the OnNotification callback.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
if err := pgConn.lock(); err != nil {
return nil, err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
msg, err := pgConn.receiveMessage()
if err != nil {
err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true}
}
return msg, err
}
// receiveMessage receives a message without setting up context cancellation
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
var msg pgproto3.BackendMessage var msg pgproto3.BackendMessage
var err error var err error
if pgConn.bufferingReceive { if pgConn.bufferingReceive {
@@ -312,10 +376,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
// If a timeout error happened in the background try the read again. // If a timeout error happened in the background try the read again.
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
msg, err = pgConn.Frontend.Receive() msg, err = pgConn.frontend.Receive()
} }
} else { } else {
msg, err = pgConn.Frontend.Receive() msg, err = pgConn.frontend.Receive()
} }
if err != nil { if err != nil {
@@ -329,21 +393,21 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
pgConn.TxStatus = msg.TxStatus pgConn.txStatus = msg.TxStatus
case *pgproto3.ParameterStatus: case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
if msg.Severity == "FATAL" { if msg.Severity == "FATAL" {
pgConn.hardClose() pgConn.hardClose()
return nil, errorResponseToPgError(msg) return nil, ErrorResponseToPgError(msg)
} }
case *pgproto3.NoticeResponse: case *pgproto3.NoticeResponse:
if pgConn.Config.OnNotice != nil { if pgConn.config.OnNotice != nil {
pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg))
} }
case *pgproto3.NotificationResponse: case *pgproto3.NotificationResponse:
if pgConn.Config.OnNotification != nil { if pgConn.config.OnNotification != nil {
pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
} }
} }
@@ -360,6 +424,11 @@ func (pgConn *PgConn) PID() uint32 {
return pgConn.pid return pgConn.pid
} }
// TxStatus returns the current TxStatus as reported by the server.
func (pgConn *PgConn) TxStatus() byte {
return pgConn.txStatus
}
// SecretKey returns the backend secret key used to send a cancel query message to the server. // SecretKey returns the backend secret key used to send a cancel query message to the server.
func (pgConn *PgConn) SecretKey() uint32 { func (pgConn *PgConn) SecretKey() uint32 {
return pgConn.secretKey return pgConn.secretKey
@@ -381,12 +450,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil { if err != nil {
return linkErrors(ctx.Err(), err) return err
} }
_, err = pgConn.conn.Read(make([]byte, 1)) _, err = pgConn.conn.Read(make([]byte, 1))
if err != io.EOF { if err != io.EOF {
return linkErrors(ctx.Err(), err) return err
} }
return pgConn.conn.Close() return pgConn.conn.Close()
@@ -402,21 +471,20 @@ func (pgConn *PgConn) hardClose() error {
return pgConn.conn.Close() return pgConn.conn.Close()
} }
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of // IsClosed reports if the connection has been closed.
// underlying connection. func (pgConn *PgConn) IsClosed() bool {
func (pgConn *PgConn) IsAlive() bool { return pgConn.status < connStatusIdle
return pgConn.status >= connStatusIdle
} }
// lock locks the connection. It panics if the connection is already locked or is closed. // lock locks the connection.
func (pgConn *PgConn) lock() error { func (pgConn *PgConn) lock() error {
switch pgConn.status { switch pgConn.status {
case connStatusBusy: case connStatusBusy:
return ErrConnBusy // This only should be possible in case of an application bug. return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug.
case connStatusClosed: case connStatusClosed:
return errors.New("conn closed") return &connLockError{status: "conn closed"}
case connStatusUninitialized: case connStatusUninitialized:
return errors.New("conn uninitialized") return &connLockError{status: "conn uninitialized"}
} }
pgConn.status = connStatusBusy pgConn.status = connStatusBusy
return nil return nil
@@ -456,23 +524,24 @@ func (ct CommandTag) String() string {
return string(ct) return string(ct)
} }
type PreparedStatementDescription struct { type StatementDescription struct {
Name string Name string
SQL string SQL string
ParamOIDs []uint32 ParamOIDs []uint32
Fields []pgproto3.FieldDescription Fields []pgproto3.FieldDescription
} }
// Prepare creates a prepared statement. // Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { // allows Prepare to also to describe statements without creating a server-side prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent) return nil, err
} }
defer pgConn.unlock() defer pgConn.unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent) return nil, &contextAlreadyDoneError{err: ctx.Err()}
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@@ -486,22 +555,19 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
if n == 0 { return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
} }
psd := &PreparedStatementDescription{Name: name, SQL: sql} psd := &StatementDescription{Name: name, SQL: sql}
var parseErr error var parseErr error
readloop: readloop:
for { for {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -512,7 +578,7 @@ readloop:
psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields))
copy(psd.Fields, msg.Fields) copy(psd.Fields, msg.Fields)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
parseErr = errorResponseToPgError(msg) parseErr = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
break readloop break readloop
} }
@@ -524,7 +590,8 @@ readloop:
return psd, nil return psd, nil
} }
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { // ErrorResponseToPgError converts a wire protocol error message to a *PgError.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
return &PgError{ return &PgError{
Severity: msg.Severity, Severity: msg.Severity,
Code: string(msg.Code), Code: string(msg.Code),
@@ -547,7 +614,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
} }
func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg))
return (*Notice)(pgerr) return (*Notice)(pgerr)
} }
@@ -559,7 +626,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// the connection config. This is important in high availability configurations where fallback connections may be // the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance. // specified or DNS may be used to load balance.
serverAddr := pgConn.conn.RemoteAddr() serverAddr := pgConn.conn.RemoteAddr()
cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
if err != nil { if err != nil {
return err return err
} }
@@ -579,12 +646,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
_, err = cancelConn.Write(buf) _, err = cancelConn.Write(buf)
if err != nil { if err != nil {
return linkErrors(ctx.Err(), err) return err
} }
_, err = cancelConn.Read(buf) _, err = cancelConn.Read(buf)
if err != io.EOF { if err != io.EOF {
return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) return err
} }
return nil return nil
@@ -608,9 +675,9 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
for { for {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
return linkErrors(ctx.Err(), err) return err
} }
switch msg.(type) { switch msg.(type) {
@@ -629,7 +696,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return &MultiResultReader{ return &MultiResultReader{
closed: true, closed: true,
err: linkErrors(err, ErrNoBytesSent), err: err,
} }
} }
@@ -642,7 +709,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
default: default:
@@ -657,10 +724,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.hardClose() pgConn.hardClose()
pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Unwatch()
multiResult.closed = true multiResult.closed = true
if n == 0 { multiResult.err = &writeError{err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
multiResult.err = linkErrors(ctx.Err(), err)
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
} }
@@ -729,19 +793,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
} }
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
if err := pgConn.lock(); err != nil {
return &ResultReader{
closed: true,
err: linkErrors(err, ErrNoBytesSent),
}
}
pgConn.resultReader = ResultReader{ pgConn.resultReader = ResultReader{
pgConn: pgConn, pgConn: pgConn,
ctx: ctx, ctx: ctx,
} }
result := &pgConn.resultReader result := &pgConn.resultReader
if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, err)
result.closed = true
return result
}
if len(paramValues) > math.MaxUint16 { if len(paramValues) > math.MaxUint16 {
result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
result.closed = true result.closed = true
@@ -751,7 +814,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
select { select {
case <-ctx.Done(): case <-ctx.Done():
result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
return result return result
@@ -770,10 +833,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
if n == 0 { result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
err = linkErrors(err, ErrNoBytesSent)
}
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Unwatch()
result.closed = true result.closed = true
pgConn.unlock() pgConn.unlock()
@@ -783,13 +843,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
// CopyTo executes the copy command sql and copies the results to w. // CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent) return nil, err
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
pgConn.unlock() pgConn.unlock()
return nil, linkErrors(ctx.Err(), ErrNoBytesSent) return nil, &contextAlreadyDoneError{err: ctx.Err()}
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@@ -803,20 +863,17 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
pgConn.unlock() pgConn.unlock()
if n == 0 { return nil, &writeError{err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
} }
// Read results // Read results
var commandTag CommandTag var commandTag CommandTag
var pgErr error var pgErr error
for { for {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -833,7 +890,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag) commandTag = CommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
} }
} }
} }
@@ -844,13 +901,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// could still block. // could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent) return nil, err
} }
defer pgConn.unlock() defer pgConn.unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent) return nil, &contextAlreadyDoneError{err: ctx.Err()}
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
@@ -863,10 +920,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
if n == 0 { return nil, &writeError{err: err, safeToRetry: n == 0}
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
} }
// Read until copy in response or error. // Read until copy in response or error.
@@ -874,17 +928,17 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
var pgErr error var pgErr error
pendingCopyInResponse := true pendingCopyInResponse := true
for pendingCopyInResponse { for pendingCopyInResponse {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.CopyInResponse: case *pgproto3.CopyInResponse:
pendingCopyInResponse = false pendingCopyInResponse = false
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
return commandTag, pgErr return commandTag, pgErr
} }
@@ -906,21 +960,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
} }
select { select {
case <-signalMessageChan: case <-signalMessageChan:
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
} }
default: default:
} }
@@ -937,15 +991,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf) _, err = pgConn.conn.Write(buf)
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
// Read results // Read results
for { for {
msg, err := pgConn.ReceiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
pgConn.hardClose() pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err) return nil, err
} }
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -954,7 +1008,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag) commandTag = CommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg) pgErr = ErrorResponseToPgError(msg)
} }
} }
} }
@@ -983,11 +1037,11 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
} }
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
msg, err := mrr.pgConn.ReceiveMessage() msg, err := mrr.pgConn.receiveMessage()
if err != nil { if err != nil {
mrr.pgConn.contextWatcher.Unwatch() mrr.pgConn.contextWatcher.Unwatch()
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.err = err
mrr.closed = true mrr.closed = true
mrr.pgConn.hardClose() mrr.pgConn.hardClose()
return nil, mrr.err return nil, mrr.err
@@ -999,7 +1053,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
mrr.closed = true mrr.closed = true
mrr.pgConn.unlock() mrr.pgConn.unlock()
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
mrr.err = errorResponseToPgError(msg) mrr.err = ErrorResponseToPgError(msg)
} }
return msg, nil return msg, nil
@@ -1151,7 +1205,10 @@ func (rr *ResultReader) Close() (CommandTag, error) {
return nil, rr.err return nil, rr.err
} }
switch msg.(type) { switch msg := msg.(type) {
// Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete.
case *pgproto3.ErrorResponse:
rr.err = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery: case *pgproto3.ReadyForQuery:
rr.pgConn.contextWatcher.Unwatch() rr.pgConn.contextWatcher.Unwatch()
rr.pgConn.unlock() rr.pgConn.unlock()
@@ -1165,7 +1222,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
if rr.multiResultReader == nil { if rr.multiResultReader == nil {
msg, err = rr.pgConn.ReceiveMessage() msg, err = rr.pgConn.receiveMessage()
} else { } else {
msg, err = rr.multiResultReader.receiveMessage() msg, err = rr.multiResultReader.receiveMessage()
} }
@@ -1187,7 +1244,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
rr.concludeCommand(CommandTag(msg.CommandTag), nil) rr.concludeCommand(CommandTag(msg.CommandTag), nil)
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
rr.concludeCommand(nil, errorResponseToPgError(msg)) rr.concludeCommand(nil, ErrorResponseToPgError(msg))
} }
return msg, nil return msg, nil
@@ -1199,7 +1256,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
} }
rr.commandTag = commandTag rr.commandTag = commandTag
rr.err = preferContextOverNetTimeoutError(rr.ctx, err) rr.err = err
rr.fieldDescriptions = nil rr.fieldDescriptions = nil
rr.rowValues = nil rr.rowValues = nil
rr.commandConcluded = true rr.commandConcluded = true
@@ -1229,7 +1286,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
return &MultiResultReader{ return &MultiResultReader{
closed: true, closed: true,
err: linkErrors(err, ErrNoBytesSent), err: err,
} }
} }
@@ -1242,7 +1299,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock() pgConn.unlock()
return multiResult return multiResult
default: default:
+219 -27
View File
@@ -18,6 +18,8 @@ import (
"time" "time"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgmock"
"github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -72,6 +74,67 @@ func TestConnectTLS(t *testing.T) {
closeConn(t, conn) closeConn(t, conn)
} }
type pgmockWaitStep time.Duration
func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
time.Sleep(time.Duration(s))
return nil
}
func TestConnectWithContextThatTimesOut(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: []pgmock.Step{
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
pgmockWaitStep(time.Millisecond * 500),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
},
}
ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
conn, err := ln.Accept()
if err != nil {
serverErrChan <- err
return
}
defer conn.Close()
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
if err != nil {
serverErrChan <- err
return
}
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
if err != nil {
serverErrChan <- err
return
}
}()
parts := strings.Split(ln.Addr().String(), ":")
host := parts[0]
port := parts[1]
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
tooLate := time.Now().Add(time.Millisecond * 500)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err = pgconn.Connect(ctx, connStr)
require.True(t, pgconn.Timeout(err), err)
require.True(t, time.Now().Before(tooLate))
}
func TestConnectInvalidUser(t *testing.T) { func TestConnectInvalidUser(t *testing.T) {
t.Parallel() t.Parallel()
@@ -85,14 +148,11 @@ func TestConnectInvalidUser(t *testing.T) {
config.User = "pgxinvalidusertest" config.User = "pgxinvalidusertest"
conn, err := pgconn.ConnectConfig(context.Background(), config) _, err = pgconn.ConnectConfig(context.Background(), config)
if err == nil { require.Error(t, err)
conn.Close(context.Background()) pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
t.Fatal("expected err but got none")
}
pgErr, ok := err.(*pgconn.PgError)
if !ok { if !ok {
t.Fatalf("Expected to receive a PgError, instead received: %v", err) t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
} }
if pgErr.Code != "28000" && pgErr.Code != "28P01" { if pgErr.Code != "28000" && pgErr.Code != "28P01" {
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
@@ -262,6 +322,14 @@ func TestConnectWithAfterConnect(t *testing.T) {
assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
} }
func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
t.Parallel()
config := &pgconn.Config{}
require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) })
}
func TestConnPrepareSyntaxError(t *testing.T) { func TestConnPrepareSyntaxError(t *testing.T) {
t.Parallel() t.Parallel()
@@ -289,7 +357,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
assert.Nil(t, psd) assert.Nil(t, psd)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@@ -381,6 +449,34 @@ func TestConnExecMultipleQueriesError(t *testing.T) {
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
func TestConnExecDeferredError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
setupSQL := `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
assert.NoError(t, err)
_, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll()
require.NotNil(t, err)
var pgErr *pgconn.PgError
require.True(t, errors.As(err, &pgErr))
require.Equal(t, "23505", pgErr.Code)
ensureConnValid(t, pgConn)
}
func TestConnExecContextCanceled(t *testing.T) { func TestConnExecContextCanceled(t *testing.T) {
t.Parallel() t.Parallel()
@@ -395,8 +491,8 @@ func TestConnExecContextCanceled(t *testing.T) {
for multiResult.NextResult() { for multiResult.NextResult() {
} }
err = multiResult.Close() err = multiResult.Close()
assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgconn.Timeout(err))
assert.False(t, pgConn.IsAlive()) assert.True(t, pgConn.IsClosed())
} }
func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecContextPrecanceled(t *testing.T) {
@@ -411,7 +507,7 @@ func TestConnExecContextPrecanceled(t *testing.T) {
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
assert.Error(t, err) assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@@ -437,6 +533,33 @@ func TestConnExecParams(t *testing.T) {
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
func TestConnExecParamsDeferredError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
setupSQL := `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
assert.NoError(t, err)
result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
require.NotNil(t, result.Err)
var pgErr *pgconn.PgError
require.True(t, errors.As(result.Err, &pgErr))
require.Equal(t, "23505", pgErr.Code)
ensureConnValid(t, pgConn)
}
func TestConnExecParamsMaxNumberOfParams(t *testing.T) { func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
t.Parallel() t.Parallel()
@@ -500,9 +623,9 @@ func TestConnExecParamsCanceled(t *testing.T) {
assert.Equal(t, 0, rowCount) assert.Equal(t, 0, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgconn.Timeout(err))
assert.False(t, pgConn.IsAlive()) assert.True(t, pgConn.IsClosed())
} }
func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecParamsPrecanceled(t *testing.T) {
@@ -517,7 +640,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) {
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
require.Error(t, result.Err) require.Error(t, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled)) assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(result.Err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@@ -627,8 +750,8 @@ func TestConnExecPreparedCanceled(t *testing.T) {
assert.Equal(t, 0, rowCount) assert.Equal(t, 0, rowCount)
commandTag, err := result.Close() commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgconn.Timeout(err))
assert.False(t, pgConn.IsAlive()) assert.True(t, pgConn.IsClosed())
} }
func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecPreparedPrecanceled(t *testing.T) {
@@ -646,7 +769,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) {
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
require.Error(t, result.Err) require.Error(t, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled)) assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(result.Err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@@ -683,6 +806,36 @@ func TestConnExecBatch(t *testing.T) {
assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
} }
func TestConnExecBatchDeferredError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
setupSQL := `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
assert.NoError(t, err)
batch := &pgconn.Batch{}
batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
_, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
require.NotNil(t, err)
var pgErr *pgconn.PgError
require.True(t, errors.As(err, &pgErr))
require.Equal(t, "23505", pgErr.Code)
ensureConnValid(t, pgConn)
}
func TestConnExecBatchPrecanceled(t *testing.T) { func TestConnExecBatchPrecanceled(t *testing.T) {
t.Parallel() t.Parallel()
@@ -704,7 +857,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) {
_, err = pgConn.ExecBatch(ctx, batch).ReadAll() _, err = pgConn.ExecBatch(ctx, batch).ReadAll()
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@@ -777,8 +930,8 @@ func TestConnLocking(t *testing.T) {
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
assert.Error(t, err) assert.Error(t, err)
assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) assert.Equal(t, "conn busy", err.Error())
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
results, err := mrr.ReadAll() results, err := mrr.ReadAll()
assert.NoError(t, err) assert.NoError(t, err)
@@ -935,7 +1088,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
err = pgConn.WaitForNotification(ctx) err = pgConn.WaitForNotification(ctx)
cancel() cancel()
assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.True(t, pgconn.Timeout(err))
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
@@ -1045,10 +1198,10 @@ func TestConnCopyToCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Error(t, err)
assert.Equal(t, pgconn.CommandTag(nil), res) assert.Equal(t, pgconn.CommandTag(nil), res)
assert.False(t, pgConn.IsAlive()) assert.True(t, pgConn.IsClosed())
} }
func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyToPrecanceled(t *testing.T) {
@@ -1065,7 +1218,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), res) assert.Equal(t, pgconn.CommandTag(nil), res)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
@@ -1137,9 +1290,9 @@ func TestConnCopyFromCanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
cancel() cancel()
assert.Equal(t, int64(0), ct.RowsAffected()) assert.Equal(t, int64(0), ct.RowsAffected())
assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Error(t, err)
assert.False(t, pgConn.IsAlive()) assert.True(t, pgConn.IsClosed())
} }
func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromPrecanceled(t *testing.T) {
@@ -1173,7 +1326,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), ct) assert.Equal(t, pgconn.CommandTag(nil), ct)
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
@@ -1331,6 +1484,45 @@ func TestConnCancelRequest(t *testing.T) {
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
queryMsg := pgproto3.Query{String: "select 42"}
buf := queryMsg.Encode(nil)
err = pgConn.SendBytes(ctx, buf)
require.NoError(t, err)
msg, err := pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok := msg.(*pgproto3.RowDescription)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.DataRow)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.CommandComplete)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.ReadyForQuery)
require.True(t, ok)
ensureConnValid(t, pgConn)
}
func Example() { func Example() {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
if err != nil { if err != nil {
+111
View File
@@ -0,0 +1,111 @@
package stmtcache
import (
"container/list"
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgconn"
)
var lruCount uint64
// LRU implements Cache with a Least Recently Used (LRU) cache.
type LRU struct {
conn *pgconn.PgConn
mode int
cap int
prepareCount int
m map[string]*list.Element
l *list.List
psNamePrefix string
}
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
mustBeValidMode(mode)
mustBeValidCap(cap)
n := atomic.AddUint64(&lruCount, 1)
return &LRU{
conn: conn,
mode: mode,
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
}
}
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
if el, ok := c.m[sql]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription), nil
}
if c.l.Len() == c.cap {
err := c.removeOldest(ctx)
if err != nil {
return nil, err
}
}
psd, err := c.prepare(ctx, sql)
if err != nil {
return nil, err
}
el := c.l.PushFront(psd)
c.m[sql] = el
return psd, nil
}
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
func (c *LRU) Clear(ctx context.Context) error {
for c.l.Len() > 0 {
err := c.removeOldest(ctx)
if err != nil {
return err
}
}
return nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRU) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRU) Cap() int {
return c.cap
}
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
func (c *LRU) Mode() int {
return c.mode
}
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
var name string
if c.mode == ModePrepare {
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
c.prepareCount += 1
}
return c.conn.Prepare(ctx, name, sql, nil)
}
func (c *LRU) removeOldest(ctx context.Context) error {
oldest := c.l.Back()
c.l.Remove(oldest)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close()
}
return nil
}
+113
View File
@@ -0,0 +1,113 @@
package stmtcache_test
import (
"context"
"os"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgconn/stmtcache"
"github.com/stretchr/testify/require"
)
func TestLRUModePrepare(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 2, cache.Cap())
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
psd, err := cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 3")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
}
func TestLRUModeDescribe(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 2, cache.Cap())
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
psd, err := cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 3")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
}
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
var statements []string
for _, r := range result.Rows {
statements = append(statements, string(r[0]))
}
return statements
}
+52
View File
@@ -0,0 +1,52 @@
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
package stmtcache
import (
"context"
"github.com/jackc/pgconn"
)
const (
ModePrepare = iota // Cache should prepare named statements.
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
)
// Cache prepares and caches prepared statement descriptions.
type Cache interface {
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
Clear(ctx context.Context) error
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
Mode() int
}
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
// the maximum size of the cache.
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
mustBeValidMode(mode)
mustBeValidCap(cap)
return NewLRU(conn, mode, cap)
}
func mustBeValidMode(mode int) {
if mode != ModePrepare && mode != ModeDescribe {
panic("mode must be ModePrepare or ModeDescribe")
}
}
func mustBeValidCap(cap int) {
if cap < 1 {
panic("cache must have cap of >= 1")
}
}
-14
View File
@@ -1,14 +0,0 @@
#!/usr/bin/env bash
set -eux
go get -u github.com/cockroachdb/apd
go get -u github.com/shopspring/decimal
go get -u gopkg.in/inconshreveable/log15.v2
go get -u github.com/jackc/fake
go get -u github.com/lib/pq
go get -u github.com/hashicorp/go-version
go get -u github.com/satori/go.uuid
go get -u github.com/sirupsen/logrus
go get -u github.com/pkg/errors
go get -u go.uber.org/zap
go get -u github.com/rs/zerolog