2
0

Merge pull request #2 from jackc/master

Sync
This commit is contained in:
Artemiy Ryabinkov
2019-09-13 17:20:09 +03:00
committed by GitHub
14 changed files with 1009 additions and 273 deletions
+11 -2
View File
@@ -4,6 +4,9 @@ go:
- 1.x
- tip
git:
depth: 1
# Derived from https://github.com/lib/pq/blob/master/.travis.yml
before_install:
- ./travis/before_install.bash
@@ -11,6 +14,8 @@ before_install:
env:
global:
- GO111MODULE=on
- GOPROXY=https://proxy.golang.org
- GOFLAGS=-mod=readonly
- PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test
- PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test"
- PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test
@@ -25,11 +30,15 @@ env:
- PGVERSION=9.4
- PGVERSION=9.3
cache:
directories:
- $HOME/.cache/go-build
- $HOME/gopath/pkg/mod
before_script:
- ./travis/before_script.bash
install:
- ./travis/install.bash
install: go mod download
script:
- ./travis/script.bash
+24 -14
View File
@@ -31,7 +31,7 @@ const clientNonceLen = 18
// Perform SCRAM authentication.
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
sc, err := newScramClient(serverAuthMechanisms, c.Config.Password)
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
if err != nil {
return err
}
@@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
}
// Receive server-first-message payload in a AuthenticationSASLContinue.
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue)
saslContinue, err := c.rxSASLContinue()
if err != nil {
return err
}
err = sc.recvServerFirstMessage(authMsg.SASLData)
err = sc.recvServerFirstMessage(saslContinue.Data)
if err != nil {
return err
}
@@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
}
// Receive server-final-message payload in a AuthenticationSASLFinal.
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal)
saslFinal, err := c.rxSASLFinal()
if err != nil {
return err
}
return sc.recvServerFinalMessage(authMsg.SASLData)
return sc.recvServerFinalMessage(saslFinal.Data)
}
func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
msg, err := c.ReceiveMessage()
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
authMsg, ok := msg.(*pgproto3.Authentication)
if !ok {
return nil, errors.New("unexpected message type")
}
if authMsg.Type != typ {
return nil, errors.New("unexpected auth type")
saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue)
if ok {
return saslContinue, nil
}
return authMsg, nil
return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message")
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal)
if ok {
return saslFinal, nil
}
return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message")
}
type scramClient struct {
+48 -14
View File
@@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"math"
"net"
@@ -17,22 +18,26 @@ import (
"strings"
"time"
"github.com/jackc/chunkreader/v2"
"github.com/jackc/pgpassfile"
"github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors"
)
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server.
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and
// then it can be modified. A manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
DialFunc DialFunc // e.g. net.Dialer.DialContext
TLSConfig *tls.Config // nil disables TLS
DialFunc DialFunc // e.g. net.Dialer.DialContext
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
Fallbacks []*FallbackConfig
@@ -52,6 +57,8 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
@@ -134,36 +141,48 @@ func NetworkAddress(host string, port uint16) (network, address string) {
//
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
// does not.
//
// In addition, ParseConfig accepts the following options:
//
// min_read_buffer_size
// The minimum size of the internal read buffer. Default 8192.
func ParseConfig(connString string) (*Config, error) {
settings := defaultSettings()
addEnvSettings(settings)
if connString != "" {
// connString may be a database URL or a DSN
if strings.HasPrefix(connString, "postgres://") {
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
err := addURLSettings(settings, connString)
if err != nil {
return nil, err
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
}
} else {
err := addDSNSettings(settings, connString)
if err != nil {
return nil, err
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
}
config := &Config{
Database: settings["database"],
User: settings["user"],
Password: settings["password"],
RuntimeParams: make(map[string]string),
createdByParseConfig: true,
Database: settings["database"],
User: settings["user"],
Password: settings["password"],
RuntimeParams: make(map[string]string),
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
}
if connectTimeout, present := settings["connect_timeout"]; present {
dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout)
if err != nil {
return nil, err
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
}
config.DialFunc = dialFunc
} else {
@@ -184,6 +203,7 @@ func ParseConfig(connString string) (*Config, error) {
"sslcert": struct{}{},
"sslrootcert": struct{}{},
"target_session_attrs": struct{}{},
"min_read_buffer_size": struct{}{},
}
for k, v := range settings {
@@ -208,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) {
port, err := parsePort(portStr)
if err != nil {
return nil, errors.Errorf("invalid port: %w", err)
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
@@ -220,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) {
var err error
tlsConfigs, err = configTLS(settings)
if err != nil {
return nil, err
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
}
@@ -253,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) {
if settings["target_session_attrs"] == "read-write" {
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
} else if settings["target_session_attrs"] != "any" {
return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"])
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
}
return config, nil
@@ -276,6 +296,8 @@ func defaultSettings() map[string]string {
settings["target_session_attrs"] = "any"
settings["min_read_buffer_size"] = "8192"
return settings
}
@@ -473,6 +495,18 @@ func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
return func(r io.Reader, w io.Writer) Frontend {
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
if err != nil {
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
}
frontend := pgproto3.NewFrontend(cr, w)
return frontend
}
}
func makeConnectTimeoutDialFunc(s string) (DialFunc, error) {
timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil {
+24
View File
@@ -214,6 +214,18 @@ func TestParseConfig(t *testing.T) {
RuntimeParams: map[string]string{},
},
},
{
name: "database url postgresql protocol",
connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{
User: "jack",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
name: "DSN everything",
connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema",
@@ -561,3 +573,15 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
assertConfigsEqual(t, expected, actual, "passfile")
}
func TestParseConfigExtractsMinReadBufferSize(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig("min_read_buffer_size=0")
require.NoError(t, err)
_, present := config.RuntimeParams["min_read_buffer_size"]
require.False(t, present)
// The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param
// was removed.
}
+1 -1
View File
@@ -15,7 +15,7 @@ reads all rows into memory.
Executing Multiple Queries in a Single Round Trip
Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
result. The ReadAll method reads all query results into memory.
Context Support
+114 -42
View File
@@ -2,22 +2,31 @@ package pgconn
import (
"context"
"fmt"
"net"
"strings"
errors "golang.org/x/xerrors"
)
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
}
return false
}
// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another
// action is attempted.
var ErrConnBusy = errors.New("conn is busy")
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a
// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used
// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error.
var ErrNoBytesSent = errors.New("no bytes sent to server")
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
@@ -46,44 +55,107 @@ func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// linkedError connects two errors as if err wrapped next.
type linkedError struct {
err error
next error
type connectError struct {
config *Config
msg string
err error
}
func (le *linkedError) Error() string {
return le.err.Error()
}
func (le *linkedError) Is(target error) bool {
return errors.Is(le.err, target)
}
func (le *linkedError) As(target interface{}) bool {
return errors.As(le.err, target)
}
func (le *linkedError) Unwrap() error {
return le.next
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return ctx.Err()
func (e *connectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return err
return sb.String()
}
// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned.
func linkErrors(outer, inner error) error {
if outer == nil {
return inner
func (e *connectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}
func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
msg string
err error
}
func (e *parseConfigError) Error() string {
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg)
}
if inner == nil {
return outer
return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error())
}
func (e *parseConfigError) Unwrap() error {
return e.err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
}
return &linkedError{err: outer, next: inner}
if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}
+7 -5
View File
@@ -3,11 +3,13 @@ module github.com/jackc/pgconn
go 1.12
require (
github.com/jackc/chunkreader/v2 v2.0.0
github.com/jackc/pgio v1.0.0
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2
github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711
github.com/stretchr/testify v1.3.0
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a
golang.org/x/text v0.3.0
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29
github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586
golang.org/x/text v0.3.2
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
)
+84 -10
View File
@@ -1,31 +1,105 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0=
github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE=
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts=
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0=
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+201 -144
View File
@@ -40,9 +40,12 @@ type Notification struct {
Payload string
}
// DialFunc is a function that can be used to connect to a PostgreSQL server
// DialFunc is a function that can be used to connect to a PostgreSQL server.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
@@ -55,16 +58,21 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event.
type NotificationHandler func(*PgConn, *Notification)
// Frontend used to receive messages from backend.
type Frontend interface {
Receive() (pgproto3.BackendMessage, error)
}
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
conn net.Conn // the underlying TCP or unix domain socket connection
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server
TxStatus byte
Frontend *pgproto3.Frontend
txStatus byte
frontend Frontend
Config *Config
config *Config
status byte // One of connStatus* constants
@@ -91,22 +99,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
return ConnectConfig(ctx, config)
}
// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt.
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
// ParseConfig. ctx can be used to cancel a connect attempt.
//
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
// authentication error will terminate the chain of attempts (like libpq:
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise,
// if all attempts fail the last error is returned.
func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) {
// For convenience set a few defaults if not already set. This makes it simpler to directly construct a config.
if config.Port == 0 {
config.Port = 5432
}
if config.DialFunc == nil {
config.DialFunc = makeDefaultDialer().DialContext
}
if config.RuntimeParams == nil {
config.RuntimeParams = make(map[string]string)
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
// zero values.
if !config.createdByParseConfig {
panic("config must be created by ParseConfig")
}
// Simplify usage by treating primary config and fallbacks the same.
@@ -124,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
if err == nil {
break
} else if err, ok := err.(*PgError); ok {
return nil, err
return nil, &connectError{config: config, msg: "server error", err: err}
}
}
if err != nil {
return nil, err
return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError
}
if config.AfterConnect != nil {
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, errors.Errorf("AfterConnect: %v", err)
return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
}
}
@@ -145,14 +149,14 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
pgConn := new(PgConn)
pgConn.Config = config
pgConn.config = config
pgConn.wbuf = make([]byte, 0, 1024)
var err error
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
pgConn.conn, err = config.DialFunc(ctx, network, address)
if err != nil {
return nil, err
return nil, &connectError{config: config, msg: "dial error", err: err}
}
pgConn.parameterStatuses = make(map[string]string)
@@ -160,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if fallbackConfig.TLSConfig != nil {
if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
pgConn.conn.Close()
return nil, err
return nil, &connectError{config: config, msg: "tls error", err: err}
}
}
@@ -170,10 +174,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
func() { pgConn.conn.SetDeadline(time.Time{}) },
)
pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn)
if err != nil {
return nil, err
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
@@ -192,32 +196,52 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
pgConn.conn.Close()
return nil, err
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
}
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.conn.Close()
return nil, err
if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &connectError{config: config, msg: "failed to receive message", err: err}
}
switch msg := msg.(type) {
case *pgproto3.BackendKeyData:
pgConn.pid = msg.ProcessID
pgConn.secretKey = msg.SecretKey
case *pgproto3.Authentication:
if err = pgConn.rxAuthenticationX(msg); err != nil {
case *pgproto3.AuthenticationOk:
case *pgproto3.AuthenticationCleartextPassword:
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
return nil, err
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
if config.ValidateConnect != nil {
err := config.ValidateConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, errors.Errorf("ValidateConnect: %v", err)
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
}
}
return pgConn, nil
@@ -225,10 +249,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
// handled by ReceiveMessage
case *pgproto3.ErrorResponse:
pgConn.conn.Close()
return nil, errorResponseToPgError(msg)
return nil, ErrorResponseToPgError(msg)
default:
pgConn.conn.Close()
return nil, errors.New("unexpected message")
return nil, &connectError{config: config, msg: "received unexpected message", err: err}
}
}
}
@@ -245,7 +269,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
}
if response[0] != 'S' {
return ErrTLSRefused
return errors.New("server refused TLS connection")
}
pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
@@ -253,23 +277,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
return nil
}
func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
switch msg.Type {
case pgproto3.AuthTypeOk:
case pgproto3.AuthTypeCleartextPassword:
err = pgConn.txPasswordMessage(pgConn.Config.Password)
case pgproto3.AuthTypeMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
case pgproto3.AuthTypeSASL:
err = pgConn.scramAuth(msg.SASLAuthMechanisms)
default:
err = errors.New("Received unknown authentication message")
}
return
}
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
msg := &pgproto3.PasswordMessage{Password: password}
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))
@@ -292,7 +299,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
ch := make(chan struct{})
go func() {
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive()
pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive()
pgConn.bufferingReceiveMux.Unlock()
close(ch)
}()
@@ -300,7 +307,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
return ch
}
func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
// error to call SendBytes while reading the result of a query.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
// the OnNotification callback.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
if err := pgConn.lock(); err != nil {
return nil, err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
msg, err := pgConn.receiveMessage()
if err != nil {
err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true}
}
return msg, err
}
// receiveMessage receives a message without setting up context cancellation
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
var msg pgproto3.BackendMessage
var err error
if pgConn.bufferingReceive {
@@ -312,10 +376,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
// If a timeout error happened in the background try the read again.
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
msg, err = pgConn.Frontend.Receive()
msg, err = pgConn.frontend.Receive()
}
} else {
msg, err = pgConn.Frontend.Receive()
msg, err = pgConn.frontend.Receive()
}
if err != nil {
@@ -329,21 +393,21 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) {
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
pgConn.TxStatus = msg.TxStatus
pgConn.txStatus = msg.TxStatus
case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse:
if msg.Severity == "FATAL" {
pgConn.hardClose()
return nil, errorResponseToPgError(msg)
return nil, ErrorResponseToPgError(msg)
}
case *pgproto3.NoticeResponse:
if pgConn.Config.OnNotice != nil {
pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg))
if pgConn.config.OnNotice != nil {
pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg))
}
case *pgproto3.NotificationResponse:
if pgConn.Config.OnNotification != nil {
pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
if pgConn.config.OnNotification != nil {
pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
}
}
@@ -360,6 +424,11 @@ func (pgConn *PgConn) PID() uint32 {
return pgConn.pid
}
// TxStatus returns the current TxStatus as reported by the server.
func (pgConn *PgConn) TxStatus() byte {
return pgConn.txStatus
}
// SecretKey returns the backend secret key used to send a cancel query message to the server.
func (pgConn *PgConn) SecretKey() uint32 {
return pgConn.secretKey
@@ -381,12 +450,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil {
return linkErrors(ctx.Err(), err)
return err
}
_, err = pgConn.conn.Read(make([]byte, 1))
if err != io.EOF {
return linkErrors(ctx.Err(), err)
return err
}
return pgConn.conn.Close()
@@ -402,21 +471,20 @@ func (pgConn *PgConn) hardClose() error {
return pgConn.conn.Close()
}
// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of
// underlying connection.
func (pgConn *PgConn) IsAlive() bool {
return pgConn.status >= connStatusIdle
// IsClosed reports if the connection has been closed.
func (pgConn *PgConn) IsClosed() bool {
return pgConn.status < connStatusIdle
}
// lock locks the connection. It panics if the connection is already locked or is closed.
// lock locks the connection.
func (pgConn *PgConn) lock() error {
switch pgConn.status {
case connStatusBusy:
return ErrConnBusy // This only should be possible in case of an application bug.
return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug.
case connStatusClosed:
return errors.New("conn closed")
return &connLockError{status: "conn closed"}
case connStatusUninitialized:
return errors.New("conn uninitialized")
return &connLockError{status: "conn uninitialized"}
}
pgConn.status = connStatusBusy
return nil
@@ -456,23 +524,24 @@ func (ct CommandTag) String() string {
return string(ct)
}
type PreparedStatementDescription struct {
type StatementDescription struct {
Name string
SQL string
ParamOIDs []uint32
Fields []pgproto3.FieldDescription
}
// Prepare creates a prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) {
// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This
// allows Prepare to also to describe statements without creating a server-side prepared statement.
func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) {
if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent)
return nil, err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
@@ -486,22 +555,19 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0}
}
psd := &PreparedStatementDescription{Name: name, SQL: sql}
psd := &StatementDescription{Name: name, SQL: sql}
var parseErr error
readloop:
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
@@ -512,7 +578,7 @@ readloop:
psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields))
copy(psd.Fields, msg.Fields)
case *pgproto3.ErrorResponse:
parseErr = errorResponseToPgError(msg)
parseErr = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery:
break readloop
}
@@ -524,7 +590,8 @@ readloop:
return psd, nil
}
func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
// ErrorResponseToPgError converts a wire protocol error message to a *PgError.
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
return &PgError{
Severity: msg.Severity,
Code: string(msg.Code),
@@ -547,7 +614,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
}
func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg))
pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg))
return (*Notice)(pgerr)
}
@@ -559,7 +626,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance.
serverAddr := pgConn.conn.RemoteAddr()
cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
if err != nil {
return err
}
@@ -579,12 +646,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey))
_, err = cancelConn.Write(buf)
if err != nil {
return linkErrors(ctx.Err(), err)
return err
}
_, err = cancelConn.Read(buf)
if err != io.EOF {
return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err))
return err
}
return nil
@@ -608,9 +675,9 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
defer pgConn.contextWatcher.Unwatch()
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
return linkErrors(ctx.Err(), err)
return err
}
switch msg.(type) {
@@ -629,7 +696,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: linkErrors(err, ErrNoBytesSent),
err: err,
}
}
@@ -642,7 +709,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock()
return multiResult
default:
@@ -657,10 +724,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.hardClose()
pgConn.contextWatcher.Unwatch()
multiResult.closed = true
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
multiResult.err = linkErrors(ctx.Err(), err)
multiResult.err = &writeError{err: err, safeToRetry: n == 0}
pgConn.unlock()
return multiResult
}
@@ -729,19 +793,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
}
func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader {
if err := pgConn.lock(); err != nil {
return &ResultReader{
closed: true,
err: linkErrors(err, ErrNoBytesSent),
}
}
pgConn.resultReader = ResultReader{
pgConn: pgConn,
ctx: ctx,
}
result := &pgConn.resultReader
if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, err)
result.closed = true
return result
}
if len(paramValues) > math.MaxUint16 {
result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16))
result.closed = true
@@ -751,7 +814,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
select {
case <-ctx.Done():
result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent))
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
result.closed = true
pgConn.unlock()
return result
@@ -770,10 +833,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0})
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
@@ -783,13 +843,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
// CopyTo executes the copy command sql and copies the results to w.
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent)
return nil, err
}
select {
case <-ctx.Done():
pgConn.unlock()
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
@@ -803,20 +863,17 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
if err != nil {
pgConn.hardClose()
pgConn.unlock()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
return nil, &writeError{err: err, safeToRetry: n == 0}
}
// Read results
var commandTag CommandTag
var pgErr error
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
@@ -833,7 +890,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
pgErr = ErrorResponseToPgError(msg)
}
}
}
@@ -844,13 +901,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
// could still block.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
if err := pgConn.lock(); err != nil {
return nil, linkErrors(err, ErrNoBytesSent)
return nil, err
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
@@ -863,10 +920,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
}
return nil, linkErrors(ctx.Err(), err)
return nil, &writeError{err: err, safeToRetry: n == 0}
}
// Read until copy in response or error.
@@ -874,17 +928,17 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
var pgErr error
pendingCopyInResponse := true
for pendingCopyInResponse {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
case *pgproto3.CopyInResponse:
pendingCopyInResponse = false
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
pgErr = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery:
return commandTag, pgErr
}
@@ -906,21 +960,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
}
select {
case <-signalMessageChan:
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
pgErr = ErrorResponseToPgError(msg)
}
default:
}
@@ -937,15 +991,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
_, err = pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
// Read results
for {
msg, err := pgConn.ReceiveMessage()
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.hardClose()
return nil, linkErrors(ctx.Err(), err)
return nil, err
}
switch msg := msg.(type) {
@@ -954,7 +1008,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
case *pgproto3.CommandComplete:
commandTag = CommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse:
pgErr = errorResponseToPgError(msg)
pgErr = ErrorResponseToPgError(msg)
}
}
}
@@ -983,11 +1037,11 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
}
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
msg, err := mrr.pgConn.ReceiveMessage()
msg, err := mrr.pgConn.receiveMessage()
if err != nil {
mrr.pgConn.contextWatcher.Unwatch()
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
mrr.err = err
mrr.closed = true
mrr.pgConn.hardClose()
return nil, mrr.err
@@ -999,7 +1053,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
mrr.closed = true
mrr.pgConn.unlock()
case *pgproto3.ErrorResponse:
mrr.err = errorResponseToPgError(msg)
mrr.err = ErrorResponseToPgError(msg)
}
return msg, nil
@@ -1151,7 +1205,10 @@ func (rr *ResultReader) Close() (CommandTag, error) {
return nil, rr.err
}
switch msg.(type) {
switch msg := msg.(type) {
// Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete.
case *pgproto3.ErrorResponse:
rr.err = ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery:
rr.pgConn.contextWatcher.Unwatch()
rr.pgConn.unlock()
@@ -1165,7 +1222,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) {
if rr.multiResultReader == nil {
msg, err = rr.pgConn.ReceiveMessage()
msg, err = rr.pgConn.receiveMessage()
} else {
msg, err = rr.multiResultReader.receiveMessage()
}
@@ -1187,7 +1244,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
case *pgproto3.CommandComplete:
rr.concludeCommand(CommandTag(msg.CommandTag), nil)
case *pgproto3.ErrorResponse:
rr.concludeCommand(nil, errorResponseToPgError(msg))
rr.concludeCommand(nil, ErrorResponseToPgError(msg))
}
return msg, nil
@@ -1199,7 +1256,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
}
rr.commandTag = commandTag
rr.err = preferContextOverNetTimeoutError(rr.ctx, err)
rr.err = err
rr.fieldDescriptions = nil
rr.rowValues = nil
rr.commandConcluded = true
@@ -1229,7 +1286,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
err: linkErrors(err, ErrNoBytesSent),
err: err,
}
}
@@ -1242,7 +1299,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent)
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock()
return multiResult
default:
+219 -27
View File
@@ -18,6 +18,8 @@ import (
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgmock"
"github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors"
"github.com/stretchr/testify/assert"
@@ -72,6 +74,67 @@ func TestConnectTLS(t *testing.T) {
closeConn(t, conn)
}
type pgmockWaitStep time.Duration
func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
time.Sleep(time.Duration(s))
return nil
}
func TestConnectWithContextThatTimesOut(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: []pgmock.Step{
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
pgmockWaitStep(time.Millisecond * 500),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
},
}
ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()
serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)
conn, err := ln.Accept()
if err != nil {
serverErrChan <- err
return
}
defer conn.Close()
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
if err != nil {
serverErrChan <- err
return
}
err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
if err != nil {
serverErrChan <- err
return
}
}()
parts := strings.Split(ln.Addr().String(), ":")
host := parts[0]
port := parts[1]
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
tooLate := time.Now().Add(time.Millisecond * 500)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
defer cancel()
_, err = pgconn.Connect(ctx, connStr)
require.True(t, pgconn.Timeout(err), err)
require.True(t, time.Now().Before(tooLate))
}
func TestConnectInvalidUser(t *testing.T) {
t.Parallel()
@@ -85,14 +148,11 @@ func TestConnectInvalidUser(t *testing.T) {
config.User = "pgxinvalidusertest"
conn, err := pgconn.ConnectConfig(context.Background(), config)
if err == nil {
conn.Close(context.Background())
t.Fatal("expected err but got none")
}
pgErr, ok := err.(*pgconn.PgError)
_, err = pgconn.ConnectConfig(context.Background(), config)
require.Error(t, err)
pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
if !ok {
t.Fatalf("Expected to receive a PgError, instead received: %v", err)
t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
}
if pgErr.Code != "28000" && pgErr.Code != "28P01" {
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
@@ -262,6 +322,14 @@ func TestConnectWithAfterConnect(t *testing.T) {
assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
}
func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
t.Parallel()
config := &pgconn.Config{}
require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) })
}
func TestConnPrepareSyntaxError(t *testing.T) {
t.Parallel()
@@ -289,7 +357,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
assert.Nil(t, psd)
assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn)
}
@@ -381,6 +449,34 @@ func TestConnExecMultipleQueriesError(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnExecDeferredError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
setupSQL := `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
assert.NoError(t, err)
_, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll()
require.NotNil(t, err)
var pgErr *pgconn.PgError
require.True(t, errors.As(err, &pgErr))
require.Equal(t, "23505", pgErr.Code)
ensureConnValid(t, pgConn)
}
func TestConnExecContextCanceled(t *testing.T) {
t.Parallel()
@@ -395,8 +491,8 @@ func TestConnExecContextCanceled(t *testing.T) {
for multiResult.NextResult() {
}
err = multiResult.Close()
assert.Equal(t, context.DeadlineExceeded, err)
assert.False(t, pgConn.IsAlive())
assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed())
}
func TestConnExecContextPrecanceled(t *testing.T) {
@@ -411,7 +507,7 @@ func TestConnExecContextPrecanceled(t *testing.T) {
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
assert.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn)
}
@@ -437,6 +533,33 @@ func TestConnExecParams(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnExecParamsDeferredError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
setupSQL := `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
assert.NoError(t, err)
result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
require.NotNil(t, result.Err)
var pgErr *pgconn.PgError
require.True(t, errors.As(result.Err, &pgErr))
require.Equal(t, "23505", pgErr.Code)
ensureConnValid(t, pgConn)
}
func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
t.Parallel()
@@ -500,9 +623,9 @@ func TestConnExecParamsCanceled(t *testing.T) {
assert.Equal(t, 0, rowCount)
commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgconn.Timeout(err))
assert.False(t, pgConn.IsAlive())
assert.True(t, pgConn.IsClosed())
}
func TestConnExecParamsPrecanceled(t *testing.T) {
@@ -517,7 +640,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) {
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
require.Error(t, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(result.Err))
ensureConnValid(t, pgConn)
}
@@ -627,8 +750,8 @@ func TestConnExecPreparedCanceled(t *testing.T) {
assert.Equal(t, 0, rowCount)
commandTag, err := result.Close()
assert.Equal(t, pgconn.CommandTag(nil), commandTag)
assert.Equal(t, context.DeadlineExceeded, err)
assert.False(t, pgConn.IsAlive())
assert.True(t, pgconn.Timeout(err))
assert.True(t, pgConn.IsClosed())
}
func TestConnExecPreparedPrecanceled(t *testing.T) {
@@ -646,7 +769,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) {
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
require.Error(t, result.Err)
assert.True(t, errors.Is(result.Err, context.Canceled))
assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(result.Err))
ensureConnValid(t, pgConn)
}
@@ -683,6 +806,36 @@ func TestConnExecBatch(t *testing.T) {
assert.Equal(t, "SELECT 1", string(results[2].CommandTag))
}
func TestConnExecBatchDeferredError(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
setupSQL := `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
assert.NoError(t, err)
batch := &pgconn.Batch{}
batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
_, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
require.NotNil(t, err)
var pgErr *pgconn.PgError
require.True(t, errors.As(err, &pgErr))
require.Equal(t, "23505", pgErr.Code)
ensureConnValid(t, pgConn)
}
func TestConnExecBatchPrecanceled(t *testing.T) {
t.Parallel()
@@ -704,7 +857,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) {
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
ensureConnValid(t, pgConn)
}
@@ -777,8 +930,8 @@ func TestConnLocking(t *testing.T) {
mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
assert.Error(t, err)
assert.True(t, errors.Is(err, pgconn.ErrConnBusy))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.Equal(t, "conn busy", err.Error())
assert.True(t, pgconn.SafeToRetry(err))
results, err := mrr.ReadAll()
assert.NoError(t, err)
@@ -935,7 +1088,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
err = pgConn.WaitForNotification(ctx)
cancel()
assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.True(t, pgconn.Timeout(err))
ensureConnValid(t, pgConn)
}
@@ -1045,10 +1198,10 @@ func TestConnCopyToCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.Error(t, err)
assert.Equal(t, pgconn.CommandTag(nil), res)
assert.False(t, pgConn.IsAlive())
assert.True(t, pgConn.IsClosed())
}
func TestConnCopyToPrecanceled(t *testing.T) {
@@ -1065,7 +1218,7 @@ func TestConnCopyToPrecanceled(t *testing.T) {
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), res)
ensureConnValid(t, pgConn)
@@ -1137,9 +1290,9 @@ func TestConnCopyFromCanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
cancel()
assert.Equal(t, int64(0), ct.RowsAffected())
assert.True(t, errors.Is(err, context.DeadlineExceeded))
assert.Error(t, err)
assert.False(t, pgConn.IsAlive())
assert.True(t, pgConn.IsClosed())
}
func TestConnCopyFromPrecanceled(t *testing.T) {
@@ -1173,7 +1326,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
require.Error(t, err)
assert.True(t, errors.Is(err, context.Canceled))
assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent))
assert.True(t, pgconn.SafeToRetry(err))
assert.Equal(t, pgconn.CommandTag(nil), ct)
ensureConnValid(t, pgConn)
@@ -1331,6 +1484,45 @@ func TestConnCancelRequest(t *testing.T) {
ensureConnValid(t, pgConn)
}
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer closeConn(t, pgConn)
queryMsg := pgproto3.Query{String: "select 42"}
buf := queryMsg.Encode(nil)
err = pgConn.SendBytes(ctx, buf)
require.NoError(t, err)
msg, err := pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok := msg.(*pgproto3.RowDescription)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.DataRow)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.CommandComplete)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.ReadyForQuery)
require.True(t, ok)
ensureConnValid(t, pgConn)
}
func Example() {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
if err != nil {
+111
View File
@@ -0,0 +1,111 @@
package stmtcache
import (
"container/list"
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgconn"
)
var lruCount uint64
// LRU implements Cache with a Least Recently Used (LRU) cache.
type LRU struct {
conn *pgconn.PgConn
mode int
cap int
prepareCount int
m map[string]*list.Element
l *list.List
psNamePrefix string
}
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
mustBeValidMode(mode)
mustBeValidCap(cap)
n := atomic.AddUint64(&lruCount, 1)
return &LRU{
conn: conn,
mode: mode,
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
}
}
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
if el, ok := c.m[sql]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription), nil
}
if c.l.Len() == c.cap {
err := c.removeOldest(ctx)
if err != nil {
return nil, err
}
}
psd, err := c.prepare(ctx, sql)
if err != nil {
return nil, err
}
el := c.l.PushFront(psd)
c.m[sql] = el
return psd, nil
}
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
func (c *LRU) Clear(ctx context.Context) error {
for c.l.Len() > 0 {
err := c.removeOldest(ctx)
if err != nil {
return err
}
}
return nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRU) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRU) Cap() int {
return c.cap
}
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
func (c *LRU) Mode() int {
return c.mode
}
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
var name string
if c.mode == ModePrepare {
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
c.prepareCount += 1
}
return c.conn.Prepare(ctx, name, sql, nil)
}
func (c *LRU) removeOldest(ctx context.Context) error {
oldest := c.l.Back()
c.l.Remove(oldest)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close()
}
return nil
}
+113
View File
@@ -0,0 +1,113 @@
package stmtcache_test
import (
"context"
"os"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgconn/stmtcache"
"github.com/stretchr/testify/require"
)
func TestLRUModePrepare(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 2, cache.Cap())
require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())
psd, err := cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 3")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
}
func TestLRUModeDescribe(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
defer conn.Close(ctx)
cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
require.EqualValues(t, 0, cache.Len())
require.EqualValues(t, 2, cache.Cap())
require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())
psd, err := cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 1")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 1, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 2")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
psd, err = cache.Get(ctx, "select 3")
require.NoError(t, err)
require.NotNil(t, psd)
require.EqualValues(t, 2, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
err = cache.Clear(ctx)
require.NoError(t, err)
require.EqualValues(t, 0, cache.Len())
require.Empty(t, fetchServerStatements(t, ctx, conn))
}
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
var statements []string
for _, r := range result.Rows {
statements = append(statements, string(r[0]))
}
return statements
}
+52
View File
@@ -0,0 +1,52 @@
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
package stmtcache
import (
"context"
"github.com/jackc/pgconn"
)
const (
ModePrepare = iota // Cache should prepare named statements.
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
)
// Cache prepares and caches prepared statement descriptions.
type Cache interface {
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
Clear(ctx context.Context) error
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
Mode() int
}
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
// the maximum size of the cache.
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
mustBeValidMode(mode)
mustBeValidCap(cap)
return NewLRU(conn, mode, cap)
}
func mustBeValidMode(mode int) {
if mode != ModePrepare && mode != ModeDescribe {
panic("mode must be ModePrepare or ModeDescribe")
}
}
func mustBeValidCap(cap int) {
if cap < 1 {
panic("cache must have cap of >= 1")
}
}
-14
View File
@@ -1,14 +0,0 @@
#!/usr/bin/env bash
set -eux
go get -u github.com/cockroachdb/apd
go get -u github.com/shopspring/decimal
go get -u gopkg.in/inconshreveable/log15.v2
go get -u github.com/jackc/fake
go get -u github.com/lib/pq
go get -u github.com/hashicorp/go-version
go get -u github.com/satori/go.uuid
go get -u github.com/sirupsen/logrus
go get -u github.com/pkg/errors
go get -u go.uber.org/zap
go get -u github.com/rs/zerolog