diff --git a/.travis.yml b/.travis.yml index e5ed43a8..2c547abf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,9 @@ go: - 1.x - tip +git: + depth: 1 + # Derived from https://github.com/lib/pq/blob/master/.travis.yml before_install: - ./travis/before_install.bash @@ -11,6 +14,8 @@ before_install: env: global: - GO111MODULE=on + - GOPROXY=https://proxy.golang.org + - GOFLAGS=-mod=readonly - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test @@ -25,11 +30,15 @@ env: - PGVERSION=9.4 - PGVERSION=9.3 +cache: + directories: + - $HOME/.cache/go-build + - $HOME/gopath/pkg/mod + before_script: - ./travis/before_script.bash -install: - - ./travis/install.bash +install: go mod download script: - ./travis/script.bash diff --git a/auth_scram.go b/auth_scram.go index bdaf3e92..665fc2c2 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -31,7 +31,7 @@ const clientNonceLen = 18 // Perform SCRAM authentication. func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { - sc, err := newScramClient(serverAuthMechanisms, c.Config.Password) + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) if err != nil { return err } @@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-first-message payload in a AuthenticationSASLContinue. - authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + saslContinue, err := c.rxSASLContinue() if err != nil { return err } - err = sc.recvServerFirstMessage(authMsg.SASLData) + err = sc.recvServerFirstMessage(saslContinue.Data) if err != nil { return err } @@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-final-message payload in a AuthenticationSASLFinal. - authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + saslFinal, err := c.rxSASLFinal() if err != nil { return err } - return sc.recvServerFinalMessage(authMsg.SASLData) + return sc.recvServerFinalMessage(saslFinal.Data) } -func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { - msg, err := c.ReceiveMessage() +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { + msg, err := c.receiveMessage() if err != nil { return nil, err } - authMsg, ok := msg.(*pgproto3.Authentication) - if !ok { - return nil, errors.New("unexpected message type") - } - if authMsg.Type != typ { - return nil, errors.New("unexpected auth type") + saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) + if ok { + return saslContinue, nil } - return authMsg, nil + return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) + if ok { + return saslFinal, nil + } + + return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") } type scramClient struct { diff --git a/config.go b/config.go index 9b74945e..2ec6ae3f 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "io/ioutil" "math" "net" @@ -17,22 +18,26 @@ import ( "strings" "time" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -// Config is the settings used to establish a connection to a PostgreSQL server. +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and +// then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) Port uint16 Database string User string Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + BuildFrontend BuildFrontendFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig @@ -52,6 +57,8 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a @@ -134,36 +141,48 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn // does not. +// +// In addition, ParseConfig accepts the following options: +// +// min_read_buffer_size +// The minimum size of the internal read buffer. Default 8192. func ParseConfig(connString string) (*Config, error) { settings := defaultSettings() addEnvSettings(settings) if connString != "" { // connString may be a database URL or a DSN - if strings.HasPrefix(connString, "postgres://") { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { err := addURLSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { err := addDSNSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} + } + config := &Config{ - Database: settings["database"], - User: settings["user"], - Password: settings["password"], - RuntimeParams: make(map[string]string), + createdByParseConfig: true, + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), } if connectTimeout, present := settings["connect_timeout"]; present { dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} } config.DialFunc = dialFunc } else { @@ -184,6 +203,7 @@ func ParseConfig(connString string) (*Config, error) { "sslcert": struct{}{}, "sslrootcert": struct{}{}, "target_session_attrs": struct{}{}, + "min_read_buffer_size": struct{}{}, } for k, v := range settings { @@ -208,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) { port, err := parsePort(portStr) if err != nil { - return nil, errors.Errorf("invalid port: %w", err) + return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} } var tlsConfigs []*tls.Config @@ -220,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) { var err error tlsConfigs, err = configTLS(settings) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } } @@ -253,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) { if settings["target_session_attrs"] == "read-write" { config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { - return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} } return config, nil @@ -276,6 +296,8 @@ func defaultSettings() map[string]string { settings["target_session_attrs"] = "any" + settings["min_read_buffer_size"] = "8192" + return settings } @@ -473,6 +495,18 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } +func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { + return func(r io.Reader, w io.Writer) Frontend { + cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) + if err != nil { + panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) + } + frontend := pgproto3.NewFrontend(cr, w) + + return frontend + } +} + func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { timeout, err := strconv.ParseInt(s, 10, 64) if err != nil { diff --git a/config_test.go b/config_test.go index 23d86529..090302a2 100644 --- a/config_test.go +++ b/config_test.go @@ -214,6 +214,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url postgresql protocol", + connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "DSN everything", connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", @@ -561,3 +573,15 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { assertConfigsEqual(t, expected, actual, "passfile") } + +func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig("min_read_buffer_size=0") + require.NoError(t, err) + _, present := config.RuntimeParams["min_read_buffer_size"] + require.False(t, present) + + // The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param + // was removed. +} diff --git a/doc.go b/doc.go index d36eb0fd..cde58cd8 100644 --- a/doc.go +++ b/doc.go @@ -15,7 +15,7 @@ reads all rows into memory. Executing Multiple Queries in a Single Round Trip -Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query +Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query result. The ReadAll method reads all query results into memory. Context Support diff --git a/errors.go b/errors.go index 4f8af407..a088dcdd 100644 --- a/errors.go +++ b/errors.go @@ -2,22 +2,31 @@ package pgconn import ( "context" + "fmt" "net" + "strings" errors "golang.org/x/xerrors" ) -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + if e, ok := err.(interface{ SafeToRetry() bool }); ok { + return e.SafeToRetry() + } + return false +} -// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another -// action is attempted. -var ErrConnBusy = errors.New("conn is busy") +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a +// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } -// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used -// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. -var ErrNoBytesSent = errors.New("no bytes sent to server") + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for @@ -46,44 +55,107 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } -// linkedError connects two errors as if err wrapped next. -type linkedError struct { - err error - next error +type connectError struct { + config *Config + msg string + err error } -func (le *linkedError) Error() string { - return le.err.Error() -} - -func (le *linkedError) Is(target error) bool { - return errors.Is(le.err, target) -} - -func (le *linkedError) As(target interface{}) bool { - return errors.As(le.err, target) -} - -func (le *linkedError) Unwrap() error { - return le.next -} - -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() +func (e *connectError) Error() string { + sb := &strings.Builder{} + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + if e.err != nil { + fmt.Fprintf(sb, " (%s)", e.err.Error()) } - return err + return sb.String() } -// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. -func linkErrors(outer, inner error) error { - if outer == nil { - return inner +func (e *connectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +type parseConfigError struct { + connString string + msg string + err error +} + +func (e *parseConfigError) Error() string { + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg) } - if inner == nil { - return outer + return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error()) +} + +func (e *parseConfigError) Unwrap() error { + return e.err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() } - return &linkedError{err: outer, next: inner} + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err } diff --git a/go.mod b/go.mod index b1c84049..11692c10 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,13 @@ module github.com/jackc/pgconn go 1.12 require ( + github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 + github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 - github.com/stretchr/testify v1.3.0 - golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a - golang.org/x/text v0.3.0 - golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 + github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 + github.com/stretchr/testify v1.4.0 + golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 + golang.org/x/text v0.3.2 + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index 50dfc2fd..d0a917fc 100644 --- a/go.sum +++ b/go.sum @@ -1,31 +1,105 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pgconn.go b/pgconn.go index 6e1fb7e3..5c01d1dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -40,9 +40,12 @@ type Notification struct { Payload string } -// DialFunc is a function that can be used to connect to a PostgreSQL server +// DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. +type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -55,16 +58,21 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) +// Frontend used to receive messages from backend. +type Frontend interface { + Receive() (pgproto3.BackendMessage, error) +} + // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server - TxStatus byte - Frontend *pgproto3.Frontend + txStatus byte + frontend Frontend - Config *Config + config *Config status byte // One of connStatus* constants @@ -91,22 +99,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with +// ParseConfig. ctx can be used to cancel a connect attempt. // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, // if all attempts fail the last error is returned. func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { - // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. - if config.Port == 0 { - config.Port = 5432 - } - if config.DialFunc == nil { - config.DialFunc = makeDefaultDialer().DialContext - } - if config.RuntimeParams == nil { - config.RuntimeParams = make(map[string]string) + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") } // Simplify usage by treating primary config and fallbacks the same. @@ -124,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if err == nil { break } else if err, ok := err.(*PgError); ok { - return nil, err + return nil, &connectError{config: config, msg: "server error", err: err} } } if err != nil { - return nil, err + return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnect: %v", err) + return nil, &connectError{config: config, msg: "AfterConnect error", err: err} } } @@ -145,14 +149,14 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) - pgConn.Config = config + pgConn.config = config pgConn.wbuf = make([]byte, 0, 1024) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { - return nil, err + return nil, &connectError{config: config, msg: "dial error", err: err} } pgConn.parameterStatuses = make(map[string]string) @@ -160,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if fallbackConfig.TLSConfig != nil { if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "tls error", err: err} } } @@ -170,10 +174,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) - if err != nil { - return nil, err - } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -192,32 +196,52 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() - return nil, err + if err, ok := err.(*PgError); ok { + return nil, err + } + return nil, &connectError{config: config, msg: "failed to receive message", err: err} } switch msg := msg.(type) { case *pgproto3.BackendKeyData: pgConn.pid = msg.ProcessID pgConn.secretKey = msg.SecretKey - case *pgproto3.Authentication: - if err = pgConn.rxAuthenticationX(msg); err != nil { + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed to write password message", err: err} } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + } + case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("ValidateConnect: %v", err) + return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} } } return pgConn, nil @@ -225,10 +249,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, errorResponseToPgError(msg) + return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() - return nil, errors.New("unexpected message") + return nil, &connectError{config: config, msg: "received unexpected message", err: err} } } } @@ -245,7 +269,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { } if response[0] != 'S' { - return ErrTLSRefused + return errors.New("server refused TLS connection") } pgConn.conn = tls.Client(pgConn.conn, tlsConfig) @@ -253,23 +277,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { - switch msg.Type { - case pgproto3.AuthTypeOk: - case pgproto3.AuthTypeCleartextPassword: - err = pgConn.txPasswordMessage(pgConn.Config.Password) - case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) - err = pgConn.txPasswordMessage(digestedPassword) - case pgproto3.AuthTypeSASL: - err = pgConn.scramAuth(msg.SASLAuthMechanisms) - default: - err = errors.New("Received unknown authentication message") - } - - return -} - func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) @@ -292,7 +299,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} { ch := make(chan struct{}) go func() { - pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() pgConn.bufferingReceiveMux.Unlock() close(ch) }() @@ -300,7 +307,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} { return ch } -func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { +// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.hardClose() + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + msg, err := pgConn.receiveMessage() + if err != nil { + err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true} + } + return msg, err +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { @@ -312,10 +376,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { // If a timeout error happened in the background try the read again. if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - msg, err = pgConn.Frontend.Receive() + msg, err = pgConn.frontend.Receive() } } else { - msg, err = pgConn.Frontend.Receive() + msg, err = pgConn.frontend.Receive() } if err != nil { @@ -329,21 +393,21 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - pgConn.TxStatus = msg.TxStatus + pgConn.txStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { pgConn.hardClose() - return nil, errorResponseToPgError(msg) + return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: - if pgConn.Config.OnNotice != nil { - pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) + if pgConn.config.OnNotice != nil { + pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) } case *pgproto3.NotificationResponse: - if pgConn.Config.OnNotification != nil { - pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + if pgConn.config.OnNotification != nil { + pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) } } @@ -360,6 +424,11 @@ func (pgConn *PgConn) PID() uint32 { return pgConn.pid } +// TxStatus returns the current TxStatus as reported by the server. +func (pgConn *PgConn) TxStatus() byte { + return pgConn.txStatus +} + // SecretKey returns the backend secret key used to send a cancel query message to the server. func (pgConn *PgConn) SecretKey() uint32 { return pgConn.secretKey @@ -381,12 +450,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error { _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { - return linkErrors(ctx.Err(), err) + return err } return pgConn.conn.Close() @@ -402,21 +471,20 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of -// underlying connection. -func (pgConn *PgConn) IsAlive() bool { - return pgConn.status >= connStatusIdle +// IsClosed reports if the connection has been closed. +func (pgConn *PgConn) IsClosed() bool { + return pgConn.status < connStatusIdle } -// lock locks the connection. It panics if the connection is already locked or is closed. +// lock locks the connection. func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: - return ErrConnBusy // This only should be possible in case of an application bug. + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. case connStatusClosed: - return errors.New("conn closed") + return &connLockError{status: "conn closed"} case connStatusUninitialized: - return errors.New("conn uninitialized") + return &connLockError{status: "conn uninitialized"} } pgConn.status = connStatusBusy return nil @@ -456,23 +524,24 @@ func (ct CommandTag) String() string { return string(ct) } -type PreparedStatementDescription struct { +type StatementDescription struct { Name string SQL string ParamOIDs []uint32 Fields []pgproto3.FieldDescription } -// Prepare creates a prepared statement. -func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { +// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This +// allows Prepare to also to describe statements without creating a server-side prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -486,22 +555,19 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } - psd := &PreparedStatementDescription{Name: name, SQL: sql} + psd := &StatementDescription{Name: name, SQL: sql} var parseErr error readloop: for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -512,7 +578,7 @@ readloop: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: - parseErr = errorResponseToPgError(msg) + parseErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } @@ -524,7 +590,8 @@ readloop: return psd, nil } -func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { +// ErrorResponseToPgError converts a wire protocol error message to a *PgError. +func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, Code: string(msg.Code), @@ -547,7 +614,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { } func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { - pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) return (*Notice)(pgerr) } @@ -559,7 +626,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } @@ -579,12 +646,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = cancelConn.Read(buf) if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) + return err } return nil @@ -608,9 +675,9 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { defer pgConn.contextWatcher.Unwatch() for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { - return linkErrors(ctx.Err(), err) + return err } switch msg.(type) { @@ -629,7 +696,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -642,7 +709,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: @@ -657,10 +724,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.hardClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - multiResult.err = linkErrors(ctx.Err(), err) + multiResult.err = &writeError{err: err, safeToRetry: n == 0} pgConn.unlock() return multiResult } @@ -729,19 +793,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - if err := pgConn.lock(); err != nil { - return &ResultReader{ - closed: true, - err: linkErrors(err, ErrNoBytesSent), - } - } - pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, } result := &pgConn.resultReader + if err := pgConn.lock(); err != nil { + result.concludeCommand(nil, err) + result.closed = true + return result + } + if len(paramValues) > math.MaxUint16 { result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true @@ -751,7 +814,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) result.closed = true pgConn.unlock() return result @@ -770,10 +833,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - result.concludeCommand(nil, linkErrors(ctx.Err(), err)) + result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -783,13 +843,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } select { case <-ctx.Done(): pgConn.unlock() - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -803,20 +863,17 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm if err != nil { pgConn.hardClose() pgConn.unlock() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read results var commandTag CommandTag var pgErr error for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -833,7 +890,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } } } @@ -844,13 +901,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -863,10 +920,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read until copy in response or error. @@ -874,17 +928,17 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var pgErr error pendingCopyInResponse := true for pendingCopyInResponse { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { case *pgproto3.CopyInResponse: pendingCopyInResponse = false case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: return commandTag, pgErr } @@ -906,21 +960,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } } select { case <-signalMessageChan: - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } default: } @@ -937,15 +991,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } // Read results for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -954,7 +1008,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } } } @@ -983,11 +1037,11 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { } func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := mrr.pgConn.ReceiveMessage() + msg, err := mrr.pgConn.receiveMessage() if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.err = err mrr.closed = true mrr.pgConn.hardClose() return nil, mrr.err @@ -999,7 +1053,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: - mrr.err = errorResponseToPgError(msg) + mrr.err = ErrorResponseToPgError(msg) } return msg, nil @@ -1151,7 +1205,10 @@ func (rr *ResultReader) Close() (CommandTag, error) { return nil, rr.err } - switch msg.(type) { + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() @@ -1165,7 +1222,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { - msg, err = rr.pgConn.ReceiveMessage() + msg, err = rr.pgConn.receiveMessage() } else { msg, err = rr.multiResultReader.receiveMessage() } @@ -1187,7 +1244,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(nil, errorResponseToPgError(msg)) + rr.concludeCommand(nil, ErrorResponseToPgError(msg)) } return msg, nil @@ -1199,7 +1256,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { } rr.commandTag = commandTag - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.err = err rr.fieldDescriptions = nil rr.rowValues = nil rr.commandConcluded = true @@ -1229,7 +1286,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -1242,7 +1299,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: diff --git a/pgconn_test.go b/pgconn_test.go index feb78641..4a67a2e0 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,6 +18,8 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgmock" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" "github.com/stretchr/testify/assert" @@ -72,6 +74,67 @@ func TestConnectTLS(t *testing.T) { closeConn(t, conn) } +type pgmockWaitStep time.Duration + +func (s pgmockWaitStep) Step(*pgproto3.Backend) error { + time.Sleep(time.Duration(s)) + return nil +} + +func TestConnectWithContextThatTimesOut(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err = pgconn.Connect(ctx, connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) +} + func TestConnectInvalidUser(t *testing.T) { t.Parallel() @@ -85,14 +148,11 @@ func TestConnectInvalidUser(t *testing.T) { config.User = "pgxinvalidusertest" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if err == nil { - conn.Close(context.Background()) - t.Fatal("expected err but got none") - } - pgErr, ok := err.(*pgconn.PgError) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) if !ok { - t.Fatalf("Expected to receive a PgError, instead received: %v", err) + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) } if pgErr.Code != "28000" && pgErr.Code != "28P01" { t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) @@ -262,6 +322,14 @@ func TestConnectWithAfterConnect(t *testing.T) { assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) } +func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { + t.Parallel() + + config := &pgconn.Config{} + + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() @@ -289,7 +357,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { assert.Nil(t, psd) assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -381,6 +449,34 @@ func TestConnExecMultipleQueriesError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) + + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecContextCanceled(t *testing.T) { t.Parallel() @@ -395,8 +491,8 @@ func TestConnExecContextCanceled(t *testing.T) { for multiResult.NextResult() { } err = multiResult.Close() - assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) } func TestConnExecContextPrecanceled(t *testing.T) { @@ -411,7 +507,7 @@ func TestConnExecContextPrecanceled(t *testing.T) { _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -437,6 +533,33 @@ func TestConnExecParams(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() @@ -500,9 +623,9 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecParamsPrecanceled(t *testing.T) { @@ -517,7 +640,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) { result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -627,8 +750,8 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) } func TestConnExecPreparedPrecanceled(t *testing.T) { @@ -646,7 +769,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -683,6 +806,36 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnExecBatchDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatchPrecanceled(t *testing.T) { t.Parallel() @@ -704,7 +857,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) { _, err = pgConn.ExecBatch(ctx, batch).ReadAll() require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -777,8 +930,8 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) results, err := mrr.ReadAll() assert.NoError(t, err) @@ -935,7 +1088,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.True(t, pgconn.Timeout(err)) ensureConnValid(t, pgConn) } @@ -1045,10 +1198,10 @@ func TestConnCopyToCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.Equal(t, pgconn.CommandTag(nil), res) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnCopyToPrecanceled(t *testing.T) { @@ -1065,7 +1218,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) @@ -1137,9 +1290,9 @@ func TestConnCopyFromCanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnCopyFromPrecanceled(t *testing.T) { @@ -1173,7 +1326,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) @@ -1331,6 +1484,45 @@ func TestConnCancelRequest(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnSendBytesAndReceiveMessage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + queryMsg := pgproto3.Query{String: "select 42"} + buf := queryMsg.Encode(nil) + + err = pgConn.SendBytes(ctx, buf) + require.NoError(t, err) + + msg, err := pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok := msg.(*pgproto3.RowDescription) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.DataRow) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.CommandComplete) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { diff --git a/stmtcache/lru.go b/stmtcache/lru.go new file mode 100644 index 00000000..fff4d0b7 --- /dev/null +++ b/stmtcache/lru.go @@ -0,0 +1,111 @@ +package stmtcache + +import ( + "container/list" + "context" + "fmt" + "sync/atomic" + + "github.com/jackc/pgconn" +) + +var lruCount uint64 + +// LRU implements Cache with a Least Recently Used (LRU) cache. +type LRU struct { + conn *pgconn.PgConn + mode int + cap int + prepareCount int + m map[string]*list.Element + l *list.List + psNamePrefix string +} + +// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. +func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { + mustBeValidMode(mode) + mustBeValidCap(cap) + + n := atomic.AddUint64(&lruCount, 1) + + return &LRU{ + conn: conn, + mode: mode, + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + psNamePrefix: fmt.Sprintf("lrupsc_%d", n), + } +} + +// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + if el, ok := c.m[sql]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.StatementDescription), nil + } + + if c.l.Len() == c.cap { + err := c.removeOldest(ctx) + if err != nil { + return nil, err + } + } + + psd, err := c.prepare(ctx, sql) + if err != nil { + return nil, err + } + + el := c.l.PushFront(psd) + c.m[sql] = el + + return psd, nil +} + +// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. +func (c *LRU) Clear(ctx context.Context) error { + for c.l.Len() > 0 { + err := c.removeOldest(ctx) + if err != nil { + return err + } + } + + return nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRU) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRU) Cap() int { + return c.cap +} + +// Mode returns the mode of the cache (ModePrepare or ModeDescribe) +func (c *LRU) Mode() int { + return c.mode +} + +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + var name string + if c.mode == ModePrepare { + name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) + c.prepareCount += 1 + } + + return c.conn.Prepare(ctx, name, sql, nil) +} + +func (c *LRU) removeOldest(ctx context.Context) error { + oldest := c.l.Back() + c.l.Remove(oldest) + if c.mode == ModePrepare { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close() + } + return nil +} diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go new file mode 100644 index 00000000..b518364e --- /dev/null +++ b/stmtcache/lru_test.go @@ -0,0 +1,113 @@ +package stmtcache_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgconn/stmtcache" + + "github.com/stretchr/testify/require" +) + +func TestLRUModePrepare(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func TestLRUModeDescribe(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { + result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + var statements []string + for _, r := range result.Rows { + statements = append(statements, string(r[0])) + } + return statements +} diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go new file mode 100644 index 00000000..96215799 --- /dev/null +++ b/stmtcache/stmtcache.go @@ -0,0 +1,52 @@ +// Package stmtcache is a cache that can be used to implement lazy prepared statements. +package stmtcache + +import ( + "context" + + "github.com/jackc/pgconn" +) + +const ( + ModePrepare = iota // Cache should prepare named statements. + ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. +) + +// Cache prepares and caches prepared statement descriptions. +type Cache interface { + // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. + Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) + + // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. + Clear(ctx context.Context) error + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int + + // Mode returns the mode of the cache (ModePrepare or ModeDescribe) + Mode() int +} + +// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is +// the maximum size of the cache. +func New(conn *pgconn.PgConn, mode int, cap int) Cache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + return NewLRU(conn, mode, cap) +} + +func mustBeValidMode(mode int) { + if mode != ModePrepare && mode != ModeDescribe { + panic("mode must be ModePrepare or ModeDescribe") + } +} + +func mustBeValidCap(cap int) { + if cap < 1 { + panic("cache must have cap of >= 1") + } +} diff --git a/travis/install.bash b/travis/install.bash deleted file mode 100755 index 63ba875d..00000000 --- a/travis/install.bash +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash -set -eux - -go get -u github.com/cockroachdb/apd -go get -u github.com/shopspring/decimal -go get -u gopkg.in/inconshreveable/log15.v2 -go get -u github.com/jackc/fake -go get -u github.com/lib/pq -go get -u github.com/hashicorp/go-version -go get -u github.com/satori/go.uuid -go get -u github.com/sirupsen/logrus -go get -u github.com/pkg/errors -go get -u go.uber.org/zap -go get -u github.com/rs/zerolog