From 5d17ec41567f376fb2a78d995a3e8e262dce5c9b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 28 Dec 2018 17:09:56 -0600 Subject: [PATCH 001/290] Rename base package to pgconn --- pgconn.go | 478 +++++++++++++++++++++++++++++++++++++++++++++++++ pgconn_test.go | 34 ++++ 2 files changed, 512 insertions(+) create mode 100644 pgconn.go create mode 100644 pgconn_test.go diff --git a/pgconn.go b/pgconn.go new file mode 100644 index 00000000..c9caef42 --- /dev/null +++ b/pgconn.go @@ -0,0 +1,478 @@ +package pgconn + +import ( + "crypto/md5" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgio" + "github.com/jackc/pgx/pgproto3" +) + +const batchBufferSize = 4096 + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// DialFunc is a function that can be used to connect to a PostgreSQL server +type DialFunc func(network, addr string) (net.Conn, error) + +// ErrTLSRefused occurs when the connection attempt requires TLS and the +// PostgreSQL server refuses to use TLS +var ErrTLSRefused = errors.New("server refused TLS connection") + +type ConnConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 // default: 5432 + Database string + User string // default: OS user name + Password string + TLSConfig *tls.Config // config for TLS connection -- nil disables TLS + Dial DialFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) +} + +func (cc *ConnConfig) NetworkAddress() (network, address string) { + // If host is a valid path, then address is unix socket + if _, err := os.Stat(cc.Host); err == nil { + network = "unix" + address = cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + } + } else { + network = "tcp" + address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) + } + + return network, address +} + +func (cc *ConnConfig) assignDefaults() error { + if cc.User == "" { + user, err := user.Current() + if err != nil { + return err + } + cc.User = user.Username + } + + if cc.Port == 0 { + cc.Port = 5432 + } + + if cc.Dial == nil { + defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute} + cc.Dial = defaultDialer.Dial + } + + return nil +} + +// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. +type PgConn struct { + NetConn net.Conn // the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + parameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend *pgproto3.Frontend + + Config ConnConfig + + batchBuf []byte + batchCount int32 + + pendingReadyForQueryCount int32 + + closed bool +} + +func Connect(cc ConnConfig) (*PgConn, error) { + err := cc.assignDefaults() + if err != nil { + return nil, err + } + + pgConn := new(PgConn) + pgConn.Config = cc + + pgConn.NetConn, err = cc.Dial(cc.NetworkAddress()) + if err != nil { + return nil, err + } + + pgConn.parameterStatuses = make(map[string]string) + + if cc.TLSConfig != nil { + if err := pgConn.startTLS(cc.TLSConfig); err != nil { + return nil, err + } + } + + pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.NetConn, pgConn.NetConn) + if err != nil { + return nil, err + } + + startupMsg := pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: make(map[string]string), + } + + // Copy default run-time params + for k, v := range cc.RuntimeParams { + startupMsg.Parameters[k] = v + } + + startupMsg.Parameters["user"] = cc.User + if cc.Database != "" { + startupMsg.Parameters["database"] = cc.Database + } + + if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { + return nil, err + } + + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.BackendKeyData: + pgConn.PID = msg.ProcessID + pgConn.SecretKey = msg.SecretKey + case *pgproto3.Authentication: + if err = pgConn.rxAuthenticationX(msg); err != nil { + return nil, err + } + case *pgproto3.ReadyForQuery: + return pgConn, nil + case *pgproto3.ParameterStatus: + // handled by ReceiveMessage + case *pgproto3.ErrorResponse: + return nil, PgError{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, + } + default: + return nil, errors.New("unexpected message") + } + } +} + +func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { + err = binary.Write(pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) + if err != nil { + return + } + + response := make([]byte, 1) + if _, err = io.ReadFull(pgConn.NetConn, response); err != nil { + return + } + + if response[0] != 'S' { + return ErrTLSRefused + } + + pgConn.NetConn = tls.Client(pgConn.NetConn, tlsConfig) + + return nil +} + +func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { + switch msg.Type { + case pgproto3.AuthTypeOk: + case pgproto3.AuthTypeCleartextPassword: + err = c.txPasswordMessage(c.Config.Password) + case pgproto3.AuthTypeMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) + err = c.txPasswordMessage(digestedPassword) + default: + err = errors.New("Received unknown authentication message") + } + + return +} + +func (pgConn *PgConn) txPasswordMessage(password string) (err error) { + msg := &pgproto3.PasswordMessage{Password: password} + _, err = pgConn.NetConn.Write(msg.Encode(nil)) + return err +} + +func hexMD5(s string) string { + hash := md5.New() + io.WriteString(hash, s) + return hex.EncodeToString(hash.Sum(nil)) +} + +func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.Frontend.Receive() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + // Under normal circumstances pendingReadyForQueryCount will be > 0 when a + // ReadyForQuery is received. However, this is not the case on initial + // connection. + if pgConn.pendingReadyForQueryCount > 0 { + pgConn.pendingReadyForQueryCount -= 1 + } + pgConn.TxStatus = msg.TxStatus + case *pgproto3.ParameterStatus: + pgConn.parameterStatuses[msg.Name] = msg.Value + case *pgproto3.ErrorResponse: + if msg.Severity == "FATAL" { + // TODO - close pgConn + return nil, errorResponseToPgError(msg) + } + } + + return msg, nil +} + +// Close closes a connection. It is safe to call Close on a already closed +// connection. +func (pgConn *PgConn) Close() error { + if pgConn.closed { + return nil + } + pgConn.closed = true + + _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) + if err != nil { + pgConn.NetConn.Close() + return err + } + + _, err = pgConn.NetConn.Read(make([]byte, 1)) + if err != io.EOF { + pgConn.NetConn.Close() + return err + } + + return pgConn.NetConn.Close() +} + +// ParameterStatus returns the value of a parameter reported by the server (e.g. +// server_version). Returns an empty string for unknown parameters. +func (pgConn *PgConn) ParameterStatus(key string) string { + return pgConn.parameterStatuses[key] +} + +// CommandTag is the result of an Exec function +type CommandTag string + +// RowsAffected returns the number of rows affected. If the CommandTag was not +// for a row affecting command (e.g. "CREATE TABLE") then it returns 0. +func (ct CommandTag) RowsAffected() int64 { + s := string(ct) + index := strings.LastIndex(s, " ") + if index == -1 { + return 0 + } + n, _ := strconv.ParseInt(s[index+1:], 10, 64) + return n +} + +// SendExec enqueues the execution of sql via the PostgreSQL simple query +// protocol. sql may contain multipe queries. Multiple queries will be processed +// within a single transation. It is only sent to the PostgreSQL server when +// Flush is called. +func (pgConn *PgConn) SendExec(sql string) { + pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql) + pgConn.batchCount += 1 +} + +// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. +func appendQuery(buf []byte, query string) []byte { + buf = append(buf, 'Q') + buf = pgio.AppendInt32(buf, int32(len(query)+5)) + buf = append(buf, query...) + buf = append(buf, 0) + return buf +} + +type PgResultReader struct { + pgConn *PgConn + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + err error + complete bool +} + +// GetResult returns a PgResultReader for the next result. If all results are +// consumed it returns nil. If an error occurs it will be reported on the +// returned PgResultReader. +func (pgConn *PgConn) GetResult() *PgResultReader { + if pgConn.pendingReadyForQueryCount == 0 { + return nil + } + + return &PgResultReader{pgConn: pgConn} +} + +func (rr *PgResultReader) NextRow() (present bool) { + if rr.complete { + return false + } + + for { + msg, err := rr.pgConn.ReceiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = msg.Fields + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + case *pgproto3.CommandComplete: + rr.commandTag = CommandTag(msg.CommandTag) + rr.complete = true + return false + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) + rr.complete = true + return false + } + } +} + +func (rr *PgResultReader) Value(c int) []byte { + return rr.rowValues[c] +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *PgResultReader) Close() (CommandTag, error) { + if rr.complete { + return rr.commandTag, rr.err + } + + for { + msg, err := rr.pgConn.ReceiveMessage() + if err != nil { + rr.err = err + rr.complete = true + return rr.commandTag, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.CommandComplete: + rr.commandTag = CommandTag(msg.CommandTag) + rr.complete = true + return rr.commandTag, rr.err + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) + rr.complete = true + return rr.commandTag, rr.err + } + } +} + +// Flush sends the enqueued execs to the server. +func (pgConn *PgConn) Flush() error { + defer pgConn.resetBatch() + + n, err := pgConn.NetConn.Write(pgConn.batchBuf) + if err != nil { + if n > 0 { + // TODO - kill connection - we sent a partial message + } + return err + } + + pgConn.pendingReadyForQueryCount += pgConn.batchCount + return nil +} + +func (pgConn *PgConn) resetBatch() { + pgConn.batchCount = 0 + if len(pgConn.batchBuf) > batchBufferSize { + pgConn.batchBuf = make([]byte, 0, batchBufferSize) + } else { + pgConn.batchBuf = pgConn.batchBuf[0:0] + } +} + +func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { + return PgError{ + Severity: msg.Severity, + Code: msg.Code, + Message: msg.Message, + Detail: msg.Detail, + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: msg.InternalQuery, + Where: msg.Where, + SchemaName: msg.SchemaName, + TableName: msg.TableName, + ColumnName: msg.ColumnName, + DataTypeName: msg.DataTypeName, + ConstraintName: msg.ConstraintName, + File: msg.File, + Line: msg.Line, + Routine: msg.Routine, + } +} diff --git a/pgconn_test.go b/pgconn_test.go new file mode 100644 index 00000000..dbcf2704 --- /dev/null +++ b/pgconn_test.go @@ -0,0 +1,34 @@ +package pgconn_test + +import ( + "github.com/jackc/pgx/pgconn" + + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSimple(t *testing.T) { + pgConn, err := pgconn.Connect(pgconn.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"}) + require.Nil(t, err) + + pgConn.SendExec("select current_database()") + err = pgConn.Flush() + require.Nil(t, err) + + result := pgConn.GetResult() + require.NotNil(t, result) + + rowFound := result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "pgx_test", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) + + err = pgConn.Close() + assert.Nil(t, err) +} From beeb69ff0bed06647f93f4eafae419ff43fd4da1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 30 Dec 2018 16:53:57 -0600 Subject: [PATCH 002/290] Restructure connect process - Moved lots of connection logic to pgconn from pgx - Extracted pgpassfile package --- config.go | 421 +++++++++++++++++++++++++++++++++++++++++++++++++ config_test.go | 392 +++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 130 ++++++++------- pgconn_test.go | 8 +- 4 files changed, 881 insertions(+), 70 deletions(-) create mode 100644 config.go create mode 100644 config_test.go diff --git a/config.go b/config.go new file mode 100644 index 00000000..515d6356 --- /dev/null +++ b/config.go @@ -0,0 +1,421 @@ +package pgconn + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "math" + "net" + "net/url" + "os" + "os/user" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/pgpassfile" + "github.com/pkg/errors" +) + +// Config is the settings used to establish a connection to a PostgreSQL server. +type Config struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + Fallbacks []*FallbackConfig +} + +// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a +// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. +type FallbackConfig struct { + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + TLSConfig *tls.Config // nil disables TLS +} + +// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with +// net.Dial. +func NetworkAddress(host string, port uint16) (network, address string) { + if strings.HasPrefix(host, "/") { + network = "unix" + address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) + } else { + network = "tcp" + address = fmt.Sprintf("%s:%d", host, port) + } + return network, address +} + +// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. +// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment +// variables. connString may be a URL or a DSN. It also may be empty to only read from the +// environment. If a password is not supplied it will attempt to read the .pgpass file. +// +// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca" +// +// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca" +// +// Multiple configs may be returned due to sslmode settings with fallback options (e.g. +// sslmode=prefer). Future implementations may also support multiple hosts +// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS). +// +// ParseConfig currently recognizes the following environment variable and their parameter key word +// equivalents passed via database URL or DSN: +// +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGAPPNAME +// PGCONNECT_TIMEOUT +// +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of +// environment variables. +// +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key +// word names. They are usually but not always the environment variable name downcased and without +// the "PG" prefix. +// +// Important TLS Security Notes: +// +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to +// "prefer" behavior if not set. +// +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on +// what level of security each sslmode provides. +// +// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger +// security guarantees than it would with libpq. Do not rely on this behavior as it +// may be possible to match libpq in the future. If you need full security use +// "verify-full". +func ParseConfig(connString string) (*Config, error) { + settings := defaultSettings() + addEnvSettings(settings) + + if connString != "" { + // connString may be a database URL or a DSN + if strings.HasPrefix(connString, "postgres://") { + url, err := url.Parse(connString) + if err != nil { + return nil, err + } + + err = addURLSettings(settings, url) + if err != nil { + return nil, err + } + } else { + err := addDSNSettings(settings, connString) + if err != nil { + return nil, err + } + } + } + + config := &Config{ + Host: settings["host"], + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + } + + if port, err := parsePort(settings["port"]); err == nil { + config.Port = port + } else { + return nil, fmt.Errorf("invalid port: %v", settings["port"]) + } + + if connectTimeout, present := settings["connect_timeout"]; present { + dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) + if err != nil { + return nil, err + } + config.DialFunc = dialFunc + } else { + defaultDialer := makeDefaultDialer() + config.DialFunc = defaultDialer.DialContext + } + + notRuntimeParams := map[string]struct{}{ + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + } + + for k, v := range settings { + if _, present := notRuntimeParams[k]; present { + continue + } + config.RuntimeParams[k] = v + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings) + if err != nil { + return nil, err + } + } + + config.TLSConfig = tlsConfigs[0] + + for _, tlsConfig := range tlsConfigs[1:] { + config.Fallbacks = append(config.Fallbacks, &FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: tlsConfig, + }) + } + + passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) + if err == nil { + if config.Password == "" { + host := config.Host + if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { + host = "localhost" + } + + config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) + } + } + + return config, nil +} + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + } + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} + +func addEnvSettings(settings map[string]string) { + nameMap := map[string]string{ + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + } + + for envname, realname := range nameMap { + value := os.Getenv(envname) + if value != "" { + settings[realname] = value + } + } +} + +func addURLSettings(settings map[string]string, url *url.URL) error { + if url.User != nil { + settings["user"] = url.User.Username() + if password, present := url.User.Password(); present { + settings["password"] = password + } + } + + parts := strings.SplitN(url.Host, ":", 2) + if parts[0] != "" { + settings["host"] = parts[0] + } + if len(parts) == 2 { + settings["port"] = parts[1] + } + + database := strings.TrimLeft(url.Path, "/") + if database != "" { + settings["database"] = database + } + + for k, v := range url.Query() { + settings[k] = v[0] + } + + return nil +} + +var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) + +func addDSNSettings(settings map[string]string, s string) error { + m := dsnRegexp.FindAllStringSubmatch(s, -1) + + for _, b := range m { + settings[b[1]] = b[2] + } + + return nil +} + +type pgTLSArgs struct { + sslMode string + sslRootCert string + sslCert string + sslKey string +} + +// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is +// necessary to allow returning multiple TLS configs as sslmode "allow" and +// "prefer" allow fallback. +func configTLS(settings map[string]string) ([]*tls.Config, error) { + host := settings["host"] + sslmode := settings["sslmode"] + sslrootcert := settings["sslrootcert"] + sslcert := settings["sslcert"] + sslkey := settings["sslkey"] + + // Match libpq default behavior + if sslmode == "" { + sslmode = "prefer" + } + + tlsConfig := &tls.Config{} + + switch sslmode { + case "disable": + return []*tls.Config{nil}, nil + case "allow", "prefer": + tlsConfig.InsecureSkipVerify = true + case "require": + tlsConfig.InsecureSkipVerify = sslrootcert == "" + case "verify-ca", "verify-full": + tlsConfig.ServerName = host + default: + return nil, errors.New("sslmode is invalid") + } + + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + + caPath := sslrootcert + caCert, err := ioutil.ReadFile(caPath) + if err != nil { + return nil, errors.Wrapf(err, "unable to read CA file %q", caPath) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.Wrap(err, "unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return nil, errors.Wrap(err, "unable to read cert") + } + + tlsConfig.Certificates = []tls.Certificate{cert} + } + + switch sslmode { + case "allow": + return []*tls.Config{nil, tlsConfig}, nil + case "prefer": + return []*tls.Config{tlsConfig, nil}, nil + case "require", "verify-ca", "verify-full": + return []*tls.Config{tlsConfig}, nil + default: + panic("BUG: bad sslmode should already have been caught") + } +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +func makeDefaultDialer() *net.Dialer { + return &net.Dialer{KeepAlive: 5 * time.Minute} +} + +func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { + timeout, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + if timeout < 0 { + return nil, errors.New("negative timeout") + } + + d := makeDefaultDialer() + d.Timeout = time.Duration(timeout) * time.Second + return d.DialContext, nil +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 00000000..796876f2 --- /dev/null +++ b/config_test.go @@ -0,0 +1,392 @@ +package pgconn_test + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "os" + "os/user" + "testing" + + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + t.Parallel() + + var osUserName string + osUser, err := user.Current() + if err == nil { + osUserName = osUser.Username + } + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + // Test all sslmodes + { + name: "sslmode not set (prefer)", + connString: "postgres://jack:secret@localhost:5432/mydb", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode disable", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode allow", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + }, + }, + { + name: "sslmode prefer", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + config: &pgconn.Config{ + + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "sslmode require", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=require", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-ca", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "sslmode verify-full", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ServerName: "localhost"}, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url everything", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + { + name: "database url missing password", + connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing user and password", + connString: "postgres://localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url missing port", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url unix domain socket host", + connString: "postgres:///foo?host=/tmp", + config: &pgconn.Config{ + User: osUserName, + Host: "/tmp", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN everything", + connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + }, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) + assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) + assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) + assert.Equalf(t, expected.User, actual.User, "%s - User", testName) + assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { + if expected.TLSConfig != nil { + assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) + } + } + + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) { + for i := range expected.Fallbacks { + assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) + assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) + + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) { + if expected.Fallbacks[i].TLSConfig != nil { + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) + assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) + } + } + } + } +} + +func TestParseConfigEnvLibpq(t *testing.T) { + var osUserName string + osUser, err := user.Current() + if err == nil { + osUserName = osUser.Username + } + + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} + + savedEnv := make(map[string]string) + for _, n := range pgEnvvars { + savedEnv[n] = os.Getenv(n) + } + defer func() { + for k, v := range savedEnv { + err := os.Setenv(k, v) + if err != nil { + t.Fatalf("Unable to restore environment: %v", err) + } + } + }() + + tests := []struct { + name string + envvars map[string]string + config *pgconn.Config + }{ + { + // not testing no environment at all as that would use default host and that can vary. + name: "PGHOST only", + envvars: map[string]string{"PGHOST": "123.123.123.123"}, + config: &pgconn.Config{ + User: osUserName, + Host: "123.123.123.123", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "123.123.123.123", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "All non-TLS environment", + envvars: map[string]string{ + "PGHOST": "123.123.123.123", + "PGPORT": "7777", + "PGDATABASE": "foo", + "PGUSER": "bar", + "PGPASSWORD": "baz", + "PGCONNECT_TIMEOUT": "10", + "PGSSLMODE": "disable", + "PGAPPNAME": "pgxtest", + }, + config: &pgconn.Config{ + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, + }, + }, + } + + for i, tt := range tests { + for _, n := range pgEnvvars { + err := os.Unsetenv(n) + require.Nil(t, err) + } + + for k, v := range tt.envvars { + err := os.Setenv(k, v) + require.Nil(t, err) + } + + config, err := pgconn.ParseConfig("") + if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + +func TestParseConfigReadsPgPassfile(t *testing.T) { + tf, err := ioutil.TempFile("", "") + require.Nil(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) + require.Nil(t, err) + + connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) + expected := &pgconn.Config{ + User: "curly", + Password: "nyuknyuknyuk", + Host: "test1", + Port: 5432, + Database: "curlydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + } + + actual, err := pgconn.ParseConfig(connString) + assert.Nil(t, err) + + assertConfigsEqual(t, expected, actual, "passfile") +} diff --git a/pgconn.go b/pgconn.go index c9caef42..37a205dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1,20 +1,16 @@ package pgconn import ( + "context" "crypto/md5" "crypto/tls" "encoding/binary" "encoding/hex" "errors" - "fmt" "io" "net" - "os" - "os/user" - "path/filepath" "strconv" "strings" - "time" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" @@ -23,7 +19,7 @@ import ( const batchBufferSize = 4096 // PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. type PgError struct { Severity string @@ -50,60 +46,12 @@ func (pe PgError) Error() string { } // DialFunc is a function that can be used to connect to a PostgreSQL server -type DialFunc func(network, addr string) (net.Conn, error) +type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") -type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name - Password string - TLSConfig *tls.Config // config for TLS connection -- nil disables TLS - Dial DialFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) -} - -func (cc *ConnConfig) NetworkAddress() (network, address string) { - // If host is a valid path, then address is unix socket - if _, err := os.Stat(cc.Host); err == nil { - network = "unix" - address = cc.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) - } - } else { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - } - - return network, address -} - -func (cc *ConnConfig) assignDefaults() error { - if cc.User == "" { - user, err := user.Current() - if err != nil { - return err - } - cc.User = user.Username - } - - if cc.Port == 0 { - cc.Port = 5432 - } - - if cc.Dial == nil { - defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute} - cc.Dial = defaultDialer.Dial - } - - return nil -} - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { NetConn net.Conn // the underlying TCP or unix domain socket connection @@ -113,7 +61,7 @@ type PgConn struct { TxStatus byte Frontend *pgproto3.Frontend - Config ConnConfig + Config *Config batchBuf []byte batchCount int32 @@ -123,24 +71,72 @@ type PgConn struct { closed bool } -func Connect(cc ConnConfig) (*PgConn, error) { - err := cc.assignDefaults() +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) +// to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. +func Connect(ctx context.Context, connString string) (*PgConn, error) { + config, err := ParseConfig(connString) if err != nil { return nil, err } - pgConn := new(PgConn) - pgConn.Config = cc + return ConnectConfig(ctx, config) +} - pgConn.NetConn, err = cc.Dial(cc.NetworkAddress()) +// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. +// +// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An +// authentication error will terminate the chain of attempts (like libpq: +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, +// if all attempts fail the last error is returned. +func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { + // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. + if config.Port == 0 { + config.Port = 5432 + } + if config.DialFunc == nil { + config.DialFunc = makeDefaultDialer().DialContext + } + if config.RuntimeParams == nil { + config.RuntimeParams = make(map[string]string) + } + + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + for _, fc := range fallbackConfigs { + pgConn, err = connect(ctx, config, fc) + if err == nil { + return pgConn, nil + } else if err, ok := err.(PgError); ok { + return nil, err + } + } + + return nil, err +} + +func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { + pgConn := new(PgConn) + pgConn.Config = config + + var err error + network, address := NetworkAddress(config.Host, config.Port) + pgConn.NetConn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } pgConn.parameterStatuses = make(map[string]string) - if cc.TLSConfig != nil { - if err := pgConn.startTLS(cc.TLSConfig); err != nil { + if config.TLSConfig != nil { + if err := pgConn.startTLS(config.TLSConfig); err != nil { return nil, err } } @@ -156,13 +152,13 @@ func Connect(cc ConnConfig) (*PgConn, error) { } // Copy default run-time params - for k, v := range cc.RuntimeParams { + for k, v := range config.RuntimeParams { startupMsg.Parameters[k] = v } - startupMsg.Parameters["user"] = cc.User - if cc.Database != "" { - startupMsg.Parameters["database"] = cc.Database + startupMsg.Parameters["user"] = config.User + if config.Database != "" { + startupMsg.Parameters["database"] = config.Database } if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { diff --git a/pgconn_test.go b/pgconn_test.go index dbcf2704..f165786e 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,16 +1,18 @@ package pgconn_test import ( - "github.com/jackc/pgx/pgconn" - + "context" + "os" "testing" + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestSimple(t *testing.T) { - pgConn, err := pgconn.Connect(pgconn.ConnConfig{Host: "/var/run/postgresql", User: "jack", Database: "pgx_test"}) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) pgConn.SendExec("select current_database()") From c4080cce35dcf4f76c7807f4d0e5fd98593a9521 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 30 Dec 2018 21:10:06 -0600 Subject: [PATCH 003/290] Move connection tests to pgconn --- helper_test.go | 13 +++++ pgconn_test.go | 146 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 helper_test.go diff --git a/helper_test.go b/helper_test.go new file mode 100644 index 00000000..e6a7c73b --- /dev/null +++ b/helper_test.go @@ -0,0 +1,13 @@ +package pgconn_test + +import ( + "testing" + + "github.com/jackc/pgx/pgconn" + + "github.com/stretchr/testify/require" +) + +func closeConn(t testing.TB, conn *pgconn.PgConn) { + require.Nil(t, conn.Close()) +} diff --git a/pgconn_test.go b/pgconn_test.go index f165786e..9e16e925 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -2,15 +2,161 @@ package pgconn_test import ( "context" + "crypto/tls" + "net" "os" "testing" + "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestConnect(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(t, err) + + err = conn.Close() + require.Nil(t, err) + }) + } +} + +// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure +// connection. +func TestConnectTLS(t *testing.T) { + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(t, err) + + if _, ok := conn.NetConn.(*tls.Conn); !ok { + t.Error("not a TLS connection") + } + + err = conn.Close() + require.Nil(t, err) +} + +func TestConnectInvalidUser(t *testing.T) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.Nil(t, err) + + config.User = "pgxinvalidusertest" + + conn, err := pgconn.ConnectConfig(context.Background(), config) + if err == nil { + conn.Close() + t.Fatal("expected err but got none") + } + pgErr, ok := err.(pgx.PgError) + if !ok { + t.Fatalf("Expected to receive a PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } +} + +func TestConnectWithConnectionRefused(t *testing.T) { + t.Parallel() + + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close() + t.Fatal("Expected error establishing connection to bad port") + } +} + +func TestConnectCustomDialer(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + require.True(t, dialed) + conn.Close() +} + +func TestConnectWithRuntimeParams(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, conn) + + // TODO - refactor these selects once there are higher level query functions + + conn.SendExec("show application_name") + conn.SendExec("show search_path") + err = conn.Flush() + require.Nil(t, err) + + result := conn.GetResult() + require.NotNil(t, result) + + rowFound := result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "pgxtest", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) + + result = conn.GetResult() + require.NotNil(t, result) + + rowFound = result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "myschema", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) From 1836f7be464fb7ce1b69c7cec17dba86d2437634 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 11:14:13 -0600 Subject: [PATCH 004/290] Support comma separated hosts and ports like libpq Also add test and fix the fallback config implementation. --- config.go | 138 +++++++++++++++++++++++++------------------ config_test.go | 155 ++++++++++++++++++++++++++++++++++++++++++++++++- pgconn.go | 2 +- pgconn_test.go | 31 ++++++++++ 4 files changed, 267 insertions(+), 59 deletions(-) diff --git a/config.go b/config.go index 515d6356..d2001dc5 100644 --- a/config.go +++ b/config.go @@ -55,21 +55,23 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. -// It uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment -// variables. connString may be a URL or a DSN. It also may be empty to only read from the -// environment. If a password is not supplied it will attempt to read the .pgpass file. +// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same +// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. +// It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the +// .pgpass file. // -// Example DSN: "user=jack password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-ca" +// Example DSN: "user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca" // -// Example URL: "postgres://jack:secret@1.2.3.4:5432/mydb?sslmode=verify-ca" +// Example URL: "postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca" // -// Multiple configs may be returned due to sslmode settings with fallback options (e.g. -// sslmode=prefer). Future implementations may also support multiple hosts -// (https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS). +// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated +// values that will be tried in order. This can be used as part of a high availability system. See +// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // -// ParseConfig currently recognizes the following environment variable and their parameter key word -// equivalents passed via database URL or DSN: +// Example URL: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb" +// +// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed +// via database URL or DSN: // // PGHOST // PGPORT @@ -84,20 +86,18 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGAPPNAME // PGCONNECT_TIMEOUT // -// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of -// environment variables. +// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // -// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key -// word names. They are usually but not always the environment variable name downcased and without -// the "PG" prefix. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// usually but not always the environment variable name downcased and without the "PG" prefix. // // Important TLS Security Notes: // -// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to -// "prefer" behavior if not set. +// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if +// not set. // -// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on -// what level of security each sslmode provides. +// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// security each sslmode provides. // // "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger // security guarantees than it would with libpq. Do not rely on this behavior as it @@ -110,12 +110,7 @@ func ParseConfig(connString string) (*Config, error) { if connString != "" { // connString may be a database URL or a DSN if strings.HasPrefix(connString, "postgres://") { - url, err := url.Parse(connString) - if err != nil { - return nil, err - } - - err = addURLSettings(settings, url) + err := addURLSettings(settings, connString) if err != nil { return nil, err } @@ -128,19 +123,12 @@ func ParseConfig(connString string) (*Config, error) { } config := &Config{ - Host: settings["host"], Database: settings["database"], User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), } - if port, err := parsePort(settings["port"]); err == nil { - config.Port = port - } else { - return nil, fmt.Errorf("invalid port: %v", settings["port"]) - } - if connectTimeout, present := settings["connect_timeout"]; present { dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) if err != nil { @@ -173,28 +161,50 @@ func ParseConfig(connString string) (*Config, error) { config.RuntimeParams[k] = v } - var tlsConfigs []*tls.Config + fallbacks := []*FallbackConfig{} - // Ignore TLS settings if Unix domain socket like libpq - if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { - tlsConfigs = append(tlsConfigs, nil) - } else { - var err error - tlsConfigs, err = configTLS(settings) + hosts := strings.Split(settings["host"], ",") + ports := strings.Split(settings["port"], ",") + + for i, host := range hosts { + var portStr string + if i < len(ports) { + portStr = ports[i] + } else { + portStr = ports[0] + } + + port, err := parsePort(portStr) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid port: %v", settings["port"]) + } + + var tlsConfigs []*tls.Config + + // Ignore TLS settings if Unix domain socket like libpq + if network, _ := NetworkAddress(host, port); network == "unix" { + tlsConfigs = append(tlsConfigs, nil) + } else { + var err error + tlsConfigs, err = configTLS(settings) + if err != nil { + return nil, err + } + } + + for _, tlsConfig := range tlsConfigs { + fallbacks = append(fallbacks, &FallbackConfig{ + Host: host, + Port: port, + TLSConfig: tlsConfig, + }) } } - config.TLSConfig = tlsConfigs[0] - - for _, tlsConfig := range tlsConfigs[1:] { - config.Fallbacks = append(config.Fallbacks, &FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: tlsConfig, - }) - } + config.Host = fallbacks[0].Host + config.Port = fallbacks[0].Port + config.TLSConfig = fallbacks[0].TLSConfig + config.Fallbacks = fallbacks[1:] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { @@ -272,7 +282,12 @@ func addEnvSettings(settings map[string]string) { } } -func addURLSettings(settings map[string]string, url *url.URL) error { +func addURLSettings(settings map[string]string, connString string) error { + url, err := url.Parse(connString) + if err != nil { + return err + } + if url.User != nil { settings["user"] = url.User.Username() if password, present := url.User.Password(); present { @@ -280,12 +295,23 @@ func addURLSettings(settings map[string]string, url *url.URL) error { } } - parts := strings.SplitN(url.Host, ":", 2) - if parts[0] != "" { - settings["host"] = parts[0] + // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. + var hosts []string + var ports []string + for _, host := range strings.Split(url.Host, ",") { + parts := strings.SplitN(host, ":", 2) + if parts[0] != "" { + hosts = append(hosts, parts[0]) + } + if len(parts) == 2 { + ports = append(ports, parts[1]) + } } - if len(parts) == 2 { - settings["port"] = parts[1] + if len(hosts) > 0 { + settings["host"] = strings.Join(hosts, ",") + } + if len(ports) > 0 { + settings["port"] = strings.Join(ports, ",") } database := strings.TrimLeft(url.Path, "/") diff --git a/config_test.go b/config_test.go index 796876f2..566a44f0 100644 --- a/config_test.go +++ b/config_test.go @@ -230,6 +230,150 @@ func TestParseConfig(t *testing.T) { }, }, }, + { + name: "URL multiple hosts", + connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "URL multiple hosts and ports", + connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "DSN multiple hosts one port", + connString: "user=jack password=secret host=foo,bar,baz port=5432 database=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "DSN multiple hosts multiple ports", + connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 database=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "bar", + Port: 2, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 3, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "multiple hosts and fallback tsl", + connString: "user=jack password=secret host=foo,bar,baz database=mydb sslmode=prefer", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "foo", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "foo", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }}, + &pgconn.FallbackConfig{ + Host: "bar", + Port: 5432, + TLSConfig: nil, + }, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }}, + &pgconn.FallbackConfig{ + Host: "baz", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, } for i, tt := range tests { @@ -243,6 +387,13 @@ func TestParseConfig(t *testing.T) { } func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { + if !assert.NotNil(t, expected) { + return + } + if !assert.NotNil(t, actual) { + return + } + assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) @@ -257,12 +408,12 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName } } - if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks %v", testName) { + if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { for i := range expected.Fallbacks { assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) - if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName) { + if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { if expected.Fallbacks[i].TLSConfig != nil { assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) diff --git a/pgconn.go b/pgconn.go index 37a205dc..09860eb2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -127,7 +127,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.Config = config var err error - network, address := NetworkAddress(config.Host, config.Port) + network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.NetConn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err diff --git a/pgconn_test.go b/pgconn_test.go index 9e16e925..d53bbc09 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -157,6 +157,37 @@ func TestConnectWithRuntimeParams(t *testing.T) { assert.Nil(t, err) } +func TestConnectWithFallback(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here + + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + closeConn(t, conn) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) From 5ae6310b058d73bc8fe19e6e71a857a4d3796eff Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 11:39:22 -0600 Subject: [PATCH 005/290] Add AcceptConnFunc for filtering HA connections --- config.go | 7 +++++++ pgconn.go | 11 ++++++++++- pgconn_test.go | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index d2001dc5..a07fa533 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,8 @@ import ( "github.com/pkg/errors" ) +type AcceptConnFunc func(pgconn *PgConn) bool + // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) @@ -32,6 +34,11 @@ type Config struct { RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig + + // AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is + // acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This + // allows implementing high availability behavior such as libpq does with target_session_attrs. + AcceptConnFunc AcceptConnFunc } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index 09860eb2..ac48f870 100644 --- a/pgconn.go +++ b/pgconn.go @@ -137,6 +137,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if config.TLSConfig != nil { if err := pgConn.startTLS(config.TLSConfig); err != nil { + pgConn.NetConn.Close() return nil, err } } @@ -162,6 +163,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { + pgConn.NetConn.Close() return nil, err } @@ -177,13 +179,19 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { + pgConn.NetConn.Close() return nil, err } case *pgproto3.ReadyForQuery: - return pgConn, nil + if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) { + return pgConn, nil + } + pgConn.NetConn.Close() + return nil, errors.New("AcceptConnFunc rejected connection") case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: + pgConn.NetConn.Close() return nil, PgError{ Severity: msg.Severity, Code: msg.Code, @@ -204,6 +212,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig Routine: msg.Routine, } default: + pgConn.NetConn.Close() return nil, errors.New("unexpected message") } } diff --git a/pgconn_test.go b/pgconn_test.go index d53bbc09..ad06ae7b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -188,6 +188,40 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } +func TestConnectWithAcceptConnFunc(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount += 1 + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.AcceptConnFunc = func(conn *pgconn.PgConn) bool { + acceptConnCount += 1 + return acceptConnCount > 1 + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) From 8c574c39f830d10c0b5a1c4ad46cc1e010646071 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 12:14:41 -0600 Subject: [PATCH 006/290] Add support for libpq target_session_attrs Generalize AcceptConnFunc into AfterConnectFunc. --- config.go | 93 +++++++++++++++++++++++++++++++++++--------------- config_test.go | 17 +++++++++ pgconn.go | 12 ++++--- pgconn_test.go | 23 +++++++++++-- 4 files changed, 111 insertions(+), 34 deletions(-) diff --git a/config.go b/config.go index a07fa533..38144be7 100644 --- a/config.go +++ b/config.go @@ -20,7 +20,7 @@ import ( "github.com/pkg/errors" ) -type AcceptConnFunc func(pgconn *PgConn) bool +type AfterConnectFunc func(pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -35,10 +35,10 @@ type Config struct { Fallbacks []*FallbackConfig - // AcceptConnFunc is called after successful connection allow custom logic for determining if the connection is - // acceptable. If AcceptConnFunc returns false the connection is closed and the next fallback config is tried. This + // AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that + // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This // allows implementing high availability behavior such as libpq does with target_session_attrs. - AcceptConnFunc AcceptConnFunc + AfterConnectFunc AfterConnectFunc } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a @@ -92,6 +92,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS // // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // @@ -148,17 +149,18 @@ func ParseConfig(connString string) (*Config, error) { } notRuntimeParams := map[string]struct{}{ - "host": struct{}{}, - "port": struct{}{}, - "database": struct{}{}, - "user": struct{}{}, - "password": struct{}{}, - "passfile": struct{}{}, - "connect_timeout": struct{}{}, - "sslmode": struct{}{}, - "sslkey": struct{}{}, - "sslcert": struct{}{}, - "sslrootcert": struct{}{}, + "host": struct{}{}, + "port": struct{}{}, + "database": struct{}{}, + "user": struct{}{}, + "password": struct{}{}, + "passfile": struct{}{}, + "connect_timeout": struct{}{}, + "sslmode": struct{}{}, + "sslkey": struct{}{}, + "sslcert": struct{}{}, + "sslrootcert": struct{}{}, + "target_session_attrs": struct{}{}, } for k, v := range settings { @@ -225,6 +227,12 @@ func ParseConfig(connString string) (*Config, error) { } } + if settings["target_session_attrs"] == "read-write" { + config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite + } else if settings["target_session_attrs"] != "any" { + return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"]) + } + return config, nil } @@ -243,6 +251,8 @@ func defaultSettings() map[string]string { settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") } + settings["target_session_attrs"] = "any" + return settings } @@ -267,18 +277,19 @@ func defaultHost() string { func addEnvSettings(settings map[string]string) { nameMap := map[string]string{ - "PGHOST": "host", - "PGPORT": "port", - "PGDATABASE": "database", - "PGUSER": "user", - "PGPASSWORD": "password", - "PGPASSFILE": "passfile", - "PGAPPNAME": "application_name", - "PGCONNECT_TIMEOUT": "connect_timeout", - "PGSSLMODE": "sslmode", - "PGSSLKEY": "sslkey", - "PGSSLCERT": "sslcert", - "PGSSLROOTCERT": "sslrootcert", + "PGHOST": "host", + "PGPORT": "port", + "PGDATABASE": "database", + "PGUSER": "user", + "PGPASSWORD": "password", + "PGPASSFILE": "passfile", + "PGAPPNAME": "application_name", + "PGCONNECT_TIMEOUT": "connect_timeout", + "PGSSLMODE": "sslmode", + "PGSSLKEY": "sslkey", + "PGSSLCERT": "sslcert", + "PGSSLROOTCERT": "sslrootcert", + "PGTARGETSESSIONATTRS": "target_session_attrs", } for envname, realname := range nameMap { @@ -452,3 +463,31 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { d.Timeout = time.Duration(timeout) * time.Second return d.DialContext, nil } + +// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible +// target_session_attrs=read-write. +func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { + pgConn.SendExec("show transaction_read_only") + err := pgConn.Flush() + if err != nil { + return err + } + + result := pgConn.GetResult() + if err != nil { + return err + } + + rowFound := result.NextRow() + if !rowFound { + return errors.New("show transaction_read_only failed") + } + + if string(result.Value(0)) == "on" { + return errors.New("read only connection") + } + + _, err = result.Close() + + return err +} diff --git a/config_test.go b/config_test.go index 566a44f0..36f3fee2 100644 --- a/config_test.go +++ b/config_test.go @@ -374,6 +374,20 @@ func TestParseConfig(t *testing.T) { }, }, }, + { + name: "target_session_attrs", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite, + }, + }, } for i, tt := range tests { @@ -401,6 +415,9 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) + // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName) + if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) diff --git a/pgconn.go b/pgconn.go index ac48f870..94397759 100644 --- a/pgconn.go +++ b/pgconn.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "encoding/hex" "errors" + "fmt" "io" "net" "strconv" @@ -183,11 +184,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, err } case *pgproto3.ReadyForQuery: - if config.AcceptConnFunc == nil || config.AcceptConnFunc(pgConn) { - return pgConn, nil + if config.AfterConnectFunc != nil { + err := config.AfterConnectFunc(pgConn) + if err != nil { + pgConn.NetConn.Close() + return nil, fmt.Errorf("AfterConnectFunc: %v", err) + } } - pgConn.NetConn.Close() - return nil, errors.New("AcceptConnFunc rejected connection") + return pgConn, nil case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: diff --git a/pgconn_test.go b/pgconn_test.go index ad06ae7b..0dccc99f 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -188,7 +189,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAcceptConnFunc(t *testing.T) { +func TestConnectWithAfterConnectFunc(t *testing.T) { config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -199,9 +200,12 @@ func TestConnectWithAcceptConnFunc(t *testing.T) { } acceptConnCount := 0 - config.AcceptConnFunc = func(conn *pgconn.PgConn) bool { + config.AfterConnectFunc = func(conn *pgconn.PgConn) error { acceptConnCount += 1 - return acceptConnCount > 1 + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil } // Append current primary config to fallbacks @@ -222,6 +226,19 @@ func TestConnectWithAcceptConnFunc(t *testing.T) { assert.True(t, acceptConnCount > 1) } +func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" + + conn, err := pgconn.ConnectConfig(context.Background(), config) + if !assert.NotNil(t, err) { + conn.Close() + } +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) From b419493e5ca130ab9c4eae74eb64c23467d73843 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 13:32:26 -0600 Subject: [PATCH 007/290] Add pgconn.Exec --- config.go | 2 +- pgconn.go | 109 +++++++++++++++++++++++++++++++++++++++++-------- pgconn_test.go | 84 +++++++++++++++++-------------------- 3 files changed, 132 insertions(+), 63 deletions(-) diff --git a/config.go b/config.go index 38144be7..a446a67e 100644 --- a/config.go +++ b/config.go @@ -483,7 +483,7 @@ func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { return errors.New("show transaction_read_only failed") } - if string(result.Value(0)) == "on" { + if string(result.Values()[0]) == "on" { return errors.New("read only connection") } diff --git a/pgconn.go b/pgconn.go index 94397759..c243d2f6 100644 --- a/pgconn.go +++ b/pgconn.go @@ -340,10 +340,9 @@ func (ct CommandTag) RowsAffected() int64 { return n } -// SendExec enqueues the execution of sql via the PostgreSQL simple query -// protocol. sql may contain multipe queries. Multiple queries will be processed -// within a single transation. It is only sent to the PostgreSQL server when -// Flush is called. +// SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries. +// Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains +// transaction control statements. It is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExec(sql string) { pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql) pgConn.batchCount += 1 @@ -359,30 +358,51 @@ func appendQuery(buf []byte, query string) []byte { } type PgResultReader struct { - pgConn *PgConn - fieldDescriptions []pgproto3.FieldDescription - rowValues [][]byte - commandTag CommandTag - err error - complete bool + pgConn *PgConn + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + err error + complete bool + preloadedRowValues bool } // GetResult returns a PgResultReader for the next result. If all results are // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. func (pgConn *PgConn) GetResult() *PgResultReader { - if pgConn.pendingReadyForQueryCount == 0 { - return nil + for pgConn.pendingReadyForQueryCount > 0 { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return &PgResultReader{pgConn: pgConn, err: err, complete: true} + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields} + case *pgproto3.DataRow: + return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true} + case *pgproto3.CommandComplete: + return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true} + case *pgproto3.ErrorResponse: + return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true} + } } - return &PgResultReader{pgConn: pgConn} + return nil } -func (rr *PgResultReader) NextRow() (present bool) { +// NextRow returns advances the PgResultReader to the next row and returns true if a row is available. +func (rr *PgResultReader) NextRow() bool { if rr.complete { return false } + if rr.preloadedRowValues { + rr.preloadedRowValues = false + return true + } + for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { @@ -396,6 +416,7 @@ func (rr *PgResultReader) NextRow() (present bool) { rr.rowValues = msg.Values return true case *pgproto3.CommandComplete: + rr.rowValues = nil rr.commandTag = CommandTag(msg.CommandTag) rr.complete = true return false @@ -407,8 +428,11 @@ func (rr *PgResultReader) NextRow() (present bool) { } } -func (rr *PgResultReader) Value(c int) []byte { - return rr.rowValues[c] +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to +// retain a reference to and mutate. +func (rr *PgResultReader) Values() [][]byte { + return rr.rowValues } // Close consumes any remaining result data and returns the command tag or @@ -418,6 +442,8 @@ func (rr *PgResultReader) Close() (CommandTag, error) { return rr.commandTag, rr.err } + rr.rowValues = nil + for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { @@ -464,6 +490,57 @@ func (pgConn *PgConn) resetBatch() { } } +type PgResult struct { + Rows [][][]byte + CommandTag CommandTag +} + +// Exec executes sql via the PostgreSQL simple query protocol, buffers the entire result, and returns it. sql may +// contain multiple queries, but only the last results will be returned. Execution is implicitly wrapped in a +// transactions unless a transaction is already in progress or sql contains transaction control statements. +// +// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExec(sql) + err := pgConn.Flush() + if err != nil { + return nil, err + } + + var result *PgResult + + for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() { + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + } + if result == nil { + return nil, errors.New("unexpected missing result") + } + + return result, nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { return PgError{ Severity: msg.Severity, diff --git a/pgconn_test.go b/pgconn_test.go index 0dccc99f..f3f22d42 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -126,36 +126,15 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - // TODO - refactor these selects once there are higher level query functions - - conn.SendExec("show application_name") - conn.SendExec("show search_path") - err = conn.Flush() + result, err := conn.Exec("show application_name") require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result := conn.GetResult() - require.NotNil(t, result) - - rowFound := result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "pgxtest", string(result.Value(0))) - } - - _, err = result.Close() - assert.Nil(t, err) - - result = conn.GetResult() - require.NotNil(t, result) - - rowFound = result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "myschema", string(result.Value(0))) - } - - _, err = result.Close() - assert.Nil(t, err) + result, err = conn.Exec("show search_path") + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) } func TestConnectWithFallback(t *testing.T) { @@ -239,26 +218,39 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } -func TestSimple(t *testing.T) { +func TestExec(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) + defer closeConn(t, pgConn) - pgConn.SendExec("select current_database()") - err = pgConn.Flush() + result, err := pgConn.Exec("select current_database()") require.Nil(t, err) - - result := pgConn.GetResult() - require.NotNil(t, result) - - rowFound := result.NextRow() - assert.True(t, rowFound) - if rowFound { - assert.Equal(t, "pgx_test", string(result.Value(0))) - } - - _, err = result.Close() - assert.Nil(t, err) - - err = pgConn.Close() - assert.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) +} + +func TestExecMultipleQueries(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.Exec("select current_database(); select 1") + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) +} + +func TestExecMultipleQueriesError(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.Exec("select 1; select 1/0; select 1") + require.NotNil(t, err) + require.Nil(t, result) + if pgErr, ok := err.(pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } } From 4e12c08b04a441cfbedf8c0f7dcaac9414ca9f26 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 14:14:40 -0600 Subject: [PATCH 008/290] Use buffered exec --- config.go | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/config.go b/config.go index a446a67e..4d8bee4c 100644 --- a/config.go +++ b/config.go @@ -467,27 +467,14 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { - pgConn.SendExec("show transaction_read_only") - err := pgConn.Flush() + result, err := pgConn.Exec("show transaction_read_only") if err != nil { return err } - result := pgConn.GetResult() - if err != nil { - return err - } - - rowFound := result.NextRow() - if !rowFound { - return errors.New("show transaction_read_only failed") - } - - if string(result.Values()[0]) == "on" { + if string(result.Rows[0][0]) == "on" { return errors.New("read only connection") } - _, err = result.Close() - - return err + return nil } From 4ee6fef45286e5a0056b0d07a1a388be151b92cd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 17:17:11 -0600 Subject: [PATCH 009/290] Add context to potentially blocking methods --- config.go | 7 ++- helper_test.go | 6 +- pgconn.go | 165 ++++++++++++++++++++++++++++++++++++++++--------- pgconn_test.go | 67 +++++++++++++++----- 4 files changed, 195 insertions(+), 50 deletions(-) diff --git a/config.go b/config.go index 4d8bee4c..d8872f66 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package pgconn import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -20,7 +21,7 @@ import ( "github.com/pkg/errors" ) -type AfterConnectFunc func(pgconn *PgConn) error +type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -466,8 +467,8 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. -func AfterConnectTargetSessionAttrsReadWrite(pgConn *PgConn) error { - result, err := pgConn.Exec("show transaction_read_only") +func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { + result, err := pgConn.Exec(ctx, "show transaction_read_only") if err != nil { return err } diff --git a/helper_test.go b/helper_test.go index e6a7c73b..8e7ca92f 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,7 +1,9 @@ package pgconn_test import ( + "context" "testing" + "time" "github.com/jackc/pgx/pgconn" @@ -9,5 +11,7 @@ import ( ) func closeConn(t testing.TB, conn *pgconn.PgConn) { - require.Nil(t, conn.Close()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.Nil(t, conn.Close(ctx)) } diff --git a/pgconn.go b/pgconn.go index c243d2f6..311b06a3 100644 --- a/pgconn.go +++ b/pgconn.go @@ -12,6 +12,7 @@ import ( "net" "strconv" "strings" + "time" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" @@ -19,6 +20,8 @@ import ( const batchBufferSize = 4096 +var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -185,7 +188,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: if config.AfterConnectFunc != nil { - err := config.AfterConnectFunc(pgConn) + err := config.AfterConnectFunc(ctx, pgConn) if err != nil { pgConn.NetConn.Close() return nil, fmt.Errorf("AfterConnectFunc: %v", err) @@ -296,24 +299,28 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// Close closes a connection. It is safe to call Close on a already closed -// connection. -func (pgConn *PgConn) Close() error { +// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by +// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The +// underlying net.Conn.Close() will always be called regardless of any other errors. +func (pgConn *PgConn) Close(ctx context.Context) error { if pgConn.closed { return nil } pgConn.closed = true + defer pgConn.NetConn.Close() + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanupContext() + _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - pgConn.NetConn.Close() - return err + return preferContextOverNetTimeoutError(ctx, err) } _, err = pgConn.NetConn.Read(make([]byte, 1)) if err != io.EOF { - pgConn.NetConn.Close() - return err + return preferContextOverNetTimeoutError(ctx, err) } return pgConn.NetConn.Close() @@ -365,30 +372,38 @@ type PgResultReader struct { err error complete bool preloadedRowValues bool + ctx context.Context + cleanupContext func() } // GetResult returns a PgResultReader for the next result. If all results are // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. -func (pgConn *PgConn) GetResult() *PgResultReader { +func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() if err != nil { - return &PgResultReader{pgConn: pgConn, err: err, complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} } switch msg := msg.(type) { case *pgproto3.RowDescription: - return &PgResultReader{pgConn: pgConn, fieldDescriptions: msg.Fields} + return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} case *pgproto3.DataRow: - return &PgResultReader{pgConn: pgConn, rowValues: msg.Values, preloadedRowValues: true} + return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} case *pgproto3.CommandComplete: - return &PgResultReader{pgConn: pgConn, commandTag: CommandTag(msg.CommandTag), complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} case *pgproto3.ErrorResponse: - return &PgResultReader{pgConn: pgConn, err: errorResponseToPgError(msg), complete: true} + cleanupContext() + return &PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} } } + cleanupContext() return nil } @@ -406,6 +421,8 @@ func (rr *PgResultReader) NextRow() bool { for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.close() return false } @@ -416,13 +433,12 @@ func (rr *PgResultReader) NextRow() bool { rr.rowValues = msg.Values return true case *pgproto3.CommandComplete: - rr.rowValues = nil rr.commandTag = CommandTag(msg.CommandTag) - rr.complete = true + rr.close() return false case *pgproto3.ErrorResponse: rr.err = errorResponseToPgError(msg) - rr.complete = true + rr.close() return false } } @@ -441,46 +457,137 @@ func (rr *PgResultReader) Close() (CommandTag, error) { if rr.complete { return rr.commandTag, rr.err } - - rr.rowValues = nil + defer rr.close() for { msg, err := rr.pgConn.ReceiveMessage() if err != nil { - rr.err = err - rr.complete = true + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) return rr.commandTag, rr.err } switch msg := msg.(type) { case *pgproto3.CommandComplete: rr.commandTag = CommandTag(msg.CommandTag) - rr.complete = true return rr.commandTag, rr.err case *pgproto3.ErrorResponse: rr.err = errorResponseToPgError(msg) - rr.complete = true return rr.commandTag, rr.err } } } +func (rr *PgResultReader) close() { + if rr.complete { + return + } + + rr.cleanupContext() + rr.rowValues = nil + rr.complete = true +} + // Flush sends the enqueued execs to the server. -func (pgConn *PgConn) Flush() error { +func (pgConn *PgConn) Flush(ctx context.Context) error { defer pgConn.resetBatch() + cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanup() + n, err := pgConn.NetConn.Write(pgConn.batchBuf) if err != nil { if n > 0 { - // TODO - kill connection - we sent a partial message + // Close connection because cannot recover from partially sent message. + pgConn.NetConn.Close() + pgConn.closed = true } - return err + return preferContextOverNetTimeoutError(ctx, err) } pgConn.pendingReadyForQueryCount += pgConn.batchCount return nil } +// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from +// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to +// call multiple times. +func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { + if ctx.Done() != nil { + deadlineWasSet := false + doneChan := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + conn.SetDeadline(deadlineTime) + deadlineWasSet = true + <-doneChan + // TODO + case <-doneChan: + } + }() + + finished := false + return func() { + if !finished { + doneChan <- struct{}{} + if deadlineWasSet { + conn.SetDeadline(time.Time{}) + } + finished = true + } + } + } + + return func() {} +} + +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return ctx.Err() + } + return err +} + +// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is +// successful true is returned. If recovery is not successful the connection is closed and false it returned. Recovery +// should usually be possible except in the case of a partial write. This must be called after any context cancellation. +// +// As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block +// indefinitely. Use ctx to guard against this. +func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { + if pgConn.closed { + return false + } + pgConn.resetBatch() + + pgConn.NetConn.SetDeadline(time.Time{}) + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + defer cleanupContext() + + for pgConn.pendingReadyForQueryCount > 0 { + _, err := pgConn.ReceiveMessage() + if err != nil { + preferContextOverNetTimeoutError(ctx, err) + pgConn.Close(context.Background()) + return false + } + } + + result, err := pgConn.Exec( + context.Background(), // do not use ctx again because deadline goroutine already started above + "select 'RecoverFromTimeout'", + ) + if err != nil || len(result.Rows) != 1 || len(result.Rows[0]) != 1 || string(result.Rows[0][0]) != "RecoverFromTimeout" { + pgConn.Close(context.Background()) + return false + } + + return true +} + func (pgConn *PgConn) resetBatch() { pgConn.batchCount = 0 if len(pgConn.batchBuf) > batchBufferSize { @@ -500,7 +607,7 @@ type PgResult struct { // transactions unless a transaction is already in progress or sql contains transaction control statements. // // Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { +func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } @@ -509,14 +616,14 @@ func (pgConn *PgConn) Exec(sql string) (*PgResult, error) { } pgConn.SendExec(sql) - err := pgConn.Flush() + err := pgConn.Flush(ctx) if err != nil { return nil, err } var result *PgResult - for resultReader := pgConn.GetResult(); resultReader != nil; resultReader = pgConn.GetResult() { + for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) { rows := [][][]byte{} for resultReader.NextRow() { row := make([][]byte, len(resultReader.Values())) diff --git a/pgconn_test.go b/pgconn_test.go index f3f22d42..98fd198e 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -6,6 +6,7 @@ import ( "net" "os" "testing" + "time" "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" @@ -36,8 +37,7 @@ func TestConnect(t *testing.T) { conn, err := pgconn.Connect(context.Background(), connString) require.Nil(t, err) - err = conn.Close() - require.Nil(t, err) + closeConn(t, conn) }) } } @@ -57,8 +57,7 @@ func TestConnectTLS(t *testing.T) { t.Error("not a TLS connection") } - err = conn.Close() - require.Nil(t, err) + closeConn(t, conn) } func TestConnectInvalidUser(t *testing.T) { @@ -74,7 +73,7 @@ func TestConnectInvalidUser(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) if err == nil { - conn.Close() + conn.Close(context.Background()) t.Fatal("expected err but got none") } pgErr, ok := err.(pgx.PgError) @@ -92,7 +91,7 @@ func TestConnectWithConnectionRefused(t *testing.T) { // Presumably nothing is listening on 127.0.0.1:1 conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") if err == nil { - conn.Close() + conn.Close(context.Background()) t.Fatal("Expected error establishing connection to bad port") } } @@ -110,7 +109,7 @@ func TestConnectCustomDialer(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) require.Nil(t, err) require.True(t, dialed) - conn.Close() + closeConn(t, conn) } func TestConnectWithRuntimeParams(t *testing.T) { @@ -126,12 +125,12 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - result, err := conn.Exec("show application_name") + result, err := conn.Exec(context.Background(), "show application_name") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result, err = conn.Exec("show search_path") + result, err = conn.Exec(context.Background(), "show search_path") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) @@ -179,7 +178,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { } acceptConnCount := 0 - config.AfterConnectFunc = func(conn *pgconn.PgConn) error { + config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount += 1 if acceptConnCount < 2 { return errors.New("reject first conn") @@ -214,38 +213,38 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { conn, err := pgconn.ConnectConfig(context.Background(), config) if !assert.NotNil(t, err) { - conn.Close() + conn.Close(context.Background()) } } -func TestExec(t *testing.T) { +func TestConnExec(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select current_database()") + result, err := pgConn.Exec(context.Background(), "select current_database()") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) } -func TestExecMultipleQueries(t *testing.T) { +func TestConnExecMultipleQueries(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select current_database(); select 1") + result, err := pgConn.Exec(context.Background(), "select current_database(); select 1") require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) } -func TestExecMultipleQueriesError(t *testing.T) { +func TestConnExecMultipleQueriesError(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec("select 1; select 1/0; select 1") + result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1") require.NotNil(t, err) require.Nil(t, result) if pgErr, ok := err.(pgconn.PgError); ok { @@ -254,3 +253,37 @@ func TestExecMultipleQueriesError(t *testing.T) { t.Errorf("unexpected error: %v", err) } } + +func TestConnExecContextCanceled(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") + require.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) +} + +func TestConnRecoverFromTimeout(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") + cancel() + require.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + if assert.True(t, pgConn.RecoverFromTimeout(ctx)) { + result, err := pgConn.Exec(ctx, "select 1") + require.Nil(t, err) + assert.Len(t, result.Rows, 1) + assert.Len(t, result.Rows[0], 1) + assert.Equal(t, "1", string(result.Rows[0][0])) + } + cancel() +} From 53175a7badc5a1a035517900f0d42a323d03f04b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 17:32:04 -0600 Subject: [PATCH 010/290] Add cancel request to PgConn RecoverFromTimeout automatically tries to cancel in progress requests. --- pgconn.go | 41 +++++++++++++++++++++++++++++++++++++++++ pgconn_test.go | 22 ++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/pgconn.go b/pgconn.go index 311b06a3..a7c4eea3 100644 --- a/pgconn.go +++ b/pgconn.go @@ -562,8 +562,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { } pgConn.resetBatch() + // Clear any existing timeout pgConn.NetConn.SetDeadline(time.Time{}) + // Try to cancel any in-progress requests + for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { + pgConn.CancelRequest(ctx) + } + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) defer cleanupContext() @@ -669,3 +675,38 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { Routine: msg.Routine, } } + +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing + // the connection config. This is important in high availability configurations where fallback connections may be + // specified or DNS may be used to load balance. + serverAddr := pgConn.NetConn.RemoteAddr() + cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + if err != nil { + return err + } + defer cancelConn.Close() + + cleanupContext := contextDoneToConnDeadline(ctx, cancelConn) + defer cleanupContext() + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.PID)) + binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.SecretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return fmt.Errorf("Server failed to close connection after cancel query request: %v", preferContextOverNetTimeoutError(ctx, err)) + } + + return nil +} diff --git a/pgconn_test.go b/pgconn_test.go index 98fd198e..9873013c 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -264,6 +264,8 @@ func TestConnExecContextCanceled(t *testing.T) { result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") require.Nil(t, result) assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -287,3 +289,23 @@ func TestConnRecoverFromTimeout(t *testing.T) { } cancel() } + +func TestConnCancelQuery(t *testing.T) { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + pgConn.SendExec("select current_database(), pg_sleep(5)") + err = pgConn.Flush(context.Background()) + require.Nil(t, err) + + err = pgConn.CancelRequest(context.Background()) + require.Nil(t, err) + + _, err = pgConn.GetResult(context.Background()).Close() + if err, ok := err.(pgconn.PgError); ok { + assert.Equal(t, "57014", err.Code) + } else { + t.Errorf("expected pgconn.PgError got %v", err) + } +} From bcc3da490cd2c06889c03f18c0a0e41eea51e45d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 17:34:44 -0600 Subject: [PATCH 011/290] Run tests in parallel --- config_test.go | 2 ++ pgconn_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/config_test.go b/config_test.go index 36f3fee2..e7a5bb44 100644 --- a/config_test.go +++ b/config_test.go @@ -533,6 +533,8 @@ func TestParseConfigEnvLibpq(t *testing.T) { } func TestParseConfigReadsPgPassfile(t *testing.T) { + t.Parallel() + tf, err := ioutil.TempFile("", "") require.Nil(t, err) diff --git a/pgconn_test.go b/pgconn_test.go index 9873013c..741c1b4b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -29,6 +29,8 @@ func TestConnect(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + connString := os.Getenv(tt.env) if connString == "" { t.Skipf("Skipping due to missing environment variable %v", tt.env) @@ -45,6 +47,8 @@ func TestConnect(t *testing.T) { // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure // connection. func TestConnectTLS(t *testing.T) { + t.Parallel() + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") @@ -61,6 +65,8 @@ func TestConnectTLS(t *testing.T) { } func TestConnectInvalidUser(t *testing.T) { + t.Parallel() + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") @@ -97,6 +103,8 @@ func TestConnectWithConnectionRefused(t *testing.T) { } func TestConnectCustomDialer(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -113,6 +121,8 @@ func TestConnectCustomDialer(t *testing.T) { } func TestConnectWithRuntimeParams(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -137,6 +147,8 @@ func TestConnectWithRuntimeParams(t *testing.T) { } func TestConnectWithFallback(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -168,6 +180,8 @@ func TestConnectWithFallback(t *testing.T) { } func TestConnectWithAfterConnectFunc(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -205,6 +219,8 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { } func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { + t.Parallel() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) @@ -218,6 +234,8 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } func TestConnExec(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -229,6 +247,8 @@ func TestConnExec(t *testing.T) { } func TestConnExecMultipleQueries(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -240,6 +260,8 @@ func TestConnExecMultipleQueries(t *testing.T) { } func TestConnExecMultipleQueriesError(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -255,6 +277,8 @@ func TestConnExecMultipleQueriesError(t *testing.T) { } func TestConnExecContextCanceled(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -269,6 +293,8 @@ func TestConnExecContextCanceled(t *testing.T) { } func TestConnRecoverFromTimeout(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) @@ -291,6 +317,8 @@ func TestConnRecoverFromTimeout(t *testing.T) { } func TestConnCancelQuery(t *testing.T) { + t.Parallel() + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err) defer closeConn(t, pgConn) From 49c9674102c3c151a004f2ef1d54072c9cb8244d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 17:46:56 -0600 Subject: [PATCH 012/290] PG error type is *pgconn.PgError --- pgconn.go | 10 +++++----- pgconn_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pgconn.go b/pgconn.go index a7c4eea3..fef113e0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -45,7 +45,7 @@ type PgError struct { Routine string } -func (pe PgError) Error() string { +func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } @@ -118,7 +118,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err pgConn, err = connect(ctx, config, fc) if err == nil { return pgConn, nil - } else if err, ok := err.(PgError); ok { + } else if err, ok := err.(*PgError); ok { return nil, err } } @@ -199,7 +199,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.NetConn.Close() - return nil, PgError{ + return nil, &PgError{ Severity: msg.Severity, Code: msg.Code, Message: msg.Message, @@ -654,8 +654,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return result, nil } -func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError { - return PgError{ +func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { + return &PgError{ Severity: msg.Severity, Code: msg.Code, Message: msg.Message, diff --git a/pgconn_test.go b/pgconn_test.go index 741c1b4b..e46093b0 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -269,7 +269,7 @@ func TestConnExecMultipleQueriesError(t *testing.T) { result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1") require.NotNil(t, err) require.Nil(t, result) - if pgErr, ok := err.(pgconn.PgError); ok { + if pgErr, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "22012", pgErr.Code) } else { t.Errorf("unexpected error: %v", err) @@ -331,7 +331,7 @@ func TestConnCancelQuery(t *testing.T) { require.Nil(t, err) _, err = pgConn.GetResult(context.Background()).Close() - if err, ok := err.(pgconn.PgError); ok { + if err, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "57014", err.Code) } else { t.Errorf("expected pgconn.PgError got %v", err) From f5faed65688c703f48a5712b2a41fc7db928fea9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 18:00:08 -0600 Subject: [PATCH 013/290] Access underlying net.Conn via method Also remove some dead code. --- pgconn.go | 57 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/pgconn.go b/pgconn.go index fef113e0..776141f9 100644 --- a/pgconn.go +++ b/pgconn.go @@ -58,7 +58,7 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - NetConn net.Conn // the underlying TCP or unix domain socket connection + conn net.Conn // the underlying TCP or unix domain socket connection PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server @@ -132,7 +132,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - pgConn.NetConn, err = config.DialFunc(ctx, network, address) + pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } @@ -141,12 +141,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if config.TLSConfig != nil { if err := pgConn.startTLS(config.TLSConfig); err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, err } } - pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.NetConn, pgConn.NetConn) + pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.conn, pgConn.conn) if err != nil { return nil, err } @@ -166,8 +166,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { - pgConn.NetConn.Close() + if _, err := pgConn.conn.Write(startupMsg.Encode(nil)); err != nil { + pgConn.conn.Close() return nil, err } @@ -183,14 +183,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, err } case *pgproto3.ReadyForQuery: if config.AfterConnectFunc != nil { err := config.AfterConnectFunc(ctx, pgConn) if err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, fmt.Errorf("AfterConnectFunc: %v", err) } } @@ -198,7 +198,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, &PgError{ Severity: msg.Severity, Code: msg.Code, @@ -219,20 +219,20 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig Routine: msg.Routine, } default: - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, errors.New("unexpected message") } } } func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) + err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return } response := make([]byte, 1) - if _, err = io.ReadFull(pgConn.NetConn, response); err != nil { + if _, err = io.ReadFull(pgConn.conn, response); err != nil { return } @@ -240,7 +240,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return ErrTLSRefused } - pgConn.NetConn = tls.Client(pgConn.NetConn, tlsConfig) + pgConn.conn = tls.Client(pgConn.conn, tlsConfig) return nil } @@ -262,7 +262,7 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.NetConn.Write(msg.Encode(nil)) + _, err = pgConn.conn.Write(msg.Encode(nil)) return err } @@ -299,6 +299,11 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } +// Conn returns the underlying net.Conn. +func (pgConn *PgConn) Conn() net.Conn { + return pgConn.conn +} + // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. @@ -308,22 +313,22 @@ func (pgConn *PgConn) Close(ctx context.Context) error { } pgConn.closed = true - defer pgConn.NetConn.Close() + defer pgConn.conn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) + _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { return preferContextOverNetTimeoutError(ctx, err) } - _, err = pgConn.NetConn.Read(make([]byte, 1)) + _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { return preferContextOverNetTimeoutError(ctx, err) } - return pgConn.NetConn.Close() + return pgConn.conn.Close() } // ParameterStatus returns the value of a parameter reported by the server (e.g. @@ -380,7 +385,7 @@ type PgResultReader struct { // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() @@ -491,14 +496,14 @@ func (rr *PgResultReader) close() { func (pgConn *PgConn) Flush(ctx context.Context) error { defer pgConn.resetBatch() - cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanup() - n, err := pgConn.NetConn.Write(pgConn.batchBuf) + n, err := pgConn.conn.Write(pgConn.batchBuf) if err != nil { if n > 0 { // Close connection because cannot recover from partially sent message. - pgConn.NetConn.Close() + pgConn.conn.Close() pgConn.closed = true } return preferContextOverNetTimeoutError(ctx, err) @@ -563,14 +568,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { pgConn.resetBatch() // Clear any existing timeout - pgConn.NetConn.SetDeadline(time.Time{}) + pgConn.conn.SetDeadline(time.Time{}) // Try to cancel any in-progress requests for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { pgConn.CancelRequest(ctx) } - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() for pgConn.pendingReadyForQueryCount > 0 { @@ -683,7 +688,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. - serverAddr := pgConn.NetConn.RemoteAddr() + serverAddr := pgConn.conn.RemoteAddr() cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err From 2f156c7add3a4026af92c7a9626ec2ea85e17f61 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 18:03:55 -0600 Subject: [PATCH 014/290] Access PID and SecretKey via method --- pgconn.go | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/pgconn.go b/pgconn.go index 776141f9..87ba0096 100644 --- a/pgconn.go +++ b/pgconn.go @@ -59,8 +59,8 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection - PID uint32 // backend pid - SecretKey uint32 // key to use to send a cancel query message to the server + pid uint32 // backend pid + secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte Frontend *pgproto3.Frontend @@ -179,8 +179,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig switch msg := msg.(type) { case *pgproto3.BackendKeyData: - pgConn.PID = msg.ProcessID - pgConn.SecretKey = msg.SecretKey + pgConn.pid = msg.ProcessID + pgConn.secretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { pgConn.conn.Close() @@ -304,6 +304,16 @@ func (pgConn *PgConn) Conn() net.Conn { return pgConn.conn } +// PID returns the backend PID. +func (pgConn *PgConn) PID() uint32 { + return pgConn.pid +} + +// SecretKey returns the backend secret key used to send a cancel query message to the server. +func (pgConn *PgConn) SecretKey() uint32 { + return pgConn.secretKey +} + // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. @@ -701,8 +711,8 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.PID)) - binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.SecretKey)) + binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { return preferContextOverNetTimeoutError(ctx, err) From 650aa7059a3ac45b9e4215ecc8cdb21e35846e8a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 18:45:51 -0600 Subject: [PATCH 015/290] Fix broken tests --- pgconn_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pgconn_test.go b/pgconn_test.go index e46093b0..05318dac 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" "github.com/pkg/errors" @@ -57,7 +56,7 @@ func TestConnectTLS(t *testing.T) { conn, err := pgconn.Connect(context.Background(), connString) require.Nil(t, err) - if _, ok := conn.NetConn.(*tls.Conn); !ok { + if _, ok := conn.Conn().(*tls.Conn); !ok { t.Error("not a TLS connection") } @@ -82,7 +81,7 @@ func TestConnectInvalidUser(t *testing.T) { conn.Close(context.Background()) t.Fatal("expected err but got none") } - pgErr, ok := err.(pgx.PgError) + pgErr, ok := err.(*pgconn.PgError) if !ok { t.Fatalf("Expected to receive a PgError, instead received: %v", err) } From 5f69253174a8f3712c90f5b791a0467a89d347a2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 19:59:32 -0600 Subject: [PATCH 016/290] Added ExecParams --- pgconn.go | 171 +++++++++++++++++++++++++++++++++++++++++++++++++ pgconn_test.go | 31 ++++++++- 2 files changed, 201 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 87ba0096..db9c758d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -20,6 +20,12 @@ import ( const batchBufferSize = 4096 +// PostgreSQL extended protocol format codes +const ( + TextFormatCode = 0 + BinaryFormatCode = 1 +) + var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) // PgError represents an error reported by the PostgreSQL server. See @@ -379,6 +385,127 @@ func appendQuery(buf []byte, query string) []byte { return buf } +// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. +func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte { + buf = append(buf, 'P') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, name...) + buf = append(buf, 0) + buf = append(buf, query...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(paramOIDs))) + for _, oid := range paramOIDs { + buf = pgio.AppendUint32(buf, oid) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. +func appendSync(buf []byte) []byte { + buf = append(buf, 'S') + buf = pgio.AppendInt32(buf, 4) + + return buf +} + +// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. +func appendBind( + buf []byte, + destinationPortal, + preparedStatement string, + paramFormats []int16, + paramValues [][]byte, + resultFormatCodes []int16, +) []byte { + if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) { + panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats))) + } + + buf = append(buf, 'B') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, destinationPortal...) + buf = append(buf, 0) + buf = append(buf, preparedStatement...) + buf = append(buf, 0) + + buf = pgio.AppendInt16(buf, int16(len(paramFormats))) + for _, f := range paramFormats { + buf = pgio.AppendInt16(buf, f) + } + + buf = pgio.AppendInt16(buf, int16(len(paramValues))) + for _, p := range paramValues { + if p == nil { + buf = pgio.AppendInt32(buf, -1) + continue + } + + buf = pgio.AppendInt32(buf, int32(len(p))) + buf = append(buf, p...) + } + + buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) + for _, fc := range resultFormatCodes { + buf = pgio.AppendInt16(buf, fc) + } + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. +func appendExecute(buf []byte, portal string, maxRows uint32) []byte { + buf = append(buf, 'E') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + buf = append(buf, portal...) + buf = append(buf, 0) + buf = pgio.AppendUint32(buf, maxRows) + + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + +// SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// SendExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Query is only sent to the PostgreSQL server when Flush is called. +func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + if len(paramValues) > 65535 { + panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues))) + } + if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { + panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) + } + + pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs) + pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats) + pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 +} + type PgResultReader struct { pgConn *PgConn fieldDescriptions []pgproto3.FieldDescription @@ -669,6 +796,50 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return result, nil } +// ExecParams executes sql via the PostgreSQL extended query protocol, buffers the entire result, and returns it. See +// SendExecParams for parameter descriptions. +// +// ExecParams must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) + err := pgConn.Flush(ctx) + if err != nil { + return nil, err + } + + resultReader := pgConn.GetResult(ctx) + if resultReader == nil { + return nil, errors.New("unexpected missing result") + } + + var result *PgResult + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + + return result, nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, diff --git a/pgconn_test.go b/pgconn_test.go index 05318dac..fa1ec5fc 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -285,7 +285,36 @@ func TestConnExecContextCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - require.Nil(t, result) + assert.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) +} + +func TestConnExecParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "Hello, world", string(result.Rows[0][0])) +} + +func TestConnExecParamsCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + assert.Nil(t, result) assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) From 13323df0dd20310714151bada713d8f168a672df Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 20:08:11 -0600 Subject: [PATCH 017/290] Add batched query test --- pgconn_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index fa1ec5fc..a765dc4c 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -320,6 +320,96 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgConn.RecoverFromTimeout(context.Background())) } +func TestConnBatchedQueries(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + pgConn.SendExec("select 'SendExec 1'") + pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil) + pgConn.SendExec("select 'SendExec 2'") + pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil) + err = pgConn.Flush(context.Background()) + + // "select 'SendExec 1'" + resultReader := pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExec 1", string(rows[0][0])) + + commandTag, err := resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + + // "SendExecParams 1" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExecParams 1", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + + // "SendExec 2" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExec 2", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + + // "SendExecParams 2" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExecParams 2", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + + // Done + resultReader = pgConn.GetResult(context.Background()) + assert.Nil(t, resultReader) +} + func TestConnRecoverFromTimeout(t *testing.T) { t.Parallel() From 54df8c691874b91855671462f081a4dd3ce9df42 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 11:32:56 -0600 Subject: [PATCH 018/290] Add ExecPrepared --- pgconn.go | 126 +++++++++++++++++++++++++++++++++++++++++++++++-- pgconn_test.go | 57 ++++++++++++++++++++++ 2 files changed, 180 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index db9c758d..f2e46539 100644 --- a/pgconn.go +++ b/pgconn.go @@ -387,6 +387,10 @@ func appendQuery(buf []byte, query string) []byte { // appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte { + if len(paramOIDs) > 65535 { + panic(fmt.Sprintf("len(paramOIDs) must be between 0 and 65535, received %d", len(paramOIDs))) + } + buf = append(buf, 'P') sp := len(buf) buf = pgio.AppendInt32(buf, -1) @@ -404,6 +408,19 @@ func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []by return buf } +// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. +func appendDescribe(buf []byte, objectType byte, name string) []byte { + buf = append(buf, 'D') + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + buf = append(buf, objectType) + buf = append(buf, name...) + buf = append(buf, 0) + pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) + + return buf +} + // appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. func appendSync(buf []byte) []byte { buf = append(buf, 'S') @@ -424,6 +441,9 @@ func appendBind( if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) { panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats))) } + if len(paramValues) > 65535 { + panic(fmt.Sprintf("len(paramValues) must be between 0 and 65535, received %d", len(paramValues))) + } buf = append(buf, 'B') sp := len(buf) @@ -492,9 +512,6 @@ func appendExecute(buf []byte, portal string, maxRows uint32) []byte { // // Query is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - if len(paramValues) > 65535 { - panic(fmt.Sprintf("Number of params 0 and 65535, received %d", len(paramValues))) - } if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) } @@ -506,6 +523,25 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs pgConn.batchCount += 1 } +// SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Query is only sent to the PostgreSQL server when Flush is called. +func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats) + pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 +} + type PgResultReader struct { pgConn *PgConn fieldDescriptions []pgproto3.FieldDescription @@ -840,6 +876,90 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result, nil } +// ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and +// returns it. See SendExecPrepared for parameter descriptions. +// +// ExecPrepared must not be called when there are pending results from previous Send* methods (e.g. SendExec). +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) (*PgResult, error) { + if pgConn.batchCount != 0 { + return nil, errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return nil, errors.New("unread previous results") + } + + pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) + err := pgConn.Flush(ctx) + if err != nil { + return nil, err + } + + resultReader := pgConn.GetResult(ctx) + if resultReader == nil { + return nil, errors.New("unexpected missing result") + } + + var result *PgResult + rows := [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + + commandTag, err := resultReader.Close() + if err != nil { + return nil, err + } + + result = &PgResult{ + Rows: rows, + CommandTag: commandTag, + } + + return result, nil +} + +// Prepare creates a prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error { + if pgConn.batchCount != 0 { + return errors.New("unflushed previous sends") + } + if pgConn.pendingReadyForQueryCount != 0 { + return errors.New("unread previous results") + } + + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContext() + + pgConn.batchBuf = appendParse(pgConn.batchBuf, name, sql, paramOIDs) + pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', name) + pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchCount += 1 + err := pgConn.Flush(context.Background()) + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + for pgConn.pendingReadyForQueryCount > 0 { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ParameterDescription: + // TODO + case *pgproto3.RowDescription: + // TODO + case *pgproto3.ErrorResponse: + return errorResponseToPgError(msg) + } + } + + return nil +} + func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, diff --git a/pgconn_test.go b/pgconn_test.go index a765dc4c..35f5b536 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -320,6 +320,41 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgConn.RecoverFromTimeout(context.Background())) } +func TestConnExecPrepared(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.Nil(t, err) + + result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Nil(t, err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "Hello, world", string(result.Rows[0][0])) +} + +func TestConnExecPreparedCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.Nil(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result, err := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + assert.Nil(t, result) + assert.Equal(t, context.DeadlineExceeded, err) + + assert.True(t, pgConn.RecoverFromTimeout(context.Background())) +} + func TestConnBatchedQueries(t *testing.T) { t.Parallel() @@ -327,8 +362,12 @@ func TestConnBatchedQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) + err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.Nil(t, err) + pgConn.SendExec("select 'SendExec 1'") pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil) + pgConn.SendExecPrepared("ps1", [][]byte{[]byte("SendExecPrepared 1")}, nil, nil) pgConn.SendExec("select 'SendExec 2'") pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil) err = pgConn.Flush(context.Background()) @@ -369,6 +408,24 @@ func TestConnBatchedQueries(t *testing.T) { assert.Equal(t, "SELECT 1", string(commandTag)) assert.Nil(t, err) + // "SendExecPrepared 1" + resultReader = pgConn.GetResult(context.Background()) + require.NotNil(t, resultReader) + + rows = [][][]byte{} + for resultReader.NextRow() { + row := make([][]byte, len(resultReader.Values())) + copy(row, resultReader.Values()) + rows = append(rows, row) + } + require.Len(t, rows, 1) + require.Len(t, rows[0], 1) + assert.Equal(t, "SendExecPrepared 1", string(rows[0][0])) + + commandTag, err = resultReader.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) + // "SendExec 2" resultReader = pgConn.GetResult(context.Background()) require.NotNil(t, resultReader) From 51d654d32a9c4975ff27f499fa5cb1149d750cf0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 11:35:39 -0600 Subject: [PATCH 019/290] Format code constants already in pgproto3 --- pgconn.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pgconn.go b/pgconn.go index f2e46539..9aeba757 100644 --- a/pgconn.go +++ b/pgconn.go @@ -20,12 +20,6 @@ import ( const batchBufferSize = 4096 -// PostgreSQL extended protocol format codes -const ( - TextFormatCode = 0 - BinaryFormatCode = 1 -) - var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) // PgError represents an error reported by the PostgreSQL server. See From b793875c1ffbdf077c20d2eb36fe3346ae6d77a4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:16:50 -0600 Subject: [PATCH 020/290] Extract bufferLastResult Buffered exec methods need to read until pending ready for queries is 0. Factor this common logic out. Add stress test for PgConn. --- pgconn.go | 54 ++++++------------------------------------------------ 1 file changed, 6 insertions(+), 48 deletions(-) diff --git a/pgconn.go b/pgconn.go index 9aeba757..ec5413de 100644 --- a/pgconn.go +++ b/pgconn.go @@ -799,6 +799,10 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return nil, err } + return pgConn.bufferLastResult(ctx) +} + +func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { var result *PgResult for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) { @@ -844,30 +848,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return nil, err } - resultReader := pgConn.GetResult(ctx) - if resultReader == nil { - return nil, errors.New("unexpected missing result") - } - - var result *PgResult - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - - commandTag, err := resultReader.Close() - if err != nil { - return nil, err - } - - result = &PgResult{ - Rows: rows, - CommandTag: commandTag, - } - - return result, nil + return pgConn.bufferLastResult(ctx) } // ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and @@ -888,30 +869,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return nil, err } - resultReader := pgConn.GetResult(ctx) - if resultReader == nil { - return nil, errors.New("unexpected missing result") - } - - var result *PgResult - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - - commandTag, err := resultReader.Close() - if err != nil { - return nil, err - } - - result = &PgResult{ - Rows: rows, - CommandTag: commandTag, - } - - return result, nil + return pgConn.bufferLastResult(ctx) } // Prepare creates a prepared statement. From 8df3f2010f3b448bcbf5499e889df94223c8d7fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:47:37 -0600 Subject: [PATCH 021/290] Avoid allocating strings in common message types --- pgconn.go | 62 +++++++++++++++++++++++++------------------------------ 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/pgconn.go b/pgconn.go index ec5413de..df823042 100644 --- a/pgconn.go +++ b/pgconn.go @@ -199,25 +199,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, &PgError{ - Severity: msg.Severity, - Code: msg.Code, - Message: msg.Message, - Detail: msg.Detail, - Hint: msg.Hint, - Position: msg.Position, - InternalPosition: msg.InternalPosition, - InternalQuery: msg.InternalQuery, - Where: msg.Where, - SchemaName: msg.SchemaName, - TableName: msg.TableName, - ColumnName: msg.ColumnName, - DataTypeName: msg.DataTypeName, - ConstraintName: msg.ConstraintName, - File: msg.File, - Line: msg.Line, - Routine: msg.Routine, - } + return nil, errorResponseToPgError(msg) default: pgConn.conn.Close() return nil, errors.New("unexpected message") @@ -348,7 +330,7 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag string +type CommandTag []byte // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. @@ -362,6 +344,10 @@ func (ct CommandTag) RowsAffected() int64 { return n } +func (ct CommandTag) String() string { + return string(ct) +} + // SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries. // Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains // transaction control statements. It is only sent to the PostgreSQL server when Flush is called. @@ -511,6 +497,7 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs } pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs) + pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', "") pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats) pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) pgConn.batchBuf = appendSync(pgConn.batchBuf) @@ -530,6 +517,7 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs // // Query is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', stmtName) pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats) pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) pgConn.batchBuf = appendSync(pgConn.batchBuf) @@ -616,6 +604,12 @@ func (rr *PgResultReader) NextRow() bool { } } +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the PgResultReader is closed. +func (rr *PgResultReader) FieldDescriptions() []pgproto3.FieldDescription { + return rr.fieldDescriptions +} + // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only // valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to // retain a reference to and mutate. @@ -914,23 +908,23 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: msg.Severity, - Code: msg.Code, - Message: msg.Message, - Detail: msg.Detail, - Hint: msg.Hint, + Severity: string(msg.Severity), + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: string(msg.Hint), Position: msg.Position, InternalPosition: msg.InternalPosition, - InternalQuery: msg.InternalQuery, - Where: msg.Where, - SchemaName: msg.SchemaName, - TableName: msg.TableName, - ColumnName: msg.ColumnName, - DataTypeName: msg.DataTypeName, - ConstraintName: msg.ConstraintName, - File: msg.File, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: string(msg.ConstraintName), + File: string(msg.File), Line: msg.Line, - Routine: msg.Routine, + Routine: string(msg.Routine), } } From 4f00c6aebdee5322c29a5672eef1fef53d058bdb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:49:12 -0600 Subject: [PATCH 022/290] Add pgconn stress test --- pgconn_stress_test.go | 199 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 pgconn_stress_test.go diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go new file mode 100644 index 00000000..cc6acab8 --- /dev/null +++ b/pgconn_stress_test.go @@ -0,0 +1,199 @@ +package pgconn_test + +import ( + "context" + "math/rand" + "os" + "strconv" + "testing" + "time" + + "github.com/jackc/pgx/pgconn" + "github.com/pkg/errors" + + "github.com/stretchr/testify/require" +) + +func TestConnStress(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + actionCount := 100 + if s := os.Getenv("PTX_TEST_STRESS_FACTOR"); s != "" { + stressFactor, err := strconv.ParseInt(s, 10, 64) + require.Nil(t, err, "Failed to parse PTX_TEST_STRESS_FACTOR") + actionCount *= int(stressFactor) + } + + setupStressDB(t, pgConn) + + actions := []struct { + name string + fn func(*pgconn.PgConn) error + }{ + {"Exec Select", stressExecSelect}, + {"ExecParams Select", stressExecParamsSelect}, + {"Batch", stressBatch}, + {"ExecCanceled", stressExecSelectCanceled}, + {"ExecParamsCanceled", stressExecParamsSelectCanceled}, + {"BatchCanceled", stressBatchCanceled}, + } + + for i := 0; i < actionCount; i++ { + action := actions[rand.Intn(len(actions))] + err := action.fn(pgConn) + require.Nilf(t, err, "%d: %s", i, action.name) + } +} + +func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { + _, err := pgConn.Exec(context.Background(), ` + create temporary table widgets( + id serial primary key, + name varchar not null, + description text, + creation_time timestamptz default now() + ); + + insert into widgets(name, description) values + ('Foo', 'bar'), + ('baz', 'Something really long Something really long Something really long Something really long Something really long'), + ('a', 'b')`) + require.Nil(t, err) +} + +func stressExecSelect(pgConn *pgconn.PgConn) error { + _, err := pgConn.Exec(context.Background(), "select * from widgets") + return err +} + +func stressExecParamsSelect(pgConn *pgconn.PgConn) error { + _, err := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + return err +} + +func stressBatch(pgConn *pgconn.PgConn) error { + pgConn.SendExec("select * from widgets") + pgConn.SendExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + err := pgConn.Flush(context.Background()) + if err != nil { + return err + } + + // Query 1 + resultReader := pgConn.GetResult(context.Background()) + if resultReader == nil { + return errors.New("missing resultReader") + } + + for resultReader.NextRow() { + } + _, err = resultReader.Close() + if err != nil { + return err + } + + // Query 2 + resultReader = pgConn.GetResult(context.Background()) + if resultReader == nil { + return errors.New("missing resultReader") + } + + for resultReader.NextRow() { + } + _, err = resultReader.Close() + if err != nil { + return err + } + + // No more + resultReader = pgConn.GetResult(context.Background()) + if resultReader != nil { + return errors.New("unexpected result reader") + } + + return nil +} + +func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets") + cancel() + if err != context.DeadlineExceeded { + return err + } + + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + recovered := pgConn.RecoverFromTimeout(ctx) + cancel() + if !recovered { + return errors.New("unable to recover from timeout") + } + return nil +} + +func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + _, err := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + cancel() + if err != context.DeadlineExceeded { + return err + } + + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + recovered := pgConn.RecoverFromTimeout(ctx) + cancel() + if !recovered { + return errors.New("unable to recover from timeout") + } + return nil +} + +func stressBatchCanceled(pgConn *pgconn.PgConn) error { + + pgConn.SendExec("select * from widgets") + pgConn.SendExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + err := pgConn.Flush(context.Background()) + if err != nil { + return err + } + + // Query 1 + resultReader := pgConn.GetResult(context.Background()) + if resultReader == nil { + return errors.New("missing resultReader") + } + + for resultReader.NextRow() { + } + _, err = resultReader.Close() + if err != nil { + return err + } + + // Query 2 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + resultReader = pgConn.GetResult(ctx) + cancel() + if resultReader == nil { + return errors.New("missing resultReader") + } + + for resultReader.NextRow() { + } + _, err = resultReader.Close() + if err != context.DeadlineExceeded { + return err + } + + ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) + recovered := pgConn.RecoverFromTimeout(ctx) + cancel() + if !recovered { + return errors.New("unable to recover from timeout") + } + return nil +} From 9af9f57f1575f0ae0c35a473f9125056acb90cba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 13:56:09 -0600 Subject: [PATCH 023/290] Remove another allocation --- pgconn.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/pgconn.go b/pgconn.go index df823042..d9755f6c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -73,6 +73,8 @@ type PgConn struct { pendingReadyForQueryCount int32 closed bool + + resultReader PgResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -536,9 +538,9 @@ type PgResultReader struct { cleanupContext func() } -// GetResult returns a PgResultReader for the next result. If all results are -// consumed it returns nil. If an error occurs it will be reported on the -// returned PgResultReader. +// GetResult returns a PgResultReader for the next result. If all results are consumed it returns nil. If an error +// occurs it will be reported on the returned PgResultReader. Returned PgResultReader is only valid until next call of +// GetResult. func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -546,20 +548,25 @@ func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { msg, err := pgConn.ReceiveMessage() if err != nil { cleanupContext() - return &PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} + return &pgConn.resultReader } switch msg := msg.(type) { case *pgproto3.RowDescription: - return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} + return &pgConn.resultReader case *pgproto3.DataRow: - return &PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} + return &pgConn.resultReader case *pgproto3.CommandComplete: cleanupContext() - return &PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} + return &pgConn.resultReader case *pgproto3.ErrorResponse: cleanupContext() - return &PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} + return &pgConn.resultReader } } From 914766af9b867dd946da258bb19269a632ba5296 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:10:16 -0600 Subject: [PATCH 024/290] Use result readers in next/get fashion --- pgconn.go | 27 ++++++++++++++++----------- pgconn_stress_test.go | 29 ++++++++++++++--------------- pgconn_test.go | 26 +++++++++++++------------- 3 files changed, 43 insertions(+), 39 deletions(-) diff --git a/pgconn.go b/pgconn.go index d9755f6c..8511d5b9 100644 --- a/pgconn.go +++ b/pgconn.go @@ -538,10 +538,9 @@ type PgResultReader struct { cleanupContext func() } -// GetResult returns a PgResultReader for the next result. If all results are consumed it returns nil. If an error -// occurs it will be reported on the returned PgResultReader. Returned PgResultReader is only valid until next call of -// GetResult. -func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { +// NextResult reads until a result is ready to be read or no results are pending. Returns true if a result is available. +// Use ResultReader() to acquire a reader for the result. +func (pgConn *PgConn) NextResult(ctx context.Context) bool { cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) for pgConn.pendingReadyForQueryCount > 0 { @@ -549,29 +548,34 @@ func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { if err != nil { cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} - return &pgConn.resultReader + return true } switch msg := msg.(type) { case *pgproto3.RowDescription: pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} - return &pgConn.resultReader + return true case *pgproto3.DataRow: pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} - return &pgConn.resultReader + return true case *pgproto3.CommandComplete: cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} - return &pgConn.resultReader + return true case *pgproto3.ErrorResponse: cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} - return &pgConn.resultReader + return true } } cleanupContext() - return nil + return false +} + +// ResultReader returns the result reader prepared by next result. It is only valid until the result is completed. +func (pgConn *PgConn) ResultReader() *PgResultReader { + return &pgConn.resultReader } // NextRow returns advances the PgResultReader to the next row and returns true if a row is available. @@ -806,7 +810,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { var result *PgResult - for resultReader := pgConn.GetResult(ctx); resultReader != nil; resultReader = pgConn.GetResult(ctx) { + for pgConn.NextResult(ctx) { + resultReader := pgConn.ResultReader() rows := [][][]byte{} for resultReader.NextRow() { row := make([][]byte, len(resultReader.Values())) diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index cc6acab8..9aa94539 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -84,10 +84,10 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // Query 1 - resultReader := pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader := pgConn.ResultReader() for resultReader.NextRow() { } @@ -97,10 +97,10 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // Query 2 - resultReader = pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader = pgConn.ResultReader() for resultReader.NextRow() { } @@ -110,8 +110,7 @@ func stressBatch(pgConn *pgconn.PgConn) error { } // No more - resultReader = pgConn.GetResult(context.Background()) - if resultReader != nil { + if pgConn.NextResult(context.Background()) { return errors.New("unexpected result reader") } @@ -162,10 +161,10 @@ func stressBatchCanceled(pgConn *pgconn.PgConn) error { } // Query 1 - resultReader := pgConn.GetResult(context.Background()) - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(context.Background()) { + return errors.New("missing result") } + resultReader := pgConn.ResultReader() for resultReader.NextRow() { } @@ -176,11 +175,11 @@ func stressBatchCanceled(pgConn *pgconn.PgConn) error { // Query 2 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - resultReader = pgConn.GetResult(ctx) - cancel() - if resultReader == nil { - return errors.New("missing resultReader") + if !pgConn.NextResult(ctx) { + return errors.New("missing result") } + cancel() + resultReader = pgConn.ResultReader() for resultReader.NextRow() { } diff --git a/pgconn_test.go b/pgconn_test.go index 35f5b536..8b578d42 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -373,8 +373,8 @@ func TestConnBatchedQueries(t *testing.T) { err = pgConn.Flush(context.Background()) // "select 'SendExec 1'" - resultReader := pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader := pgConn.ResultReader() rows := [][][]byte{} for resultReader.NextRow() { @@ -391,8 +391,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecParams 1" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -409,8 +409,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecPrepared 1" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -427,8 +427,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExec 2" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -445,8 +445,8 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // "SendExecParams 2" - resultReader = pgConn.GetResult(context.Background()) - require.NotNil(t, resultReader) + require.True(t, pgConn.NextResult(context.Background())) + resultReader = pgConn.ResultReader() rows = [][][]byte{} for resultReader.NextRow() { @@ -463,8 +463,7 @@ func TestConnBatchedQueries(t *testing.T) { assert.Nil(t, err) // Done - resultReader = pgConn.GetResult(context.Background()) - assert.Nil(t, resultReader) + require.False(t, pgConn.NextResult(context.Background())) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -505,7 +504,8 @@ func TestConnCancelQuery(t *testing.T) { err = pgConn.CancelRequest(context.Background()) require.Nil(t, err) - _, err = pgConn.GetResult(context.Background()).Close() + require.True(t, pgConn.NextResult(context.Background())) + _, err = pgConn.ResultReader().Close() if err, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "57014", err.Code) } else { From bd2a5d97d0f850520f40ef02a5d328f39fd94d7f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:10:24 -0600 Subject: [PATCH 025/290] Add benchmark to pgconn --- benchmark_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 benchmark_test.go diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 00000000..da5bd4fc --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,52 @@ +package pgconn_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/require" +) + +func BenchmarkConnect(b *testing.B) { + benchmarks := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + connString := os.Getenv(bm.env) + if connString == "" { + b.Skipf("Skipping due to missing environment variable %v", bm.env) + } + + for i := 0; i < b.N; i++ { + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(b, err) + + err = conn.Close(context.Background()) + require.Nil(b, err) + } + }) + } +} + +func BenchmarkExecPrepared(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) + require.Nil(b, err) + } +} From 11964a6ec38e6b899e826ffe721817e634665c81 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:17:17 -0600 Subject: [PATCH 026/290] Add non-buffered benchmark --- benchmark_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index da5bd4fc..aff21216 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -50,3 +50,24 @@ func BenchmarkExecPrepared(b *testing.B) { require.Nil(b, err) } } + +func BenchmarkSendExecPrepared(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + conn.SendExecPrepared("ps1", nil, nil, nil) + err := conn.Flush(context.Background()) + require.Nil(b, err) + + for conn.NextResult(context.Background()) { + _, err := conn.ResultReader().Close() + require.Nil(b, err) + } + } +} From fdbf2ba728332987aa341ecef2a2b9e0232b5654 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 14:32:42 -0600 Subject: [PATCH 027/290] Use pgproto3 instead of custom message encoders --- benchmark_test.go | 13 +++++ pgconn.go | 141 ++++------------------------------------------ 2 files changed, 23 insertions(+), 131 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index aff21216..bdc550cb 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -36,6 +36,19 @@ func BenchmarkConnect(b *testing.B) { } } +func BenchmarkExec(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + require.Nil(b, err) + } +} + func BenchmarkExecPrepared(b *testing.B) { conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) diff --git a/pgconn.go b/pgconn.go index 8511d5b9..d7a99676 100644 --- a/pgconn.go +++ b/pgconn.go @@ -14,7 +14,6 @@ import ( "strings" "time" - "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -354,127 +353,10 @@ func (ct CommandTag) String() string { // Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains // transaction control statements. It is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExec(sql string) { - pgConn.batchBuf = appendQuery(pgConn.batchBuf, sql) + pgConn.batchBuf = (&pgproto3.Query{String: sql}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 } -// appendQuery appends a PostgreSQL wire protocol query message to buf and returns it. -func appendQuery(buf []byte, query string) []byte { - buf = append(buf, 'Q') - buf = pgio.AppendInt32(buf, int32(len(query)+5)) - buf = append(buf, query...) - buf = append(buf, 0) - return buf -} - -// appendParse appends a PostgreSQL wire protocol parse message to buf and returns it. -func appendParse(buf []byte, name string, query string, paramOIDs []uint32) []byte { - if len(paramOIDs) > 65535 { - panic(fmt.Sprintf("len(paramOIDs) must be between 0 and 65535, received %d", len(paramOIDs))) - } - - buf = append(buf, 'P') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, name...) - buf = append(buf, 0) - buf = append(buf, query...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(paramOIDs))) - for _, oid := range paramOIDs { - buf = pgio.AppendUint32(buf, oid) - } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendDescribe appends a PostgreSQL wire protocol describe message to buf and returns it. -func appendDescribe(buf []byte, objectType byte, name string) []byte { - buf = append(buf, 'D') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, objectType) - buf = append(buf, name...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendSync appends a PostgreSQL wire protocol sync message to buf and returns it. -func appendSync(buf []byte) []byte { - buf = append(buf, 'S') - buf = pgio.AppendInt32(buf, 4) - - return buf -} - -// appendBind appends a PostgreSQL wire protocol bind message to buf and returns it. -func appendBind( - buf []byte, - destinationPortal, - preparedStatement string, - paramFormats []int16, - paramValues [][]byte, - resultFormatCodes []int16, -) []byte { - if len(paramFormats) != 0 && len(paramFormats) != len(paramValues) && len(paramFormats) != len(paramValues) { - panic(fmt.Sprintf("len(paramFormats) must be 0, 1, or len(paramValues), received %d", len(paramFormats))) - } - if len(paramValues) > 65535 { - panic(fmt.Sprintf("len(paramValues) must be between 0 and 65535, received %d", len(paramValues))) - } - - buf = append(buf, 'B') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, destinationPortal...) - buf = append(buf, 0) - buf = append(buf, preparedStatement...) - buf = append(buf, 0) - - buf = pgio.AppendInt16(buf, int16(len(paramFormats))) - for _, f := range paramFormats { - buf = pgio.AppendInt16(buf, f) - } - - buf = pgio.AppendInt16(buf, int16(len(paramValues))) - for _, p := range paramValues { - if p == nil { - buf = pgio.AppendInt32(buf, -1) - continue - } - - buf = pgio.AppendInt32(buf, int32(len(p))) - buf = append(buf, p...) - } - - buf = pgio.AppendInt16(buf, int16(len(resultFormatCodes))) - for _, fc := range resultFormatCodes { - buf = pgio.AppendInt16(buf, fc) - } - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - -// appendExecute appends a PostgreSQL wire protocol execute message to buf and returns it. -func appendExecute(buf []byte, portal string, maxRows uint32) []byte { - buf = append(buf, 'E') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - - buf = append(buf, portal...) - buf = append(buf, 0) - buf = pgio.AppendUint32(buf, maxRows) - - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - return buf -} - // SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol. // // sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc. @@ -498,11 +380,8 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) } - pgConn.batchBuf = appendParse(pgConn.batchBuf, "", sql, paramOIDs) - pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', "") - pgConn.batchBuf = appendBind(pgConn.batchBuf, "", "", paramFormats, paramValues, resultFormats) - pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) - pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) + pgConn.SendExecPrepared("", paramValues, paramFormats, resultFormats) pgConn.batchCount += 1 } @@ -519,10 +398,10 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs // // Query is only sent to the PostgreSQL server when Flush is called. func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', stmtName) - pgConn.batchBuf = appendBind(pgConn.batchBuf, "", stmtName, paramFormats, paramValues, resultFormats) - pgConn.batchBuf = appendExecute(pgConn.batchBuf, "", 0) - pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: stmtName}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Execute{}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 } @@ -890,9 +769,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - pgConn.batchBuf = appendParse(pgConn.batchBuf, name, sql, paramOIDs) - pgConn.batchBuf = appendDescribe(pgConn.batchBuf, 'S', name) - pgConn.batchBuf = appendSync(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) + pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 err := pgConn.Flush(context.Background()) if err != nil { From 7986e2726d1679e78d3cce4c3df19e3f7bd3a866 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 16:55:48 -0600 Subject: [PATCH 028/290] pgx uses pgconn.CommandTag instead of own definition --- pgconn_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 8b578d42..8f976d87 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -512,3 +512,26 @@ func TestConnCancelQuery(t *testing.T) { t.Errorf("expected pgconn.PgError got %v", err) } } + +func TestCommandTag(t *testing.T) { + t.Parallel() + + var tests = []struct { + commandTag pgconn.CommandTag + rowsAffected int64 + }{ + {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5}, + {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1}, + {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1}, + {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, + {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, + } + + for i, tt := range tests { + actual := tt.commandTag.RowsAffected() + assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag) + } +} From 547741ae6aac6cec4b7c7a724d746a30c2ade465 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 17:08:56 -0600 Subject: [PATCH 029/290] Fix bug with ready for query counter --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index d7a99676..1e70a82b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -382,7 +382,6 @@ func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs pgConn.batchBuf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.SendExecPrepared("", paramValues, paramFormats, resultFormats) - pgConn.batchCount += 1 } // SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. @@ -708,6 +707,7 @@ func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { CommandTag: commandTag, } } + if result == nil { return nil, errors.New("unexpected missing result") } From d545e0704e4a57498c357254d3cfb8b528d19697 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 1 Jan 2019 18:03:20 -0600 Subject: [PATCH 030/290] Prepare returns description --- benchmark_test.go | 4 ++-- pgconn.go | 53 +++++++++++++++++++++++++++++++++++++++-------- pgconn_test.go | 9 +++++--- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index bdc550cb..269ac59b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -54,7 +54,7 @@ func BenchmarkExecPrepared(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) - err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() @@ -69,7 +69,7 @@ func BenchmarkSendExecPrepared(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) - err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() diff --git a/pgconn.go b/pgconn.go index 1e70a82b..de7020b2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -757,13 +757,42 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return pgConn.bufferLastResult(ctx) } +type FieldDescription struct { + Name string + TableOID uint32 + TableAttributeNumber uint16 + DataTypeOID uint32 + DataTypeSize int16 + TypeModifier int32 + FormatCode int16 +} + +// pgproto3FieldDescriptionToPgconnFieldDescription copies and converts the data from a pgproto3.FieldDescription to a +// FieldDescription. +func pgproto3FieldDescriptionToPgconnFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) { + dst.Name = string(src.Name) + dst.TableOID = src.TableOID + dst.TableAttributeNumber = src.TableAttributeNumber + dst.DataTypeOID = src.DataTypeOID + dst.DataTypeSize = src.DataTypeSize + dst.TypeModifier = src.TypeModifier + dst.FormatCode = src.Format +} + +type PreparedStatementDescription struct { + Name string + SQL string + ParamOIDs []uint32 + Fields []FieldDescription +} + // Prepare creates a prepared statement. -func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) error { +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { if pgConn.batchCount != 0 { - return errors.New("unflushed previous sends") + return nil, errors.New("unflushed previous sends") } if pgConn.pendingReadyForQueryCount != 0 { - return errors.New("unread previous results") + return nil, errors.New("unread previous results") } cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -775,26 +804,32 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.batchCount += 1 err := pgConn.Flush(context.Background()) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } + psd := &PreparedStatementDescription{Name: name, SQL: sql} + for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { case *pgproto3.ParameterDescription: - // TODO + psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) + copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - // TODO + psd.Fields = make([]FieldDescription, len(msg.Fields)) + for i := range msg.Fields { + pgproto3FieldDescriptionToPgconnFieldDescription(&msg.Fields[i], &psd.Fields[i]) + } case *pgproto3.ErrorResponse: - return errorResponseToPgError(msg) + return nil, errorResponseToPgError(msg) } } - return nil + return psd, nil } func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { diff --git a/pgconn_test.go b/pgconn_test.go index 8f976d87..ee573d42 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -327,8 +327,11 @@ func TestConnExecPrepared(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) require.Nil(t, err) @@ -343,7 +346,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) require.Nil(t, err) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -362,7 +365,7 @@ func TestConnBatchedQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) pgConn.SendExec("select 'SendExec 1'") From 6d2fa9c5cf5f09b696faf2597341c132139258fb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 12:28:11 -0600 Subject: [PATCH 031/290] Handle empty query response --- pgconn.go | 4 ++++ pgconn_test.go | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/pgconn.go b/pgconn.go index de7020b2..b3abe8e0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -440,6 +440,10 @@ func (pgConn *PgConn) NextResult(ctx context.Context) bool { cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} return true + case *pgproto3.EmptyQueryResponse: + cleanupContext() + pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, complete: true} + return true case *pgproto3.ErrorResponse: cleanupContext() pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} diff --git a/pgconn_test.go b/pgconn_test.go index ee573d42..8d6b606a 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -245,6 +245,19 @@ func TestConnExec(t *testing.T) { assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) } +func TestConnExecEmpty(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + result, err := pgConn.Exec(context.Background(), ";") + require.Nil(t, err) + assert.Nil(t, result.CommandTag) + assert.Equal(t, 0, len(result.Rows)) +} + func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() From 460946d66256f11b10bbd3b3bac99def937143f2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 13:14:34 -0600 Subject: [PATCH 032/290] Move notice handling to pgconn --- config.go | 2 ++ pgconn.go | 19 +++++++++++++++++++ pgconn_test.go | 23 +++++++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/config.go b/config.go index d8872f66..bd1fec9b 100644 --- a/config.go +++ b/config.go @@ -40,6 +40,8 @@ type Config struct { // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This // allows implementing high availability behavior such as libpq does with target_session_attrs. AfterConnectFunc AfterConnectFunc + + OnNotice NoticeHandler // Callback function called when a notice response is received. } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index b3abe8e0..6b6330dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -48,9 +48,19 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } +// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from +// LISTEN/NOTIFY notification. +type Notice PgError + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at +// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin +// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY +// notification. +type NoticeHandler func(*PgConn, *Notice) + // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") @@ -277,6 +287,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { // TODO - close pgConn return nil, errorResponseToPgError(msg) } + case *pgproto3.NoticeResponse: + if pgConn.Config.OnNotice != nil { + pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) + } } return msg, nil @@ -858,6 +872,11 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { } } +func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { + pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + return (*Notice)(pgerr) +} + // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 diff --git a/pgconn_test.go b/pgconn_test.go index 8d6b606a..98ec9664 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -551,3 +551,26 @@ func TestCommandTag(t *testing.T) { assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag) } } + +func TestConnOnNotice(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `do $$ +begin + raise notice 'hello, world'; +end$$;`) + require.Nil(t, err) + assert.Equal(t, "hello, world", msg) +} From b213299a9261bb9845b0785a9adcc3d7aebe12f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 13:59:00 -0600 Subject: [PATCH 033/290] Add ensureReadyForQuery to pgconn --- helper_test.go | 14 +++++++ pgconn.go | 102 ++++++++++++++++++++++++++++++++----------------- pgconn_test.go | 28 ++++++++++++++ 3 files changed, 109 insertions(+), 35 deletions(-) diff --git a/helper_test.go b/helper_test.go index 8e7ca92f..1053310b 100644 --- a/helper_test.go +++ b/helper_test.go @@ -7,6 +7,7 @@ import ( "github.com/jackc/pgx/pgconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -15,3 +16,16 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { defer cancel() require.Nil(t, conn.Close(ctx)) } + +// Do a simple query to ensure the connection is still usable +func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil) + cancel() + + require.Nil(t, err) + assert.Equal(t, 3, len(result.Rows)) + assert.Equal(t, "1", string(result.Rows[0][0])) + assert.Equal(t, "2", string(result.Rows[1][0])) + assert.Equal(t, "3", string(result.Rows[2][0])) +} diff --git a/pgconn.go b/pgconn.go index 6b6330dc..76836b9c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -562,23 +562,28 @@ func (rr *PgResultReader) close() { // Flush sends the enqueued execs to the server. func (pgConn *PgConn) Flush(ctx context.Context) error { - defer pgConn.resetBatch() - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() + err := pgConn.flush() + cleanup() + return preferContextOverNetTimeoutError(ctx, err) +} +// flush sends the enqueued execs to the server without handling a context. +func (pgConn *PgConn) flush() error { n, err := pgConn.conn.Write(pgConn.batchBuf) - if err != nil { - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - return preferContextOverNetTimeoutError(ctx, err) + if err != nil && n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true } - pgConn.pendingReadyForQueryCount += pgConn.batchCount - return nil + if err == nil { + pgConn.pendingReadyForQueryCount += pgConn.batchCount + } + + pgConn.resetBatch() + + return err } // contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from @@ -646,13 +651,11 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - for pgConn.pendingReadyForQueryCount > 0 { - _, err := pgConn.ReceiveMessage() - if err != nil { - preferContextOverNetTimeoutError(ctx, err) - pgConn.Close(context.Background()) - return false - } + err := pgConn.ensureReadyForQuery() + if err != nil { + preferContextOverNetTimeoutError(ctx, err) + pgConn.Close(context.Background()) + return false } result, err := pgConn.Exec( @@ -667,6 +670,18 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { return true } +// ensureReadyForQuery reads until pendingReadyForQueryCount == 0. +func (pgConn *PgConn) ensureReadyForQuery() error { + for pgConn.pendingReadyForQueryCount > 0 { + _, err := pgConn.ReceiveMessage() + if err != nil { + return err + } + } + + return nil +} + func (pgConn *PgConn) resetBatch() { pgConn.batchCount = 0 if len(pgConn.batchBuf) > batchBufferSize { @@ -690,14 +705,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExec(sql) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } return pgConn.bufferLastResult(ctx) @@ -741,12 +761,17 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { return nil, err } @@ -762,12 +787,17 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") + + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) } pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) - err := pgConn.Flush(ctx) + err = pgConn.flush() if err != nil { return nil, err } @@ -809,18 +839,20 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ if pgConn.batchCount != 0 { return nil, errors.New("unflushed previous sends") } - if pgConn.pendingReadyForQueryCount != 0 { - return nil, errors.New("unread previous results") - } - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanup() + + err := pgConn.ensureReadyForQuery() + if err != nil { + return nil, preferContextOverNetTimeoutError(ctx, err) + } pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) pgConn.batchCount += 1 - err := pgConn.Flush(context.Background()) + err = pgConn.flush() if err != nil { return nil, preferContextOverNetTimeoutError(ctx, err) } diff --git a/pgconn_test.go b/pgconn_test.go index 98ec9664..e436d739 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -243,6 +243,8 @@ func TestConnExec(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecEmpty(t *testing.T) { @@ -256,6 +258,8 @@ func TestConnExecEmpty(t *testing.T) { require.Nil(t, err) assert.Nil(t, result.CommandTag) assert.Equal(t, 0, len(result.Rows)) + + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueries(t *testing.T) { @@ -269,6 +273,8 @@ func TestConnExecMultipleQueries(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueriesError(t *testing.T) { @@ -286,6 +292,8 @@ func TestConnExecMultipleQueriesError(t *testing.T) { } else { t.Errorf("unexpected error: %v", err) } + + ensureConnValid(t, pgConn) } func TestConnExecContextCanceled(t *testing.T) { @@ -302,6 +310,8 @@ func TestConnExecContextCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnExecParams(t *testing.T) { @@ -315,6 +325,8 @@ func TestConnExecParams(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecParamsCanceled(t *testing.T) { @@ -331,6 +343,8 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnExecPrepared(t *testing.T) { @@ -350,6 +364,8 @@ func TestConnExecPrepared(t *testing.T) { require.Nil(t, err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + + ensureConnValid(t, pgConn) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -369,6 +385,8 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnBatchedQueries(t *testing.T) { @@ -480,6 +498,8 @@ func TestConnBatchedQueries(t *testing.T) { // Done require.False(t, pgConn.NextResult(context.Background())) + + ensureConnValid(t, pgConn) } func TestConnRecoverFromTimeout(t *testing.T) { @@ -504,6 +524,8 @@ func TestConnRecoverFromTimeout(t *testing.T) { assert.Equal(t, "1", string(result.Rows[0][0])) } cancel() + + ensureConnValid(t, pgConn) } func TestConnCancelQuery(t *testing.T) { @@ -527,6 +549,10 @@ func TestConnCancelQuery(t *testing.T) { } else { t.Errorf("expected pgconn.PgError got %v", err) } + + require.False(t, pgConn.NextResult(context.Background())) + + ensureConnValid(t, pgConn) } func TestCommandTag(t *testing.T) { @@ -573,4 +599,6 @@ begin end$$;`) require.Nil(t, err) assert.Equal(t, "hello, world", msg) + + ensureConnValid(t, pgConn) } From 475720d172af1aaff5711146b60bb8839e2952f8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 14:10:57 -0600 Subject: [PATCH 034/290] Fix typo --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 76836b9c..ff43f8a8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -629,7 +629,7 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { } // RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is -// successful true is returned. If recovery is not successful the connection is closed and false it returned. Recovery +// successful true is returned. If recovery is not successful the connection is closed and false is returned. Recovery // should usually be possible except in the case of a partial write. This must be called after any context cancellation. // // As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block From de2b9bb301c52f92abdd4f3caf13520e6e4855a9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 14:20:10 -0600 Subject: [PATCH 035/290] Tweak RecoverFromTimeout docs --- pgconn.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pgconn.go b/pgconn.go index ff43f8a8..9661f99e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -628,12 +628,12 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { return err } -// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. If recovery is -// successful true is returned. If recovery is not successful the connection is closed and false is returned. Recovery -// should usually be possible except in the case of a partial write. This must be called after any context cancellation. -// -// As RecoverFromTimeout may need to read and ignored data already sent from the server, it potentially can block -// indefinitely. Use ctx to guard against this. +// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. This must be +// called after any context cancellation. This is not done automatically as RecoverFromTimeout may need to signal the +// server to abort the in-progress query and read and ignore data already sent from the server. This potentially can +// block indefinitely. Use ctx to guard against this. If recovery is successful true is returned. If recovery is not +// successful the connection is closed and false is returned. Recovery should usually be possible except in the case of +// a partial write. func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { if pgConn.closed { return false From ec622237e97fe4258b9f33e60af1aae33f622c27 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 14:56:24 -0600 Subject: [PATCH 036/290] Extract startOperation --- pgconn.go | 113 +++++++++++++++++++++++++++++------------------------- 1 file changed, 60 insertions(+), 53 deletions(-) diff --git a/pgconn.go b/pgconn.go index 9661f99e..e22a0de8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -586,39 +586,6 @@ func (pgConn *PgConn) flush() error { return err } -// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from -// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to -// call multiple times. -func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { - if ctx.Done() != nil { - deadlineWasSet := false - doneChan := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - conn.SetDeadline(deadlineTime) - deadlineWasSet = true - <-doneChan - // TODO - case <-doneChan: - } - }() - - finished := false - return func() { - if !finished { - doneChan <- struct{}{} - if deadlineWasSet { - conn.SetDeadline(time.Time{}) - } - finished = true - } - } - } - - return func() {} -} - // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -670,6 +637,54 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { return true } +// startOperation gets the connection ready for a new operation. It should be called at the beginning of every public +// method that communicates with the server. The returned cleanup function must be called if err == nil or a goroutine may +// be leaked. The cleanup function is safe to call multiple times. +func (pgConn *PgConn) startOperation(ctx context.Context) (cleanup func(), err error) { + cleanup = contextDoneToConnDeadline(ctx, pgConn.conn) + + err = pgConn.ensureReadyForQuery() + if err != nil { + cleanup() + return cleanup, preferContextOverNetTimeoutError(ctx, err) + } + + return cleanup, nil +} + +// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from +// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to +// call multiple times. +func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { + if ctx.Done() != nil { + deadlineWasSet := false + doneChan := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + conn.SetDeadline(deadlineTime) + deadlineWasSet = true + <-doneChan + // TODO + case <-doneChan: + } + }() + + finished := false + return func() { + if !finished { + doneChan <- struct{}{} + if deadlineWasSet { + conn.SetDeadline(time.Time{}) + } + finished = true + } + } + } + + return func() {} +} + // ensureReadyForQuery reads until pendingReadyForQueryCount == 0. func (pgConn *PgConn) ensureReadyForQuery() error { for pgConn.pendingReadyForQueryCount > 0 { @@ -706,13 +721,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExec(sql) err = pgConn.flush() @@ -762,13 +775,11 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) err = pgConn.flush() @@ -788,13 +799,11 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) err = pgConn.flush() @@ -840,13 +849,11 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, errors.New("unflushed previous sends") } - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanup() - - err := pgConn.ensureReadyForQuery() + cleanup, err := pgConn.startOperation(ctx) if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, err } + defer cleanup() pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) From fa5e1d3ec4ad6453e36af71c042ffc261b379cfa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 18:16:08 -0600 Subject: [PATCH 037/290] Back out of some over optimization --- pgconn.go | 30 +++--------------------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/pgconn.go b/pgconn.go index e22a0de8..ee8127bf 100644 --- a/pgconn.go +++ b/pgconn.go @@ -814,33 +814,11 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return pgConn.bufferLastResult(ctx) } -type FieldDescription struct { - Name string - TableOID uint32 - TableAttributeNumber uint16 - DataTypeOID uint32 - DataTypeSize int16 - TypeModifier int32 - FormatCode int16 -} - -// pgproto3FieldDescriptionToPgconnFieldDescription copies and converts the data from a pgproto3.FieldDescription to a -// FieldDescription. -func pgproto3FieldDescriptionToPgconnFieldDescription(src *pgproto3.FieldDescription, dst *FieldDescription) { - dst.Name = string(src.Name) - dst.TableOID = src.TableOID - dst.TableAttributeNumber = src.TableAttributeNumber - dst.DataTypeOID = src.DataTypeOID - dst.DataTypeSize = src.DataTypeSize - dst.TypeModifier = src.TypeModifier - dst.FormatCode = src.Format -} - type PreparedStatementDescription struct { Name string SQL string ParamOIDs []uint32 - Fields []FieldDescription + Fields []pgproto3.FieldDescription } // Prepare creates a prepared statement. @@ -877,10 +855,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: - psd.Fields = make([]FieldDescription, len(msg.Fields)) - for i := range msg.Fields { - pgproto3FieldDescriptionToPgconnFieldDescription(&msg.Fields[i], &psd.Fields[i]) - } + psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) + copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: return nil, errorResponseToPgError(msg) } From 64e80f1f723cc2edc3495db56b23755b164abf62 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 2 Jan 2019 18:16:20 -0600 Subject: [PATCH 038/290] Add benchmarks when cancellable --- benchmark_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index 269ac59b..fc4b6057 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -49,6 +49,22 @@ func BenchmarkExec(b *testing.B) { } } +func BenchmarkExecPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + b.ResetTimer() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + require.Nil(b, err) + } +} + func BenchmarkExecPrepared(b *testing.B) { conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) @@ -64,6 +80,24 @@ func BenchmarkExecPrepared(b *testing.B) { } } +func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(b, err) + defer closeConn(b, conn) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) + require.Nil(b, err) + } +} + func BenchmarkSendExecPrepared(b *testing.B) { conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) From cddf01180659a163df1e7ca9cd03dd648ea8c153 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 17:37:28 -0600 Subject: [PATCH 039/290] Big restructure to better handle context cancel --- benchmark_test.go | 33 +- config.go | 6 +- helper_test.go | 4 +- pgconn.go | 995 +++++++++++++++++++++++------------------- pgconn_stress_test.go | 116 +---- pgconn_test.go | 289 +++++------- 6 files changed, 686 insertions(+), 757 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index fc4b6057..ffb1455c 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -44,7 +44,7 @@ func BenchmarkExec(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() require.Nil(b, err) } } @@ -60,7 +60,7 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { defer cancel() for i := 0; i < b.N; i++ { - _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() require.Nil(b, err) } } @@ -71,12 +71,13 @@ func BenchmarkExecPrepared(b *testing.B) { defer closeConn(b, conn) _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) - require.Nil(b, err) + result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).ReadAll() + require.Nil(b, result.Err) } } @@ -89,32 +90,12 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { defer cancel() _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) - require.Nil(b, err) - } -} - -func BenchmarkSendExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(b, err) - defer closeConn(b, conn) - - _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) b.ResetTimer() for i := 0; i < b.N; i++ { - conn.SendExecPrepared("ps1", nil, nil, nil) - err := conn.Flush(context.Background()) - require.Nil(b, err) - - for conn.NextResult(context.Background()) { - _, err := conn.ResultReader().Close() - require.Nil(b, err) - } + result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).ReadAll() + require.Nil(b, result.Err) } } diff --git a/config.go b/config.go index bd1fec9b..fb0719cd 100644 --- a/config.go +++ b/config.go @@ -470,9 +470,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result, err := pgConn.Exec(ctx, "show transaction_read_only") - if err != nil { - return err + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).ReadAll() + if result.Err != nil { + return result.Err } if string(result.Rows[0][0]) == "on" { diff --git a/helper_test.go b/helper_test.go index 1053310b..a50f7cb1 100644 --- a/helper_test.go +++ b/helper_test.go @@ -20,10 +20,10 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { // Do a simple query to ensure the connection is still usable func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - result, err := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil) + result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).ReadAll() cancel() - require.Nil(t, err) + require.Nil(t, result.Err) assert.Equal(t, 3, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) assert.Equal(t, "2", string(result.Rows[1][0])) diff --git a/pgconn.go b/pgconn.go index ee8127bf..cfacc7bb 100644 --- a/pgconn.go +++ b/pgconn.go @@ -17,8 +17,6 @@ import ( "github.com/jackc/pgx/pgproto3" ) -const batchBufferSize = 4096 - var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) // PgError represents an error reported by the PostgreSQL server. See @@ -76,14 +74,9 @@ type PgConn struct { Config *Config - batchBuf []byte - batchCount int32 - - pendingReadyForQueryCount int32 + controller chan interface{} closed bool - - resultReader PgResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -140,6 +133,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config + pgConn.controller = make(chan interface{}, 1) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -268,23 +262,22 @@ func hexMD5(s string) string { func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.Frontend.Receive() if err != nil { + // Close on anything other than timeout error - everything else is fatal + if err, ok := err.(net.Error); !ok && err.Timeout() { + pgConn.hardClose() + } + return nil, err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - // Under normal circumstances pendingReadyForQueryCount will be > 0 when a - // ReadyForQuery is received. However, this is not the case on initial - // connection. - if pgConn.pendingReadyForQueryCount > 0 { - pgConn.pendingReadyForQueryCount -= 1 - } pgConn.TxStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - // TODO - close pgConn + pgConn.hardClose() return nil, errorResponseToPgError(msg) } case *pgproto3.NoticeResponse: @@ -338,6 +331,15 @@ func (pgConn *PgConn) Close(ctx context.Context) error { return pgConn.conn.Close() } +// hardClose closes the underlying connection without sending the exit message. +func (pgConn *PgConn) hardClose() error { + if pgConn.closed { + return nil + } + pgConn.closed = true + return pgConn.conn.Close() +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -363,229 +365,6 @@ func (ct CommandTag) String() string { return string(ct) } -// SendExec enqueues the execution of sql via the PostgreSQL simple query protocol. sql may contain multiple queries. -// Execution is implicitly wrapped in a transactions unless a transaction is already in progress or sql contains -// transaction control statements. It is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExec(sql string) { - pgConn.batchBuf = (&pgproto3.Query{String: sql}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 -} - -// SendExecParams enqueues the execution of sql via the PostgreSQL extended query protocol. -// -// sql is a SQL command string. It may only contain one query. Parameter substitution is position using $1, $2, $3, etc. -// -// paramValues are the parameter values. It must be encoded in the format given by paramFormats. -// -// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for -// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. -// SendExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). -// -// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if -// len(paramFormats) is not 0, 1, or len(paramValues). -// -// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. -// -// Query is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - if len(paramOIDs) != 0 && len(paramOIDs) != len(paramValues) && len(paramOIDs) != len(paramValues) { - panic(fmt.Sprintf("len(paramOIDs) must be 0, 1, or len(paramValues), received %d", len(paramOIDs))) - } - - pgConn.batchBuf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) - pgConn.SendExecPrepared("", paramValues, paramFormats, resultFormats) -} - -// SendExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. -// -// paramValues are the parameter values. It must be encoded in the format given by paramFormats. -// -// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. SendExecParams will panic if -// len(paramFormats) is not 0, 1, or len(paramValues). -// -// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. -// -// Query is only sent to the PostgreSQL server when Flush is called. -func (pgConn *PgConn) SendExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: stmtName}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Execute{}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 -} - -type PgResultReader struct { - pgConn *PgConn - fieldDescriptions []pgproto3.FieldDescription - rowValues [][]byte - commandTag CommandTag - err error - complete bool - preloadedRowValues bool - ctx context.Context - cleanupContext func() -} - -// NextResult reads until a result is ready to be read or no results are pending. Returns true if a result is available. -// Use ResultReader() to acquire a reader for the result. -func (pgConn *PgConn) NextResult(ctx context.Context) bool { - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - - for pgConn.pendingReadyForQueryCount > 0 { - msg, err := pgConn.ReceiveMessage() - if err != nil { - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: preferContextOverNetTimeoutError(ctx, err), complete: true} - return true - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, fieldDescriptions: msg.Fields} - return true - case *pgproto3.DataRow: - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, cleanupContext: cleanupContext, rowValues: msg.Values, preloadedRowValues: true} - return true - case *pgproto3.CommandComplete: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, commandTag: CommandTag(msg.CommandTag), complete: true} - return true - case *pgproto3.EmptyQueryResponse: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, complete: true} - return true - case *pgproto3.ErrorResponse: - cleanupContext() - pgConn.resultReader = PgResultReader{pgConn: pgConn, ctx: ctx, err: errorResponseToPgError(msg), complete: true} - return true - } - } - - cleanupContext() - return false -} - -// ResultReader returns the result reader prepared by next result. It is only valid until the result is completed. -func (pgConn *PgConn) ResultReader() *PgResultReader { - return &pgConn.resultReader -} - -// NextRow returns advances the PgResultReader to the next row and returns true if a row is available. -func (rr *PgResultReader) NextRow() bool { - if rr.complete { - return false - } - - if rr.preloadedRowValues { - rr.preloadedRowValues = false - return true - } - - for { - msg, err := rr.pgConn.ReceiveMessage() - if err != nil { - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) - rr.close() - return false - } - - switch msg := msg.(type) { - case *pgproto3.RowDescription: - rr.fieldDescriptions = msg.Fields - case *pgproto3.DataRow: - rr.rowValues = msg.Values - return true - case *pgproto3.CommandComplete: - rr.commandTag = CommandTag(msg.CommandTag) - rr.close() - return false - case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) - rr.close() - return false - } - } -} - -// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the PgResultReader is closed. -func (rr *PgResultReader) FieldDescriptions() []pgproto3.FieldDescription { - return rr.fieldDescriptions -} - -// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the PgResultReader is closed. However, the underlying byte data is safe to -// retain a reference to and mutate. -func (rr *PgResultReader) Values() [][]byte { - return rr.rowValues -} - -// Close consumes any remaining result data and returns the command tag or -// error. -func (rr *PgResultReader) Close() (CommandTag, error) { - if rr.complete { - return rr.commandTag, rr.err - } - defer rr.close() - - for { - msg, err := rr.pgConn.ReceiveMessage() - if err != nil { - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) - return rr.commandTag, rr.err - } - - switch msg := msg.(type) { - case *pgproto3.CommandComplete: - rr.commandTag = CommandTag(msg.CommandTag) - return rr.commandTag, rr.err - case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) - return rr.commandTag, rr.err - } - } -} - -func (rr *PgResultReader) close() { - if rr.complete { - return - } - - rr.cleanupContext() - rr.rowValues = nil - rr.complete = true -} - -// Flush sends the enqueued execs to the server. -func (pgConn *PgConn) Flush(ctx context.Context) error { - cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) - err := pgConn.flush() - cleanup() - return preferContextOverNetTimeoutError(ctx, err) -} - -// flush sends the enqueued execs to the server without handling a context. -func (pgConn *PgConn) flush() error { - n, err := pgConn.conn.Write(pgConn.batchBuf) - if err != nil && n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - - if err == nil { - pgConn.pendingReadyForQueryCount += pgConn.batchCount - } - - pgConn.resetBatch() - - return err -} - // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -595,63 +374,6 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { return err } -// RecoverFromTimeout attempts to recover from a timeout error such as is caused by a canceled context. This must be -// called after any context cancellation. This is not done automatically as RecoverFromTimeout may need to signal the -// server to abort the in-progress query and read and ignore data already sent from the server. This potentially can -// block indefinitely. Use ctx to guard against this. If recovery is successful true is returned. If recovery is not -// successful the connection is closed and false is returned. Recovery should usually be possible except in the case of -// a partial write. -func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { - if pgConn.closed { - return false - } - pgConn.resetBatch() - - // Clear any existing timeout - pgConn.conn.SetDeadline(time.Time{}) - - // Try to cancel any in-progress requests - for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { - pgConn.CancelRequest(ctx) - } - - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() - - err := pgConn.ensureReadyForQuery() - if err != nil { - preferContextOverNetTimeoutError(ctx, err) - pgConn.Close(context.Background()) - return false - } - - result, err := pgConn.Exec( - context.Background(), // do not use ctx again because deadline goroutine already started above - "select 'RecoverFromTimeout'", - ) - if err != nil || len(result.Rows) != 1 || len(result.Rows[0]) != 1 || string(result.Rows[0][0]) != "RecoverFromTimeout" { - pgConn.Close(context.Background()) - return false - } - - return true -} - -// startOperation gets the connection ready for a new operation. It should be called at the beginning of every public -// method that communicates with the server. The returned cleanup function must be called if err == nil or a goroutine may -// be leaked. The cleanup function is safe to call multiple times. -func (pgConn *PgConn) startOperation(ctx context.Context) (cleanup func(), err error) { - cleanup = contextDoneToConnDeadline(ctx, pgConn.conn) - - err = pgConn.ensureReadyForQuery() - if err != nil { - cleanup() - return cleanup, preferContextOverNetTimeoutError(ctx, err) - } - - return cleanup, nil -} - // contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from // ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to // call multiple times. @@ -665,7 +387,6 @@ func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func conn.SetDeadline(deadlineTime) deadlineWasSet = true <-doneChan - // TODO case <-doneChan: } }() @@ -685,135 +406,6 @@ func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func return func() {} } -// ensureReadyForQuery reads until pendingReadyForQueryCount == 0. -func (pgConn *PgConn) ensureReadyForQuery() error { - for pgConn.pendingReadyForQueryCount > 0 { - _, err := pgConn.ReceiveMessage() - if err != nil { - return err - } - } - - return nil -} - -func (pgConn *PgConn) resetBatch() { - pgConn.batchCount = 0 - if len(pgConn.batchBuf) > batchBufferSize { - pgConn.batchBuf = make([]byte, 0, batchBufferSize) - } else { - pgConn.batchBuf = pgConn.batchBuf[0:0] - } -} - -type PgResult struct { - Rows [][][]byte - CommandTag CommandTag -} - -// Exec executes sql via the PostgreSQL simple query protocol, buffers the entire result, and returns it. sql may -// contain multiple queries, but only the last results will be returned. Execution is implicitly wrapped in a -// transactions unless a transaction is already in progress or sql contains transaction control statements. -// -// Exec must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) Exec(ctx context.Context, sql string) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExec(sql) - err = pgConn.flush() - if err != nil { - return nil, preferContextOverNetTimeoutError(ctx, err) - } - - return pgConn.bufferLastResult(ctx) -} - -func (pgConn *PgConn) bufferLastResult(ctx context.Context) (*PgResult, error) { - var result *PgResult - - for pgConn.NextResult(ctx) { - resultReader := pgConn.ResultReader() - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - - commandTag, err := resultReader.Close() - if err != nil { - return nil, err - } - - result = &PgResult{ - Rows: rows, - CommandTag: commandTag, - } - } - - if result == nil { - return nil, errors.New("unexpected missing result") - } - - return result, nil -} - -// ExecParams executes sql via the PostgreSQL extended query protocol, buffers the entire result, and returns it. See -// SendExecParams for parameter descriptions. -// -// ExecParams must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExecParams(sql, paramValues, paramOIDs, paramFormats, resultFormats) - err = pgConn.flush() - if err != nil { - return nil, err - } - - return pgConn.bufferLastResult(ctx) -} - -// ExecPrepared executes a prepared statement via the PostgreSQL extended query protocol, buffers the entire result, and -// returns it. See SendExecPrepared for parameter descriptions. -// -// ExecPrepared must not be called when there are pending results from previous Send* methods (e.g. SendExec). -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) (*PgResult, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") - } - - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() - - pgConn.SendExecPrepared(stmtName, paramValues, paramFormats, resultFormats) - err = pgConn.flush() - if err != nil { - return nil, err - } - - return pgConn.bufferLastResult(ctx) -} - type PreparedStatementDescription struct { Name string SQL string @@ -823,30 +415,38 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - if pgConn.batchCount != 0 { - return nil, errors.New("unflushed previous sends") + select { + case <-ctx.Done(): + return nil, ctx.Err() + case pgConn.controller <- pgConn: } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() - cleanup, err := pgConn.startOperation(ctx) - if err != nil { - return nil, err - } - defer cleanup() + var buf []byte + buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) - pgConn.batchBuf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(pgConn.batchBuf) - pgConn.batchBuf = (&pgproto3.Sync{}).Encode(pgConn.batchBuf) - pgConn.batchCount += 1 - err = pgConn.flush() + n, err := pgConn.conn.Write(buf) if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + return nil, preferContextOverNetTimeoutError(ctx, err) } psd := &PreparedStatementDescription{Name: name, SQL: sql} - for pgConn.pendingReadyForQueryCount > 0 { +readloop: + for { msg, err := pgConn.ReceiveMessage() if err != nil { + go pgConn.recoverFromTimeout() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -858,10 +458,14 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: + go pgConn.recoverFromTimeout() return nil, errorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + break readloop } } + <-pgConn.controller return psd, nil } @@ -892,10 +496,10 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { return (*Notice)(pgerr) } -// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// cancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 -func (pgConn *PgConn) CancelRequest(ctx context.Context) error { +func (pgConn *PgConn) cancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. @@ -926,3 +530,514 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { return nil } + +// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is +// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control +// statements. +// +// Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *PgMultiResult { + multiResult := &PgMultiResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + case pgConn.controller <- multiResult: + } + multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + multiResult.cleanupContextDeadline() + multiResult.closed = true + multiResult.err = preferContextOverNetTimeoutError(ctx, err) + <-pgConn.controller + return multiResult + } + + return multiResult +} + +// ExecParams executes a command via the PostgreSQL extended query protocol. +// +// sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, +// etc. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for +// all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. +// ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. ExecParams will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Result must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *PgResult { + result := &PgResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + result.concludeCommand(nil, ctx.Err()) + result.closed = true + return result + case pgConn.controller <- result: + } + result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + + // TODO - refactor ExecParams and ExecPrepared - these lines only difference + buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) + buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + buf = (&pgproto3.Execute{}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + result.concludeCommand(nil, err) + result.cleanupContextDeadline() + result.closed = true + <-pgConn.controller + } + + return result +} + +// ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. +// +// paramValues are the parameter values. It must be encoded in the format given by paramFormats. +// +// paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or +// binary format. If paramFormats is nil all results will be in text protocol. ExecPrepared will panic if +// len(paramFormats) is not 0, 1, or len(paramValues). +// +// resultFormats is a slice of format codes determining for each result column whether it is encoded in text or +// binary format. If resultFormats is nil all results will be in text protocol. +// +// Result must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *PgResult { + result := &PgResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + result.concludeCommand(nil, ctx.Err()) + result.closed = true + return result + case pgConn.controller <- result: + } + result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + var buf []byte + buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) + buf = (&pgproto3.Execute{}).Encode(buf) + buf = (&pgproto3.Sync{}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + result.concludeCommand(nil, err) + result.cleanupContextDeadline() + result.closed = true + <-pgConn.controller + } + + return result +} + +type PgMultiResult struct { + pgConn *PgConn + ctx context.Context + cleanupContextDeadline func() + + pgResult *PgResult + + closed bool + err error +} + +func (mr *PgMultiResult) ReadAll() ([]*BufferedResult, error) { + var results []*BufferedResult + + for mr.NextResult() { + results = append(results, mr.Result().ReadAll()) + } + err := mr.Close() + + return results, err +} + +func (mr *PgMultiResult) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mr.pgConn.ReceiveMessage() + + if err != nil { + mr.cleanupContextDeadline() + mr.err = preferContextOverNetTimeoutError(mr.ctx, err) + mr.closed = true + + if err, ok := err.(net.Error); ok && err.Timeout() { + go mr.pgConn.recoverFromTimeout() + } else { + <-mr.pgConn.controller + } + + return nil, mr.err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + mr.cleanupContextDeadline() + mr.closed = true + <-mr.pgConn.controller + case *pgproto3.ErrorResponse: + mr.err = errorResponseToPgError(msg) + } + + return msg, nil +} + +// NextResult returns advances the PgMultiResult to the next result and returns true if a result is available. +func (mr *PgMultiResult) NextResult() bool { + for !mr.closed && mr.err == nil { + msg, err := mr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + mr.pgResult = &PgResult{ + pgConn: mr.pgConn, + pgMultiResult: mr, + ctx: mr.ctx, + cleanupContextDeadline: func() {}, + fieldDescriptions: msg.Fields, + } + return true + case *pgproto3.CommandComplete: + mr.pgResult = &PgResult{ + commandTag: CommandTag(msg.CommandTag), + commandConcluded: true, + closed: true, + } + return true + case *pgproto3.EmptyQueryResponse: + return false + } + } + + return false +} + +func (mr *PgMultiResult) Result() *PgResult { + return mr.pgResult +} + +func (mr *PgMultiResult) Close() error { + for !mr.closed { + _, err := mr.receiveMessage() + if err != nil { + return mr.err + } + } + + return mr.err +} + +type PgResult struct { + pgConn *PgConn + pgMultiResult *PgMultiResult + ctx context.Context + cleanupContextDeadline func() + + fieldDescriptions []pgproto3.FieldDescription + rowValues [][]byte + commandTag CommandTag + commandConcluded bool + closed bool + err error +} + +type BufferedResult struct { + FieldDescriptions []pgproto3.FieldDescription + Rows [][][]byte + CommandTag CommandTag + Err error +} + +func (rr *PgResult) ReadAll() *BufferedResult { + br := &BufferedResult{} + + for rr.NextRow() { + if br.FieldDescriptions == nil { + br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) + copy(br.FieldDescriptions, rr.FieldDescriptions()) + } + + row := make([][]byte, len(rr.Values())) + copy(row, rr.Values()) + br.Rows = append(br.Rows, row) + } + + br.CommandTag, br.Err = rr.Close() + + return br +} + +// NextRow advances the PgResult to the next row and returns true if a row is available. +func (rr *PgResult) NextRow() bool { + for !rr.commandConcluded { + msg, err := rr.receiveMessage() + if err != nil { + return false + } + + switch msg := msg.(type) { + case *pgproto3.DataRow: + rr.rowValues = msg.Values + return true + } + } + + return false +} + +// FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until +// the PgResult is closed. +func (rr *PgResult) FieldDescriptions() []pgproto3.FieldDescription { + return rr.fieldDescriptions +} + +// Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only +// valid until the next NextRow call or the PgResult is closed. However, the underlying byte data is safe to +// retain a reference to and mutate. +func (rr *PgResult) Values() [][]byte { + return rr.rowValues +} + +// Close consumes any remaining result data and returns the command tag or +// error. +func (rr *PgResult) Close() (CommandTag, error) { + if rr.closed { + return rr.commandTag, rr.err + } + rr.closed = true + + for !rr.commandConcluded { + _, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + } + + if rr.pgMultiResult == nil { + for { + msg, err := rr.receiveMessage() + if err != nil { + return nil, rr.err + } + + switch msg.(type) { + case *pgproto3.ReadyForQuery: + rr.cleanupContextDeadline() + <-rr.pgConn.controller + return rr.commandTag, rr.err + } + } + } + + return rr.commandTag, rr.err +} + +func (rr *PgResult) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.pgMultiResult == nil { + msg, err = rr.pgConn.ReceiveMessage() + } else { + msg, err = rr.pgMultiResult.receiveMessage() + } + + if err != nil { + rr.concludeCommand(nil, err) + rr.cleanupContextDeadline() + rr.closed = true + if rr.pgMultiResult == nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + go rr.pgConn.recoverFromTimeout() + } else { + <-rr.pgConn.controller + } + } + + return nil, rr.err + } + + switch msg := msg.(type) { + case *pgproto3.RowDescription: + rr.fieldDescriptions = msg.Fields + case *pgproto3.CommandComplete: + rr.concludeCommand(CommandTag(msg.CommandTag), nil) + case *pgproto3.ErrorResponse: + rr.concludeCommand(nil, errorResponseToPgError(msg)) + } + + return msg, nil +} + +func (rr *PgResult) concludeCommand(commandTag CommandTag, err error) { + if rr.commandConcluded { + return + } + + rr.commandTag = commandTag + rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.fieldDescriptions = nil + rr.rowValues = nil + rr.commandConcluded = true +} + +func (pgConn *PgConn) recoverFromTimeout() { + // Regardless of recovery outcome the lock on the pgConn must be released. + defer func() { <-pgConn.controller }() + + // Send a cancellation request to the PostgreSQL server. If it is not successful in a reasonable amount of time do not + // try further to recover the connection. + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + err := pgConn.cancelRequest(ctx) + cancel() + if err != nil { + pgConn.hardClose() + return + } + + // Limit time to wait for ReadyForQuery message. + err = pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + pgConn.hardClose() + return + } + + // A cancel query request will always return a "57014" error response, even if no query was in progress. This error + // may be returned before or after the ReadyForQuery message. Must ensure both messages are read. + needError57014 := true + needReadyForQuery := true + + for needError57014 || needReadyForQuery { + msg, err := pgConn.ReceiveMessage() + if err != nil { + pgConn.hardClose() + return + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + if msg.Code == "57014" { + needError57014 = false + } + case *pgproto3.ReadyForQuery: + needReadyForQuery = false + } + } + + err = pgConn.conn.SetDeadline(time.Time{}) + if err != nil { + pgConn.hardClose() + } +} + +type Batch struct { + buf []byte +} + +// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { + // TODO - refactor ExecParams and ExecPrepared - these lines only difference + batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + batch.ExecPrepared("", paramValues, paramFormats, resultFormats) +} + +// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { + batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) +} + +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *PgMultiResult { + multiResult := &PgMultiResult{ + pgConn: pgConn, + ctx: ctx, + cleanupContextDeadline: func() {}, + } + + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + case pgConn.controller <- multiResult: + } + multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + + batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + n, err := pgConn.conn.Write(batch.buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + multiResult.cleanupContextDeadline() + multiResult.closed = true + multiResult.err = preferContextOverNetTimeoutError(ctx, err) + <-pgConn.controller + return multiResult + } + + return multiResult +} diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 9aa94539..17d344b7 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/jackc/pgx/pgconn" - "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -22,9 +21,9 @@ func TestConnStress(t *testing.T) { defer closeConn(t, pgConn) actionCount := 100 - if s := os.Getenv("PTX_TEST_STRESS_FACTOR"); s != "" { + if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { stressFactor, err := strconv.ParseInt(s, 10, 64) - require.Nil(t, err, "Failed to parse PTX_TEST_STRESS_FACTOR") + require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") actionCount *= int(stressFactor) } @@ -61,138 +60,61 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { insert into widgets(name, description) values ('Foo', 'bar'), ('baz', 'Something really long Something really long Something really long Something really long Something really long'), - ('a', 'b')`) + ('a', 'b')`).ReadAll() require.Nil(t, err) } func stressExecSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.Exec(context.Background(), "select * from widgets") + _, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll() return err } func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - return err + result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() + return result.Err } func stressBatch(pgConn *pgconn.PgConn) error { - pgConn.SendExec("select * from widgets") - pgConn.SendExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - err := pgConn.Flush(context.Background()) - if err != nil { - return err - } + batch := &pgconn.Batch{} - // Query 1 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader := pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // Query 2 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader = pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // No more - if pgConn.NextResult(context.Background()) { - return errors.New("unexpected result reader") - } - - return nil + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + _, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + return err } func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets") + _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll() cancel() if err != context.DeadlineExceeded { return err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) + result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() cancel() - if err != context.DeadlineExceeded { - return err + if result.Err != context.DeadlineExceeded { + return result.Err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } func stressBatchCanceled(pgConn *pgconn.PgConn) error { - - pgConn.SendExec("select * from widgets") - pgConn.SendExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - err := pgConn.Flush(context.Background()) - if err != nil { - return err - } - - // Query 1 - if !pgConn.NextResult(context.Background()) { - return errors.New("missing result") - } - resultReader := pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() - if err != nil { - return err - } - - // Query 2 + batch := &pgconn.Batch{} + batch.ExecParams("select * from widgets", nil, nil, nil, nil) + batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - if !pgConn.NextResult(ctx) { - return errors.New("missing result") - } + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() cancel() - resultReader = pgConn.ResultReader() - - for resultReader.NextRow() { - } - _, err = resultReader.Close() if err != context.DeadlineExceeded { return err } - ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) - recovered := pgConn.RecoverFromTimeout(ctx) - cancel() - if !recovered { - return errors.New("unable to recover from timeout") - } return nil } diff --git a/pgconn_test.go b/pgconn_test.go index e436d739..a2eb7838 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -134,13 +134,13 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - result, err := conn.Exec(context.Background(), "show application_name") - require.Nil(t, err) + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).ReadAll() + require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result, err = conn.Exec(context.Background(), "show search_path") - require.Nil(t, err) + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).ReadAll() + require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) } @@ -239,10 +239,14 @@ func TestConnExec(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select current_database()") - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, pgConn.Config.Database, string(result.Rows[0][0])) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Nil(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) ensureConnValid(t, pgConn) } @@ -254,10 +258,16 @@ func TestConnExecEmpty(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), ";") - require.Nil(t, err) - assert.Nil(t, result.CommandTag) - assert.Equal(t, 0, len(result.Rows)) + multiResult := pgConn.Exec(context.Background(), ";") + + resultCount := 0 + for multiResult.NextResult() { + resultCount += 1 + multiResult.Result().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -269,10 +279,20 @@ func TestConnExecMultipleQueries(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select current_database(); select 1") - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "1", string(result.Rows[0][0])) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.Nil(t, err) + + assert.Len(t, results, 2) + + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) ensureConnValid(t, pgConn) } @@ -284,15 +304,18 @@ func TestConnExecMultipleQueriesError(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1") + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() require.NotNil(t, err) - require.Nil(t, result) if pgErr, ok := err.(*pgconn.PgError); ok { assert.Equal(t, "22012", pgErr.Code) } else { t.Errorf("unexpected error: %v", err) } + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + ensureConnValid(t, pgConn) } @@ -305,11 +328,12 @@ func TestConnExecContextCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - assert.Nil(t, result) - assert.Equal(t, context.DeadlineExceeded, err) + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)") - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.Equal(t, context.DeadlineExceeded, err) ensureConnValid(t, pgConn) } @@ -321,10 +345,16 @@ func TestConnExecParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - result, err := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -338,12 +368,16 @@ func TestConnExecParamsCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) - assert.Nil(t, result) + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Nil(t, commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) - ensureConnValid(t, pgConn) } @@ -360,10 +394,16 @@ func TestConnExecPrepared(t *testing.T) { assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) - result, err := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - require.Nil(t, err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "Hello, world", string(result.Rows[0][0])) + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.Nil(t, err) ensureConnValid(t, pgConn) } @@ -380,16 +420,20 @@ func TestConnExecPreparedCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - result, err := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) - assert.Nil(t, result) + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + } + assert.Equal(t, 0, rowCount) + commandTag, err := result.Close() + assert.Nil(t, commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.True(t, pgConn.RecoverFromTimeout(context.Background())) - ensureConnValid(t, pgConn) } -func TestConnBatchedQueries(t *testing.T) { +func TestConnExecBatch(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) @@ -399,160 +443,26 @@ func TestConnBatchedQueries(t *testing.T) { _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) require.Nil(t, err) - pgConn.SendExec("select 'SendExec 1'") - pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 1")}, nil, nil, nil) - pgConn.SendExecPrepared("ps1", [][]byte{[]byte("SendExecPrepared 1")}, nil, nil) - pgConn.SendExec("select 'SendExec 2'") - pgConn.SendExecParams("select $1::text", [][]byte{[]byte("SendExecParams 2")}, nil, nil, nil) - err = pgConn.Flush(context.Background()) + batch := &pgconn.Batch{} - // "select 'SendExec 1'" - require.True(t, pgConn.NextResult(context.Background())) - resultReader := pgConn.ResultReader() - - rows := [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExec 1", string(rows[0][0])) - - commandTag, err := resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecParams 1" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecParams 1", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecPrepared 1" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecPrepared 1", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExec 2" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExec 2", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // "SendExecParams 2" - require.True(t, pgConn.NextResult(context.Background())) - resultReader = pgConn.ResultReader() - - rows = [][][]byte{} - for resultReader.NextRow() { - row := make([][]byte, len(resultReader.Values())) - copy(row, resultReader.Values()) - rows = append(rows, row) - } - require.Len(t, rows, 1) - require.Len(t, rows[0], 1) - assert.Equal(t, "SendExecParams 2", string(rows[0][0])) - - commandTag, err = resultReader.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) - - // Done - require.False(t, pgConn.NextResult(context.Background())) - - ensureConnValid(t, pgConn) -} - -func TestConnRecoverFromTimeout(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() require.Nil(t, err) - defer closeConn(t, pgConn) + require.Len(t, results, 3) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)") - cancel() - require.Nil(t, result) - assert.Equal(t, context.DeadlineExceeded, err) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - if assert.True(t, pgConn.RecoverFromTimeout(ctx)) { - result, err := pgConn.Exec(ctx, "select 1") - require.Nil(t, err) - assert.Len(t, result.Rows, 1) - assert.Len(t, result.Rows[0], 1) - assert.Equal(t, "1", string(result.Rows[0][0])) - } - cancel() + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - ensureConnValid(t, pgConn) -} - -func TestConnCancelQuery(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) - defer closeConn(t, pgConn) - - pgConn.SendExec("select current_database(), pg_sleep(5)") - err = pgConn.Flush(context.Background()) - require.Nil(t, err) - - err = pgConn.CancelRequest(context.Background()) - require.Nil(t, err) - - require.True(t, pgConn.NextResult(context.Background())) - _, err = pgConn.ResultReader().Close() - if err, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "57014", err.Code) - } else { - t.Errorf("expected pgconn.PgError got %v", err) - } - - require.False(t, pgConn.NextResult(context.Background())) - - ensureConnValid(t, pgConn) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } func TestCommandTag(t *testing.T) { @@ -593,10 +503,11 @@ func TestConnOnNotice(t *testing.T) { require.Nil(t, err) defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `do $$ + multiResult := pgConn.Exec(context.Background(), `do $$ begin raise notice 'hello, world'; end$$;`) + err = multiResult.Close() require.Nil(t, err) assert.Equal(t, "hello, world", msg) From 04ee3b8cbd64e2acbab4ceff2b1369677f3cb2d5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 17:41:43 -0600 Subject: [PATCH 040/290] Remove Pg prefix for a couple types --- pgconn.go | 60 +++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/pgconn.go b/pgconn.go index cfacc7bb..d4563086 100644 --- a/pgconn.go +++ b/pgconn.go @@ -536,8 +536,8 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { // statements. // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. -func (pgConn *PgConn) Exec(ctx context.Context, sql string) *PgMultiResult { - multiResult := &PgMultiResult{ +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResult { + multiResult := &MultiResult{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -593,8 +593,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *PgMultiResult { // binary format. If resultFormats is nil all results will be in text protocol. // // Result must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *PgResult { - result := &PgResult{ +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *Result { + result := &Result{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -649,8 +649,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // binary format. If resultFormats is nil all results will be in text protocol. // // Result must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *PgResult { - result := &PgResult{ +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *Result { + result := &Result{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -689,18 +689,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } -type PgMultiResult struct { +type MultiResult struct { pgConn *PgConn ctx context.Context cleanupContextDeadline func() - pgResult *PgResult + pgResult *Result closed bool err error } -func (mr *PgMultiResult) ReadAll() ([]*BufferedResult, error) { +func (mr *MultiResult) ReadAll() ([]*BufferedResult, error) { var results []*BufferedResult for mr.NextResult() { @@ -711,7 +711,7 @@ func (mr *PgMultiResult) ReadAll() ([]*BufferedResult, error) { return results, err } -func (mr *PgMultiResult) receiveMessage() (pgproto3.BackendMessage, error) { +func (mr *MultiResult) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := mr.pgConn.ReceiveMessage() if err != nil { @@ -740,8 +740,8 @@ func (mr *PgMultiResult) receiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// NextResult returns advances the PgMultiResult to the next result and returns true if a result is available. -func (mr *PgMultiResult) NextResult() bool { +// NextResult returns advances the MultiResult to the next result and returns true if a result is available. +func (mr *MultiResult) NextResult() bool { for !mr.closed && mr.err == nil { msg, err := mr.receiveMessage() if err != nil { @@ -750,7 +750,7 @@ func (mr *PgMultiResult) NextResult() bool { switch msg := msg.(type) { case *pgproto3.RowDescription: - mr.pgResult = &PgResult{ + mr.pgResult = &Result{ pgConn: mr.pgConn, pgMultiResult: mr, ctx: mr.ctx, @@ -759,7 +759,7 @@ func (mr *PgMultiResult) NextResult() bool { } return true case *pgproto3.CommandComplete: - mr.pgResult = &PgResult{ + mr.pgResult = &Result{ commandTag: CommandTag(msg.CommandTag), commandConcluded: true, closed: true, @@ -773,11 +773,11 @@ func (mr *PgMultiResult) NextResult() bool { return false } -func (mr *PgMultiResult) Result() *PgResult { +func (mr *MultiResult) Result() *Result { return mr.pgResult } -func (mr *PgMultiResult) Close() error { +func (mr *MultiResult) Close() error { for !mr.closed { _, err := mr.receiveMessage() if err != nil { @@ -788,9 +788,9 @@ func (mr *PgMultiResult) Close() error { return mr.err } -type PgResult struct { +type Result struct { pgConn *PgConn - pgMultiResult *PgMultiResult + pgMultiResult *MultiResult ctx context.Context cleanupContextDeadline func() @@ -809,7 +809,7 @@ type BufferedResult struct { Err error } -func (rr *PgResult) ReadAll() *BufferedResult { +func (rr *Result) ReadAll() *BufferedResult { br := &BufferedResult{} for rr.NextRow() { @@ -828,8 +828,8 @@ func (rr *PgResult) ReadAll() *BufferedResult { return br } -// NextRow advances the PgResult to the next row and returns true if a row is available. -func (rr *PgResult) NextRow() bool { +// NextRow advances the Result to the next row and returns true if a row is available. +func (rr *Result) NextRow() bool { for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { @@ -847,21 +847,21 @@ func (rr *PgResult) NextRow() bool { } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the PgResult is closed. -func (rr *PgResult) FieldDescriptions() []pgproto3.FieldDescription { +// the Result is closed. +func (rr *Result) FieldDescriptions() []pgproto3.FieldDescription { return rr.fieldDescriptions } // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the PgResult is closed. However, the underlying byte data is safe to +// valid until the next NextRow call or the Result is closed. However, the underlying byte data is safe to // retain a reference to and mutate. -func (rr *PgResult) Values() [][]byte { +func (rr *Result) Values() [][]byte { return rr.rowValues } // Close consumes any remaining result data and returns the command tag or // error. -func (rr *PgResult) Close() (CommandTag, error) { +func (rr *Result) Close() (CommandTag, error) { if rr.closed { return rr.commandTag, rr.err } @@ -893,7 +893,7 @@ func (rr *PgResult) Close() (CommandTag, error) { return rr.commandTag, rr.err } -func (rr *PgResult) receiveMessage() (msg pgproto3.BackendMessage, err error) { +func (rr *Result) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.pgMultiResult == nil { msg, err = rr.pgConn.ReceiveMessage() } else { @@ -927,7 +927,7 @@ func (rr *PgResult) receiveMessage() (msg pgproto3.BackendMessage, err error) { return msg, nil } -func (rr *PgResult) concludeCommand(commandTag CommandTag, err error) { +func (rr *Result) concludeCommand(commandTag CommandTag, err error) { if rr.commandConcluded { return } @@ -1006,8 +1006,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) } -func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *PgMultiResult { - multiResult := &PgMultiResult{ +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResult { + multiResult := &MultiResult{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, From 379be3508b5e79eba7dcd7ac4a47f80cfeba8058 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 17:46:47 -0600 Subject: [PATCH 041/290] Add some docs for batch --- pgconn.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgconn.go b/pgconn.go index d4563086..09d87b31 100644 --- a/pgconn.go +++ b/pgconn.go @@ -988,6 +988,7 @@ func (pgConn *PgConn) recoverFromTimeout() { } } +// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte } @@ -1006,6 +1007,7 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) } +// ExecBatch executes all the queries in batch in a single round-trip. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResult { multiResult := &MultiResult{ pgConn: pgConn, From 2c8971b38263182f5644c7c6f65ec19a89e6f428 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 18:01:57 -0600 Subject: [PATCH 042/290] Rename some types and methods --- benchmark_test.go | 4 +- config.go | 2 +- helper_test.go | 2 +- pgconn.go | 126 +++++++++++++++++++++--------------------- pgconn_stress_test.go | 4 +- pgconn_test.go | 6 +- 6 files changed, 72 insertions(+), 72 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index ffb1455c..d2576324 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -76,7 +76,7 @@ func BenchmarkExecPrepared(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).ReadAll() + result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).Read() require.Nil(b, result.Err) } } @@ -95,7 +95,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).ReadAll() + result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Nil(b, result.Err) } } diff --git a/config.go b/config.go index fb0719cd..b85bcaec 100644 --- a/config.go +++ b/config.go @@ -470,7 +470,7 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { // AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible // target_session_attrs=read-write. func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).ReadAll() + result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err } diff --git a/helper_test.go b/helper_test.go index a50f7cb1..c5ac6e01 100644 --- a/helper_test.go +++ b/helper_test.go @@ -20,7 +20,7 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { // Do a simple query to ensure the connection is still usable func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).ReadAll() + result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read() cancel() require.Nil(t, result.Err) diff --git a/pgconn.go b/pgconn.go index 09d87b31..be7d37ae 100644 --- a/pgconn.go +++ b/pgconn.go @@ -536,8 +536,8 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { // statements. // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. -func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResult { - multiResult := &MultiResult{ +func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { + multiResult := &MultiResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -592,9 +592,9 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResult { // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text protocol. // -// Result must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *Result { - result := &Result{ +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { + result := &ResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -648,9 +648,9 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text protocol. // -// Result must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *Result { - result := &Result{ +// ResultReader must be closed before PgConn can be used again. +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { + result := &ResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, @@ -689,77 +689,77 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } -type MultiResult struct { +type MultiResultReader struct { pgConn *PgConn ctx context.Context cleanupContextDeadline func() - pgResult *Result + rr *ResultReader closed bool err error } -func (mr *MultiResult) ReadAll() ([]*BufferedResult, error) { - var results []*BufferedResult +func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { + var results []*Result - for mr.NextResult() { - results = append(results, mr.Result().ReadAll()) + for mrr.NextResult() { + results = append(results, mrr.ResultReader().Read()) } - err := mr.Close() + err := mrr.Close() return results, err } -func (mr *MultiResult) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := mr.pgConn.ReceiveMessage() +func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := mrr.pgConn.ReceiveMessage() if err != nil { - mr.cleanupContextDeadline() - mr.err = preferContextOverNetTimeoutError(mr.ctx, err) - mr.closed = true + mrr.cleanupContextDeadline() + mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.closed = true if err, ok := err.(net.Error); ok && err.Timeout() { - go mr.pgConn.recoverFromTimeout() + go mrr.pgConn.recoverFromTimeout() } else { - <-mr.pgConn.controller + <-mrr.pgConn.controller } - return nil, mr.err + return nil, mrr.err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mr.cleanupContextDeadline() - mr.closed = true - <-mr.pgConn.controller + mrr.cleanupContextDeadline() + mrr.closed = true + <-mrr.pgConn.controller case *pgproto3.ErrorResponse: - mr.err = errorResponseToPgError(msg) + mrr.err = errorResponseToPgError(msg) } return msg, nil } -// NextResult returns advances the MultiResult to the next result and returns true if a result is available. -func (mr *MultiResult) NextResult() bool { - for !mr.closed && mr.err == nil { - msg, err := mr.receiveMessage() +// NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. +func (mrr *MultiResultReader) NextResult() bool { + for !mrr.closed && mrr.err == nil { + msg, err := mrr.receiveMessage() if err != nil { return false } switch msg := msg.(type) { case *pgproto3.RowDescription: - mr.pgResult = &Result{ - pgConn: mr.pgConn, - pgMultiResult: mr, - ctx: mr.ctx, + mrr.rr = &ResultReader{ + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, cleanupContextDeadline: func() {}, fieldDescriptions: msg.Fields, } return true case *pgproto3.CommandComplete: - mr.pgResult = &Result{ + mrr.rr = &ResultReader{ commandTag: CommandTag(msg.CommandTag), commandConcluded: true, closed: true, @@ -773,24 +773,24 @@ func (mr *MultiResult) NextResult() bool { return false } -func (mr *MultiResult) Result() *Result { - return mr.pgResult +func (mrr *MultiResultReader) ResultReader() *ResultReader { + return mrr.rr } -func (mr *MultiResult) Close() error { - for !mr.closed { - _, err := mr.receiveMessage() +func (mrr *MultiResultReader) Close() error { + for !mrr.closed { + _, err := mrr.receiveMessage() if err != nil { - return mr.err + return mrr.err } } - return mr.err + return mrr.err } -type Result struct { +type ResultReader struct { pgConn *PgConn - pgMultiResult *MultiResult + multiResultReader *MultiResultReader ctx context.Context cleanupContextDeadline func() @@ -802,15 +802,15 @@ type Result struct { err error } -type BufferedResult struct { +type Result struct { FieldDescriptions []pgproto3.FieldDescription Rows [][][]byte CommandTag CommandTag Err error } -func (rr *Result) ReadAll() *BufferedResult { - br := &BufferedResult{} +func (rr *ResultReader) Read() *Result { + br := &Result{} for rr.NextRow() { if br.FieldDescriptions == nil { @@ -828,8 +828,8 @@ func (rr *Result) ReadAll() *BufferedResult { return br } -// NextRow advances the Result to the next row and returns true if a row is available. -func (rr *Result) NextRow() bool { +// NextRow advances the ResultReader to the next row and returns true if a row is available. +func (rr *ResultReader) NextRow() bool { for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { @@ -847,21 +847,21 @@ func (rr *Result) NextRow() bool { } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until -// the Result is closed. -func (rr *Result) FieldDescriptions() []pgproto3.FieldDescription { +// the ResultReader is closed. +func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { return rr.fieldDescriptions } // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only -// valid until the next NextRow call or the Result is closed. However, the underlying byte data is safe to +// valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to // retain a reference to and mutate. -func (rr *Result) Values() [][]byte { +func (rr *ResultReader) Values() [][]byte { return rr.rowValues } // Close consumes any remaining result data and returns the command tag or // error. -func (rr *Result) Close() (CommandTag, error) { +func (rr *ResultReader) Close() (CommandTag, error) { if rr.closed { return rr.commandTag, rr.err } @@ -874,7 +874,7 @@ func (rr *Result) Close() (CommandTag, error) { } } - if rr.pgMultiResult == nil { + if rr.multiResultReader == nil { for { msg, err := rr.receiveMessage() if err != nil { @@ -893,18 +893,18 @@ func (rr *Result) Close() (CommandTag, error) { return rr.commandTag, rr.err } -func (rr *Result) receiveMessage() (msg pgproto3.BackendMessage, err error) { - if rr.pgMultiResult == nil { +func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { + if rr.multiResultReader == nil { msg, err = rr.pgConn.ReceiveMessage() } else { - msg, err = rr.pgMultiResult.receiveMessage() + msg, err = rr.multiResultReader.receiveMessage() } if err != nil { rr.concludeCommand(nil, err) rr.cleanupContextDeadline() rr.closed = true - if rr.pgMultiResult == nil { + if rr.multiResultReader == nil { if err, ok := err.(net.Error); ok && err.Timeout() { go rr.pgConn.recoverFromTimeout() } else { @@ -927,7 +927,7 @@ func (rr *Result) receiveMessage() (msg pgproto3.BackendMessage, err error) { return msg, nil } -func (rr *Result) concludeCommand(commandTag CommandTag, err error) { +func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { if rr.commandConcluded { return } @@ -1008,8 +1008,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor } // ExecBatch executes all the queries in batch in a single round-trip. -func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResult { - multiResult := &MultiResult{ +func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + multiResult := &MultiResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 17d344b7..6b5efd9f 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -70,7 +70,7 @@ func stressExecSelect(pgConn *pgconn.PgConn) error { } func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() + result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() return result.Err } @@ -96,7 +96,7 @@ func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).ReadAll() + result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() cancel() if result.Err != context.DeadlineExceeded { return result.Err diff --git a/pgconn_test.go b/pgconn_test.go index a2eb7838..a524d18f 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -134,12 +134,12 @@ func TestConnectWithRuntimeParams(t *testing.T) { require.Nil(t, err) defer closeConn(t, conn) - result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).ReadAll() + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).ReadAll() + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) @@ -263,7 +263,7 @@ func TestConnExecEmpty(t *testing.T) { resultCount := 0 for multiResult.NextResult() { resultCount += 1 - multiResult.Result().Close() + multiResult.ResultReader().Close() } assert.Equal(t, 0, resultCount) err = multiResult.Close() From 2959411c419c147d5eef0d1d8ae14b611b7850ac Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 18:06:00 -0600 Subject: [PATCH 043/290] CommandTag is string --- pgconn.go | 22 +++++++++------------- pgconn_test.go | 4 ++-- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/pgconn.go b/pgconn.go index be7d37ae..7cf3c91d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -347,7 +347,7 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag []byte +type CommandTag string // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. @@ -361,10 +361,6 @@ func (ct CommandTag) RowsAffected() int64 { return n } -func (ct CommandTag) String() string { - return string(ct) -} - // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -602,7 +598,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] select { case <-ctx.Done(): - result.concludeCommand(nil, ctx.Err()) + result.concludeCommand("", ctx.Err()) result.closed = true return result case pgConn.controller <- result: @@ -628,7 +624,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] pgConn.closed = true } - result.concludeCommand(nil, err) + result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true <-pgConn.controller @@ -658,7 +654,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa select { case <-ctx.Done(): - result.concludeCommand(nil, ctx.Err()) + result.concludeCommand("", ctx.Err()) result.closed = true return result case pgConn.controller <- result: @@ -680,7 +676,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa pgConn.closed = true } - result.concludeCommand(nil, err) + result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true <-pgConn.controller @@ -870,7 +866,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { - return nil, rr.err + return "", rr.err } } @@ -878,7 +874,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for { msg, err := rr.receiveMessage() if err != nil { - return nil, rr.err + return "", rr.err } switch msg.(type) { @@ -901,7 +897,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { - rr.concludeCommand(nil, err) + rr.concludeCommand("", err) rr.cleanupContextDeadline() rr.closed = true if rr.multiResultReader == nil { @@ -921,7 +917,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(nil, errorResponseToPgError(msg)) + rr.concludeCommand("", errorResponseToPgError(msg)) } return msg, nil diff --git a/pgconn_test.go b/pgconn_test.go index a524d18f..a63aee38 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -375,7 +375,7 @@ func TestConnExecParamsCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Nil(t, commandTag) + assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) ensureConnValid(t, pgConn) @@ -427,7 +427,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Nil(t, commandTag) + assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) ensureConnValid(t, pgConn) From 406e95650a8823d98074dd3e08bdcf097e4b50cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 18:40:33 -0600 Subject: [PATCH 044/290] Add more docs --- config.go | 35 +++++++++++++++++++---------------- doc.go | 29 +++++++++++++++++++++++++++++ pgconn.go | 10 +++++++++- 3 files changed, 57 insertions(+), 17 deletions(-) create mode 100644 doc.go diff --git a/config.go b/config.go index b85bcaec..13167729 100644 --- a/config.go +++ b/config.go @@ -70,32 +70,35 @@ func NetworkAddress(host string, port uint16) (network, address string) { // It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the // .pgpass file. // -// Example DSN: "user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca" +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca // -// Example URL: "postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca" +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca // // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // values that will be tried in order. This can be used as part of a high availability system. See // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // -// Example URL: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb" +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed // via database URL or DSN: // -// PGHOST -// PGPORT -// PGDATABASE -// PGUSER -// PGPASSWORD -// PGPASSFILE -// PGSSLMODE -// PGSSLCERT -// PGSSLKEY -// PGSSLROOTCERT -// PGAPPNAME -// PGCONNECT_TIMEOUT -// PGTARGETSESSIONATTRS +// PGHOST +// PGPORT +// PGDATABASE +// PGUSER +// PGPASSWORD +// PGPASSFILE +// PGSSLMODE +// PGSSLCERT +// PGSSLKEY +// PGSSLROOTCERT +// PGAPPNAME +// PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS // // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // diff --git a/doc.go b/doc.go new file mode 100644 index 00000000..89e47536 --- /dev/null +++ b/doc.go @@ -0,0 +1,29 @@ +// Package pgconn is a low-level PostgreSQL database driver. +/* +pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at +nearly the same level is the C library libpq. + +Establishing a Connection + +Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for +libpq style environment variables. + +Executing a Query + +ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method +reads all rows into memory. + +Executing Multiple Queries in a Single Round Trip + +Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query +result. The ReadAll method reads all query results into memory. + +Context Support + +All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the +method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the +cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is +safe to use the connection while this background cancellation is in progress. Any calls will block until the +cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation). +*/ +package pgconn diff --git a/pgconn.go b/pgconn.go index 7cf3c91d..bab4370a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -685,6 +685,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } +// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn ctx context.Context @@ -696,6 +697,7 @@ type MultiResultReader struct { err error } +// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods. func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { var results []*Result @@ -769,10 +771,12 @@ func (mrr *MultiResultReader) NextResult() bool { return false } +// ResultReader returns the current ResultReader. func (mrr *MultiResultReader) ResultReader() *ResultReader { return mrr.rr } +// Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. func (mrr *MultiResultReader) Close() error { for !mrr.closed { _, err := mrr.receiveMessage() @@ -784,6 +788,7 @@ func (mrr *MultiResultReader) Close() error { return mrr.err } +// ResultReader is a reader for the result of a single query. type ResultReader struct { pgConn *PgConn multiResultReader *MultiResultReader @@ -798,6 +803,7 @@ type ResultReader struct { err error } +// Result is the saved query response that is returned by calling Read on a ResultReader. type Result struct { FieldDescriptions []pgproto3.FieldDescription Rows [][][]byte @@ -805,6 +811,7 @@ type Result struct { Err error } +// Read saves the query response to a Result. func (rr *ResultReader) Read() *Result { br := &Result{} @@ -1003,7 +1010,8 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) } -// ExecBatch executes all the queries in batch in a single round-trip. +// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a +// transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { multiResult := &MultiResultReader{ pgConn: pgConn, From c6a73a469a84661171e31116a8228e54f3f52aa6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Jan 2019 18:47:50 -0600 Subject: [PATCH 045/290] Add example --- pgconn_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index a63aee38..2d8cc784 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -3,6 +3,8 @@ package pgconn_test import ( "context" "crypto/tls" + "fmt" + "log" "net" "os" "testing" @@ -513,3 +515,27 @@ end$$;`) ensureConnValid(t, pgConn) } + +func Example() { + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + log.Fatalln(err) + } + defer pgConn.Close(context.Background()) + + result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() + if result.Err != nil { + log.Fatalln(result.Err) + } + + for _, row := range result.Rows { + fmt.Println(string(row[0])) + } + + fmt.Println(result.CommandTag) + // Output: + // 1 + // 2 + // 3 + // SELECT 3 +} From bd777fe20c73cf2eea37d2ada1a62164f0074bd1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Jan 2019 11:37:13 -0600 Subject: [PATCH 046/290] Add custom context cancellation hook --- config.go | 14 ++++++++++- pgconn.go | 38 +++++++++++++++++++++++++++- pgconn_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 13167729..40cbd0bb 100644 --- a/config.go +++ b/config.go @@ -41,7 +41,19 @@ type Config struct { // allows implementing high availability behavior such as libpq does with target_session_attrs. AfterConnectFunc AfterConnectFunc - OnNotice NoticeHandler // Callback function called when a notice response is received. + // OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context + // is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a + // query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire + // protocol do not support this cancellation method. + // + // It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be + // called whether it was successful or not. If an error occurs the connection should be closed. The connection must be + // in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read + // the connection until a ready for query message is received. + OnContextCancel func(*ContextCancel) + + // OnNotice is a callback function called when a notice response is received. + OnNotice NoticeHandler } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index bab4370a..08fce16e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -527,6 +527,22 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { return nil } +// WaitUntilReady waits until a previous context cancellation has been competed processed and the connection is ready +// for use. This is done automatically by all methods that need the connection to be ready for use. The only expected +// use for this method is for a connection pool to wait for a returned connection to be usable again before making it +// available. +func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case pgConn.controller <- pgConn: + // The connection must be ready since it was locked. Immediately unlock it. + <-pgConn.controller + } + + return nil +} + // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. @@ -942,7 +958,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { rr.commandConcluded = true } -func (pgConn *PgConn) recoverFromTimeout() { +func (pgConn *PgConn) defaultCancel() { // Regardless of recovery outcome the lock on the pgConn must be released. defer func() { <-pgConn.controller }() @@ -991,6 +1007,26 @@ func (pgConn *PgConn) recoverFromTimeout() { } } +type ContextCancel struct { + PgConn *PgConn +} + +// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for +// query state or the connection must be closed. This must be called regardless of the success of the cancellation and +// whether the connection is still valid or not. It releases an internal busy lock on the connection. +func (cc *ContextCancel) Finish() { + <-cc.PgConn.controller +} + +func (pgConn *PgConn) recoverFromTimeout() { + if pgConn.Config.OnContextCancel == nil { + pgConn.defaultCancel() + } else { + cc := &ContextCancel{PgConn: pgConn} + pgConn.Config.OnContextCancel(cc) + } +} + // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte diff --git a/pgconn_test.go b/pgconn_test.go index 2d8cc784..9452ffc0 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgx/pgproto3" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -490,6 +491,72 @@ func TestCommandTag(t *testing.T) { } } +func TestConnContextCancelWithOnContextCancel(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + calledChan := make(chan struct{}) + + config.OnContextCancel = func(cc *pgconn.ContextCancel) { + defer cc.Finish() + close(calledChan) + + for { + msg, err := cc.PgConn.ReceiveMessage() + if err != nil { + cc.PgConn.Close(context.Background()) + return + } + + switch msg.(type) { + case *pgproto3.ReadyForQuery: + return + } + } + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil) + _, err = result.Close() + assert.Equal(t, context.DeadlineExceeded, err) + + called := false + select { + case <-calledChan: + called = true + case <-time.NewTimer(time.Second).C: + } + + assert.True(t, called) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitUntilReady(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read() + assert.Equal(t, context.DeadlineExceeded, result.Err) + + err = pgConn.WaitUntilReady(context.Background()) + require.Nil(t, err) + + ensureConnValid(t, pgConn) +} + func TestConnOnNotice(t *testing.T) { t.Parallel() From 9c36fa1e5038662693788b43752bceae00f00417 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Jan 2019 15:38:20 -0600 Subject: [PATCH 047/290] Fix prepare failure --- pgconn.go | 9 +++++++-- pgconn_test.go | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index 08fce16e..2a3c5936 100644 --- a/pgconn.go +++ b/pgconn.go @@ -438,6 +438,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd := &PreparedStatementDescription{Name: name, SQL: sql} + var parseErr error + readloop: for { msg, err := pgConn.ReceiveMessage() @@ -454,14 +456,17 @@ readloop: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: - go pgConn.recoverFromTimeout() - return nil, errorResponseToPgError(msg) + parseErr = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } } <-pgConn.controller + + if parseErr != nil { + return nil, parseErr + } return psd, nil } diff --git a/pgconn_test.go b/pgconn_test.go index 9452ffc0..90f99325 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -235,6 +235,20 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } +func TestConnPrepareFailure(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) + + ensureConnValid(t, pgConn) +} + func TestConnExec(t *testing.T) { t.Parallel() From b3cde6830f0ae451d08f0854a421cc538e2d2e6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Jan 2019 16:17:03 -0600 Subject: [PATCH 048/290] Fix die on receive message error --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 2a3c5936..9277d4a8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -263,7 +263,7 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.Frontend.Receive() if err != nil { // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !ok && err.Timeout() { + if err, ok := err.(net.Error); !(ok && err.Timeout()) { pgConn.hardClose() } From cd4b0025c3d322ac21fe481a8e55a0743e37f27b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 14 Jan 2019 20:39:10 -0600 Subject: [PATCH 049/290] Add listen/notify to pgconn --- config.go | 3 +++ pgconn.go | 17 +++++++++++++++++ pgconn_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/config.go b/config.go index 40cbd0bb..fec1fedf 100644 --- a/config.go +++ b/config.go @@ -54,6 +54,9 @@ type Config struct { // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler + + // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. + OnNotification NotificationHandler } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn.go b/pgconn.go index 9277d4a8..b2ffe7ca 100644 --- a/pgconn.go +++ b/pgconn.go @@ -50,6 +50,13 @@ func (pe *PgError) Error() string { // LISTEN/NOTIFY notification. type Notice PgError +// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system +type Notification struct { + PID uint32 // backend pid that sent the notification + Channel string // channel from which notification was received + Payload string +} + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) @@ -59,6 +66,12 @@ type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // notification. type NoticeHandler func(*PgConn, *Notice) +// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications +// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is +// aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a +// notice event. +type NotificationHandler func(*PgConn, *Notification) + // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") @@ -284,6 +297,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { if pgConn.Config.OnNotice != nil { pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) } + case *pgproto3.NotificationResponse: + if pgConn.Config.OnNotification != nil { + pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + } } return msg, nil diff --git a/pgconn_test.go b/pgconn_test.go index 90f99325..ad538257 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -597,6 +597,38 @@ end$$;`) ensureConnValid(t, pgConn) } +func TestConnOnNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.Nil(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.Nil(t, err) + + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.Nil(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From edfd837ba4192c55770f6a18bd1fcfb49ed07f4f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 14 Jan 2019 20:51:53 -0600 Subject: [PATCH 050/290] Add PgConn.WaitForNotification --- pgconn.go | 25 +++++++++++++++++++++++++ pgconn_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/pgconn.go b/pgconn.go index b2ffe7ca..efd7686f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -565,6 +565,31 @@ func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { return nil } +// WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not +// received. +func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case pgConn.controller <- pgConn: + } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() + defer func() { <-pgConn.controller }() + + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + return preferContextOverNetTimeoutError(ctx, err) + } + + switch msg.(type) { + case *pgproto3.NotificationResponse: + return nil + } + } +} + // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. diff --git a/pgconn_test.go b/pgconn_test.go index ad538257..07e54c75 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -629,6 +629,56 @@ func TestConnOnNotification(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnWaitForNotification(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.Nil(t, err) + + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.Nil(t, err) + + err = pgConn.WaitForNotification(context.Background()) + require.Nil(t, err) + + assert.Equal(t, "bar", msg) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitForNotificationTimeout(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + err = pgConn.WaitForNotification(ctx) + cancel() + require.Equal(t, context.DeadlineExceeded, err) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From e441d4828c13a21de6c8f96aa814ab0d119e639e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 14:49:26 -0600 Subject: [PATCH 051/290] Fix doc typo --- pgconn.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index efd7686f..13301364 100644 --- a/pgconn.go +++ b/pgconn.go @@ -549,10 +549,9 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { return nil } -// WaitUntilReady waits until a previous context cancellation has been competed processed and the connection is ready -// for use. This is done automatically by all methods that need the connection to be ready for use. The only expected -// use for this method is for a connection pool to wait for a returned connection to be usable again before making it -// available. +// WaitUntilReady waits until a previous context cancellation has been completed and the connection is ready for use. +// This is done automatically by all methods that need the connection to be ready for use. The only expected use for +// this method is for a connection pool to wait for a returned connection to be usable again before making it available. func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { select { case <-ctx.Done(): From 19ef57ad9a7392ef6cd4ae96665ec2e32d1caa0c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 14:49:39 -0600 Subject: [PATCH 052/290] Add PgConn.CopyTo --- pgconn.go | 65 ++++++++++++++++++++++++++++ pgconn_test.go | 112 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) diff --git a/pgconn.go b/pgconn.go index 13301364..476cd046 100644 --- a/pgconn.go +++ b/pgconn.go @@ -747,6 +747,71 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } +// CopyTo executes the copy command sql and copies the results to w. +func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case pgConn.controller <- pgConn: + } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + + // Send copy to command + var buf []byte + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + // Read results + var commandTag CommandTag + var pgErr error + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeout() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyDone: + case *pgproto3.CopyData: + _, err := w.Write(msg.Data) + if err != nil { + // This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup. + cleanupContextDeadline() + go pgConn.recoverFromTimeout() + return "", err + } + case *pgproto3.ReadyForQuery: + <-pgConn.controller + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + } + } +} + // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn diff --git a/pgconn_test.go b/pgconn_test.go index 07e54c75..ab7cfa72 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "bytes" "context" "crypto/tls" "fmt" @@ -679,6 +680,117 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCopyToSmall(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.Nil(t, err) + + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.Nil(t, err) + + _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.Nil(t, err) + + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.Nil(t, err) + + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToLarge(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.Nil(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.Nil(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } + + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.Nil(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToQueryError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + outputWriter := bytes.NewBuffer(make([]byte, 0)) + + res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyToCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") + assert.Equal(t, context.DeadlineExceeded, err) + assert.Equal(t, pgconn.CommandTag(""), res) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From c447ff4e797dc10be183fed254cbed82c61cc4f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 14:51:07 -0600 Subject: [PATCH 053/290] Use NoError instead of Nil for assertions --- config_test.go | 10 ++-- pgconn_stress_test.go | 4 +- pgconn_test.go | 128 +++++++++++++++++++++--------------------- 3 files changed, 71 insertions(+), 71 deletions(-) diff --git a/config_test.go b/config_test.go index e7a5bb44..c7b65861 100644 --- a/config_test.go +++ b/config_test.go @@ -515,12 +515,12 @@ func TestParseConfigEnvLibpq(t *testing.T) { for i, tt := range tests { for _, n := range pgEnvvars { err := os.Unsetenv(n) - require.Nil(t, err) + require.NoError(t, err) } for k, v := range tt.envvars { err := os.Setenv(k, v) - require.Nil(t, err) + require.NoError(t, err) } config, err := pgconn.ParseConfig("") @@ -536,13 +536,13 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { t.Parallel() tf, err := ioutil.TempFile("", "") - require.Nil(t, err) + require.NoError(t, err) defer tf.Close() defer os.Remove(tf.Name()) _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) - require.Nil(t, err) + require.NoError(t, err) connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) expected := &pgconn.Config{ @@ -556,7 +556,7 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { } actual, err := pgconn.ParseConfig(connString) - assert.Nil(t, err) + assert.NoError(t, err) assertConfigsEqual(t, expected, actual, "passfile") } diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 6b5efd9f..7a95fa98 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -17,7 +17,7 @@ func TestConnStress(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) actionCount := 100 @@ -61,7 +61,7 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { ('Foo', 'bar'), ('baz', 'Something really long Something really long Something really long Something really long Something really long'), ('a', 'b')`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) } func stressExecSelect(pgConn *pgconn.PgConn) error { diff --git a/pgconn_test.go b/pgconn_test.go index ab7cfa72..f3ed04df 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -40,7 +40,7 @@ func TestConnect(t *testing.T) { } conn, err := pgconn.Connect(context.Background(), connString) - require.Nil(t, err) + require.NoError(t, err) closeConn(t, conn) }) @@ -58,7 +58,7 @@ func TestConnectTLS(t *testing.T) { } conn, err := pgconn.Connect(context.Background(), connString) - require.Nil(t, err) + require.NoError(t, err) if _, ok := conn.Conn().(*tls.Conn); !ok { t.Error("not a TLS connection") @@ -76,7 +76,7 @@ func TestConnectInvalidUser(t *testing.T) { } config, err := pgconn.ParseConfig(connString) - require.Nil(t, err) + require.NoError(t, err) config.User = "pgxinvalidusertest" @@ -109,7 +109,7 @@ func TestConnectCustomDialer(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) dialed := false config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { @@ -118,7 +118,7 @@ func TestConnectCustomDialer(t *testing.T) { } conn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) require.True(t, dialed) closeConn(t, conn) } @@ -127,7 +127,7 @@ func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) config.RuntimeParams = map[string]string{ "application_name": "pgxtest", @@ -135,7 +135,7 @@ func TestConnectWithRuntimeParams(t *testing.T) { } conn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, conn) result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() @@ -153,7 +153,7 @@ func TestConnectWithFallback(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) // Prepend current primary config to fallbacks config.Fallbacks = append([]*pgconn.FallbackConfig{ @@ -178,7 +178,7 @@ func TestConnectWithFallback(t *testing.T) { }, config.Fallbacks...) conn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) closeConn(t, conn) } @@ -186,7 +186,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) dialCount := 0 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { @@ -214,7 +214,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) conn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) closeConn(t, conn) assert.True(t, dialCount > 1) @@ -225,7 +225,7 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" @@ -240,7 +240,7 @@ func TestConnPrepareFailure(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) @@ -254,11 +254,11 @@ func TestConnExec(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) @@ -273,7 +273,7 @@ func TestConnExecEmpty(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) multiResult := pgConn.Exec(context.Background(), ";") @@ -285,7 +285,7 @@ func TestConnExecEmpty(t *testing.T) { } assert.Equal(t, 0, resultCount) err = multiResult.Close() - assert.Nil(t, err) + assert.NoError(t, err) ensureConnValid(t, pgConn) } @@ -294,11 +294,11 @@ func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() - assert.Nil(t, err) + assert.NoError(t, err) assert.Len(t, results, 2) @@ -319,7 +319,7 @@ func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() @@ -341,7 +341,7 @@ func TestConnExecContextCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -360,7 +360,7 @@ func TestConnExecParams(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) @@ -372,7 +372,7 @@ func TestConnExecParams(t *testing.T) { assert.Equal(t, 1, rowCount) commandTag, err := result.Close() assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) + assert.NoError(t, err) ensureConnValid(t, pgConn) } @@ -381,7 +381,7 @@ func TestConnExecParamsCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -403,11 +403,11 @@ func TestConnExecPrepared(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.Nil(t, err) + require.NoError(t, err) require.NotNil(t, psd) assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) @@ -421,7 +421,7 @@ func TestConnExecPrepared(t *testing.T) { assert.Equal(t, 1, rowCount) commandTag, err := result.Close() assert.Equal(t, "SELECT 1", string(commandTag)) - assert.Nil(t, err) + assert.NoError(t, err) ensureConnValid(t, pgConn) } @@ -430,11 +430,11 @@ func TestConnExecPreparedCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) - require.Nil(t, err) + require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -455,11 +455,11 @@ func TestConnExecBatch(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.Nil(t, err) + require.NoError(t, err) batch := &pgconn.Batch{} @@ -467,7 +467,7 @@ func TestConnExecBatch(t *testing.T) { batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.Nil(t, err) + require.NoError(t, err) require.Len(t, results, 3) require.Len(t, results[0].Rows, 1) @@ -510,7 +510,7 @@ func TestConnContextCancelWithOnContextCancel(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) calledChan := make(chan struct{}) @@ -533,7 +533,7 @@ func TestConnContextCancelWithOnContextCancel(t *testing.T) { } pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -558,7 +558,7 @@ func TestConnWaitUntilReady(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -567,7 +567,7 @@ func TestConnWaitUntilReady(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, result.Err) err = pgConn.WaitUntilReady(context.Background()) - require.Nil(t, err) + require.NoError(t, err) ensureConnValid(t, pgConn) } @@ -576,7 +576,7 @@ func TestConnOnNotice(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) var msg string config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { @@ -584,7 +584,7 @@ func TestConnOnNotice(t *testing.T) { } pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) multiResult := pgConn.Exec(context.Background(), `do $$ @@ -592,7 +592,7 @@ begin raise notice 'hello, world'; end$$;`) err = multiResult.Close() - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "hello, world", msg) ensureConnValid(t, pgConn) @@ -602,7 +602,7 @@ func TestConnOnNotification(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) var msg string config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { @@ -610,20 +610,20 @@ func TestConnOnNotification(t *testing.T) { } pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.Nil(t, err) + require.NoError(t, err) notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, notifier) _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.Nil(t, err) + require.NoError(t, err) _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "bar", msg) @@ -634,7 +634,7 @@ func TestConnWaitForNotification(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) var msg string config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { @@ -642,20 +642,20 @@ func TestConnWaitForNotification(t *testing.T) { } pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.Nil(t, err) + require.NoError(t, err) notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, notifier) _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.Nil(t, err) + require.NoError(t, err) err = pgConn.WaitForNotification(context.Background()) - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "bar", msg) @@ -666,10 +666,10 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) @@ -684,7 +684,7 @@ func TestConnCopyToSmall(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), `create temporary table foo( @@ -696,13 +696,13 @@ func TestConnCopyToSmall(t *testing.T) { f date, g json )`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") @@ -710,7 +710,7 @@ func TestConnCopyToSmall(t *testing.T) { outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, int64(2), res.RowsAffected()) assert.Equal(t, inputBytes, outputWriter.Bytes()) @@ -722,7 +722,7 @@ func TestConnCopyToLarge(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), `create temporary table foo( @@ -735,20 +735,20 @@ func TestConnCopyToLarge(t *testing.T) { g json, h bytea )`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) inputBytes := make([]byte, 0) for i := 0; i < 1000; i++ { _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() - require.Nil(t, err) + require.NoError(t, err) inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) } outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, int64(1000), res.RowsAffected()) assert.Equal(t, inputBytes, outputWriter.Bytes()) @@ -760,7 +760,7 @@ func TestConnCopyToQueryError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) outputWriter := bytes.NewBuffer(make([]byte, 0)) @@ -777,7 +777,7 @@ func TestConnCopyToCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.Nil(t, err) + require.NoError(t, err) defer closeConn(t, pgConn) outputWriter := &bytes.Buffer{} From e15528c4195b2e3b8cb2e9f8b0eacf80d5a5fba3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 15:41:42 -0600 Subject: [PATCH 054/290] Remove obsolete comment --- pgconn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 476cd046..aa246614 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1145,7 +1145,6 @@ type Batch struct { // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - // TODO - refactor ExecParams and ExecPrepared - these lines only difference batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } From c9f985c1e40fea85c7acc1a404063d3f4d94b001 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 15:44:03 -0600 Subject: [PATCH 055/290] Add PgConn.EscapeString --- pgconn.go | 17 +++++++++++++++++ pgconn_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/pgconn.go b/pgconn.go index aa246614..49062f23 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1193,3 +1193,20 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult } + +// EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include +// the surrounding single quotes. +// +// The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these +// conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. +func (pgConn *PgConn) EscapeString(s string) (string, error) { + if pgConn.ParameterStatus("standard_conforming_strings") != "on" { + return "", errors.New("EscapeString must be run with standard_conforming_strings=on") + } + + if pgConn.ParameterStatus("client_encoding") != "UTF8" { + return "", errors.New("EscapeString must be run with client_encoding=UTF8") + } + + return strings.Replace(s, "'", "''", -1), nil +} diff --git a/pgconn_test.go b/pgconn_test.go index f3ed04df..587acc57 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -791,6 +791,34 @@ func TestConnCopyToCanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnEscapeString(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From 3683e4a0a16d4508f14ac54a8986a8e5dc658a59 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 19 Jan 2019 17:24:48 -0600 Subject: [PATCH 056/290] Move CopyFrom to pgconn --- pgconn.go | 129 ++++++++++++++++++++++++++++++++++++++++++++++ pgconn_test.go | 136 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+) diff --git a/pgconn.go b/pgconn.go index 49062f23..e8baffa2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" ) @@ -812,6 +813,134 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm } } +// CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. +// +// Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r +// could still block. +func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case pgConn.controller <- pgConn: + } + cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + + // Send copy to command + var buf []byte + buf = (&pgproto3.Query{String: sql}).Encode(buf) + + n, err := pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + // Read until copy in response or error. + var commandTag CommandTag + var pgErr error + pendingCopyInResponse := true + for pendingCopyInResponse { + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeout() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.CopyInResponse: + pendingCopyInResponse = false + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + <-pgConn.controller + return commandTag, pgErr + } + } + + // Send copy data + buf = make([]byte, 0, 65536) + buf = append(buf, 'd') + sp := len(buf) + for { + n, err := r.Read(buf[5:cap(buf)]) + if err == io.EOF && n == 0 { + break + } + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + _, err = pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to + // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to + // close the connection. + pgConn.conn.Close() + pgConn.closed = true + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + } + + // Send copy done + buf = buf[:0] + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + + _, err = pgConn.conn.Write(buf) + if err != nil { + pgConn.conn.Close() + pgConn.closed = true + + cleanupContextDeadline() + <-pgConn.controller + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + // Read results + for { + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeout() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + <-pgConn.controller + return commandTag, pgErr + case *pgproto3.CommandComplete: + commandTag = CommandTag(msg.CommandTag) + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + } + } +} + // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn diff --git a/pgconn_test.go b/pgconn_test.go index 587acc57..47b3b3fb 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -2,12 +2,15 @@ package pgconn_test import ( "bytes" + "compress/gzip" "context" "crypto/tls" "fmt" + "io/ioutil" "log" "net" "os" + "strconv" "testing" "time" @@ -791,6 +794,139 @@ func TestConnCopyToCanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCopyFrom(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromGzipReader(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } + + err = gw.Close() + require.NoError(t, err) + + _, err = f.Seek(0, 0) + require.NoError(t, err) + + gr, err := gzip.NewReader(f) + require.NoError(t, err) + + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + err = gr.Close() + require.NoError(t, err) + + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQuerySyntaxError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + +func TestConnCopyFromQueryNoTableError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + srcBuf := &bytes.Buffer{} + + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) + + ensureConnValid(t, pgConn) +} + func TestConnEscapeString(t *testing.T) { t.Parallel() From 01b54c7cb6f204983e3ece13262e4560b798eab9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 10:21:16 -0600 Subject: [PATCH 057/290] Properly abort CopyFrom on reader error --- pgconn.go | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/pgconn.go b/pgconn.go index e8baffa2..d8ec6b07 100644 --- a/pgconn.go +++ b/pgconn.go @@ -876,34 +876,37 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) - for { - n, err := r.Read(buf[5:cap(buf)]) - if err == io.EOF && n == 0 { - break - } - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) + var readErr error + for readErr == nil { + n, readErr = r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) - if err != nil { - // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to - // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to - // close the connection. - pgConn.conn.Close() - pgConn.closed = true + _, err = pgConn.conn.Write(buf) + if err != nil { + // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to + // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to + // close the connection. + pgConn.conn.Close() + pgConn.closed = true - cleanupContextDeadline() - <-pgConn.controller + cleanupContextDeadline() + <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) + return "", preferContextOverNetTimeoutError(ctx, err) + } } } - // Send copy done buf = buf[:0] - copyDone := &pgproto3.CopyDone{} - buf = copyDone.Encode(buf) - + if readErr == io.EOF { + copyDone := &pgproto3.CopyDone{} + buf = copyDone.Encode(buf) + } else { + copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + buf = copyFail.Encode(buf) + } _, err = pgConn.conn.Write(buf) if err != nil { pgConn.conn.Close() From 96c85cf0c3981d8e35cd3c5fd34a9d1c1ddad313 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 12:20:36 -0600 Subject: [PATCH 058/290] Recover from context cancellation during CopyFrom --- pgconn.go | 131 ++++++++++++++++++++++++++++++++++++++++++++----- pgconn_test.go | 36 ++++++++++++++ 2 files changed, 155 insertions(+), 12 deletions(-) diff --git a/pgconn.go b/pgconn.go index d8ec6b07..e34853a0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -12,6 +12,7 @@ import ( "net" "strconv" "strings" + "sync" "time" "github.com/jackc/pgx/pgio" @@ -91,6 +92,11 @@ type PgConn struct { controller chan interface{} closed bool + + bufferingReceive bool + bufferingReceiveMux sync.Mutex + bufferingReceiveMsg pgproto3.BackendMessage + bufferingReceiveErr error } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -273,8 +279,42 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } +func (pgConn *PgConn) signalMessage() chan struct{} { + if pgConn.bufferingReceive { + panic("BUG: signalMessage when already in progress") + } + + pgConn.bufferingReceive = true + pgConn.bufferingReceiveMux.Lock() + + ch := make(chan struct{}) + go func() { + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() + pgConn.bufferingReceiveMux.Unlock() + close(ch) + }() + + return ch +} + func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { - msg, err := pgConn.Frontend.Receive() + var msg pgproto3.BackendMessage + var err error + if pgConn.bufferingReceive { + pgConn.bufferingReceiveMux.Lock() + msg = pgConn.bufferingReceiveMsg + err = pgConn.bufferingReceiveErr + pgConn.bufferingReceiveMux.Unlock() + pgConn.bufferingReceive = false + + // If a timeout error happened in the background try the read again. + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + msg, err = pgConn.Frontend.Receive() + } + } else { + msg, err = pgConn.Frontend.Receive() + } + if err != nil { // Close on anything other than timeout error - everything else is fatal if err, ok := err.(net.Error); !(ok && err.Timeout()) { @@ -853,7 +893,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co if err != nil { cleanupContextDeadline() if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() + go pgConn.recoverFromTimeoutDuringCopyFrom() } else { <-pgConn.controller } @@ -877,30 +917,56 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = append(buf, 'd') sp := len(buf) var readErr error - for readErr == nil { + signalMessageChan := pgConn.signalMessage() + for readErr == nil && pgErr == nil { n, readErr = r.Read(buf[5:cap(buf)]) if n > 0 { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) + n, err = pgConn.conn.Write(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. If nothing was sent it might be possible to - // recover the connection with a CopyFail, but that could be rather complicated and error prone. Simpler just to - // close the connection. - pgConn.conn.Close() - pgConn.closed = true - + // Partially sent messages are a fatal error for the connection. + if n > 0 { + // Close connection because cannot recover from partially sent message. + pgConn.conn.Close() + pgConn.closed = true + } cleanupContextDeadline() - <-pgConn.controller + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeoutDuringCopyFrom() + } else { + <-pgConn.controller + } return "", preferContextOverNetTimeoutError(ctx, err) } } + + select { + case <-signalMessageChan: + msg, err := pgConn.ReceiveMessage() + if err != nil { + cleanupContextDeadline() + if err, ok := err.(net.Error); ok && err.Timeout() { + go pgConn.recoverFromTimeoutDuringCopyFrom() + } else { + <-pgConn.controller + } + + return "", preferContextOverNetTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + pgErr = errorResponseToPgError(msg) + } + default: + } } buf = buf[:0] - if readErr == io.EOF { + if readErr == io.EOF || pgErr != nil { copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { @@ -944,6 +1010,47 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } } +func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() { + // Regardless of recovery outcome the lock on the pgConn must be released. + defer func() { <-pgConn.controller }() + + // Limit time to wait for entire cancellation process. + err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + pgConn.hardClose() + return + } + + copyFail := &pgproto3.CopyFail{Error: "client cancel"} + buf := copyFail.Encode(nil) + + _, err = pgConn.conn.Write(buf) + if err != nil { + pgConn.hardClose() + return + } + + pendingReadyForQuery := true + + for pendingReadyForQuery { + msg, err := pgConn.ReceiveMessage() + if err != nil { + pgConn.hardClose() + return + } + + switch msg.(type) { + case *pgproto3.ReadyForQuery: + pendingReadyForQuery = false + } + } + + err = pgConn.conn.SetDeadline(time.Time{}) + if err != nil { + pgConn.hardClose() + } +} + // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn diff --git a/pgconn_test.go b/pgconn_test.go index 47b3b3fb..7fb01e2c 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "io/ioutil" "log" "net" @@ -830,6 +831,41 @@ func TestConnCopyFrom(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCopyFromCanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + cancel() + assert.Equal(t, int64(0), ct.RowsAffected()) + require.Equal(t, context.DeadlineExceeded, err) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() From f5aecdd4992504d8344ea0730800e38d48b32f28 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 12:33:51 -0600 Subject: [PATCH 059/290] Extract writeAll --- pgconn.go | 87 ++++++++++++++----------------------------------------- 1 file changed, 21 insertions(+), 66 deletions(-) diff --git a/pgconn.go b/pgconn.go index e34853a0..461ff1c0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -398,6 +398,15 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } +// writeAll writes the entire buffer successfully or it hard closes the connection. +func (pgConn *PgConn) writeAll(buf []byte) error { + n, err := pgConn.conn.Write(buf) + if err != nil && n > 0 { + pgConn.hardClose() + } + return err +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -482,15 +491,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -654,15 +656,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) @@ -718,15 +713,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -770,15 +758,8 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -801,15 +782,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - cleanupContextDeadline() <-pgConn.controller @@ -869,15 +843,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - n, err := pgConn.conn.Write(buf) + err := pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - cleanupContextDeadline() <-pgConn.controller @@ -913,25 +880,21 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy data - buf = make([]byte, 0, 65536) + buf = make([]byte, 0, 20000) + // buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) var readErr error signalMessageChan := pgConn.signalMessage() for readErr == nil && pgErr == nil { + var n int n, readErr = r.Read(buf[5:cap(buf)]) if n > 0 { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - n, err = pgConn.conn.Write(buf) + err = pgConn.writeAll(buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } cleanupContextDeadline() if err, ok := err.(net.Error); ok && err.Timeout() { go pgConn.recoverFromTimeoutDuringCopyFrom() @@ -975,8 +938,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.conn.Close() - pgConn.closed = true + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -1414,15 +1376,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - n, err := pgConn.conn.Write(batch.buf) + err := pgConn.writeAll(batch.buf) if err != nil { - // Partially sent messages are a fatal error for the connection. - if n > 0 { - // Close connection because cannot recover from partially sent message. - pgConn.conn.Close() - pgConn.closed = true - } - multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) From b59437f6ecfec5b604e0ae2063134078578e1d7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 16:45:06 -0600 Subject: [PATCH 060/290] writeAll dies on permanent net errors --- pgconn.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index 461ff1c0..06f9e833 100644 --- a/pgconn.go +++ b/pgconn.go @@ -398,11 +398,15 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// writeAll writes the entire buffer successfully or it hard closes the connection. +// writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error. func (pgConn *PgConn) writeAll(buf []byte) error { n, err := pgConn.conn.Write(buf) - if err != nil && n > 0 { - pgConn.hardClose() + if err != nil { + if n > 0 { + pgConn.hardClose() + } else if ne, ok := err.(net.Error); ok && !ne.Temporary() { + pgConn.hardClose() + } } return err } From 9229e03d06a317a765275e7bd82f301a623b760d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Jan 2019 16:46:30 -0600 Subject: [PATCH 061/290] Partial conversion of pgx to use pgconn --- pgconn.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pgconn.go b/pgconn.go index 06f9e833..512c9a88 100644 --- a/pgconn.go +++ b/pgconn.go @@ -398,6 +398,12 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } +// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of +// underlying connection. +func (pgConn *PgConn) IsAlive() bool { + return !pgConn.closed +} + // writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error. func (pgConn *PgConn) writeAll(buf []byte) error { n, err := pgConn.conn.Write(buf) From 79ffab98367fef041c34bfc3307b82f833661694 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 28 Jan 2019 23:13:03 -0600 Subject: [PATCH 062/290] All writes errors are fatal --- pgconn.go | 43 +++++++++++++++++-------------------------- pgconn_test.go | 2 +- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/pgconn.go b/pgconn.go index 512c9a88..c785f367 100644 --- a/pgconn.go +++ b/pgconn.go @@ -404,19 +404,6 @@ func (pgConn *PgConn) IsAlive() bool { return !pgConn.closed } -// writeAll writes the entire buffer. The connection is hard closed on a partial write or a non-temporary error. -func (pgConn *PgConn) writeAll(buf []byte) error { - n, err := pgConn.conn.Write(buf) - if err != nil { - if n > 0 { - pgConn.hardClose() - } else if ne, ok := err.(net.Error); ok && !ne.Temporary() { - pgConn.hardClose() - } - } - return err -} - // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -501,8 +488,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -666,8 +654,9 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) @@ -723,8 +712,9 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -768,8 +758,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true @@ -792,8 +783,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -853,8 +845,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var buf []byte buf = (&pgproto3.Query{String: sql}).Encode(buf) - err := pgConn.writeAll(buf) + _, err := pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() <-pgConn.controller @@ -903,14 +896,11 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - err = pgConn.writeAll(buf) + _, err = pgConn.conn.Write(buf) if err != nil { + pgConn.hardClose() cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } + <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) } @@ -1386,8 +1376,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - err := pgConn.writeAll(batch.buf) + _, err := pgConn.conn.Write(batch.buf) if err != nil { + pgConn.hardClose() multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) diff --git a/pgconn_test.go b/pgconn_test.go index 7fb01e2c..dbf9b840 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -863,7 +863,7 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.Equal(t, int64(0), ct.RowsAffected()) require.Equal(t, context.DeadlineExceeded, err) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnCopyFromGzipReader(t *testing.T) { From fbdfccf1f91a4c0bc042cb37f3c7c2c9e27a4877 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:55:56 -0500 Subject: [PATCH 063/290] Use Go modules --- .gitignore | 1 + benchmark_test.go | 2 +- config.go | 2 +- config_test.go | 2 +- go.mod | 11 +++++++++++ go.sum | 19 +++++++++++++++++++ helper_test.go | 2 +- pgconn.go | 6 +++--- pgconn_stress_test.go | 2 +- pgconn_test.go | 4 ++-- 10 files changed, 41 insertions(+), 10 deletions(-) create mode 100644 .gitignore create mode 100644 go.mod create mode 100644 go.sum diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..7a6353d6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.envrc diff --git a/benchmark_test.go b/benchmark_test.go index d2576324..959e86be 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/stretchr/testify/require" ) diff --git a/config.go b/config.go index fec1fedf..1cde9c57 100644 --- a/config.go +++ b/config.go @@ -17,7 +17,7 @@ import ( "strings" "time" - "github.com/jackc/pgx/pgpassfile" + "github.com/jackc/pgpassfile" "github.com/pkg/errors" ) diff --git a/config_test.go b/config_test.go index c7b65861..ce6f3957 100644 --- a/config_test.go +++ b/config_test.go @@ -8,7 +8,7 @@ import ( "os/user" "testing" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..3dc806a4 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/jackc/pgconn + +go 1.12 + +require ( + github.com/jackc/pgio v1.0.0 + github.com/jackc/pgpassfile v1.0.0 + github.com/jackc/pgproto3 v1.0.0 + github.com/pkg/errors v0.8.1 + github.com/stretchr/testify v1.3.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..5b6f835b --- /dev/null +++ b/go.sum @@ -0,0 +1,19 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v0.0.0-20190330174656-bb06e6b3ff87 h1:xueDi0R+HxuFmuOA1xyFbbF+2LSXqWQJZSPWmmMFB0A= +github.com/jackc/pgproto3 v0.0.0-20190330174656-bb06e6b3ff87/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3 v1.0.0 h1:25tUmlES7eyD96oYaUHc1dLOFbgcJtFzCdnOOoqmA1I= +github.com/jackc/pgproto3 v1.0.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/helper_test.go b/helper_test.go index c5ac6e01..5d44f3b8 100644 --- a/helper_test.go +++ b/helper_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/pgconn.go b/pgconn.go index c785f367..6490617a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,8 +15,8 @@ import ( "sync" "time" - "github.com/jackc/pgx/pgio" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgio" + "github.com/jackc/pgproto3" ) var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) @@ -171,7 +171,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } - pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.conn, pgConn.conn) + pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) if err != nil { return nil, err } diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 7a95fa98..1ebbe04a 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgconn" "github.com/stretchr/testify/require" ) diff --git a/pgconn_test.go b/pgconn_test.go index dbf9b840..716761ad 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -15,8 +15,8 @@ import ( "testing" "time" - "github.com/jackc/pgx/pgconn" - "github.com/jackc/pgx/pgproto3" + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3" "github.com/pkg/errors" "github.com/stretchr/testify/assert" From 08fcc7f2736a16192388fc08a0dc7863951e2c69 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 12:59:04 -0500 Subject: [PATCH 064/290] Add license and readme --- LICENSE | 22 ++++++++++++++++++++++ README.md | 8 ++++++++ 2 files changed, 30 insertions(+) create mode 100644 LICENSE create mode 100644 README.md diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..c1c4f50f --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2019 Jack Christensen + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 00000000..8a881009 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) +[![Build Status](https://travis-ci.org/jackc/pgconn.svg)](https://travis-ci.org/jackc/pgconn) + +# pgconn + +Package pgconn is a low-level PostgreSQL database driver. + +It is intended to serve as the foundation for the next generation of https://github.com/jackc/pgx. From b2fc69d32f5cdf79a4119888a36c997dcecdc073 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 13:03:28 -0500 Subject: [PATCH 065/290] Import pgx travis config --- .travis.yml | 38 +++++++++++++++++++++++++++++++++++++ travis/before_install.bash | 39 ++++++++++++++++++++++++++++++++++++++ travis/before_script.bash | 16 ++++++++++++++++ travis/install.bash | 14 ++++++++++++++ travis/script.bash | 10 ++++++++++ 5 files changed, 117 insertions(+) create mode 100644 .travis.yml create mode 100755 travis/before_install.bash create mode 100755 travis/before_script.bash create mode 100755 travis/install.bash create mode 100755 travis/script.bash diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..950792d1 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,38 @@ +language: go + +go: + - 1.x + - tip + +# Derived from https://github.com/lib/pq/blob/master/.travis.yml +before_install: + - ./travis/before_install.bash + +env: + global: + - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" + - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + - PGX_TEST_MD5_PASSWORD_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test + matrix: + - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx database=pgx_test" + - PGVERSION=10 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret database=pgx_test" + - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret database=pgx_test" + - PGVERSION=9.5 + - PGVERSION=9.4 + - PGVERSION=9.3 + +before_script: + - ./travis/before_script.bash + +install: + - ./travis/install.bash + +script: + - ./travis/script.bash + +matrix: + allow_failures: + - go: tip diff --git a/travis/before_install.bash b/travis/before_install.bash new file mode 100755 index 00000000..23c7d9cf --- /dev/null +++ b/travis/before_install.bash @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common + sudo rm -rf /var/lib/postgresql + wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" + sudo apt-get update -qq + sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION + sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf + sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf + if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then + echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf + fi + sudo /etc/init.d/postgresql restart +fi + +if [ "${CRATEVERSION-}" != "" ] +then + docker run \ + -p "6543:5432" \ + -d \ + crate:"$CRATEVERSION" \ + crate \ + -Cnetwork.host=0.0.0.0 \ + -Ctransport.host=localhost \ + -Clicense.enterprise=false +fi diff --git a/travis/before_script.bash b/travis/before_script.bash new file mode 100755 index 00000000..bcf748a1 --- /dev/null +++ b/travis/before_script.bash @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + # The tricky test user, below, has to actually exist so that it can be used in a test + # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + psql -U postgres -c 'create database pgx_test' + psql -U postgres pgx_test -c 'create extension hstore' + psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' + psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_replication with replication password 'secret'" + psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" +fi diff --git a/travis/install.bash b/travis/install.bash new file mode 100755 index 00000000..63ba875d --- /dev/null +++ b/travis/install.bash @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -eux + +go get -u github.com/cockroachdb/apd +go get -u github.com/shopspring/decimal +go get -u gopkg.in/inconshreveable/log15.v2 +go get -u github.com/jackc/fake +go get -u github.com/lib/pq +go get -u github.com/hashicorp/go-version +go get -u github.com/satori/go.uuid +go get -u github.com/sirupsen/logrus +go get -u github.com/pkg/errors +go get -u go.uber.org/zap +go get -u github.com/rs/zerolog diff --git a/travis/script.bash b/travis/script.bash new file mode 100755 index 00000000..5bf1b77e --- /dev/null +++ b/travis/script.bash @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +set -eux + +if [ "${PGVERSION-}" != "" ] +then + go test -v -race ./... +elif [ "${CRATEVERSION-}" != "" ] +then + go test -v -race -run 'TestCrateDBConnect' +fi From 444bd6deaf2065c0d108dadfa042df36af88ea57 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 16:44:20 -0500 Subject: [PATCH 066/290] Context cancellation is fatal during query --- config.go | 11 --- doc.go | 10 +-- pgconn.go | 189 +++--------------------------------------- pgconn_stress_test.go | 62 ++++---------- pgconn_test.go | 99 ++++++---------------- 5 files changed, 60 insertions(+), 311 deletions(-) diff --git a/config.go b/config.go index 1cde9c57..d392924c 100644 --- a/config.go +++ b/config.go @@ -41,17 +41,6 @@ type Config struct { // allows implementing high availability behavior such as libpq does with target_session_attrs. AfterConnectFunc AfterConnectFunc - // OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context - // is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a - // query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire - // protocol do not support this cancellation method. - // - // It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be - // called whether it was successful or not. If an error occurs the connection should be closed. The connection must be - // in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read - // the connection until a ready for query message is received. - OnContextCancel func(*ContextCancel) - // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler diff --git a/doc.go b/doc.go index 89e47536..d36eb0fd 100644 --- a/doc.go +++ b/doc.go @@ -20,10 +20,10 @@ result. The ReadAll method reads all query results into memory. Context Support -All potentially blocking operations take a context.Context. If a context is canceled while a query is in progress the -method immediately returns. In the background a cancel request will be sent to the PostgreSQL server. If the -cancellation fails or hangs for more than a short time (approximately 15 seconds) the connection will be closed. It is -safe to use the connection while this background cancellation is in progress. Any calls will block until the -cancellation and resynchronization is complete (and those calls can be aborted by a context cancellation). +All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the +method immediately returns. In most circumstances, this will close the underlying connection. + +The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the +client to abort. */ package pgconn diff --git a/pgconn.go b/pgconn.go index 6490617a..8b0ddcb4 100644 --- a/pgconn.go +++ b/pgconn.go @@ -199,6 +199,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig for { msg, err := pgConn.ReceiveMessage() if err != nil { + pgConn.conn.Close() return nil, err } @@ -502,7 +503,7 @@ readloop: for { msg, err := pgConn.ReceiveMessage() if err != nil { - go pgConn.recoverFromTimeout() + pgConn.hardClose() return nil, preferContextOverNetTimeoutError(ctx, err) } @@ -555,10 +556,10 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { return (*Notice)(pgerr) } -// cancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel +// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 -func (pgConn *PgConn) cancelRequest(ctx context.Context) error { +func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. @@ -590,21 +591,6 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { return nil } -// WaitUntilReady waits until a previous context cancellation has been completed and the connection is ready for use. -// This is done automatically by all methods that need the connection to be ready for use. The only expected use for -// this method is for a connection pool to wait for a returned connection to be usable again before making it available. -func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { - select { - case <-ctx.Done(): - return ctx.Err() - case pgConn.controller <- pgConn: - // The connection must be ready since it was locked. Immediately unlock it. - <-pgConn.controller - } - - return nil -} - // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { @@ -778,6 +764,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case pgConn.controller <- pgConn: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() // Send copy to command var buf []byte @@ -786,7 +773,6 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -798,13 +784,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm for { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -813,9 +793,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { - // This isn't actually a timeout, but we want the same behavior. Abort the request and cleanup. - cleanupContextDeadline() - go pgConn.recoverFromTimeout() + pgConn.hardClose() return "", err } case *pgproto3.ReadyForQuery: @@ -840,6 +818,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case pgConn.controller <- pgConn: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) + defer cleanupContextDeadline() // Send copy to command var buf []byte @@ -848,7 +827,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -861,13 +839,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for pendingCopyInResponse { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -899,7 +871,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -910,13 +881,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case <-signalMessageChan: msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeoutDuringCopyFrom() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -939,8 +904,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - - cleanupContextDeadline() <-pgConn.controller return "", preferContextOverNetTimeoutError(ctx, err) @@ -950,13 +913,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for { msg, err := pgConn.ReceiveMessage() if err != nil { - cleanupContextDeadline() - if err, ok := err.(net.Error); ok && err.Timeout() { - go pgConn.recoverFromTimeout() - } else { - <-pgConn.controller - } - + pgConn.hardClose() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -972,47 +929,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } } -func (pgConn *PgConn) recoverFromTimeoutDuringCopyFrom() { - // Regardless of recovery outcome the lock on the pgConn must be released. - defer func() { <-pgConn.controller }() - - // Limit time to wait for entire cancellation process. - err := pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - pgConn.hardClose() - return - } - - copyFail := &pgproto3.CopyFail{Error: "client cancel"} - buf := copyFail.Encode(nil) - - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - return - } - - pendingReadyForQuery := true - - for pendingReadyForQuery { - msg, err := pgConn.ReceiveMessage() - if err != nil { - pgConn.hardClose() - return - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - pendingReadyForQuery = false - } - } - - err = pgConn.conn.SetDeadline(time.Time{}) - if err != nil { - pgConn.hardClose() - } -} - // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn @@ -1044,13 +960,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.cleanupContextDeadline() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true - - if err, ok := err.(net.Error); ok && err.Timeout() { - go mrr.pgConn.recoverFromTimeout() - } else { - <-mrr.pgConn.controller - } - + mrr.pgConn.hardClose() return nil, mrr.err } @@ -1236,11 +1146,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.cleanupContextDeadline() rr.closed = true if rr.multiResultReader == nil { - if err, ok := err.(net.Error); ok && err.Timeout() { - go rr.pgConn.recoverFromTimeout() - } else { - <-rr.pgConn.controller - } + rr.pgConn.hardClose() } return nil, rr.err @@ -1270,75 +1176,6 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { rr.commandConcluded = true } -func (pgConn *PgConn) defaultCancel() { - // Regardless of recovery outcome the lock on the pgConn must be released. - defer func() { <-pgConn.controller }() - - // Send a cancellation request to the PostgreSQL server. If it is not successful in a reasonable amount of time do not - // try further to recover the connection. - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - err := pgConn.cancelRequest(ctx) - cancel() - if err != nil { - pgConn.hardClose() - return - } - - // Limit time to wait for ReadyForQuery message. - err = pgConn.conn.SetDeadline(time.Now().Add(15 * time.Second)) - if err != nil { - pgConn.hardClose() - return - } - - // A cancel query request will always return a "57014" error response, even if no query was in progress. This error - // may be returned before or after the ReadyForQuery message. Must ensure both messages are read. - needError57014 := true - needReadyForQuery := true - - for needError57014 || needReadyForQuery { - msg, err := pgConn.ReceiveMessage() - if err != nil { - pgConn.hardClose() - return - } - - switch msg := msg.(type) { - case *pgproto3.ErrorResponse: - if msg.Code == "57014" { - needError57014 = false - } - case *pgproto3.ReadyForQuery: - needReadyForQuery = false - } - } - - err = pgConn.conn.SetDeadline(time.Time{}) - if err != nil { - pgConn.hardClose() - } -} - -type ContextCancel struct { - PgConn *PgConn -} - -// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for -// query state or the connection must be closed. This must be called regardless of the success of the cancellation and -// whether the connection is still valid or not. It releases an internal busy lock on the connection. -func (cc *ContextCancel) Finish() { - <-cc.PgConn.controller -} - -func (pgConn *PgConn) recoverFromTimeout() { - if pgConn.Config.OnContextCancel == nil { - pgConn.defaultCancel() - } else { - cc := &ContextCancel{PgConn: pgConn} - pgConn.Config.OnContextCancel(cc) - } -} - // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 1ebbe04a..7288c9b4 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -4,9 +4,9 @@ import ( "context" "math/rand" "os" + "runtime" "strconv" "testing" - "time" "github.com/jackc/pgconn" @@ -14,13 +14,11 @@ import ( ) func TestConnStress(t *testing.T) { - t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer closeConn(t, pgConn) - actionCount := 100 + actionCount := 10000 if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" { stressFactor, err := strconv.ParseInt(s, 10, 64) require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR") @@ -36,9 +34,6 @@ func TestConnStress(t *testing.T) { {"Exec Select", stressExecSelect}, {"ExecParams Select", stressExecParamsSelect}, {"Batch", stressBatch}, - {"ExecCanceled", stressExecSelectCanceled}, - {"ExecParamsCanceled", stressExecParamsSelectCanceled}, - {"BatchCanceled", stressBatchCanceled}, } for i := 0; i < actionCount; i++ { @@ -46,6 +41,10 @@ func TestConnStress(t *testing.T) { err := action.fn(pgConn) require.Nilf(t, err, "%d: %s", i, action.name) } + + // Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled. + numGoroutine := runtime.NumGoroutine() + require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine) } func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { @@ -65,56 +64,27 @@ func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) { } func stressExecSelect(pgConn *pgconn.PgConn) error { - _, err := pgConn.Exec(context.Background(), "select * from widgets").ReadAll() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := pgConn.Exec(ctx, "select * from widgets").ReadAll() return err } func stressExecParamsSelect(pgConn *pgconn.PgConn) error { - result := pgConn.ExecParams(context.Background(), "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() return result.Err } func stressBatch(pgConn *pgconn.PgConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + batch := &pgconn.Batch{} batch.ExecParams("select * from widgets", nil, nil, nil, nil) batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - _, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + _, err := pgConn.ExecBatch(ctx, batch).ReadAll() return err } - -func stressExecSelectCanceled(pgConn *pgconn.PgConn) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.Exec(ctx, "select *, pg_sleep(1) from widgets").ReadAll() - cancel() - if err != context.DeadlineExceeded { - return err - } - - return nil -} - -func stressExecParamsSelectCanceled(pgConn *pgconn.PgConn) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - result := pgConn.ExecParams(ctx, "select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read() - cancel() - if result.Err != context.DeadlineExceeded { - return result.Err - } - - return nil -} - -func stressBatchCanceled(pgConn *pgconn.PgConn) error { - batch := &pgconn.Batch{} - batch.ExecParams("select * from widgets", nil, nil, nil, nil) - batch.ExecParams("select *, pg_sleep(1) from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) - _, err := pgConn.ExecBatch(ctx, batch).ReadAll() - cancel() - if err != context.DeadlineExceeded { - return err - } - - return nil -} diff --git a/pgconn_test.go b/pgconn_test.go index 716761ad..88c6f7c4 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -16,7 +16,6 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgproto3" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -356,8 +355,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecParams(t *testing.T) { @@ -400,7 +398,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecPrepared(t *testing.T) { @@ -451,8 +449,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(""), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnExecBatch(t *testing.T) { @@ -510,72 +507,6 @@ func TestCommandTag(t *testing.T) { } } -func TestConnContextCancelWithOnContextCancel(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - - calledChan := make(chan struct{}) - - config.OnContextCancel = func(cc *pgconn.ContextCancel) { - defer cc.Finish() - close(calledChan) - - for { - msg, err := cc.PgConn.ReceiveMessage() - if err != nil { - cc.PgConn.Close(context.Background()) - return - } - - switch msg.(type) { - case *pgproto3.ReadyForQuery: - return - } - } - } - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil) - _, err = result.Close() - assert.Equal(t, context.DeadlineExceeded, err) - - called := false - select { - case <-calledChan: - called = true - case <-time.NewTimer(time.Second).C: - } - - assert.True(t, called) - - ensureConnValid(t, pgConn) -} - -func TestConnWaitUntilReady(t *testing.T) { - t.Parallel() - - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read() - assert.Equal(t, context.DeadlineExceeded, result.Err) - - err = pgConn.WaitUntilReady(context.Background()) - require.NoError(t, err) - - ensureConnValid(t, pgConn) -} - func TestConnOnNotice(t *testing.T) { t.Parallel() @@ -792,7 +723,7 @@ func TestConnCopyToCanceled(t *testing.T) { assert.Equal(t, context.DeadlineExceeded, err) assert.Equal(t, pgconn.CommandTag(""), res) - ensureConnValid(t, pgConn) + assert.False(t, pgConn.IsAlive()) } func TestConnCopyFrom(t *testing.T) { @@ -991,6 +922,28 @@ func TestConnEscapeString(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCancelRequest(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(5)") + + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) + + for multiResult.NextResult() { + } + err = multiResult.Close() + + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { From 3d9e42d74c14ed6f091449fc7602727c3dc49d07 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 17:09:39 -0500 Subject: [PATCH 067/290] Replace chan based conn locking with bool This is conceptually simpler and will lead to error messages instead of deadlocks. --- pgconn.go | 112 ++++++++++++++++++++++++++++++++++++------------- pgconn_test.go | 23 ++++++++++ 2 files changed, 106 insertions(+), 29 deletions(-) diff --git a/pgconn.go b/pgconn.go index 8b0ddcb4..e246bcdd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -89,8 +89,7 @@ type PgConn struct { Config *Config - controller chan interface{} - + locked bool closed bool bufferingReceive bool @@ -153,7 +152,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config - pgConn.controller = make(chan interface{}, 1) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -405,6 +403,29 @@ func (pgConn *PgConn) IsAlive() bool { return !pgConn.closed } +// lock locks the connection. It returns an error if the connection is already locked or is closed. +func (pgConn *PgConn) lock() error { + if pgConn.locked { + return errors.New("connection busy") + } + + if pgConn.closed { + return errors.New("connection closed") + } + + pgConn.locked = true + + return nil +} + +func (pgConn *PgConn) unlock() { + if !pgConn.locked { + panic("BUG: cannot unlock unlocked connection") + } + + pgConn.locked = false +} + // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { @@ -476,10 +497,14 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + select { case <-ctx.Done(): return nil, ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -521,7 +546,7 @@ readloop: } } - <-pgConn.controller + pgConn.unlock() if parseErr != nil { return nil, parseErr @@ -594,14 +619,18 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { + if err := pgConn.lock(); err != nil { + return err + } + select { case <-ctx.Done(): return ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() - defer func() { <-pgConn.controller }() + defer pgConn.unlock() for { msg, err := pgConn.ReceiveMessage() @@ -628,12 +657,18 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + multiResult.closed = true + multiResult.err = err + return multiResult + } + select { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() return multiResult - case pgConn.controller <- multiResult: + default: } multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -646,7 +681,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) - <-pgConn.controller + pgConn.unlock() return multiResult } @@ -679,12 +714,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + result.concludeCommand("", err) + result.closed = true + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true return result - case pgConn.controller <- result: + default: } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -704,7 +745,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true - <-pgConn.controller + pgConn.unlock() } return result @@ -729,12 +770,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + result.concludeCommand("", err) + result.closed = true + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true return result - case pgConn.controller <- result: + default: } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -750,7 +797,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa result.concludeCommand("", err) result.cleanupContextDeadline() result.closed = true - <-pgConn.controller + pgConn.unlock() } return result @@ -758,10 +805,14 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return "", err + } + select { case <-ctx.Done(): return "", ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -773,7 +824,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller + pgConn.unlock() return "", preferContextOverNetTimeoutError(ctx, err) } @@ -797,7 +848,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return "", err } case *pgproto3.ReadyForQuery: - <-pgConn.controller + pgConn.unlock() return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) @@ -812,10 +863,15 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { + if err := pgConn.lock(); err != nil { + return "", err + } + defer pgConn.unlock() + select { case <-ctx.Done(): return "", ctx.Err() - case pgConn.controller <- pgConn: + default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() @@ -827,8 +883,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } @@ -849,7 +903,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.ErrorResponse: pgErr = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: - <-pgConn.controller return commandTag, pgErr } } @@ -871,8 +924,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } } @@ -904,8 +955,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - <-pgConn.controller - return "", preferContextOverNetTimeoutError(ctx, err) } @@ -919,7 +968,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - <-pgConn.controller return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) @@ -968,7 +1016,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) case *pgproto3.ReadyForQuery: mrr.cleanupContextDeadline() mrr.closed = true - <-mrr.pgConn.controller + mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = errorResponseToPgError(msg) } @@ -1125,7 +1173,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: rr.cleanupContextDeadline() - <-rr.pgConn.controller + rr.pgConn.unlock() return rr.commandTag, rr.err } } @@ -1203,12 +1251,18 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR cleanupContextDeadline: func() {}, } + if err := pgConn.lock(); err != nil { + multiResult.closed = true + multiResult.err = ctx.Err() + return multiResult + } + select { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() return multiResult - case pgConn.controller <- multiResult: + default: } multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) @@ -1219,7 +1273,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR multiResult.cleanupContextDeadline() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) - <-pgConn.controller + pgConn.unlock() return multiResult } diff --git a/pgconn_test.go b/pgconn_test.go index 88c6f7c4..53e3b9d8 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -484,6 +484,29 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnLocking(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "connection busy", err.Error()) + + results, err = mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + func TestCommandTag(t *testing.T) { t.Parallel() From ed7d91dc987364b13d9039fe83a76aa993e9cdf9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Mar 2019 17:13:23 -0500 Subject: [PATCH 068/290] Force Go modules for Travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 950792d1..50e81eb5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,6 +10,7 @@ before_install: env: global: + - GO111MODULE=on - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test From c745509c595970c1776bee941d3fd969f313b845 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 11:27:04 -0500 Subject: [PATCH 069/290] Rename test --- pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 53e3b9d8..ab8ae173 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -239,7 +239,7 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } -func TestConnPrepareFailure(t *testing.T) { +func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) From 408837dcb1e5fb4535ab313178a64a6ad79d9bbb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 11:47:31 -0500 Subject: [PATCH 070/290] Handle extended protocol with too many arguments --- pgconn.go | 17 ++++++++ pgconn_test.go | 106 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+) diff --git a/pgconn.go b/pgconn.go index e246bcdd..223b8e3d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "math" "net" "strconv" "strings" @@ -720,10 +721,18 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } + if len(paramValues) > math.MaxUint16 { + result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true + pgConn.unlock() return result default: } @@ -776,10 +785,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } + if len(paramValues) > math.MaxUint16 { + result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.closed = true + pgConn.unlock() + return result + } + select { case <-ctx.Done(): result.concludeCommand("", ctx.Err()) result.closed = true + pgConn.unlock() return result default: } diff --git a/pgconn_test.go b/pgconn_test.go index ab8ae173..b2514e48 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -9,9 +9,11 @@ import ( "io" "io/ioutil" "log" + "math" "net" "os" "strconv" + "strings" "testing" "time" @@ -379,6 +381,52 @@ func TestConnExecParams(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecParamsTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + func TestConnExecParamsCanceled(t *testing.T) { t.Parallel() @@ -428,6 +476,64 @@ func TestConnExecPrepared(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) + + ensureConnValid(t, pgConn) +} + +func TestConnExecPreparedTooManyParams(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") + + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) + + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + + ensureConnValid(t, pgConn) +} + func TestConnExecPreparedCanceled(t *testing.T) { t.Parallel() From 7ad3625edd3b36e00d73c0c09009d8841074daed Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 12:06:59 -0500 Subject: [PATCH 071/290] unlock connection when context is pre-canceled --- pgconn.go | 5 ++ pgconn_test.go | 166 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+) diff --git a/pgconn.go b/pgconn.go index 223b8e3d..db741d47 100644 --- a/pgconn.go +++ b/pgconn.go @@ -504,6 +504,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ select { case <-ctx.Done(): + pgConn.unlock() return nil, ctx.Err() default: } @@ -626,6 +627,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { select { case <-ctx.Done(): + pgConn.unlock() return ctx.Err() default: } @@ -668,6 +670,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() + pgConn.unlock() return multiResult default: } @@ -828,6 +831,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): + pgConn.unlock() return "", ctx.Err() default: } @@ -1278,6 +1282,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR case <-ctx.Done(): multiResult.closed = true multiResult.err = ctx.Err() + pgConn.unlock() return multiResult default: } diff --git a/pgconn_test.go b/pgconn_test.go index b2514e48..66a4337b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -255,6 +255,23 @@ func TestConnPrepareSyntaxError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnPrepareContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) + require.Nil(t, psd) + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnExec(t *testing.T) { t.Parallel() @@ -360,6 +377,22 @@ func TestConnExecContextCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecContextPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnExecParams(t *testing.T) { t.Parallel() @@ -449,6 +482,22 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecParamsPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, context.Canceled, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecPrepared(t *testing.T) { t.Parallel() @@ -558,6 +607,25 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnExecPreparedPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, context.Canceled, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatch(t *testing.T) { t.Parallel() @@ -590,6 +658,31 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnExecBatchPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnLocking(t *testing.T) { t.Parallel() @@ -726,6 +819,24 @@ func TestConnWaitForNotification(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnWaitForNotificationPrecanceled(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = pgConn.WaitForNotification(ctx) + require.Equal(t, context.Canceled, err) + + ensureConnValid(t, pgConn) +} + func TestConnWaitForNotificationTimeout(t *testing.T) { t.Parallel() @@ -855,6 +966,25 @@ func TestConnCopyToCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnCopyToPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + outputWriter := &bytes.Buffer{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") + require.Error(t, err) + require.Equal(t, context.Canceled, err) + assert.Equal(t, pgconn.CommandTag(""), res) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFrom(t *testing.T) { t.Parallel() @@ -926,6 +1056,42 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.False(t, pgConn.IsAlive()) } +func TestConnCopyFromPrecanceled(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 1000000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Microsecond) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.Error(t, err) + require.Equal(t, context.Canceled, err) + assert.Equal(t, pgconn.CommandTag(""), ct) + + ensureConnValid(t, pgConn) +} + func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() From 0ebe322ac3600c12c9d1989f0ad6b3322c224c4a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 5 Apr 2019 16:10:11 -0500 Subject: [PATCH 072/290] Extract common code from ExecParams and ExecPrepared --- pgconn.go | 65 ++++++++++++++++++------------------------------------- 1 file changed, 21 insertions(+), 44 deletions(-) diff --git a/pgconn.go b/pgconn.go index db741d47..7ddc50e6 100644 --- a/pgconn.go +++ b/pgconn.go @@ -712,53 +712,16 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { - result := &ResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, - } - - if err := pgConn.lock(); err != nil { - result.concludeCommand("", err) - result.closed = true + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { return result } - if len(paramValues) > math.MaxUint16 { - result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) - result.closed = true - pgConn.unlock() - return result - } - - select { - case <-ctx.Done(): - result.concludeCommand("", ctx.Err()) - result.closed = true - pgConn.unlock() - return result - default: - } - result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - var buf []byte - - // TODO - refactor ExecParams and ExecPrepared - these lines only difference buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) - buf = (&pgproto3.Execute{}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) - - _, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.hardClose() - result.concludeCommand("", err) - result.cleanupContextDeadline() - result.closed = true - pgConn.unlock() - } + pgConn.execExtendedSuffix(buf, result) return result } @@ -776,6 +739,20 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { + result := pgConn.execExtendedPrefix(ctx, paramValues) + if result.closed { + return result + } + + var buf []byte + buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + + pgConn.execExtendedSuffix(buf, result) + + return result +} + +func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { result := &ResultReader{ pgConn: pgConn, ctx: ctx, @@ -805,8 +782,10 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - var buf []byte - buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + return result +} + +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -819,8 +798,6 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa result.closed = true pgConn.unlock() } - - return result } // CopyTo executes the copy command sql and copies the results to w. From 698bd4bf5a75e4c6386e38225e701d7a08da4c86 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 10:30:49 -0500 Subject: [PATCH 073/290] Use defer to unlock pgConn in Prepare --- pgconn.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index 7ddc50e6..c9891dbf 100644 --- a/pgconn.go +++ b/pgconn.go @@ -501,10 +501,10 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ if err := pgConn.lock(); err != nil { return nil, err } + defer pgConn.unlock() select { case <-ctx.Done(): - pgConn.unlock() return nil, ctx.Err() default: } @@ -548,8 +548,6 @@ readloop: } } - pgConn.unlock() - if parseErr != nil { return nil, parseErr } From 244e114435d4afb7934392284613255b639d6fb9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 16 Apr 2019 20:41:38 -0500 Subject: [PATCH 074/290] Add SCRAM authentication --- auth_scram.go | 255 +++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 4 +- go.sum | 11 ++- pgconn.go | 2 + pgconn_test.go | 3 +- 5 files changed, 268 insertions(+), 7 deletions(-) create mode 100644 auth_scram.go diff --git a/auth_scram.go b/auth_scram.go new file mode 100644 index 00000000..b78a236a --- /dev/null +++ b/auth_scram.go @@ -0,0 +1,255 @@ +// SCRAM-SHA-256 authentication +// +// Resources: +// https://tools.ietf.org/html/rfc5802 +// https://tools.ietf.org/html/rfc8265 +// https://www.postgresql.org/docs/current/sasl-authentication.html +// +// Inspiration drawn from other implementations: +// https://github.com/lib/pq/pull/608 +// https://github.com/lib/pq/pull/788 +// https://github.com/lib/pq/pull/833 +package pgconn + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "strconv" + + "github.com/jackc/pgproto3" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/text/secure/precis" +) + +const clientNonceLen = 18 + +// Perform SCRAM authentication. +func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { + sc, err := newScramClient(serverAuthMechanisms, c.Config.Password) + if err != nil { + return err + } + + // Send client-first-message in a SASLInitialResponse + saslInitialResponse := &pgproto3.SASLInitialResponse{ + AuthMechanism: "SCRAM-SHA-256", + Data: sc.clientFirstMessage(), + } + _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-first-message payload in a AuthenticationSASLContinue. + authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + if err != nil { + return err + } + err = sc.recvServerFirstMessage(authMsg.SASLData) + if err != nil { + return err + } + + // Send client-final-message in a SASLResponse + saslResponse := &pgproto3.SASLResponse{ + Data: []byte(sc.clientFinalMessage()), + } + _, err = c.conn.Write(saslResponse.Encode(nil)) + if err != nil { + return err + } + + // Receive server-final-message payload in a AuthenticationSASLFinal. + authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + if err != nil { + return err + } + return sc.recvServerFinalMessage(authMsg.SASLData) +} + +func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { + msg, err := c.ReceiveMessage() + if err != nil { + return nil, err + } + authMsg, ok := msg.(*pgproto3.Authentication) + if !ok { + return nil, errors.New("unexpected message type") + } + if authMsg.Type != typ { + return nil, errors.New("unexpected auth type") + } + + return authMsg, nil +} + +type scramClient struct { + serverAuthMechanisms []string + password []byte + clientNonce []byte + + clientFirstMessageBare []byte + + serverFirstMessage []byte + clientAndServerNonce []byte + salt []byte + iterations int + + saltedPassword []byte + authMessage []byte +} + +func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { + sc := &scramClient{ + serverAuthMechanisms: serverAuthMechanisms, + } + + // Ensure server supports SCRAM-SHA-256 + hasScramSHA256 := false + for _, mech := range sc.serverAuthMechanisms { + if mech == "SCRAM-SHA-256" { + hasScramSHA256 = true + break + } + } + if !hasScramSHA256 { + return nil, errors.New("server does not support SCRAM-SHA-256") + } + + // precis.OpaqueString is equivalent to SASLprep for password. + var err error + sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + if err != nil { + // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. + sc.password = []byte(password) + } + + buf := make([]byte, clientNonceLen) + _, err = rand.Read(buf) + if err != nil { + return nil, err + } + sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) + base64.RawStdEncoding.Encode(sc.clientNonce, buf) + + return sc, nil +} + +func (sc *scramClient) clientFirstMessage() []byte { + sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) + return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) +} + +func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { + sc.serverFirstMessage = serverFirstMessage + buf := serverFirstMessage + if !bytes.HasPrefix(buf, []byte("r=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include r=") + } + buf = buf[2:] + + idx := bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + sc.clientAndServerNonce = buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("s=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include s=") + } + buf = buf[2:] + + idx = bytes.IndexByte(buf, ',') + if idx == -1 { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + saltStr := buf[:idx] + buf = buf[idx+1:] + + if !bytes.HasPrefix(buf, []byte("i=")) { + return errors.New("invalid SCRAM server-first-message received from server: did not include i=") + } + buf = buf[2:] + iterationsStr := buf + + var err error + sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) + if err != nil { + return fmt.Errorf("invalid SCRAM salt received from server: %v", err) + } + + sc.iterations, err = strconv.Atoi(string(iterationsStr)) + if err != nil || sc.iterations <= 0 { + return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr) + } + + if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not start with client nonce") + } + + if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { + return errors.New("invalid SCRAM nonce: did not include server nonce") + } + + return nil +} + +func (sc *scramClient) clientFinalMessage() string { + clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + + sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) + + clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) + + return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) +} + +func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { + if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { + return errors.New("invalid SCRAM server-final-message received from server") + } + + serverSignature := serverFinalMessage[2:] + + if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { + return errors.New("invalid SCRAM ServerSignature received from server") + } + + return nil +} + +func computeHMAC(key, msg []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(msg) + return mac.Sum(nil) +} + +func computeClientProof(saltedPassword, authMessage []byte) []byte { + clientKey := computeHMAC(saltedPassword, []byte("Client Key")) + storedKey := sha256.Sum256(clientKey) + clientSignature := computeHMAC(storedKey[:], authMessage) + + clientProof := make([]byte, len(clientSignature)) + for i := 0; i < len(clientSignature); i++ { + clientProof[i] = clientKey[i] ^ clientSignature[i] + } + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) + base64.StdEncoding.Encode(buf, clientProof) + return buf +} + +func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { + serverKey := computeHMAC(saltedPassword, []byte("Server Key")) + serverSignature := computeHMAC(serverKey[:], authMessage) + buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) + base64.StdEncoding.Encode(buf, serverSignature) + return buf +} diff --git a/go.mod b/go.mod index 3dc806a4..09b4471d 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3 v1.0.0 + github.com/jackc/pgproto3 v1.1.0 github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 + golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a + golang.org/x/text v0.3.0 ) diff --git a/go.sum b/go.sum index 5b6f835b..8872aac1 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,8 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v0.0.0-20190330174656-bb06e6b3ff87 h1:xueDi0R+HxuFmuOA1xyFbbF+2LSXqWQJZSPWmmMFB0A= -github.com/jackc/pgproto3 v0.0.0-20190330174656-bb06e6b3ff87/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3 v1.0.0 h1:25tUmlES7eyD96oYaUHc1dLOFbgcJtFzCdnOOoqmA1I= -github.com/jackc/pgproto3 v1.0.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -17,3 +15,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/pgconn.go b/pgconn.go index c9891dbf..264d9e8c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -260,6 +260,8 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { case pgproto3.AuthTypeMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) err = c.txPasswordMessage(digestedPassword) + case pgproto3.AuthTypeSASL: + err = c.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } diff --git a/pgconn_test.go b/pgconn_test.go index 66a4337b..fd57face 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -33,12 +33,11 @@ func TestConnect(t *testing.T) { {"TCP", "PGX_TEST_TCP_CONN_STRING"}, {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Parallel() - connString := os.Getenv(tt.env) if connString == "" { t.Skipf("Skipping due to missing environment variable %v", tt.env) From 0174907e04e75b23e393d98c31593b886e599e5f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 16 Apr 2019 20:58:10 -0500 Subject: [PATCH 075/290] Fix travis unix domain socket test --- travis/before_script.bash | 1 + 1 file changed, 1 insertion(+) diff --git a/travis/before_script.bash b/travis/before_script.bash index bcf748a1..923b7d06 100755 --- a/travis/before_script.bash +++ b/travis/before_script.bash @@ -11,6 +11,7 @@ then psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user travis" psql -U postgres -c "create user pgx_replication with replication password 'secret'" psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi From e948dc3246b579f05e9d508be07bf2816ba96b3d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 21:51:58 -0500 Subject: [PATCH 076/290] Reuse buffer for writing --- pgconn.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pgconn.go b/pgconn.go index 264d9e8c..82a010b8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -97,6 +97,8 @@ type PgConn struct { bufferingReceiveMux sync.Mutex bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error + + wbuf []byte // Reusable write buffer } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -153,6 +155,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config + pgConn.wbuf = make([]byte, 0, 1024) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -190,7 +193,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.conn.Write(startupMsg.Encode(nil)); err != nil { + if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() return nil, err } @@ -271,7 +274,7 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.conn.Write(msg.Encode(nil)) + _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) return err } @@ -513,7 +516,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContextDeadline() - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -676,7 +679,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { } multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) _, err := pgConn.conn.Write(buf) @@ -717,7 +720,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) @@ -744,7 +747,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) pgConn.execExtendedSuffix(buf, result) @@ -816,7 +819,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm defer cleanupContextDeadline() // Send copy to command - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) _, err := pgConn.conn.Write(buf) @@ -875,7 +878,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co defer cleanupContextDeadline() // Send copy to command - var buf []byte + buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) _, err := pgConn.conn.Write(buf) From bc139fadb5b49cf4159b33c1312cc66ef0582c7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 22:01:47 -0500 Subject: [PATCH 077/290] Reuse one ResultReader per connection --- pgconn.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index 82a010b8..4cf4d745 100644 --- a/pgconn.go +++ b/pgconn.go @@ -98,7 +98,9 @@ type PgConn struct { bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error - wbuf []byte // Reusable write buffer + // Reusable / preallocated resources + wbuf []byte // write buffer + resultReader ResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -756,11 +758,12 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - result := &ResultReader{ + pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, } + result := &pgConn.resultReader if err := pgConn.lock(); err != nil { result.concludeCommand("", err) @@ -1035,20 +1038,22 @@ func (mrr *MultiResultReader) NextResult() bool { switch msg := msg.(type) { case *pgproto3.RowDescription: - mrr.rr = &ResultReader{ + mrr.pgConn.resultReader = ResultReader{ pgConn: mrr.pgConn, multiResultReader: mrr, ctx: mrr.ctx, cleanupContextDeadline: func() {}, fieldDescriptions: msg.Fields, } + mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.CommandComplete: - mrr.rr = &ResultReader{ + mrr.pgConn.resultReader = ResultReader{ commandTag: CommandTag(msg.CommandTag), commandConcluded: true, closed: true, } + mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.EmptyQueryResponse: return false From b6e5b74e2c82dc3305355453ec86dc002bf577b4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 22:50:36 -0500 Subject: [PATCH 078/290] Reuse one MultiResultReader per connection Using a PgConn while locked now panics. i.e. You must Close any ResultReader or MultiResultReader. --- pgconn.go | 65 ++++++++++++++++---------------------------------- pgconn_test.go | 6 ++--- 2 files changed, 23 insertions(+), 48 deletions(-) diff --git a/pgconn.go b/pgconn.go index 4cf4d745..7e8909ea 100644 --- a/pgconn.go +++ b/pgconn.go @@ -99,8 +99,9 @@ type PgConn struct { bufferingReceiveErr error // Reusable / preallocated resources - wbuf []byte // write buffer - resultReader ResultReader + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -411,24 +412,18 @@ func (pgConn *PgConn) IsAlive() bool { return !pgConn.closed } -// lock locks the connection. It returns an error if the connection is already locked or is closed. -func (pgConn *PgConn) lock() error { +// lock locks the connection. It panics if the connection is already locked or is closed. +func (pgConn *PgConn) lock() { if pgConn.locked { - return errors.New("connection busy") - } - - if pgConn.closed { - return errors.New("connection closed") + panic("connection busy") // This only should be possible in case of an application bug. } pgConn.locked = true - - return nil } func (pgConn *PgConn) unlock() { if !pgConn.locked { - panic("BUG: cannot unlock unlocked connection") + panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. } pgConn.locked = false @@ -505,9 +500,7 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - if err := pgConn.lock(); err != nil { - return nil, err - } + pgConn.lock() defer pgConn.unlock() select { @@ -626,9 +619,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { - if err := pgConn.lock(); err != nil { - return err - } + pgConn.lock() select { case <-ctx.Done(): @@ -659,17 +650,14 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { - multiResult := &MultiResultReader{ + pgConn.lock() + + pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, } - - if err := pgConn.lock(); err != nil { - multiResult.closed = true - multiResult.err = err - return multiResult - } + multiResult := &pgConn.multiResultReader select { case <-ctx.Done(): @@ -758,6 +746,8 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { + pgConn.lock() + pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, @@ -765,12 +755,6 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by } result := &pgConn.resultReader - if err := pgConn.lock(); err != nil { - result.concludeCommand("", err) - result.closed = true - return result - } - if len(paramValues) > math.MaxUint16 { result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true @@ -808,9 +792,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { - if err := pgConn.lock(); err != nil { - return "", err - } + pgConn.lock() select { case <-ctx.Done(): @@ -867,9 +849,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { - if err := pgConn.lock(); err != nil { - return "", err - } + pgConn.lock() defer pgConn.unlock() select { @@ -1251,17 +1231,14 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { - multiResult := &MultiResultReader{ + pgConn.lock() + + pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, cleanupContextDeadline: func() {}, } - - if err := pgConn.lock(); err != nil { - multiResult.closed = true - multiResult.err = ctx.Err() - return multiResult - } + multiResult := &pgConn.multiResultReader select { case <-ctx.Done(): diff --git a/pgconn_test.go b/pgconn_test.go index fd57face..3be61be8 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -690,11 +690,9 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "connection busy", err.Error()) + require.Panics(t, func() { pgConn.Exec(context.Background(), "select 'Hello, world'") }) - results, err = mrr.ReadAll() + results, err := mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) From 9d30dad837720c6b53dbf37fcb413afbd1d94045 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 22:52:07 -0500 Subject: [PATCH 079/290] Do not buffer results in benchmarks --- benchmark_test.go | 123 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 115 insertions(+), 8 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 959e86be..000dfd1b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "bytes" "context" "os" "testing" @@ -41,11 +42,42 @@ func BenchmarkExec(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() - require.Nil(b, err) + mrr := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } } } @@ -54,14 +86,45 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { require.Nil(b, err) defer closeConn(b, conn) + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + b.ResetTimer() ctx, cancel := context.WithCancel(context.Background()) defer cancel() for i := 0; i < b.N; i++ { - _, err := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date").ReadAll() - require.Nil(b, err) + mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } + } + + err := mrr.Close() + if err != nil { + b.Fatal(err) + } } } @@ -73,11 +136,33 @@ func BenchmarkExecPrepared(b *testing.B) { _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) require.Nil(b, err) + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + b.ResetTimer() for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil).Read() - require.Nil(b, result.Err) + rr := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } } } @@ -92,10 +177,32 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) require.Nil(b, err) + expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + b.ResetTimer() for i := 0; i < b.N; i++ { - result := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() - require.Nil(b, result.Err) + rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount += 1 + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) + } } } From 2383561e4d1bbf50fde6a214aa04f296764e265f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 18 Apr 2019 23:17:28 -0500 Subject: [PATCH 080/290] Use 0-alloc pgproto3/v2 --- auth_scram.go | 2 +- go.mod | 1 + go.sum | 2 ++ pgconn.go | 52 +++++++++++++++++++++++++++----------------------- pgconn_test.go | 10 +++++----- 5 files changed, 37 insertions(+), 30 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index b78a236a..50fbff40 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -21,7 +21,7 @@ import ( "fmt" "strconv" - "github.com/jackc/pgproto3" + "github.com/jackc/pgproto3/v2" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" ) diff --git a/go.mod b/go.mod index 09b4471d..232df737 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3 v1.1.0 + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a diff --git a/go.sum b/go.sum index 8872aac1..8e0e2c9f 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/pgconn.go b/pgconn.go index 7e8909ea..7bc93435 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1,6 +1,7 @@ package pgconn import ( + "bytes" "context" "crypto/md5" "crypto/tls" @@ -17,7 +18,7 @@ import ( "time" "github.com/jackc/pgio" - "github.com/jackc/pgproto3" + "github.com/jackc/pgproto3/v2" ) var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) @@ -436,20 +437,23 @@ func (pgConn *PgConn) ParameterStatus(key string) string { } // CommandTag is the result of an Exec function -type CommandTag string +type CommandTag []byte // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { - s := string(ct) - index := strings.LastIndex(s, " ") - if index == -1 { + idx := bytes.LastIndexByte([]byte(ct), ' ') + if idx == -1 { return 0 } - n, _ := strconv.ParseInt(s[index+1:], 10, 64) + n, _ := strconv.ParseInt(string([]byte(ct)[idx+1:]), 10, 64) return n } +func (ct CommandTag) String() string { + return string(ct) +} + // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { @@ -756,7 +760,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if len(paramValues) > math.MaxUint16 { - result.concludeCommand("", fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result @@ -764,7 +768,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand("", ctx.Err()) + result.concludeCommand(nil, ctx.Err()) result.closed = true pgConn.unlock() return result @@ -783,7 +787,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - result.concludeCommand("", err) + result.concludeCommand(nil, err) result.cleanupContextDeadline() result.closed = true pgConn.unlock() @@ -797,7 +801,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): pgConn.unlock() - return "", ctx.Err() + return nil, ctx.Err() default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -812,7 +816,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm pgConn.hardClose() pgConn.unlock() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read results @@ -822,7 +826,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -831,7 +835,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm _, err := w.Write(msg.Data) if err != nil { pgConn.hardClose() - return "", err + return nil, err } case *pgproto3.ReadyForQuery: pgConn.unlock() @@ -854,7 +858,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co select { case <-ctx.Done(): - return "", ctx.Err() + return nil, ctx.Err() default: } cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) @@ -867,7 +871,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read until copy in response or error. @@ -878,7 +882,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -908,7 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } } @@ -917,7 +921,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -939,7 +943,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } // Read results @@ -947,7 +951,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return "", preferContextOverNetTimeoutError(ctx, err) + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1145,7 +1149,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { - return "", rr.err + return nil, rr.err } } @@ -1153,7 +1157,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { for { msg, err := rr.receiveMessage() if err != nil { - return "", rr.err + return nil, rr.err } switch msg.(type) { @@ -1176,7 +1180,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { - rr.concludeCommand("", err) + rr.concludeCommand(nil, err) rr.cleanupContextDeadline() rr.closed = true if rr.multiResultReader == nil { @@ -1192,7 +1196,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand("", errorResponseToPgError(msg)) + rr.concludeCommand(nil, errorResponseToPgError(msg)) } return msg, nil diff --git a/pgconn_test.go b/pgconn_test.go index 3be61be8..2b1e68a3 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -475,7 +475,7 @@ func TestConnExecParamsCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(""), commandTag) + assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) assert.False(t, pgConn.IsAlive()) @@ -601,7 +601,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { } assert.Equal(t, 0, rowCount) commandTag, err := result.Close() - assert.Equal(t, pgconn.CommandTag(""), commandTag) + assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) assert.False(t, pgConn.IsAlive()) } @@ -958,7 +958,7 @@ func TestConnCopyToCanceled(t *testing.T) { defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") assert.Equal(t, context.DeadlineExceeded, err) - assert.Equal(t, pgconn.CommandTag(""), res) + assert.Equal(t, pgconn.CommandTag(nil), res) assert.False(t, pgConn.IsAlive()) } @@ -977,7 +977,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) require.Equal(t, context.Canceled, err) - assert.Equal(t, pgconn.CommandTag(""), res) + assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) } @@ -1084,7 +1084,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) require.Equal(t, context.Canceled, err) - assert.Equal(t, pgconn.CommandTag(""), ct) + assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) } From 16412e56e22d0ae96c5c8bf95b512562c71cbc80 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 Apr 2019 14:24:51 -0500 Subject: [PATCH 081/290] 0 alloc context to deadline --- chan_to_set_deadline.go | 51 ++++++++++++++++ pgconn.go | 127 ++++++++++++++-------------------------- 2 files changed, 95 insertions(+), 83 deletions(-) create mode 100644 chan_to_set_deadline.go diff --git a/chan_to_set_deadline.go b/chan_to_set_deadline.go new file mode 100644 index 00000000..04bb8fde --- /dev/null +++ b/chan_to_set_deadline.go @@ -0,0 +1,51 @@ +package pgconn + +import ( + "time" +) + +var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) + +type setDeadliner interface { + SetDeadline(time.Time) error +} + +type chanToSetDeadline struct { + cleanupChan chan struct{} + conn setDeadliner + deadlineWasSet bool + cleanupComplete bool +} + +func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) { + if this.cleanupChan == nil { + this.cleanupChan = make(chan struct{}) + } + this.conn = conn + this.deadlineWasSet = false + this.cleanupComplete = false + + if doneChan != nil { + go func() { + select { + case <-doneChan: + conn.SetDeadline(deadlineTime) + this.deadlineWasSet = true + <-this.cleanupChan + case <-this.cleanupChan: + } + }() + } else { + this.cleanupComplete = true + } +} + +func (this *chanToSetDeadline) cleanup() { + if !this.cleanupComplete { + this.cleanupChan <- struct{}{} + if this.deadlineWasSet { + this.conn.SetDeadline(time.Time{}) + } + this.cleanupComplete = true + } +} diff --git a/pgconn.go b/pgconn.go index 7bc93435..6ff0d39f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,14 +15,11 @@ import ( "strconv" "strings" "sync" - "time" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" ) -var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) - // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -100,9 +97,10 @@ type PgConn struct { bufferingReceiveErr error // Reusable / preallocated resources - wbuf []byte // write buffer - resultReader ResultReader - multiResultReader MultiResultReader + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader + doneChanToDeadline chanToSetDeadline } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -382,8 +380,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContext() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { @@ -463,38 +461,6 @@ func preferContextOverNetTimeoutError(ctx context.Context, err error) error { return err } -// contextDoneToConnDeadline starts a goroutine that will set an immediate deadline on conn after reading from -// ctx.Done(). The returned cleanup function must be called to terminate this goroutine. The cleanup function is safe to -// call multiple times. -func contextDoneToConnDeadline(ctx context.Context, conn net.Conn) (cleanup func()) { - if ctx.Done() != nil { - deadlineWasSet := false - doneChan := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - conn.SetDeadline(deadlineTime) - deadlineWasSet = true - <-doneChan - case <-doneChan: - } - }() - - finished := false - return func() { - if !finished { - doneChan <- struct{}{} - if deadlineWasSet { - conn.SetDeadline(time.Time{}) - } - finished = true - } - } - } - - return func() {} -} - type PreparedStatementDescription struct { Name string SQL string @@ -512,8 +478,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -599,8 +565,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, cancelConn) - defer cleanupContext() + var doneChanToDeadline chanToSetDeadline + doneChanToDeadline.start(ctx.Done(), cancelConn) + defer doneChanToDeadline.cleanup() buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -624,16 +591,16 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { pgConn.lock() + defer pgConn.unlock() select { case <-ctx.Done(): - pgConn.unlock() return ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() - defer pgConn.unlock() + + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() for { msg, err := pgConn.ReceiveMessage() @@ -657,9 +624,8 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.lock() pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } multiResult := &pgConn.multiResultReader @@ -671,7 +637,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult default: } - multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -679,7 +645,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - multiResult.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) pgConn.unlock() @@ -753,9 +719,8 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by pgConn.lock() pgConn.resultReader = ResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } result := &pgConn.resultReader @@ -774,7 +739,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result default: } - result.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) return result } @@ -788,7 +753,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { if err != nil { pgConn.hardClose() result.concludeCommand(nil, err) - result.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() result.closed = true pgConn.unlock() } @@ -804,8 +769,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() // Send copy to command buf := pgConn.wbuf @@ -861,8 +826,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return nil, ctx.Err() default: } - cleanupContextDeadline := contextDoneToConnDeadline(ctx, pgConn.conn) - defer cleanupContextDeadline() + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + defer pgConn.doneChanToDeadline.cleanup() // Send copy to command buf := pgConn.wbuf @@ -967,9 +932,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - cleanupContextDeadline func() + pgConn *PgConn + ctx context.Context rr *ResultReader @@ -993,7 +957,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) msg, err := mrr.pgConn.ReceiveMessage() if err != nil { - mrr.cleanupContextDeadline() + mrr.pgConn.doneChanToDeadline.cleanup() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.hardClose() @@ -1002,7 +966,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mrr.cleanupContextDeadline() + mrr.pgConn.doneChanToDeadline.cleanup() mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: @@ -1023,11 +987,10 @@ func (mrr *MultiResultReader) NextResult() bool { switch msg := msg.(type) { case *pgproto3.RowDescription: mrr.pgConn.resultReader = ResultReader{ - pgConn: mrr.pgConn, - multiResultReader: mrr, - ctx: mrr.ctx, - cleanupContextDeadline: func() {}, - fieldDescriptions: msg.Fields, + pgConn: mrr.pgConn, + multiResultReader: mrr, + ctx: mrr.ctx, + fieldDescriptions: msg.Fields, } mrr.rr = &mrr.pgConn.resultReader return true @@ -1066,10 +1029,9 @@ func (mrr *MultiResultReader) Close() error { // ResultReader is a reader for the result of a single query. type ResultReader struct { - pgConn *PgConn - multiResultReader *MultiResultReader - ctx context.Context - cleanupContextDeadline func() + pgConn *PgConn + multiResultReader *MultiResultReader + ctx context.Context fieldDescriptions []pgproto3.FieldDescription rowValues [][]byte @@ -1162,7 +1124,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: - rr.cleanupContextDeadline() + rr.pgConn.doneChanToDeadline.cleanup() rr.pgConn.unlock() return rr.commandTag, rr.err } @@ -1181,7 +1143,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error if err != nil { rr.concludeCommand(nil, err) - rr.cleanupContextDeadline() + rr.pgConn.doneChanToDeadline.cleanup() rr.closed = true if rr.multiResultReader == nil { rr.pgConn.hardClose() @@ -1238,9 +1200,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.lock() pgConn.multiResultReader = MultiResultReader{ - pgConn: pgConn, - ctx: ctx, - cleanupContextDeadline: func() {}, + pgConn: pgConn, + ctx: ctx, } multiResult := &pgConn.multiResultReader @@ -1252,13 +1213,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult default: } - multiResult.cleanupContextDeadline = contextDoneToConnDeadline(ctx, pgConn.conn) + pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) _, err := pgConn.conn.Write(batch.buf) if err != nil { pgConn.hardClose() - multiResult.cleanupContextDeadline() + pgConn.doneChanToDeadline.cleanup() multiResult.closed = true multiResult.err = preferContextOverNetTimeoutError(ctx, err) pgConn.unlock() From 7bb6c2f3e9826f233e799c894439f87ac93e007f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 Apr 2019 15:52:12 -0500 Subject: [PATCH 082/290] Unify locked and closed into status No longer panic on locking busy conn --- pgconn.go | 81 ++++++++++++++++++++++++++++++++++++-------------- pgconn_test.go | 6 ++-- 2 files changed, 63 insertions(+), 24 deletions(-) diff --git a/pgconn.go b/pgconn.go index 6ff0d39f..7a9a42e4 100644 --- a/pgconn.go +++ b/pgconn.go @@ -20,6 +20,13 @@ import ( "github.com/jackc/pgproto3/v2" ) +const ( + connStatusUninitialized = iota + connStatusClosed + connStatusIdle + connStatusBusy +) + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. @@ -88,8 +95,7 @@ type PgConn struct { Config *Config - locked bool - closed bool + status byte // One of connStatus* constants bufferingReceive bool bufferingReceiveMux sync.Mutex @@ -217,6 +223,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, err } case *pgproto3.ReadyForQuery: + pgConn.status = connStatusIdle if config.AfterConnectFunc != nil { err := config.AfterConnectFunc(ctx, pgConn) if err != nil { @@ -373,10 +380,10 @@ func (pgConn *PgConn) SecretKey() uint32 { // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. func (pgConn *PgConn) Close(ctx context.Context) error { - if pgConn.closed { + if pgConn.status == connStatusClosed { return nil } - pgConn.closed = true + pgConn.status = connStatusClosed defer pgConn.conn.Close() @@ -398,34 +405,41 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // hardClose closes the underlying connection without sending the exit message. func (pgConn *PgConn) hardClose() error { - if pgConn.closed { + if pgConn.status == connStatusClosed { return nil } - pgConn.closed = true + pgConn.status = connStatusClosed return pgConn.conn.Close() } // TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of // underlying connection. func (pgConn *PgConn) IsAlive() bool { - return !pgConn.closed + return pgConn.status >= connStatusIdle } // lock locks the connection. It panics if the connection is already locked or is closed. -func (pgConn *PgConn) lock() { - if pgConn.locked { - panic("connection busy") // This only should be possible in case of an application bug. +func (pgConn *PgConn) lock() error { + switch pgConn.status { + case connStatusBusy: + return errors.New("connection busy") // This only should be possible in case of an application bug. + case connStatusClosed: + return errors.New("conn closed") + case connStatusUninitialized: + return errors.New("conn uninitialized") } - - pgConn.locked = true + pgConn.status = connStatusBusy + return nil } func (pgConn *PgConn) unlock() { - if !pgConn.locked { + switch pgConn.status { + case connStatusBusy: + pgConn.status = connStatusIdle + case connStatusClosed: + default: panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. } - - pgConn.locked = false } // ParameterStatus returns the value of a parameter reported by the server (e.g. @@ -470,7 +484,9 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } defer pgConn.unlock() select { @@ -590,7 +606,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return err + } defer pgConn.unlock() select { @@ -621,7 +639,12 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, @@ -716,7 +739,12 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &ResultReader{ + closed: true, + err: err, + } + } pgConn.resultReader = ResultReader{ pgConn: pgConn, @@ -761,7 +789,9 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } select { case <-ctx.Done(): @@ -818,7 +848,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return nil, err + } defer pgConn.unlock() select { @@ -1197,7 +1229,12 @@ func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFor // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { - pgConn.lock() + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, diff --git a/pgconn_test.go b/pgconn_test.go index 2b1e68a3..2ad02830 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -690,9 +690,11 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - require.Panics(t, func() { pgConn.Exec(context.Background(), "select 'Hello, world'") }) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "connection busy", err.Error()) - results, err := mrr.ReadAll() + results, err = mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) From 3710e52a9a125c406f4a6f682ca9a67e695c38f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 Apr 2019 16:16:55 -0500 Subject: [PATCH 083/290] Add named error for conn busy --- pgconn.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 7a9a42e4..7d437434 100644 --- a/pgconn.go +++ b/pgconn.go @@ -84,6 +84,10 @@ type NotificationHandler func(*PgConn, *Notification) // PostgreSQL server refuses to use TLS var ErrTLSRefused = errors.New("server refused TLS connection") +// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another +// action is attempted. +var ErrConnBusy = errors.New("conn is busy") + // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -422,7 +426,7 @@ func (pgConn *PgConn) IsAlive() bool { func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: - return errors.New("connection busy") // This only should be possible in case of an application bug. + return ErrConnBusy // This only should be possible in case of an application bug. case connStatusClosed: return errors.New("conn closed") case connStatusUninitialized: From 9f774761bacc37fb32d6a8718e3aa9ccd9035de2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 10:59:50 -0500 Subject: [PATCH 084/290] Fix TestConnLocking --- pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 2ad02830..d31e8cc9 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -692,7 +692,7 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.Equal(t, "connection busy", err.Error()) + assert.Equal(t, pgconn.ErrConnBusy, err) results, err = mrr.ReadAll() assert.NoError(t, err) From 39e6ff5766bde4b27a085b021b11b5b3be18a276 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 11:11:09 -0500 Subject: [PATCH 085/290] Prevent deadlock with huge batches --- pgconn.go | 22 +++++++++++--------- pgconn_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index 7d437434..4f3cdd66 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1257,15 +1257,19 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) - _, err := pgConn.conn.Write(batch.buf) - if err != nil { - pgConn.hardClose() - pgConn.doneChanToDeadline.cleanup() - multiResult.closed = true - multiResult.err = preferContextOverNetTimeoutError(ctx, err) - pgConn.unlock() - return multiResult - } + + // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is + // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication + // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. + // The error the code reading the batch results receives will be a closed connection error. + // + // See https://github.com/jackc/pgx/issues/374. + go func() { + _, err := pgConn.conn.Write(batch.buf) + if err != nil { + pgConn.conn.Close() + } + }() return multiResult } diff --git a/pgconn_test.go b/pgconn_test.go index d31e8cc9..25cc3ee3 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -682,6 +682,60 @@ func TestConnExecBatchPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +// Without concurrent reading and writing large batches can deadlock. +// +// See https://github.com/jackc/pgx/issues/374. +func TestConnExecBatchHuge(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + + queryCount := 100000 + args := make([]string, queryCount) + + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } + + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) + + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } +} + +func TestConnExecBatchImplicitTransaction(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) + + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) +} + func TestConnLocking(t *testing.T) { t.Parallel() From cd629965e6c1920f124691c4004507467fe2069c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 12:57:52 -0500 Subject: [PATCH 086/290] Use golang.org/x/xerrors --- auth_scram.go | 6 +++--- config.go | 14 +++++++------- go.mod | 3 ++- go.sum | 2 ++ pgconn.go | 9 ++++----- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 50fbff40..5baa680b 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -17,13 +17,13 @@ import ( "crypto/rand" "crypto/sha256" "encoding/base64" - "errors" "fmt" "strconv" "github.com/jackc/pgproto3/v2" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" + errors "golang.org/x/xerrors" ) const clientNonceLen = 18 @@ -181,12 +181,12 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { var err error sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) if err != nil { - return fmt.Errorf("invalid SCRAM salt received from server: %v", err) + return errors.Errorf("invalid SCRAM salt received from server: %w", err) } sc.iterations, err = strconv.Atoi(string(iterationsStr)) if err != nil || sc.iterations <= 0 { - return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr) + return errors.Errorf("invalid SCRAM iteration count received from server: %w", err) } if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { diff --git a/config.go b/config.go index d392924c..c751cc0d 100644 --- a/config.go +++ b/config.go @@ -18,7 +18,7 @@ import ( "time" "github.com/jackc/pgpassfile" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error @@ -195,7 +195,7 @@ func ParseConfig(connString string) (*Config, error) { port, err := parsePort(portStr) if err != nil { - return nil, fmt.Errorf("invalid port: %v", settings["port"]) + return nil, errors.Errorf("invalid port: %w", err) } var tlsConfigs []*tls.Config @@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) { if settings["target_session_attrs"] == "read-write" { config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { - return nil, fmt.Errorf("unknown target_session_attrs value %v", settings["target_session_attrs"]) + return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) } return config, nil @@ -409,11 +409,11 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { caPath := sslrootcert caCert, err := ioutil.ReadFile(caPath) if err != nil { - return nil, errors.Wrapf(err, "unable to read CA file %q", caPath) + return nil, errors.Errorf("unable to read CA file: %w", err) } if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, errors.Wrap(err, "unable to add CA to cert pool") + return nil, errors.Errorf("unable to add CA to cert pool: %w", err) } tlsConfig.RootCAs = caCertPool @@ -421,13 +421,13 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { } if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { - return nil, fmt.Errorf(`both "sslcert" and "sslkey" are required`) + return nil, errors.New(`both "sslcert" and "sslkey" are required`) } if sslcert != "" && sslkey != "" { cert, err := tls.LoadX509KeyPair(sslcert, sslkey) if err != nil { - return nil, errors.Wrap(err, "unable to read cert") + return nil, errors.Errorf("unable to read cert: %w", err) } tlsConfig.Certificates = []tls.Certificate{cert} diff --git a/go.mod b/go.mod index 232df737..dda76fe1 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,11 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3 v1.1.0 + github.com/jackc/pgproto3 v1.1.0 // indirect github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 + golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 ) diff --git a/go.sum b/go.sum index 8e0e2c9f..5a100ff0 100644 --- a/go.sum +++ b/go.sum @@ -22,3 +22,5 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaE golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pgconn.go b/pgconn.go index 4f3cdd66..14377beb 100644 --- a/pgconn.go +++ b/pgconn.go @@ -7,8 +7,6 @@ import ( "crypto/tls" "encoding/binary" "encoding/hex" - "errors" - "fmt" "io" "math" "net" @@ -18,6 +16,7 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" + errors "golang.org/x/xerrors" ) const ( @@ -232,7 +231,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err := config.AfterConnectFunc(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, fmt.Errorf("AfterConnectFunc: %v", err) + return nil, errors.Errorf("AfterConnectFunc: %v", err) } } return pgConn, nil @@ -601,7 +600,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { _, err = cancelConn.Read(buf) if err != io.EOF { - return fmt.Errorf("Server failed to close connection after cancel query request: %v", preferContextOverNetTimeoutError(ctx, err)) + return errors.Errorf("Server failed to close connection after cancel query request: %w", preferContextOverNetTimeoutError(ctx, err)) } return nil @@ -757,7 +756,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if len(paramValues) > math.MaxUint16 { - result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result From 7a520059d9115a271068a920b7583a911bd3a509 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 13:01:59 -0500 Subject: [PATCH 087/290] Update to remove pgprotov3 ref --- go.mod | 3 +-- go.sum | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index dda76fe1..acbee593 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,7 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3 v1.1.0 // indirect - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a diff --git a/go.sum b/go.sum index 5a100ff0..9160f187 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,8 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= -github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf h1:wI8d/uq9/RfZOe6bKOpC4Skd4VgkTIGZqxmHu6IQGb8= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190419041544-9b6a681f50bf/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From f3b5f6b2753fb81b66507c5f42a55af75241d01c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 15:34:49 -0500 Subject: [PATCH 088/290] Allow skipping TestConnExecBatchHuge in short mode --- pgconn_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 25cc3ee3..3fc15e7a 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -686,6 +686,10 @@ func TestConnExecBatchPrecanceled(t *testing.T) { // // See https://github.com/jackc/pgx/issues/374. func TestConnExecBatchHuge(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) From 0f8e1d30e2dc1a4f359761d5418126bb0e0685d5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 15:53:30 -0500 Subject: [PATCH 089/290] Link context errors and underlying conn errors Using golang.org/x/xerrors type errors both errors can be exposed. --- errors.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 76 ++++++++++---------------------------------- pgconn_test.go | 8 ++--- 3 files changed, 105 insertions(+), 64 deletions(-) create mode 100644 errors.go diff --git a/errors.go b/errors.go new file mode 100644 index 00000000..e42dae16 --- /dev/null +++ b/errors.go @@ -0,0 +1,85 @@ +package pgconn + +import ( + "context" + "net" + + errors "golang.org/x/xerrors" +) + +// ErrTLSRefused occurs when the connection attempt requires TLS and the +// PostgreSQL server refuses to use TLS +var ErrTLSRefused = errors.New("server refused TLS connection") + +// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another +// action is attempted. +var ErrConnBusy = errors.New("conn is busy") + +// PgError represents an error reported by the PostgreSQL server. See +// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// detailed field description. +type PgError struct { + Severity string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string +} + +func (pe *PgError) Error() string { + return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" +} + +// linkedError connects two errors as if err wrapped next. +type linkedError struct { + err error + next error +} + +func (le *linkedError) Error() string { + return le.err.Error() +} + +func (le *linkedError) Is(target error) bool { + return errors.Is(le.err, target) +} + +func (le *linkedError) As(target interface{}) bool { + return errors.As(le.err, target) +} + +func (le *linkedError) Unwrap() error { + return le.next +} + +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return ctx.Err() + } + return err +} + +// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. +func linkErrors(outer, inner error) error { + if outer == nil { + return inner + } + if inner == nil { + return outer + } + return &linkedError{err: outer, next: inner} +} diff --git a/pgconn.go b/pgconn.go index 14377beb..2911211c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -26,33 +26,6 @@ const ( connStatusBusy ) -// PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for -// detailed field description. -type PgError struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string -} - -func (pe *PgError) Error() string { - return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" -} - // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -79,14 +52,6 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") - -// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another -// action is attempted. -var ErrConnBusy = errors.New("conn is busy") - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -395,12 +360,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error { _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } return pgConn.conn.Close() @@ -469,15 +434,6 @@ func (ct CommandTag) String() string { return string(ct) } -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() - } - return err -} - type PreparedStatementDescription struct { Name string SQL string @@ -508,7 +464,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } psd := &PreparedStatementDescription{Name: name, SQL: sql} @@ -520,7 +476,7 @@ readloop: msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -595,12 +551,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } _, err = cancelConn.Read(buf) if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %w", preferContextOverNetTimeoutError(ctx, err)) + return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) } return nil @@ -626,7 +582,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.ReceiveMessage() if err != nil { - return preferContextOverNetTimeoutError(ctx, err) + return linkErrors(ctx.Err(), err) } switch msg.(type) { @@ -673,7 +629,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.hardClose() pgConn.doneChanToDeadline.cleanup() multiResult.closed = true - multiResult.err = preferContextOverNetTimeoutError(ctx, err) + multiResult.err = linkErrors(ctx.Err(), err) pgConn.unlock() return multiResult } @@ -814,7 +770,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm pgConn.hardClose() pgConn.unlock() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read results @@ -824,7 +780,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -871,7 +827,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read until copy in response or error. @@ -882,7 +838,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -912,7 +868,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } } @@ -921,7 +877,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { @@ -943,7 +899,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } // Read results @@ -951,7 +907,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() - return nil, preferContextOverNetTimeoutError(ctx, err) + return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { diff --git a/pgconn_test.go b/pgconn_test.go index 3fc15e7a..30e6a425 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,7 +18,7 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/pkg/errors" + errors "golang.org/x/xerrors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -907,7 +907,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() - require.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) ensureConnValid(t, pgConn) } @@ -1017,7 +1017,7 @@ func TestConnCopyToCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Equal(t, pgconn.CommandTag(nil), res) assert.False(t, pgConn.IsAlive()) @@ -1108,7 +1108,7 @@ func TestConnCopyFromCanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) - require.Equal(t, context.DeadlineExceeded, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.False(t, pgConn.IsAlive()) } From 7e0022ef6ba389ca1b8140e50e42624af1df312e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Apr 2019 16:48:24 -0500 Subject: [PATCH 090/290] Tag errors if no bytes sent to server --- errors.go | 4 ++++ pgconn.go | 58 +++++++++++++++++++++++++++++++------------------- pgconn_test.go | 30 ++++++++++++++++---------- 3 files changed, 59 insertions(+), 33 deletions(-) diff --git a/errors.go b/errors.go index e42dae16..4f8af407 100644 --- a/errors.go +++ b/errors.go @@ -15,6 +15,10 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // action is attempted. var ErrConnBusy = errors.New("conn is busy") +// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used +// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. +var ErrNoBytesSent = errors.New("no bytes sent to server") + // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. diff --git a/pgconn.go b/pgconn.go index 2911211c..a4402a7d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -444,13 +444,13 @@ type PreparedStatementDescription struct { // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -461,9 +461,12 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -601,7 +604,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -614,7 +617,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = ctx.Err() + multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) pgConn.unlock() return multiResult default: @@ -624,11 +627,14 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() pgConn.doneChanToDeadline.cleanup() multiResult.closed = true + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } multiResult.err = linkErrors(ctx.Err(), err) pgConn.unlock() return multiResult @@ -666,7 +672,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(ctx, buf, result) return result } @@ -692,7 +698,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(ctx, buf, result) return result } @@ -701,7 +707,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by if err := pgConn.lock(); err != nil { return &ResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -720,7 +726,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand(nil, ctx.Err()) + result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) result.closed = true pgConn.unlock() return result @@ -731,15 +737,18 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { +func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - result.concludeCommand(nil, err) + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } + result.concludeCommand(nil, linkErrors(ctx.Err(), err)) pgConn.doneChanToDeadline.cleanup() result.closed = true pgConn.unlock() @@ -749,13 +758,13 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } select { case <-ctx.Done(): pgConn.unlock() - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -765,11 +774,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() pgConn.unlock() - + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -808,13 +819,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, err + return nil, linkErrors(err, ErrNoBytesSent) } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) @@ -824,9 +835,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) - _, err := pgConn.conn.Write(buf) + n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } return nil, linkErrors(ctx.Err(), err) } @@ -1191,7 +1205,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: err, + err: linkErrors(err, ErrNoBytesSent), } } @@ -1204,7 +1218,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = ctx.Err() + multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) pgConn.unlock() return multiResult default: diff --git a/pgconn_test.go b/pgconn_test.go index 30e6a425..b7cb4036 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -264,9 +264,10 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil) - require.Nil(t, psd) - require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.Nil(t, psd) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -386,8 +387,9 @@ func TestConnExecContextPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -492,7 +494,8 @@ func TestConnExecParamsPrecanceled(t *testing.T) { cancel() result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() require.Error(t, result.Err) - require.Equal(t, context.Canceled, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -620,7 +623,8 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { cancel() result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Error(t, result.Err) - require.Equal(t, context.Canceled, result.Err) + assert.True(t, errors.Is(result.Err, context.Canceled)) + assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -677,7 +681,8 @@ func TestConnExecBatchPrecanceled(t *testing.T) { cancel() _, err = pgConn.ExecBatch(ctx, batch).ReadAll() require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) ensureConnValid(t, pgConn) } @@ -750,7 +755,8 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.Equal(t, pgconn.ErrConnBusy, err) + assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) results, err = mrr.ReadAll() assert.NoError(t, err) @@ -1036,7 +1042,8 @@ func TestConnCopyToPrecanceled(t *testing.T) { cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) @@ -1143,7 +1150,8 @@ func TestConnCopyFromPrecanceled(t *testing.T) { cancel() ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) - require.Equal(t, context.Canceled, err) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) From 23a91ebc909de0d768ce3b5965a603dec0725286 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 24 Apr 2019 16:08:12 -0500 Subject: [PATCH 091/290] auth_scram.go file comment should not be part of docs --- auth_scram.go | 1 + 1 file changed, 1 insertion(+) diff --git a/auth_scram.go b/auth_scram.go index 5baa680b..d102d305 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -9,6 +9,7 @@ // https://github.com/lib/pq/pull/608 // https://github.com/lib/pq/pull/788 // https://github.com/lib/pq/pull/833 + package pgconn import ( From 1e3961bd0ea4d624dc181734894db99b9e5946f4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 24 Apr 2019 16:49:52 -0500 Subject: [PATCH 092/290] Fix flickering test --- pgconn_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index b7cb4036..dcbbfc89 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1289,7 +1289,12 @@ func TestConnCancelRequest(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(5)") + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") + + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) err = pgConn.CancelRequest(context.Background()) require.NoError(t, err) From 1baf0ef57ec8643d0417d5b2b909ba17c214d125 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 May 2019 18:05:06 -0500 Subject: [PATCH 093/290] Refactor context handling into ctxwatch package --- benchmark_test.go | 16 +++ chan_to_set_deadline.go | 51 -------- go.mod | 1 - go.sum | 1 + helper_test.go | 4 +- internal/ctxwatch/context_watcher.go | 64 ++++++++++ internal/ctxwatch/context_watcher_test.go | 139 ++++++++++++++++++++++ pgconn.go | 65 ++++++---- 8 files changed, 261 insertions(+), 80 deletions(-) delete mode 100644 chan_to_set_deadline.go create mode 100644 internal/ctxwatch/context_watcher.go create mode 100644 internal/ctxwatch/context_watcher_test.go diff --git a/benchmark_test.go b/benchmark_test.go index 000dfd1b..073281aa 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -206,3 +206,19 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { } } } + +// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { +// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) +// require.Nil(b, err) +// defer closeConn(b, conn) + +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() + +// b.ResetTimer() + +// for i := 0; i < b.N; i++ { +// conn.ChanToSetDeadline().Watch(ctx) +// conn.ChanToSetDeadline().Ignore() +// } +// } diff --git a/chan_to_set_deadline.go b/chan_to_set_deadline.go deleted file mode 100644 index 04bb8fde..00000000 --- a/chan_to_set_deadline.go +++ /dev/null @@ -1,51 +0,0 @@ -package pgconn - -import ( - "time" -) - -var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC) - -type setDeadliner interface { - SetDeadline(time.Time) error -} - -type chanToSetDeadline struct { - cleanupChan chan struct{} - conn setDeadliner - deadlineWasSet bool - cleanupComplete bool -} - -func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) { - if this.cleanupChan == nil { - this.cleanupChan = make(chan struct{}) - } - this.conn = conn - this.deadlineWasSet = false - this.cleanupComplete = false - - if doneChan != nil { - go func() { - select { - case <-doneChan: - conn.SetDeadline(deadlineTime) - this.deadlineWasSet = true - <-this.cleanupChan - case <-this.cleanupChan: - } - }() - } else { - this.cleanupComplete = true - } -} - -func (this *chanToSetDeadline) cleanup() { - if !this.cleanupComplete { - this.cleanupChan <- struct{}{} - if this.deadlineWasSet { - this.conn.SetDeadline(time.Time{}) - } - this.cleanupComplete = true - } -} diff --git a/go.mod b/go.mod index acbee593..4ad3564a 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db - github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 9160f187..9e2398cb 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/helper_test.go b/helper_test.go index 5d44f3b8..1a3ca75e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -12,9 +12,9 @@ import ( ) func closeConn(t testing.TB, conn *pgconn.PgConn) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - require.Nil(t, conn.Close(ctx)) + require.NoError(t, conn.Close(ctx)) } // Do a simple query to ensure the connection is still usable diff --git a/internal/ctxwatch/context_watcher.go b/internal/ctxwatch/context_watcher.go new file mode 100644 index 00000000..391f0b79 --- /dev/null +++ b/internal/ctxwatch/context_watcher.go @@ -0,0 +1,64 @@ +package ctxwatch + +import ( + "context" +) + +// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a +// time. +type ContextWatcher struct { + onCancel func() + onUnwatchAfterCancel func() + unwatchChan chan struct{} + watchInProgress bool + onCancelWasCalled bool +} + +// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. +// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and +// onCancel called. +func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { + cw := &ContextWatcher{ + onCancel: onCancel, + onUnwatchAfterCancel: onUnwatchAfterCancel, + unwatchChan: make(chan struct{}), + } + + return cw +} + +// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. +func (cw *ContextWatcher) Watch(ctx context.Context) { + if cw.watchInProgress { + panic("Watch already in progress") + } + + cw.onCancelWasCalled = false + + if ctx.Done() != nil { + cw.watchInProgress = true + go func() { + select { + case <-ctx.Done(): + cw.onCancel() + cw.onCancelWasCalled = true + <-cw.unwatchChan + case <-cw.unwatchChan: + } + }() + } else { + cw.watchInProgress = false + } +} + +// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was +// called then onUnwatchAfterCancel will also be called. +func (cw *ContextWatcher) Unwatch() { + if cw.watchInProgress { + cw.unwatchChan <- struct{}{} + if cw.onCancelWasCalled { + cw.onUnwatchAfterCancel() + } + cw.watchInProgress = false + } +} diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go new file mode 100644 index 00000000..0b491bf8 --- /dev/null +++ b/internal/ctxwatch/context_watcher_test.go @@ -0,0 +1,139 @@ +package ctxwatch_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgconn/internal/ctxwatch" + "github.com/stretchr/testify/require" +) + +func TestContextWatcherContextCancelled(t *testing.T) { + canceledChan := make(chan struct{}) + cleanupCalled := false + cw := ctxwatch.NewContextWatcher(func() { + canceledChan <- struct{}{} + }, func() { + cleanupCalled = true + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + + select { + case <-canceledChan: + case <-time.NewTimer(time.Second).C: + t.Fatal("Timed out waiting for cancel func to be called") + } + + cw.Unwatch() + + require.True(t, cleanupCalled, "Cleanup func was not called") +} + +func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() { + t.Error("cancel func should not have been called") + }, func() { + t.Error("cleanup func should not have been called") + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cw.Unwatch() + cancel() +} + +func TestContextWatcherMultipleWatchPanics(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") +} + +func TestContextWatcherStress(t *testing.T) { + var cancelFuncCalls int64 + var cleanupFuncCalls int64 + + cw := ctxwatch.NewContextWatcher(func() { + atomic.AddInt64(&cancelFuncCalls, 1) + }, func() { + atomic.AddInt64(&cleanupFuncCalls, 1) + }) + + cycleCount := 100000 + + for i := 0; i < cycleCount; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + if i%2 == 0 { + cancel() + } + + // Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix. + if i%3 == 0 { + time.Sleep(time.Nanosecond) + } + + cw.Unwatch() + if i%2 == 1 { + cancel() + } + } + + actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) + actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls) + + if actualCancelFuncCalls == 0 { + t.Fatal("actualCancelFuncCalls == 0") + } + + maxCancelFuncCalls := int64(cycleCount) / 2 + if actualCancelFuncCalls > maxCancelFuncCalls { + t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls) + } + + if actualCancelFuncCalls != actualCleanupFuncCalls { + t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls) + } +} + +func BenchmarkContextWatcherUncancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + cw.Watch(context.Background()) + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancelled(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + cw.Watch(ctx) + cw.Unwatch() + } +} diff --git a/pgconn.go b/pgconn.go index a4402a7d..aad5fafd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -13,7 +13,9 @@ import ( "strconv" "strings" "sync" + "time" + "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -21,6 +23,7 @@ import ( const ( connStatusUninitialized = iota + connStatusConnecting connStatusClosed connStatusIdle connStatusBusy @@ -71,10 +74,10 @@ type PgConn struct { bufferingReceiveErr error // Reusable / preallocated resources - wbuf []byte // write buffer - resultReader ResultReader - multiResultReader MultiResultReader - doneChanToDeadline chanToSetDeadline + wbuf []byte // write buffer + resultReader ResultReader + multiResultReader MultiResultReader + contextWatcher *ctxwatch.ContextWatcher } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -149,6 +152,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } } + pgConn.status = connStatusConnecting + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) if err != nil { return nil, err @@ -355,8 +364,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { @@ -377,6 +386,7 @@ func (pgConn *PgConn) hardClose() error { return nil } pgConn.status = connStatusClosed + return pgConn.conn.Close() } @@ -453,8 +463,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -543,9 +553,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - var doneChanToDeadline chanToSetDeadline - doneChanToDeadline.start(ctx.Done(), cancelConn) - defer doneChanToDeadline.cleanup() + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -579,8 +592,8 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() for { msg, err := pgConn.ReceiveMessage() @@ -622,7 +635,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -630,7 +643,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Unwatch() multiResult.closed = true if n == 0 { err = linkErrors(err, ErrNoBytesSent) @@ -732,7 +745,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) return result } @@ -749,7 +762,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result err = linkErrors(err, ErrNoBytesSent) } result.concludeCommand(nil, linkErrors(ctx.Err(), err)) - pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() } @@ -767,8 +780,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -828,8 +841,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) - defer pgConn.doneChanToDeadline.cleanup() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -962,7 +975,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) msg, err := mrr.pgConn.ReceiveMessage() if err != nil { - mrr.pgConn.doneChanToDeadline.cleanup() + mrr.pgConn.contextWatcher.Unwatch() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.hardClose() @@ -971,7 +984,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - mrr.pgConn.doneChanToDeadline.cleanup() + mrr.pgConn.contextWatcher.Unwatch() mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: @@ -1129,7 +1142,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg.(type) { case *pgproto3.ReadyForQuery: - rr.pgConn.doneChanToDeadline.cleanup() + rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() return rr.commandTag, rr.err } @@ -1148,7 +1161,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error if err != nil { rr.concludeCommand(nil, err) - rr.pgConn.doneChanToDeadline.cleanup() + rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { rr.pgConn.hardClose() @@ -1223,7 +1236,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR return multiResult default: } - pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn) + pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) From d30cf1c19f3a13beb275eb8a517d7f54d5e185bf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 May 2019 15:15:40 -0500 Subject: [PATCH 094/290] Adjust buffer size for CopyFrom --- pgconn.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index aad5fafd..bbabb0dd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -879,8 +879,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy data - buf = make([]byte, 0, 20000) - // buf = make([]byte, 0, 65536) + buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) var readErr error From de87e8be96e1ee042303fe7116d02692155e7504 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 27 May 2019 12:50:27 -0500 Subject: [PATCH 095/290] Fix: Use fallback config TLS config --- pgconn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index bbabb0dd..c51742ae 100644 --- a/pgconn.go +++ b/pgconn.go @@ -145,8 +145,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) - if config.TLSConfig != nil { - if err := pgConn.startTLS(config.TLSConfig); err != nil { + if fallbackConfig.TLSConfig != nil { + if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { pgConn.conn.Close() return nil, err } From 71ec1f78211346069f77cf843fca96d5e62ba90c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 28 May 2019 06:54:20 -0500 Subject: [PATCH 096/290] Update xerrors package --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4ad3564a..9401dce8 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,5 @@ require ( github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 - golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 + golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 ) diff --git a/go.sum b/go.sum index 9e2398cb..1b6862a0 100644 --- a/go.sum +++ b/go.sum @@ -23,3 +23,5 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From a97dd2f9f6d06658a4d189720e2d3c8e0bf51f69 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Jun 2019 09:59:04 -0500 Subject: [PATCH 097/290] Update test envvar and docs --- .travis.yml | 2 +- README.md | 28 +++++++++++++ benchmark_test.go | 10 ++--- pgconn_stress_test.go | 2 +- pgconn_test.go | 94 +++++++++++++++++++++---------------------- 5 files changed, 82 insertions(+), 54 deletions(-) diff --git a/.travis.yml b/.travis.yml index 50e81eb5..e5ed43a8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,7 +11,7 @@ before_install: env: global: - GO111MODULE=on - - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require diff --git a/README.md b/README.md index 8a881009..05cfedf1 100644 --- a/README.md +++ b/README.md @@ -6,3 +6,31 @@ Package pgconn is a low-level PostgreSQL database driver. It is intended to serve as the foundation for the next generation of https://github.com/jackc/pgx. + +## Testing + +pgconn tests need a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` +environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*` +environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify +environment variable handling. + +### Example Test Environment + +Connect to your PostgreSQL server and run: + +``` +create database pgx_test; +``` + +Now you can run the tests: + +``` +PGX_TEST_CONN_STRING="host=/var/run/postgresql database=pgx_test" go test ./... +``` + +### Connection and Authentication Tests + +There are multiple connection types and means of authentication that pgconn supports. These tests are optional. They +will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being +skipped. Typical developers will not need to enable these tests. See travis.yml for example setup if you need change +authentication code. diff --git a/benchmark_test.go b/benchmark_test.go index 073281aa..51e11e24 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -38,7 +38,7 @@ func BenchmarkConnect(b *testing.B) { } func BenchmarkExec(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) @@ -82,7 +82,7 @@ func BenchmarkExec(b *testing.B) { } func BenchmarkExecPossibleToCancel(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) @@ -129,7 +129,7 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { } func BenchmarkExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) @@ -167,7 +167,7 @@ func BenchmarkExecPrepared(b *testing.B) { } func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) @@ -208,7 +208,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { } // func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { -// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) +// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) // require.Nil(b, err) // defer closeConn(b, conn) diff --git a/pgconn_stress_test.go b/pgconn_stress_test.go index 7288c9b4..356b529a 100644 --- a/pgconn_stress_test.go +++ b/pgconn_stress_test.go @@ -14,7 +14,7 @@ import ( ) func TestConnStress(t *testing.T) { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) diff --git a/pgconn_test.go b/pgconn_test.go index dcbbfc89..310b387b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -112,7 +112,7 @@ func TestConnectWithConnectionRefused(t *testing.T) { func TestConnectCustomDialer(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) dialed := false @@ -130,7 +130,7 @@ func TestConnectCustomDialer(t *testing.T) { func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) config.RuntimeParams = map[string]string{ @@ -156,7 +156,7 @@ func TestConnectWithRuntimeParams(t *testing.T) { func TestConnectWithFallback(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) // Prepend current primary config to fallbacks @@ -189,7 +189,7 @@ func TestConnectWithFallback(t *testing.T) { func TestConnectWithAfterConnectFunc(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) dialCount := 0 @@ -228,7 +228,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite @@ -243,7 +243,7 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -257,7 +257,7 @@ func TestConnPrepareSyntaxError(t *testing.T) { func TestConnPrepareContextPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -275,7 +275,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -294,7 +294,7 @@ func TestConnExec(t *testing.T) { func TestConnExecEmpty(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -315,7 +315,7 @@ func TestConnExecEmpty(t *testing.T) { func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -340,7 +340,7 @@ func TestConnExecMultipleQueries(t *testing.T) { func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -362,7 +362,7 @@ func TestConnExecMultipleQueriesError(t *testing.T) { func TestConnExecContextCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -380,7 +380,7 @@ func TestConnExecContextCanceled(t *testing.T) { func TestConnExecContextPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -397,7 +397,7 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -418,7 +418,7 @@ func TestConnExecParams(t *testing.T) { func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -441,7 +441,7 @@ func TestConnExecParamsMaxNumberOfParams(t *testing.T) { func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -464,7 +464,7 @@ func TestConnExecParamsTooManyParams(t *testing.T) { func TestConnExecParamsCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -486,7 +486,7 @@ func TestConnExecParamsCanceled(t *testing.T) { func TestConnExecParamsPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -503,7 +503,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -530,7 +530,7 @@ func TestConnExecPrepared(t *testing.T) { func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -559,7 +559,7 @@ func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -588,7 +588,7 @@ func TestConnExecPreparedTooManyParams(t *testing.T) { func TestConnExecPreparedCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -612,7 +612,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { func TestConnExecPreparedPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -632,7 +632,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -664,7 +664,7 @@ func TestConnExecBatch(t *testing.T) { func TestConnExecBatchPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -697,7 +697,7 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -725,7 +725,7 @@ func TestConnExecBatchHuge(t *testing.T) { func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -748,7 +748,7 @@ func TestConnExecBatchImplicitTransaction(t *testing.T) { func TestConnLocking(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -795,7 +795,7 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) var msg string @@ -821,7 +821,7 @@ end$$;`) func TestConnOnNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) var msg string @@ -853,7 +853,7 @@ func TestConnOnNotification(t *testing.T) { func TestConnWaitForNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) var msg string @@ -885,7 +885,7 @@ func TestConnWaitForNotification(t *testing.T) { func TestConnWaitForNotificationPrecanceled(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) @@ -903,7 +903,7 @@ func TestConnWaitForNotificationPrecanceled(t *testing.T) { func TestConnWaitForNotificationTimeout(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) @@ -921,7 +921,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -959,7 +959,7 @@ func TestConnCopyToSmall(t *testing.T) { func TestConnCopyToLarge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -997,7 +997,7 @@ func TestConnCopyToLarge(t *testing.T) { func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1014,7 +1014,7 @@ func TestConnCopyToQueryError(t *testing.T) { func TestConnCopyToCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1032,7 +1032,7 @@ func TestConnCopyToCanceled(t *testing.T) { func TestConnCopyToPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1052,7 +1052,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1088,7 +1088,7 @@ func TestConnCopyFrom(t *testing.T) { func TestConnCopyFromCanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1123,7 +1123,7 @@ func TestConnCopyFromCanceled(t *testing.T) { func TestConnCopyFromPrecanceled(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1160,7 +1160,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1217,7 +1217,7 @@ func TestConnCopyFromGzipReader(t *testing.T) { func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1240,7 +1240,7 @@ func TestConnCopyFromQuerySyntaxError(t *testing.T) { func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1257,7 +1257,7 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) { func TestConnEscapeString(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1285,7 +1285,7 @@ func TestConnEscapeString(t *testing.T) { func TestConnCancelRequest(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) @@ -1310,7 +1310,7 @@ func TestConnCancelRequest(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } From 529805557f0334621e115ca0dabf6dcf9b5a38bb Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 22 Jun 2019 10:41:01 +0300 Subject: [PATCH 098/290] Fix linters notifications Signed-off-by: Artemiy Ryabinkov --- auth_scram.go | 2 +- benchmark_test.go | 15 ++++++++------- internal/ctxwatch/context_watcher_test.go | 3 +++ pgconn.go | 18 +++++++++--------- pgconn_test.go | 11 ++++++----- 5 files changed, 27 insertions(+), 22 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index d102d305..bdaf3e92 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -249,7 +249,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) - serverSignature := computeHMAC(serverKey[:], authMessage) + serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) base64.StdEncoding.Encode(buf, serverSignature) return buf diff --git a/benchmark_test.go b/benchmark_test.go index 51e11e24..8067c985 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -20,6 +20,7 @@ func BenchmarkConnect(b *testing.B) { } for _, bm := range benchmarks { + bm := bm b.Run(bm.name, func(b *testing.B) { connString := os.Getenv(bm.env) if connString == "" { @@ -54,12 +55,12 @@ func BenchmarkExec(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -101,12 +102,12 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -145,12 +146,12 @@ func BenchmarkExecPrepared(b *testing.B) { rowCount := 0 for rr.NextRow() { - rowCount += 1 + rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } @@ -191,7 +192,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { - if bytes.Compare(rr.Values()[i], expectedValues[i]) != 0 { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 0b491bf8..a1b3c863 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -87,6 +87,9 @@ func TestContextWatcherStress(t *testing.T) { if i%2 == 1 { cancel() } + + // To avoid context leak + cancel() } actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) diff --git a/pgconn.go b/pgconn.go index c51742ae..9e4f6253 100644 --- a/pgconn.go +++ b/pgconn.go @@ -241,16 +241,16 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { +func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = c.txPasswordMessage(c.Config.Password) + err = pgConn.txPasswordMessage(pgConn.Config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:])) - err = c.txPasswordMessage(digestedPassword) + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: - err = c.scramAuth(msg.SASLAuthMechanisms) + err = pgConn.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } @@ -514,11 +514,11 @@ readloop: func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: string(msg.Severity), + Severity: msg.Severity, Code: string(msg.Code), Message: string(msg.Message), Detail: string(msg.Detail), - Hint: string(msg.Hint), + Hint: msg.Hint, Position: msg.Position, InternalPosition: msg.InternalPosition, InternalQuery: string(msg.InternalQuery), @@ -527,7 +527,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { TableName: string(msg.TableName), ColumnName: string(msg.ColumnName), DataTypeName: string(msg.DataTypeName), - ConstraintName: string(msg.ConstraintName), + ConstraintName: msg.ConstraintName, File: string(msg.File), Line: msg.Line, Routine: string(msg.Routine), @@ -919,7 +919,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { - copyFail := &pgproto3.CopyFail{Error: readErr.Error()} + copyFail := &pgproto3.CopyFail{Message: readErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) diff --git a/pgconn_test.go b/pgconn_test.go index 310b387b..4389fe99 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -37,6 +37,7 @@ func TestConnect(t *testing.T) { } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { connString := os.Getenv(tt.env) if connString == "" { @@ -194,13 +195,13 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { dialCount := 0 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount += 1 + dialCount++ return net.Dial(network, address) } acceptConnCount := 0 config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount += 1 + acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") } @@ -302,7 +303,7 @@ func TestConnExecEmpty(t *testing.T) { resultCount := 0 for multiResult.NextResult() { - resultCount += 1 + resultCount++ multiResult.ResultReader().Close() } assert.Equal(t, 0, resultCount) @@ -753,12 +754,12 @@ func TestConnLocking(t *testing.T) { defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) - results, err = mrr.ReadAll() + results, err := mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) From 54ce9c6bb807f53394115ce7849f9e083aea095a Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 22 Jun 2019 14:35:17 +0300 Subject: [PATCH 099/290] Update pgproto3 dependency Signed-off-by: Artemiy Ryabinkov --- .gitignore | 1 + go.mod | 2 +- go.sum | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7a6353d6..6eb9d442 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .envrc +vendor/ \ No newline at end of file diff --git a/go.mod b/go.mod index 9401dce8..b1c84049 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.12 require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db + github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 1b6862a0..50dfc2fd 100644 --- a/go.sum +++ b/go.sum @@ -2,12 +2,16 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From 07904bd774d7009cd55030606fe5d30e1329c6c0 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 22 Jun 2019 20:09:55 +0300 Subject: [PATCH 100/290] Remove unnecassary ctx cancel Signed-off-by: Artemiy Ryabinkov --- internal/ctxwatch/context_watcher_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index a1b3c863..0b491bf8 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -87,9 +87,6 @@ func TestContextWatcherStress(t *testing.T) { if i%2 == 1 { cancel() } - - // To avoid context leak - cancel() } actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) From d2440c7fe62ef4c392bbae34eef01e4a6865ed03 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jun 2019 16:54:10 -0500 Subject: [PATCH 101/290] Improve documentation --- README.md | 24 ++++++++++++++++++++++-- config.go | 7 +++++++ pgconn.go | 2 +- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 05cfedf1..9e35a0f5 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,29 @@ # pgconn -Package pgconn is a low-level PostgreSQL database driver. +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level is the C library libpq. +It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. +Applications should handle normal queries with a higher level library and only use pgconn directly when required for +low-level access to PostgreSQL functionality. -It is intended to serve as the foundation for the next generation of https://github.com/jackc/pgx. +## Example Usage + +```go +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +if err != nil { + log.Fatalln("pgconn failed to connect:", err) +} +defer pgConn.Close() + +result := pgConn.ExecParams(context.Background(), "select email from users where id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +for result.NextRow() { + fmt.Println("User 123 has email:", string(result.Values()[0])) +} +_, err := result.Close() +if err != nil { + log.Fatalln("failed reading result:", err) +}) +``` ## Testing diff --git a/config.go b/config.go index c751cc0d..98755b1f 100644 --- a/config.go +++ b/config.go @@ -121,6 +121,13 @@ func NetworkAddress(host string, port uint16) (network, address string) { // security guarantees than it would with libpq. Do not rely on this behavior as it // may be possible to match libpq in the future. If you need full security use // "verify-full". +// +// Other known differences with libpq: +// +// If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. +// +// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn +// does not. func ParseConfig(connString string) (*Config, error) { settings := defaultSettings() addEnvSettings(settings) diff --git a/pgconn.go b/pgconn.go index 9e4f6253..3deb8563 100644 --- a/pgconn.go +++ b/pgconn.go @@ -390,7 +390,7 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect deatch of +// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of // underlying connection. func (pgConn *PgConn) IsAlive() bool { return pgConn.status >= connStatusIdle From 59941377c8ff1467e4805f20e7aef29201e72e2c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jul 2019 09:52:22 -0500 Subject: [PATCH 102/290] Rename Config.AfterConnectFunc to AfterConnect No need to include the type in the name. --- config.go | 6 +++--- config_test.go | 18 +++++++++--------- pgconn.go | 6 +++--- pgconn_test.go | 6 +++--- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/config.go b/config.go index 98755b1f..533791c2 100644 --- a/config.go +++ b/config.go @@ -36,10 +36,10 @@ type Config struct { Fallbacks []*FallbackConfig - // AfterConnectFunc is called after successful connection. It can be used to set up the connection or to validate that + // AfterConnect is called after successful connection. It can be used to set up the connection or to validate that // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This // allows implementing high availability behavior such as libpq does with target_session_attrs. - AfterConnectFunc AfterConnectFunc + AfterConnect AfterConnectFunc // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler @@ -245,7 +245,7 @@ func ParseConfig(connString string) (*Config, error) { } if settings["target_session_attrs"] == "read-write" { - config.AfterConnectFunc = AfterConnectTargetSessionAttrsReadWrite + config.AfterConnect = AfterConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) } diff --git a/config_test.go b/config_test.go index ce6f3957..b222d8cc 100644 --- a/config_test.go +++ b/config_test.go @@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - AfterConnectFunc: pgconn.AfterConnectTargetSessionAttrsReadWrite, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + AfterConnect: pgconn.AfterConnectTargetSessionAttrsReadWrite, }, }, } @@ -416,7 +416,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. - assert.Equalf(t, expected.AfterConnectFunc == nil, actual.AfterConnectFunc == nil, "%s - AfterConnectFunc", testName) + assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { diff --git a/pgconn.go b/pgconn.go index 3deb8563..2db35587 100644 --- a/pgconn.go +++ b/pgconn.go @@ -201,11 +201,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle - if config.AfterConnectFunc != nil { - err := config.AfterConnectFunc(ctx, pgConn) + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnectFunc: %v", err) + return nil, errors.Errorf("AfterConnect: %v", err) } } return pgConn, nil diff --git a/pgconn_test.go b/pgconn_test.go index 4389fe99..028d5e94 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -187,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAfterConnectFunc(t *testing.T) { +func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) @@ -200,7 +200,7 @@ func TestConnectWithAfterConnectFunc(t *testing.T) { } acceptConnCount := 0 - config.AfterConnectFunc = func(ctx context.Context, conn *pgconn.PgConn) error { + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") @@ -232,7 +232,7 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - config.AfterConnectFunc = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.AfterConnect = pgconn.AfterConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" conn, err := pgconn.ConnectConfig(context.Background(), config) From 3dec1848118789c4430914ca04d2f6fd0542c3d9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jul 2019 10:22:09 -0500 Subject: [PATCH 103/290] Split ValidateConnect from AfterConnect This avoids the foot-gun of ParseConfig setting AfterConnect because of target_session_attrs and the user inadvertently overriding it with an AfterConnect designed to setup the connection. Now target_session_attrs will be handled with ValidateConnect. --- config.go | 18 ++++++++++++------ config_test.go | 17 +++++++++-------- pgconn.go | 22 +++++++++++++++++----- pgconn_test.go | 29 +++++++++++++++++++++++++---- 4 files changed, 63 insertions(+), 23 deletions(-) diff --git a/config.go b/config.go index 533791c2..9b74945e 100644 --- a/config.go +++ b/config.go @@ -22,6 +22,7 @@ import ( ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error +type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { @@ -36,9 +37,14 @@ type Config struct { Fallbacks []*FallbackConfig - // AfterConnect is called after successful connection. It can be used to set up the connection or to validate that - // server is acceptable. If this returns an error the connection is closed and the next fallback config is tried. This - // allows implementing high availability behavior such as libpq does with target_session_attrs. + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. + // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with + // target_session_attrs. + ValidateConnect ValidateConnectFunc + + // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables + // or prepare statements). If this returns an error the connection attempt fails. AfterConnect AfterConnectFunc // OnNotice is a callback function called when a notice response is received. @@ -245,7 +251,7 @@ func ParseConfig(connString string) (*Config, error) { } if settings["target_session_attrs"] == "read-write" { - config.AfterConnect = AfterConnectTargetSessionAttrsReadWrite + config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) } @@ -481,9 +487,9 @@ func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { return d.DialContext, nil } -// AfterConnectTargetSessionAttrsReadWrite is an AfterConnectFunc that implements libpq compatible +// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. -func AfterConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { +func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err diff --git a/config_test.go b/config_test.go index b222d8cc..23d86529 100644 --- a/config_test.go +++ b/config_test.go @@ -378,14 +378,14 @@ func TestParseConfig(t *testing.T) { name: "target_session_attrs", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, - RuntimeParams: map[string]string{}, - AfterConnect: pgconn.AfterConnectTargetSessionAttrsReadWrite, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, }, }, } @@ -416,6 +416,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. + assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { diff --git a/pgconn.go b/pgconn.go index 2db35587..6e1fb7e3 100644 --- a/pgconn.go +++ b/pgconn.go @@ -122,13 +122,25 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { - return pgConn, nil + break } else if err, ok := err.(*PgError); ok { return nil, err } } - return nil, err + if err != nil { + return nil, err + } + + if config.AfterConnect != nil { + err := config.AfterConnect(ctx, pgConn) + if err != nil { + pgConn.conn.Close() + return nil, errors.Errorf("AfterConnect: %v", err) + } + } + + return pgConn, nil } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { @@ -201,11 +213,11 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle - if config.AfterConnect != nil { - err := config.AfterConnect(ctx, pgConn) + if config.ValidateConnect != nil { + err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnect: %v", err) + return nil, errors.Errorf("ValidateConnect: %v", err) } } return pgConn, nil diff --git a/pgconn_test.go b/pgconn_test.go index 028d5e94..feb78641 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -187,7 +187,7 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, conn) } -func TestConnectWithAfterConnect(t *testing.T) { +func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) @@ -200,7 +200,7 @@ func TestConnectWithAfterConnect(t *testing.T) { } acceptConnCount := 0 - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") @@ -226,13 +226,13 @@ func TestConnectWithAfterConnect(t *testing.T) { assert.True(t, acceptConnCount > 1) } -func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { +func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - config.AfterConnect = pgconn.AfterConnectTargetSessionAttrsReadWrite + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" conn, err := pgconn.ConnectConfig(context.Background(), config) @@ -241,6 +241,27 @@ func TestConnectWithAfterConnectTargetSessionAttrsReadWrite(t *testing.T) { } } +func TestConnectWithAfterConnect(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) + + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() From fa7e06489bda50794a89e7a6e60446c4cc1c2ba5 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Fri, 26 Jul 2019 11:14:07 +0300 Subject: [PATCH 104/290] Add MinReadBufferSize option to Config Signed-off-by: Artemiy Ryabinkov --- config.go | 3 +++ pgconn.go | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 9b74945e..bbd458e3 100644 --- a/config.go +++ b/config.go @@ -37,6 +37,9 @@ type Config struct { Fallbacks []*FallbackConfig + // MinReadBufferSize used to configure size of connection read buffer. + MinReadBufferSize int + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with diff --git a/pgconn.go b/pgconn.go index 6e1fb7e3..5077ccae 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" @@ -170,7 +171,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) + cr, err := chunkreader.NewConfig(pgConn.conn, chunkreader.Config{MinBufLen: config.MinReadBufferSize}) + if err != nil { + return nil, err + } + + pgConn.Frontend, err = pgproto3.NewFrontend(cr, pgConn.conn) if err != nil { return nil, err } From f0b479097a4868d74e83c938131f5a24d25c49e8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 6 Aug 2019 17:07:11 -0500 Subject: [PATCH 105/290] Fix missing deferred constraint violations in certain conditions See https://github.com/jackc/pgx/issues/570. --- pgconn.go | 5 ++- pgconn_test.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 6e1fb7e3..3157f17e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1151,7 +1151,10 @@ func (rr *ResultReader) Close() (CommandTag, error) { return nil, rr.err } - switch msg.(type) { + switch msg := msg.(type) { + // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. + case *pgproto3.ErrorResponse: + rr.err = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() diff --git a/pgconn_test.go b/pgconn_test.go index feb78641..1b90b9d2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -381,6 +381,34 @@ func TestConnExecMultipleQueriesError(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) + + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecContextCanceled(t *testing.T) { t.Parallel() @@ -437,6 +465,33 @@ func TestConnExecParams(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() @@ -683,6 +738,36 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } +func TestConnExecBatchDeferredError(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) + + batch := &pgconn.Batch{} + + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatchPrecanceled(t *testing.T) { t.Parallel() From 0a99b543c007eab4dd3eb284e0206eb7d8144346 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 11:46:25 +0300 Subject: [PATCH 106/290] Add BuildFrontendFunc in Config Signed-off-by: Artemiy Ryabinkov --- config.go | 30 +++++++++++++++++++----------- go.sum | 4 ---- pgconn.go | 32 +++++++++++++++++--------------- 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/config.go b/config.go index bbd458e3..be8bdab4 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "io/ioutil" "math" "net" @@ -18,6 +19,7 @@ import ( "time" "github.com/jackc/pgpassfile" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" ) @@ -26,20 +28,18 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + BuildFrontendFunc BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig - // MinReadBufferSize used to configure size of connection read buffer. - MinReadBufferSize int - // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with @@ -476,6 +476,14 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } +func makeDefaultBuildFrontendFunc() BuildFrontendFunc { + return func(r io.Reader) Frontend { + frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), nil) + + return frontend + } +} + func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { timeout, err := strconv.ParseInt(s, 10, 64) if err != nil { diff --git a/go.sum b/go.sum index 50dfc2fd..0e853203 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,6 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaKn/gYxzH6/zWyRQH1S260zvKqwJJ4h8+Kf09ooh0= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= @@ -25,7 +23,5 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqY golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= -golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pgconn.go b/pgconn.go index 5077ccae..e7833c1f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -15,7 +15,6 @@ import ( "sync" "time" - "github.com/jackc/chunkreader/v2" "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" @@ -41,9 +40,12 @@ type Notification struct { Payload string } -// DialFunc is a function that can be used to connect to a PostgreSQL server +// DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. +type BuildFrontendFunc func(r io.Reader) Frontend + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -56,6 +58,11 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) +// Frontend used to receive messages from backend. +type Frontend interface { + Receive() (pgproto3.BackendMessage, error) +} + // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -63,7 +70,7 @@ type PgConn struct { secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte - Frontend *pgproto3.Frontend + frontend Frontend Config *Config @@ -106,6 +113,9 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if config.DialFunc == nil { config.DialFunc = makeDefaultDialer().DialContext } + if config.BuildFrontendFunc == nil { + config.BuildFrontendFunc = makeDefaultBuildFrontendFunc() + } if config.RuntimeParams == nil { config.RuntimeParams = make(map[string]string) } @@ -171,15 +181,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - cr, err := chunkreader.NewConfig(pgConn.conn, chunkreader.Config{MinBufLen: config.MinReadBufferSize}) - if err != nil { - return nil, err - } - - pgConn.Frontend, err = pgproto3.NewFrontend(cr, pgConn.conn) - if err != nil { - return nil, err - } + pgConn.frontend = config.BuildFrontendFunc(pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -298,7 +300,7 @@ func (pgConn *PgConn) signalMessage() chan struct{} { ch := make(chan struct{}) go func() { - pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() + pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() pgConn.bufferingReceiveMux.Unlock() close(ch) }() @@ -318,10 +320,10 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { // If a timeout error happened in the background try the read again. if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - msg, err = pgConn.Frontend.Receive() + msg, err = pgConn.frontend.Receive() } } else { - msg, err = pgConn.Frontend.Receive() + msg, err = pgConn.frontend.Receive() } if err != nil { From dbb7aa8fd51b866cf601df8daf11306a9bb7c707 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 12:52:04 +0300 Subject: [PATCH 107/290] Add GOPROXY to travis builds to mitigate problems with github and etc Signed-off-by: Artemiy Ryabinkov --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index e5ed43a8..1687adad 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,7 @@ before_install: env: global: - GO111MODULE=on + - GOPROXY=https://proxy.golang.org - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test From c9660e30c8b4f7903eaa7814789656ea79b6d173 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 8 Aug 2019 13:12:27 +0300 Subject: [PATCH 108/290] Use go mod download to install deps on travis-ci. Add cache for travis-ci. Signed-off-by: Artemiy Ryabinkov --- .travis.yml | 12 ++++++++++-- travis/install.bash | 14 -------------- 2 files changed, 10 insertions(+), 16 deletions(-) delete mode 100755 travis/install.bash diff --git a/.travis.yml b/.travis.yml index 1687adad..2c547abf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,9 @@ go: - 1.x - tip +git: + depth: 1 + # Derived from https://github.com/lib/pq/blob/master/.travis.yml before_install: - ./travis/before_install.bash @@ -12,6 +15,7 @@ env: global: - GO111MODULE=on - GOPROXY=https://proxy.golang.org + - GOFLAGS=-mod=readonly - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test @@ -26,11 +30,15 @@ env: - PGVERSION=9.4 - PGVERSION=9.3 +cache: + directories: + - $HOME/.cache/go-build + - $HOME/gopath/pkg/mod + before_script: - ./travis/before_script.bash -install: - - ./travis/install.bash +install: go mod download script: - ./travis/script.bash diff --git a/travis/install.bash b/travis/install.bash deleted file mode 100755 index 63ba875d..00000000 --- a/travis/install.bash +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env bash -set -eux - -go get -u github.com/cockroachdb/apd -go get -u github.com/shopspring/decimal -go get -u gopkg.in/inconshreveable/log15.v2 -go get -u github.com/jackc/fake -go get -u github.com/lib/pq -go get -u github.com/hashicorp/go-version -go get -u github.com/satori/go.uuid -go get -u github.com/sirupsen/logrus -go get -u github.com/pkg/errors -go get -u go.uber.org/zap -go get -u github.com/rs/zerolog From d364370a31359546fb19828f737073b19a56f812 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 20 Aug 2019 14:11:16 -0500 Subject: [PATCH 109/290] Add SendBytes and ReceiveMessage --- auth_scram.go | 2 +- pgconn.go | 77 +++++++++++++++++++++++++++++++++++++++++++------- pgconn_test.go | 40 ++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 11 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index bdaf3e92..4409a080 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -74,7 +74,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { - msg, err := c.ReceiveMessage() + msg, err := c.receiveMessage() if err != nil { return nil, err } diff --git a/pgconn.go b/pgconn.go index 63e19ed1..abbc2d10 100644 --- a/pgconn.go +++ b/pgconn.go @@ -204,7 +204,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() return nil, err @@ -308,7 +308,64 @@ func (pgConn *PgConn) signalMessage() chan struct{} { return ch } -func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { +// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { + if err := pgConn.lock(); err != nil { + return linkErrors(err, ErrNoBytesSent) + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return linkErrors(ctx.Err(), ErrNoBytesSent) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.hardClose() + if n == 0 { + err = linkErrors(err, ErrNoBytesSent) + } + return linkErrors(ctx.Err(), err) + } + + return nil +} + +// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the +// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages +// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger +// the OnNotification callback. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { + if err := pgConn.lock(); err != nil { + return nil, linkErrors(err, ErrNoBytesSent) + } + defer pgConn.unlock() + + select { + case <-ctx.Done(): + return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + msg, err := pgConn.receiveMessage() + return msg, err +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { @@ -506,7 +563,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ readloop: for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -616,7 +673,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { defer pgConn.contextWatcher.Unwatch() for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { return linkErrors(ctx.Err(), err) } @@ -821,7 +878,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm var commandTag CommandTag var pgErr error for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -882,7 +939,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co var pgErr error pendingCopyInResponse := true for pendingCopyInResponse { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -920,7 +977,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co select { case <-signalMessageChan: - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -950,7 +1007,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Read results for { - msg, err := pgConn.ReceiveMessage() + msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) @@ -991,7 +1048,7 @@ func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { } func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { - msg, err := mrr.pgConn.ReceiveMessage() + msg, err := mrr.pgConn.receiveMessage() if err != nil { mrr.pgConn.contextWatcher.Unwatch() @@ -1176,7 +1233,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { - msg, err = rr.pgConn.ReceiveMessage() + msg, err = rr.pgConn.receiveMessage() } else { msg, err = rr.multiResultReader.receiveMessage() } diff --git a/pgconn_test.go b/pgconn_test.go index 1b90b9d2..f385bc19 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" "github.com/stretchr/testify/assert" @@ -1416,6 +1417,45 @@ func TestConnCancelRequest(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnSendBytesAndReceiveMessage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + queryMsg := pgproto3.Query{String: "select 42"} + buf := queryMsg.Encode(nil) + + err = pgConn.SendBytes(ctx, buf) + require.NoError(t, err) + + msg, err := pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok := msg.(*pgproto3.RowDescription) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.DataRow) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.CommandComplete) + require.True(t, ok) + + msg, err = pgConn.ReceiveMessage(ctx) + require.NoError(t, err) + _, ok = msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + + ensureConnValid(t, pgConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 11255efe7af4e7c2ab77e863f245f42f4ca6b4c5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 20 Aug 2019 15:49:57 -0500 Subject: [PATCH 110/290] Make ErrorResponseToPgError public --- pgconn.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/pgconn.go b/pgconn.go index abbc2d10..e51d40e8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -233,7 +233,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, errorResponseToPgError(msg) + return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() return nil, errors.New("unexpected message") @@ -400,7 +400,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { pgConn.hardClose() - return nil, errorResponseToPgError(msg) + return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: if pgConn.Config.OnNotice != nil { @@ -577,7 +577,7 @@ readloop: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: - parseErr = errorResponseToPgError(msg) + parseErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } @@ -589,7 +589,8 @@ readloop: return psd, nil } -func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { +// ErrorResponseToPgError converts a wire protocol error message to a *PgError. +func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, Code: string(msg.Code), @@ -612,7 +613,7 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { } func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { - pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) + pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) return (*Notice)(pgerr) } @@ -898,7 +899,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } } } @@ -949,7 +950,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.CopyInResponse: pendingCopyInResponse = false case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: return commandTag, pgErr } @@ -985,7 +986,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } default: } @@ -1019,7 +1020,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: - pgErr = errorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) } } } @@ -1064,7 +1065,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: - mrr.err = errorResponseToPgError(msg) + mrr.err = ErrorResponseToPgError(msg) } return msg, nil @@ -1219,7 +1220,7 @@ func (rr *ResultReader) Close() (CommandTag, error) { switch msg := msg.(type) { // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. case *pgproto3.ErrorResponse: - rr.err = errorResponseToPgError(msg) + rr.err = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() @@ -1255,7 +1256,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(nil, errorResponseToPgError(msg)) + rr.concludeCommand(nil, ErrorResponseToPgError(msg)) } return msg, nil From 1558987979c58286747e7c90ab181adc1560f027 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 22 Aug 2019 20:11:27 -0500 Subject: [PATCH 111/290] ReceiveMessage returns context error instead of io error on cancel --- pgconn.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgconn.go b/pgconn.go index e51d40e8..5d84871b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -361,6 +361,9 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa defer pgConn.contextWatcher.Unwatch() msg, err := pgConn.receiveMessage() + if err != nil { + err = linkErrors(ctx.Err(), err) + } return msg, err } From 760dd75542eb13b37333e0e134b3463efade7cb4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 09:28:44 -0500 Subject: [PATCH 112/290] Require Config to be created by ParseConfig --- config.go | 15 ++++++++++----- pgconn.go | 19 ++++++------------- pgconn_test.go | 8 ++++++++ 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/config.go b/config.go index be8bdab4..a861ff5f 100644 --- a/config.go +++ b/config.go @@ -26,7 +26,8 @@ import ( type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -// Config is the settings used to establish a connection to a PostgreSQL server. +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and +// then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) Port uint16 @@ -55,6 +56,8 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a @@ -157,10 +160,12 @@ func ParseConfig(connString string) (*Config, error) { } config := &Config{ - Database: settings["database"], - User: settings["user"], - Password: settings["password"], - RuntimeParams: make(map[string]string), + createdByParseConfig: true, + Database: settings["database"], + User: settings["user"], + Password: settings["password"], + RuntimeParams: make(map[string]string), + BuildFrontendFunc: makeDefaultBuildFrontendFunc(), } if connectTimeout, present := settings["connect_timeout"]; present { diff --git a/pgconn.go b/pgconn.go index 5d84871b..b0e4cfd2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -99,25 +99,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with +// ParseConfig. ctx can be used to cancel a connect attempt. // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, // if all attempts fail the last error is returned. func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { - // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. - if config.Port == 0 { - config.Port = 5432 - } - if config.DialFunc == nil { - config.DialFunc = makeDefaultDialer().DialContext - } - if config.BuildFrontendFunc == nil { - config.BuildFrontendFunc = makeDefaultBuildFrontendFunc() - } - if config.RuntimeParams == nil { - config.RuntimeParams = make(map[string]string) + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from + // zero values. + if !config.createdByParseConfig { + panic("config must be created by ParseConfig") } // Simplify usage by treating primary config and fallbacks the same. diff --git a/pgconn_test.go b/pgconn_test.go index f385bc19..1cd74024 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -263,6 +263,14 @@ func TestConnectWithAfterConnect(t *testing.T) { assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) } +func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { + t.Parallel() + + config := &pgconn.Config{} + + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() From e540a0576006af74ed45bea905dbb4d8a5e320bc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 14:16:38 -0500 Subject: [PATCH 113/290] Fix typo in docs --- doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc.go b/doc.go index d36eb0fd..cde58cd8 100644 --- a/doc.go +++ b/doc.go @@ -15,7 +15,7 @@ reads all rows into memory. Executing Multiple Queries in a Single Round Trip -Exec and ExecBatch can execute multiple queries in a single round trip. The return readers that iterate over each query +Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query result. The ReadAll method reads all query results into memory. Context Support From e6bd7390678ab23b1fded5035d8364e6fa704f28 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:02:27 -0500 Subject: [PATCH 114/290] Add pscache package --- pscache/lrucache.go | 111 ++++++++++++++++++++++++++++++++++++++ pscache/lrucache_test.go | 113 +++++++++++++++++++++++++++++++++++++++ pscache/pscache.go | 52 ++++++++++++++++++ 3 files changed, 276 insertions(+) create mode 100644 pscache/lrucache.go create mode 100644 pscache/lrucache_test.go create mode 100644 pscache/pscache.go diff --git a/pscache/lrucache.go b/pscache/lrucache.go new file mode 100644 index 00000000..d5d6062f --- /dev/null +++ b/pscache/lrucache.go @@ -0,0 +1,111 @@ +package pscache + +import ( + "container/list" + "context" + "fmt" + "sync/atomic" + + "github.com/jackc/pgconn" +) + +var lruCacheCount uint64 + +// LRUCache implements cache with a Least Recently Used (LRU) cache. +type LRUCache struct { + conn *pgconn.PgConn + mode int + cap int + prepareCount int + m map[string]*list.Element + l *list.List + psNamePrefix string +} + +// NewLRUCache creates a new LRUCache. mode is either PrepareMode or DescribeMode. cap is the maximum size of the cache. +func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + n := atomic.AddUint64(&lruCacheCount, 1) + + return &LRUCache{ + conn: conn, + mode: mode, + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + psNamePrefix: fmt.Sprintf("lrupsc_%d", n), + } +} + +// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. +func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { + if el, ok := c.m[sql]; ok { + c.l.MoveToFront(el) + return el.Value.(*pgconn.PreparedStatementDescription), nil + } + + if c.l.Len() == c.cap { + err := c.removeOldest(ctx) + if err != nil { + return nil, err + } + } + + psd, err := c.prepare(ctx, sql) + if err != nil { + return nil, err + } + + el := c.l.PushFront(psd) + c.m[sql] = el + + return psd, nil +} + +// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. +func (c *LRUCache) Clear(ctx context.Context) error { + for c.l.Len() > 0 { + err := c.removeOldest(ctx) + if err != nil { + return err + } + } + + return nil +} + +// Len returns the number of cached prepared statement descriptions. +func (c *LRUCache) Len() int { + return c.l.Len() +} + +// Cap returns the maximum number of cached prepared statement descriptions. +func (c *LRUCache) Cap() int { + return c.cap +} + +// Mode returns the mode of the cache (PrepareMode or DescribeMode) +func (c *LRUCache) Mode() int { + return c.mode +} + +func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { + var name string + if c.mode == PrepareMode { + name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) + c.prepareCount += 1 + } + + return c.conn.Prepare(ctx, name, sql, nil) +} + +func (c *LRUCache) removeOldest(ctx context.Context) error { + oldest := c.l.Back() + c.l.Remove(oldest) + if c.mode == PrepareMode { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() + } + return nil +} diff --git a/pscache/lrucache_test.go b/pscache/lrucache_test.go new file mode 100644 index 00000000..bf2fcbe0 --- /dev/null +++ b/pscache/lrucache_test.go @@ -0,0 +1,113 @@ +package pscache_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgconn/pscache" + + "github.com/stretchr/testify/require" +) + +func TestLRUCachePrepareMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := pscache.NewLRUCache(conn, pscache.PrepareMode, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, pscache.PrepareMode, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func TestLRUCacheDescribeMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := pscache.NewLRUCache(conn, pscache.DescribeMode, 2) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 2, cache.Cap()) + require.EqualValues(t, pscache.DescribeMode, cache.Mode()) + + psd, err := cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 1, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + psd, err = cache.Get(ctx, "select 3") + require.NoError(t, err) + require.NotNil(t, psd) + require.EqualValues(t, 2, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, cache.Len()) + require.Empty(t, fetchServerStatements(t, ctx, conn)) +} + +func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { + result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + var statements []string + for _, r := range result.Rows { + statements = append(statements, string(r[0])) + } + return statements +} diff --git a/pscache/pscache.go b/pscache/pscache.go new file mode 100644 index 00000000..bfd51e81 --- /dev/null +++ b/pscache/pscache.go @@ -0,0 +1,52 @@ +// Package pscache is a cache that can be used to implement lazy, automatic prepared statements. +package pscache + +import ( + "context" + + "github.com/jackc/pgconn" +) + +const ( + PrepareMode = iota // Cache should prepare named statements. + DescribeMode // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. +) + +// Cache prepares and caches prepared statement descriptions. +type Cache interface { + // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. + Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) + + // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. + Clear(ctx context.Context) error + + // Len returns the number of cached prepared statement descriptions. + Len() int + + // Cap returns the maximum number of cached prepared statement descriptions. + Cap() int + + // Mode returns the mode of the cache (PrepareMode or DescribeMode) + Mode() int +} + +// New returns the preferred cache implementation for mode and cap. mode is either PrepareMode or DescribeMode. cap is +// the maximum size of the cache. +func New(conn *pgconn.PgConn, mode int, cap int) Cache { + mustBeValidMode(mode) + mustBeValidCap(cap) + + return NewLRUCache(conn, mode, cap) +} + +func mustBeValidMode(mode int) { + if mode != PrepareMode && mode != DescribeMode { + panic("mode must be PrepareMode or DescribeMode") + } +} + +func mustBeValidCap(cap int) { + if cap < 1 { + panic("cache must have cap of >= 1") + } +} From 797a44bf048f27e5db5c79dcbf7e406969ca6904 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:18:01 -0500 Subject: [PATCH 115/290] Rename BuildFrontendFunc to BuildFrontend For consistency with other functions supplied in Config. --- config.go | 20 ++++++++++---------- pgconn.go | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/config.go b/config.go index a861ff5f..b5c119f5 100644 --- a/config.go +++ b/config.go @@ -29,15 +29,15 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and // then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - BuildFrontendFunc BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + DialFunc DialFunc // e.g. net.Dialer.DialContext + BuildFrontend BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig @@ -165,7 +165,7 @@ func ParseConfig(connString string) (*Config, error) { User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontendFunc: makeDefaultBuildFrontendFunc(), + BuildFrontend: makeDefaultBuildFrontendFunc(), } if connectTimeout, present := settings["connect_timeout"]; present { diff --git a/pgconn.go b/pgconn.go index b0e4cfd2..fe2f304e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -174,7 +174,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.frontend = config.BuildFrontendFunc(pgConn.conn) + pgConn.frontend = config.BuildFrontend(pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, From 2209d2e36aea43ee17610489a2644af2212a4bc3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 16:27:54 -0500 Subject: [PATCH 116/290] Rename mode constants --- pscache/lrucache.go | 8 ++++---- pscache/lrucache_test.go | 12 ++++++------ pscache/pscache.go | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pscache/lrucache.go b/pscache/lrucache.go index d5d6062f..cdcec63c 100644 --- a/pscache/lrucache.go +++ b/pscache/lrucache.go @@ -22,7 +22,7 @@ type LRUCache struct { psNamePrefix string } -// NewLRUCache creates a new LRUCache. mode is either PrepareMode or DescribeMode. cap is the maximum size of the cache. +// NewLRUCache creates a new LRUCache. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { mustBeValidMode(mode) mustBeValidCap(cap) @@ -86,14 +86,14 @@ func (c *LRUCache) Cap() int { return c.cap } -// Mode returns the mode of the cache (PrepareMode or DescribeMode) +// Mode returns the mode of the cache (ModePrepare or ModeDescribe) func (c *LRUCache) Mode() int { return c.mode } func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { var name string - if c.mode == PrepareMode { + if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) c.prepareCount += 1 } @@ -104,7 +104,7 @@ func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedSta func (c *LRUCache) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) - if c.mode == PrepareMode { + if c.mode == ModePrepare { return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() } return nil diff --git a/pscache/lrucache_test.go b/pscache/lrucache_test.go index bf2fcbe0..a5d413e3 100644 --- a/pscache/lrucache_test.go +++ b/pscache/lrucache_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLRUCachePrepareMode(t *testing.T) { +func TestLRUCacheModePrepare(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -22,10 +22,10 @@ func TestLRUCachePrepareMode(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.PrepareMode, 2) + cache := pscache.NewLRUCache(conn, pscache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.PrepareMode, cache.Mode()) + require.EqualValues(t, pscache.ModePrepare, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) @@ -57,7 +57,7 @@ func TestLRUCachePrepareMode(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } -func TestLRUCacheDescribeMode(t *testing.T) { +func TestLRUCacheModeDescribe(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -67,10 +67,10 @@ func TestLRUCacheDescribeMode(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.DescribeMode, 2) + cache := pscache.NewLRUCache(conn, pscache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.DescribeMode, cache.Mode()) + require.EqualValues(t, pscache.ModeDescribe, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) diff --git a/pscache/pscache.go b/pscache/pscache.go index bfd51e81..4f8cf723 100644 --- a/pscache/pscache.go +++ b/pscache/pscache.go @@ -8,8 +8,8 @@ import ( ) const ( - PrepareMode = iota // Cache should prepare named statements. - DescribeMode // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. + ModePrepare = iota // Cache should prepare named statements. + ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. ) // Cache prepares and caches prepared statement descriptions. @@ -26,11 +26,11 @@ type Cache interface { // Cap returns the maximum number of cached prepared statement descriptions. Cap() int - // Mode returns the mode of the cache (PrepareMode or DescribeMode) + // Mode returns the mode of the cache (ModePrepare or ModeDescribe) Mode() int } -// New returns the preferred cache implementation for mode and cap. mode is either PrepareMode or DescribeMode. cap is +// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is // the maximum size of the cache. func New(conn *pgconn.PgConn, mode int, cap int) Cache { mustBeValidMode(mode) @@ -40,8 +40,8 @@ func New(conn *pgconn.PgConn, mode int, cap int) Cache { } func mustBeValidMode(mode int) { - if mode != PrepareMode && mode != DescribeMode { - panic("mode must be PrepareMode or DescribeMode") + if mode != ModePrepare && mode != ModeDescribe { + panic("mode must be ModePrepare or ModeDescribe") } } From beba629bb5d526f8d7de6ec8754090d39b476757 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 17:18:29 -0500 Subject: [PATCH 117/290] Fix result reader returned by locked conn --- pgconn.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pgconn.go b/pgconn.go index fe2f304e..797080bd 100644 --- a/pgconn.go +++ b/pgconn.go @@ -791,19 +791,18 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { - if err := pgConn.lock(); err != nil { - return &ResultReader{ - closed: true, - err: linkErrors(err, ErrNoBytesSent), - } - } - pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, } result := &pgConn.resultReader + if err := pgConn.lock(); err != nil { + result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent)) + result.closed = true + return result + } + if len(paramValues) > math.MaxUint16 { result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true From bcd6b9244ab8fc80e85b75b604bf214f82345e59 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 19:46:14 -0500 Subject: [PATCH 118/290] Rename pscache to stmtcache --- {pscache => stmtcache}/lrucache.go | 2 +- {pscache => stmtcache}/lrucache_test.go | 12 ++++++------ pscache/pscache.go => stmtcache/stmtcache.go | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) rename {pscache => stmtcache}/lrucache.go (99%) rename {pscache => stmtcache}/lrucache_test.go (90%) rename pscache/pscache.go => stmtcache/stmtcache.go (92%) diff --git a/pscache/lrucache.go b/stmtcache/lrucache.go similarity index 99% rename from pscache/lrucache.go rename to stmtcache/lrucache.go index cdcec63c..9c4d046d 100644 --- a/pscache/lrucache.go +++ b/stmtcache/lrucache.go @@ -1,4 +1,4 @@ -package pscache +package stmtcache import ( "container/list" diff --git a/pscache/lrucache_test.go b/stmtcache/lrucache_test.go similarity index 90% rename from pscache/lrucache_test.go rename to stmtcache/lrucache_test.go index a5d413e3..ed8ebdc3 100644 --- a/pscache/lrucache_test.go +++ b/stmtcache/lrucache_test.go @@ -1,4 +1,4 @@ -package pscache_test +package stmtcache_test import ( "context" @@ -7,7 +7,7 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgconn/pscache" + "github.com/jackc/pgconn/stmtcache" "github.com/stretchr/testify/require" ) @@ -22,10 +22,10 @@ func TestLRUCacheModePrepare(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.ModePrepare, 2) + cache := stmtcache.NewLRUCache(conn, stmtcache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.ModePrepare, cache.Mode()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) @@ -67,10 +67,10 @@ func TestLRUCacheModeDescribe(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := pscache.NewLRUCache(conn, pscache.ModeDescribe, 2) + cache := stmtcache.NewLRUCache(conn, stmtcache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) - require.EqualValues(t, pscache.ModeDescribe, cache.Mode()) + require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) diff --git a/pscache/pscache.go b/stmtcache/stmtcache.go similarity index 92% rename from pscache/pscache.go rename to stmtcache/stmtcache.go index 4f8cf723..d70f277b 100644 --- a/pscache/pscache.go +++ b/stmtcache/stmtcache.go @@ -1,5 +1,5 @@ -// Package pscache is a cache that can be used to implement lazy, automatic prepared statements. -package pscache +// Package stmtcache is a cache that can be used to implement lazy prepared statements. +package stmtcache import ( "context" From 78abbdf1d7eef6b2aa78831c31141057876537f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 19:48:43 -0500 Subject: [PATCH 119/290] Rename LRUCache to LRU --- stmtcache/{lrucache.go => lru.go} | 28 ++++++++++----------- stmtcache/{lrucache_test.go => lru_test.go} | 8 +++--- stmtcache/stmtcache.go | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) rename stmtcache/{lrucache.go => lru.go} (70%) rename stmtcache/{lrucache_test.go => lru_test.go} (93%) diff --git a/stmtcache/lrucache.go b/stmtcache/lru.go similarity index 70% rename from stmtcache/lrucache.go rename to stmtcache/lru.go index 9c4d046d..432a70b4 100644 --- a/stmtcache/lrucache.go +++ b/stmtcache/lru.go @@ -9,10 +9,10 @@ import ( "github.com/jackc/pgconn" ) -var lruCacheCount uint64 +var lruCount uint64 -// LRUCache implements cache with a Least Recently Used (LRU) cache. -type LRUCache struct { +// LRU implements Cache with a Least Recently Used (LRU) cache. +type LRU struct { conn *pgconn.PgConn mode int cap int @@ -22,14 +22,14 @@ type LRUCache struct { psNamePrefix string } -// NewLRUCache creates a new LRUCache. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. -func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { +// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. +func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { mustBeValidMode(mode) mustBeValidCap(cap) - n := atomic.AddUint64(&lruCacheCount, 1) + n := atomic.AddUint64(&lruCount, 1) - return &LRUCache{ + return &LRU{ conn: conn, mode: mode, cap: cap, @@ -40,7 +40,7 @@ func NewLRUCache(conn *pgconn.PgConn, mode int, cap int) *LRUCache { } // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. -func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.PreparedStatementDescription), nil @@ -65,7 +65,7 @@ func (c *LRUCache) Get(ctx context.Context, sql string) (*pgconn.PreparedStateme } // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. -func (c *LRUCache) Clear(ctx context.Context) error { +func (c *LRU) Clear(ctx context.Context) error { for c.l.Len() > 0 { err := c.removeOldest(ctx) if err != nil { @@ -77,21 +77,21 @@ func (c *LRUCache) Clear(ctx context.Context) error { } // Len returns the number of cached prepared statement descriptions. -func (c *LRUCache) Len() int { +func (c *LRU) Len() int { return c.l.Len() } // Cap returns the maximum number of cached prepared statement descriptions. -func (c *LRUCache) Cap() int { +func (c *LRU) Cap() int { return c.cap } // Mode returns the mode of the cache (ModePrepare or ModeDescribe) -func (c *LRUCache) Mode() int { +func (c *LRU) Mode() int { return c.mode } -func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { var name string if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) @@ -101,7 +101,7 @@ func (c *LRUCache) prepare(ctx context.Context, sql string) (*pgconn.PreparedSta return c.conn.Prepare(ctx, name, sql, nil) } -func (c *LRUCache) removeOldest(ctx context.Context) error { +func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) if c.mode == ModePrepare { diff --git a/stmtcache/lrucache_test.go b/stmtcache/lru_test.go similarity index 93% rename from stmtcache/lrucache_test.go rename to stmtcache/lru_test.go index ed8ebdc3..b518364e 100644 --- a/stmtcache/lrucache_test.go +++ b/stmtcache/lru_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLRUCacheModePrepare(t *testing.T) { +func TestLRUModePrepare(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -22,7 +22,7 @@ func TestLRUCacheModePrepare(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := stmtcache.NewLRUCache(conn, stmtcache.ModePrepare, 2) + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) @@ -57,7 +57,7 @@ func TestLRUCacheModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } -func TestLRUCacheModeDescribe(t *testing.T) { +func TestLRUModeDescribe(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -67,7 +67,7 @@ func TestLRUCacheModeDescribe(t *testing.T) { require.NoError(t, err) defer conn.Close(ctx) - cache := stmtcache.NewLRUCache(conn, stmtcache.ModeDescribe, 2) + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index d70f277b..9bedf549 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -36,7 +36,7 @@ func New(conn *pgconn.PgConn, mode int, cap int) Cache { mustBeValidMode(mode) mustBeValidCap(cap) - return NewLRUCache(conn, mode, cap) + return NewLRU(conn, mode, cap) } func mustBeValidMode(mode int) { From da9fc85c4404a53f910e2f8210be5add1bc50454 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 20:39:01 -0500 Subject: [PATCH 120/290] Rename PreparedStatementDescription to StatementDescription PreparedStatementDescription was too long. It also no longer entirely represents its purpose now that it is also intended for use with described statements. --- pgconn.go | 9 +++++---- stmtcache/lru.go | 8 ++++---- stmtcache/stmtcache.go | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index 797080bd..8f3291f1 100644 --- a/pgconn.go +++ b/pgconn.go @@ -517,15 +517,16 @@ func (ct CommandTag) String() string { return string(ct) } -type PreparedStatementDescription struct { +type StatementDescription struct { Name string SQL string ParamOIDs []uint32 Fields []pgproto3.FieldDescription } -// Prepare creates a prepared statement. -func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { +// Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This +// allows Prepare to also to describe statements without creating a server-side prepared statement. +func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, linkErrors(err, ErrNoBytesSent) } @@ -553,7 +554,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ return nil, linkErrors(ctx.Err(), err) } - psd := &PreparedStatementDescription{Name: name, SQL: sql} + psd := &StatementDescription{Name: name, SQL: sql} var parseErr error diff --git a/stmtcache/lru.go b/stmtcache/lru.go index 432a70b4..fff4d0b7 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -40,10 +40,10 @@ func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { } // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. -func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) - return el.Value.(*pgconn.PreparedStatementDescription), nil + return el.Value.(*pgconn.StatementDescription), nil } if c.l.Len() == c.cap { @@ -91,7 +91,7 @@ func (c *LRU) Mode() int { return c.mode } -func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) { +func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { var name string if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) @@ -105,7 +105,7 @@ func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.PreparedStatementDescription).Name)).Close() + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close() } return nil } diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index 9bedf549..96215799 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -15,7 +15,7 @@ const ( // Cache prepares and caches prepared statement descriptions. type Cache interface { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. - Get(ctx context.Context, sql string) (*pgconn.PreparedStatementDescription, error) + Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. Clear(ctx context.Context) error From 6feea0c1c57d8ec5ff0cd806354437ed03b415f6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 23:43:26 -0500 Subject: [PATCH 121/290] Replace IsAlive with IsClosed IsAlive is ambiguous because the connection may be dead and we do not know it. It implies the possibility of a ping. IsClosed is clearer -- it does not promise the connection is alive only that it hasn't been closed. fixes #2 --- pgconn.go | 7 +++---- pgconn_test.go | 10 +++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index 8f3291f1..153829ca 100644 --- a/pgconn.go +++ b/pgconn.go @@ -463,10 +463,9 @@ func (pgConn *PgConn) hardClose() error { return pgConn.conn.Close() } -// TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of -// underlying connection. -func (pgConn *PgConn) IsAlive() bool { - return pgConn.status >= connStatusIdle +// IsClosed reports if the connection has been closed. +func (pgConn *PgConn) IsClosed() bool { + return pgConn.status < connStatusIdle } // lock locks the connection. It panics if the connection is already locked or is closed. diff --git a/pgconn_test.go b/pgconn_test.go index 1cd74024..64628262 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -433,7 +433,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecContextPrecanceled(t *testing.T) { @@ -566,7 +566,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecParamsPrecanceled(t *testing.T) { @@ -692,7 +692,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.Equal(t, context.DeadlineExceeded, err) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnExecPreparedPrecanceled(t *testing.T) { @@ -1142,7 +1142,7 @@ func TestConnCopyToCanceled(t *testing.T) { assert.True(t, errors.Is(err, context.DeadlineExceeded)) assert.Equal(t, pgconn.CommandTag(nil), res) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnCopyToPrecanceled(t *testing.T) { @@ -1233,7 +1233,7 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.Equal(t, int64(0), ct.RowsAffected()) assert.True(t, errors.Is(err, context.DeadlineExceeded)) - assert.False(t, pgConn.IsAlive()) + assert.True(t, pgConn.IsClosed()) } func TestConnCopyFromPrecanceled(t *testing.T) { From 595d09d6f1bfba423db8d00f61efebf0aaa6a85a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 23:57:24 -0500 Subject: [PATCH 122/290] Build fully operational Frontend --- config.go | 4 ++-- pgconn.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/config.go b/config.go index b5c119f5..e078a061 100644 --- a/config.go +++ b/config.go @@ -482,8 +482,8 @@ func makeDefaultDialer() *net.Dialer { } func makeDefaultBuildFrontendFunc() BuildFrontendFunc { - return func(r io.Reader) Frontend { - frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), nil) + return func(r io.Reader, w io.Writer) Frontend { + frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), w) return frontend } diff --git a/pgconn.go b/pgconn.go index 153829ca..7d301af2 100644 --- a/pgconn.go +++ b/pgconn.go @@ -44,7 +44,7 @@ type Notification struct { type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. -type BuildFrontendFunc func(r io.Reader) Frontend +type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin @@ -174,7 +174,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) - pgConn.frontend = config.BuildFrontend(pgConn.conn) + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, From e6cf51b304f1d6961663ede4ba89be363fc54237 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 25 Aug 2019 00:22:32 -0500 Subject: [PATCH 123/290] Expose min_read_buffer_size config param --- config.go | 24 +++++++++++++++++++++--- config_test.go | 12 ++++++++++++ go.mod | 1 + 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index e078a061..cb153c77 100644 --- a/config.go +++ b/config.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -140,6 +141,11 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn // does not. +// +// In addition, ParseConfig accepts the following options: +// +// min_read_buffer_size +// The minimum size of the internal read buffer. Default 8192. func ParseConfig(connString string) (*Config, error) { settings := defaultSettings() addEnvSettings(settings) @@ -159,13 +165,18 @@ func ParseConfig(connString string) (*Config, error) { } } + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) + if err != nil { + return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err) + } + config := &Config{ createdByParseConfig: true, Database: settings["database"], User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontend: makeDefaultBuildFrontendFunc(), + BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), } if connectTimeout, present := settings["connect_timeout"]; present { @@ -192,6 +203,7 @@ func ParseConfig(connString string) (*Config, error) { "sslcert": struct{}{}, "sslrootcert": struct{}{}, "target_session_attrs": struct{}{}, + "min_read_buffer_size": struct{}{}, } for k, v := range settings { @@ -284,6 +296,8 @@ func defaultSettings() map[string]string { settings["target_session_attrs"] = "any" + settings["min_read_buffer_size"] = "8192" + return settings } @@ -481,9 +495,13 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } -func makeDefaultBuildFrontendFunc() BuildFrontendFunc { +func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { return func(r io.Reader, w io.Writer) Frontend { - frontend, _ := pgproto3.NewFrontend(pgproto3.NewChunkReader(r), w) + cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) + if err != nil { + panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) + } + frontend, _ := pgproto3.NewFrontend(cr, w) return frontend } diff --git a/config_test.go b/config_test.go index 23d86529..af42094d 100644 --- a/config_test.go +++ b/config_test.go @@ -561,3 +561,15 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { assertConfigsEqual(t, expected, actual, "passfile") } + +func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig("min_read_buffer_size=0") + require.NoError(t, err) + _, present := config.RuntimeParams["min_read_buffer_size"] + require.False(t, present) + + // The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param + // was removed. +} diff --git a/go.mod b/go.mod index b1c84049..cbeef02a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/jackc/pgconn go 1.12 require ( + github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 From 138254da5b02b80a548f7858f01636f9a426b918 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 18:01:59 -0500 Subject: [PATCH 124/290] Refactor errors - Use strongly typed errors internally - SafeToRetry(error) streamlines retry logic over ErrNoBytesSent - Timeout(error) removes the need to choose between returning a context and an i/o error --- config.go | 14 ++--- errors.go | 160 +++++++++++++++++++++++++++++++++++-------------- pgconn.go | 125 +++++++++++++++++--------------------- pgconn_test.go | 41 ++++++------- 4 files changed, 197 insertions(+), 143 deletions(-) diff --git a/config.go b/config.go index cb153c77..d24d0202 100644 --- a/config.go +++ b/config.go @@ -155,19 +155,19 @@ func ParseConfig(connString string) (*Config, error) { if strings.HasPrefix(connString, "postgres://") { err := addURLSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { err := addDSNSettings(settings, connString) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) if err != nil { - return nil, errors.Errorf("cannot parse min_read_buffer_size: %w", err) + return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} } config := &Config{ @@ -182,7 +182,7 @@ func ParseConfig(connString string) (*Config, error) { if connectTimeout, present := settings["connect_timeout"]; present { dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} } config.DialFunc = dialFunc } else { @@ -228,7 +228,7 @@ func ParseConfig(connString string) (*Config, error) { port, err := parsePort(portStr) if err != nil { - return nil, errors.Errorf("invalid port: %w", err) + return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} } var tlsConfigs []*tls.Config @@ -240,7 +240,7 @@ func ParseConfig(connString string) (*Config, error) { var err error tlsConfigs, err = configTLS(settings) if err != nil { - return nil, err + return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } } @@ -273,7 +273,7 @@ func ParseConfig(connString string) (*Config, error) { if settings["target_session_attrs"] == "read-write" { config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite } else if settings["target_session_attrs"] != "any" { - return nil, errors.Errorf("unknown target_session_attrs value: %v", settings["target_session_attrs"]) + return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])} } return config, nil diff --git a/errors.go b/errors.go index 4f8af407..a088dcdd 100644 --- a/errors.go +++ b/errors.go @@ -2,22 +2,31 @@ package pgconn import ( "context" + "fmt" "net" + "strings" errors "golang.org/x/xerrors" ) -// ErrTLSRefused occurs when the connection attempt requires TLS and the -// PostgreSQL server refuses to use TLS -var ErrTLSRefused = errors.New("server refused TLS connection") +// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. +func SafeToRetry(err error) bool { + if e, ok := err.(interface{ SafeToRetry() bool }); ok { + return e.SafeToRetry() + } + return false +} -// ErrConnBusy occurs when the connection is busy (for example, in the middle of reading query results) and another -// action is attempted. -var ErrConnBusy = errors.New("conn is busy") +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a +// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true. +func Timeout(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } -// ErrNoBytesSent is used to annotate an error that occurred without sending any bytes to the server. This can be used -// to implement safe retry logic. ErrNoBytesSent will never occur alone. It will always be wrapped by another error. -var ErrNoBytesSent = errors.New("no bytes sent to server") + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for @@ -46,44 +55,107 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } -// linkedError connects two errors as if err wrapped next. -type linkedError struct { - err error - next error +type connectError struct { + config *Config + msg string + err error } -func (le *linkedError) Error() string { - return le.err.Error() -} - -func (le *linkedError) Is(target error) bool { - return errors.Is(le.err, target) -} - -func (le *linkedError) As(target interface{}) bool { - return errors.As(le.err, target) -} - -func (le *linkedError) Unwrap() error { - return le.next -} - -// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == -// true. Otherwise returns err. -func preferContextOverNetTimeoutError(ctx context.Context, err error) error { - if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { - return ctx.Err() +func (e *connectError) Error() string { + sb := &strings.Builder{} + fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) + if e.err != nil { + fmt.Fprintf(sb, " (%s)", e.err.Error()) } - return err + return sb.String() } -// linkErrors connects outer and inner as if the the fully unwrapped outer wrapped inner. If either outer or inner is nil then the other is returned. -func linkErrors(outer, inner error) error { - if outer == nil { - return inner - } - if inner == nil { - return outer - } - return &linkedError{err: outer, next: inner} +func (e *connectError) Unwrap() error { + return e.err +} + +type connLockError struct { + status string +} + +func (e *connLockError) SafeToRetry() bool { + return true // a lock failure by definition happens before the connection is used. +} + +func (e *connLockError) Error() string { + return e.status +} + +type parseConfigError struct { + connString string + msg string + err error +} + +func (e *parseConfigError) Error() string { + if e.err == nil { + return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg) + } + return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error()) +} + +func (e *parseConfigError) Unwrap() error { + return e.err +} + +type pgconnError struct { + msg string + err error + safeToRetry bool +} + +func (e *pgconnError) Error() string { + if e.msg == "" { + return e.err.Error() + } + if e.err == nil { + return e.msg + } + return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) +} + +func (e *pgconnError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *pgconnError) Unwrap() error { + return e.err +} + +type contextAlreadyDoneError struct { + err error +} + +func (e *contextAlreadyDoneError) Error() string { + return fmt.Sprintf("context already done: %s", e.err.Error()) +} + +func (e *contextAlreadyDoneError) SafeToRetry() bool { + return true +} + +func (e *contextAlreadyDoneError) Unwrap() error { + return e.err +} + +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err } diff --git a/pgconn.go b/pgconn.go index 7d301af2..347acf80 100644 --- a/pgconn.go +++ b/pgconn.go @@ -128,19 +128,19 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if err == nil { break } else if err, ok := err.(*PgError); ok { - return nil, err + return nil, &connectError{config: config, msg: "server error", err: err} } } if err != nil { - return nil, err + return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("AfterConnect: %v", err) + return nil, &connectError{config: config, msg: "AfterConnect error", err: err} } } @@ -156,7 +156,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { - return nil, err + return nil, &connectError{config: config, msg: "dial error", err: err} } pgConn.parameterStatuses = make(map[string]string) @@ -164,7 +164,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if fallbackConfig.TLSConfig != nil { if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "tls error", err: err} } } @@ -193,14 +193,17 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } for { msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() - return nil, err + if err, ok := err.(*PgError); ok { + return nil, err + } + return nil, &connectError{config: config, msg: "failed to receive message", err: err} } switch msg := msg.(type) { @@ -210,7 +213,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { pgConn.conn.Close() - return nil, err + return nil, &connectError{config: config, msg: "failed handle authentication message", err: err} } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -218,7 +221,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, errors.Errorf("ValidateConnect: %v", err) + return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} } } return pgConn, nil @@ -229,7 +232,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() - return nil, errors.New("unexpected message") + return nil, &connectError{config: config, msg: "received unexpected message", err: err} } } } @@ -246,7 +249,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { } if response[0] != 'S' { - return ErrTLSRefused + return errors.New("server refused TLS connection") } pgConn.conn = tls.Client(pgConn.conn, tlsConfig) @@ -308,13 +311,13 @@ func (pgConn *PgConn) signalMessage() chan struct{} { // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { if err := pgConn.lock(); err != nil { - return linkErrors(err, ErrNoBytesSent) + return err } defer pgConn.unlock() select { case <-ctx.Done(): - return linkErrors(ctx.Err(), ErrNoBytesSent) + return &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -323,10 +326,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return linkErrors(ctx.Err(), err) + return &writeError{err: err, safeToRetry: n == 0} } return nil @@ -341,13 +341,13 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -355,7 +355,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa msg, err := pgConn.receiveMessage() if err != nil { - err = linkErrors(ctx.Err(), err) + err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true} } return msg, err } @@ -442,12 +442,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error { _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { - return linkErrors(ctx.Err(), err) + return err } return pgConn.conn.Close() @@ -468,15 +468,15 @@ func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } -// lock locks the connection. It panics if the connection is already locked or is closed. +// lock locks the connection. func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: - return ErrConnBusy // This only should be possible in case of an application bug. + return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. case connStatusClosed: - return errors.New("conn closed") + return &connLockError{status: "conn closed"} case connStatusUninitialized: - return errors.New("conn uninitialized") + return &connLockError{status: "conn uninitialized"} } pgConn.status = connStatusBusy return nil @@ -527,13 +527,13 @@ type StatementDescription struct { // allows Prepare to also to describe statements without creating a server-side prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -547,10 +547,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } psd := &StatementDescription{Name: name, SQL: sql} @@ -562,7 +559,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -641,12 +638,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { - return linkErrors(ctx.Err(), err) + return err } _, err = cancelConn.Read(buf) if err != io.EOF { - return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) + return err } return nil @@ -672,7 +669,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return linkErrors(ctx.Err(), err) + return err } switch msg.(type) { @@ -691,7 +688,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -704,7 +701,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: @@ -719,10 +716,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.hardClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - multiResult.err = linkErrors(ctx.Err(), err) + multiResult.err = &writeError{err: err, safeToRetry: n == 0} pgConn.unlock() return multiResult } @@ -798,7 +792,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by result := &pgConn.resultReader if err := pgConn.lock(); err != nil { - result.concludeCommand(nil, linkErrors(err, ErrNoBytesSent)) + result.concludeCommand(nil, err) result.closed = true return result } @@ -812,7 +806,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by select { case <-ctx.Done(): - result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) result.closed = true pgConn.unlock() return result @@ -831,10 +825,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - result.concludeCommand(nil, linkErrors(ctx.Err(), err)) + result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -844,13 +835,13 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result // CopyTo executes the copy command sql and copies the results to w. func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } select { case <-ctx.Done(): pgConn.unlock() - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -864,10 +855,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm if err != nil { pgConn.hardClose() pgConn.unlock() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read results @@ -877,7 +865,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -905,13 +893,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // could still block. func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { - return nil, linkErrors(err, ErrNoBytesSent) + return nil, err } defer pgConn.unlock() select { case <-ctx.Done(): - return nil, linkErrors(ctx.Err(), ErrNoBytesSent) + return nil, &contextAlreadyDoneError{err: ctx.Err()} default: } pgConn.contextWatcher.Watch(ctx) @@ -924,10 +912,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - if n == 0 { - err = linkErrors(err, ErrNoBytesSent) - } - return nil, linkErrors(ctx.Err(), err) + return nil, &writeError{err: err, safeToRetry: n == 0} } // Read until copy in response or error. @@ -938,7 +923,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -967,7 +952,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } } @@ -976,7 +961,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -998,7 +983,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } // Read results @@ -1006,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.hardClose() - return nil, linkErrors(ctx.Err(), err) + return nil, err } switch msg := msg.(type) { @@ -1048,7 +1033,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) + mrr.err = err mrr.closed = true mrr.pgConn.hardClose() return nil, mrr.err @@ -1263,7 +1248,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { } rr.commandTag = commandTag - rr.err = preferContextOverNetTimeoutError(rr.ctx, err) + rr.err = err rr.fieldDescriptions = nil rr.rowValues = nil rr.commandConcluded = true @@ -1293,7 +1278,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, - err: linkErrors(err, ErrNoBytesSent), + err: err, } } @@ -1306,7 +1291,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} pgConn.unlock() return multiResult default: diff --git a/pgconn_test.go b/pgconn_test.go index 64628262..3fbdf8df 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -86,14 +86,11 @@ func TestConnectInvalidUser(t *testing.T) { config.User = "pgxinvalidusertest" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if err == nil { - conn.Close(context.Background()) - t.Fatal("expected err but got none") - } - pgErr, ok := err.(*pgconn.PgError) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) if !ok { - t.Fatalf("Expected to receive a PgError, instead received: %v", err) + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) } if pgErr.Code != "28000" && pgErr.Code != "28P01" { t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) @@ -298,7 +295,7 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { assert.Nil(t, psd) assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -432,7 +429,7 @@ func TestConnExecContextCanceled(t *testing.T) { for multiResult.NextResult() { } err = multiResult.Close() - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -448,7 +445,7 @@ func TestConnExecContextPrecanceled(t *testing.T) { _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -564,7 +561,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -581,7 +578,7 @@ func TestConnExecParamsPrecanceled(t *testing.T) { result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -691,7 +688,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, 0, rowCount) commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) - assert.Equal(t, context.DeadlineExceeded, err) + assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) } @@ -710,7 +707,7 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() require.Error(t, result.Err) assert.True(t, errors.Is(result.Err, context.Canceled)) - assert.True(t, errors.Is(result.Err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(result.Err)) ensureConnValid(t, pgConn) } @@ -798,7 +795,7 @@ func TestConnExecBatchPrecanceled(t *testing.T) { _, err = pgConn.ExecBatch(ctx, batch).ReadAll() require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) ensureConnValid(t, pgConn) } @@ -871,8 +868,8 @@ func TestConnLocking(t *testing.T) { mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) - assert.True(t, errors.Is(err, pgconn.ErrConnBusy)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) results, err := mrr.ReadAll() assert.NoError(t, err) @@ -1029,7 +1026,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.True(t, pgconn.Timeout(err)) ensureConnValid(t, pgConn) } @@ -1139,7 +1136,7 @@ func TestConnCopyToCanceled(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.Equal(t, pgconn.CommandTag(nil), res) assert.True(t, pgConn.IsClosed()) @@ -1159,7 +1156,7 @@ func TestConnCopyToPrecanceled(t *testing.T) { res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) @@ -1231,7 +1228,7 @@ func TestConnCopyFromCanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) - assert.True(t, errors.Is(err, context.DeadlineExceeded)) + assert.Error(t, err) assert.True(t, pgConn.IsClosed()) } @@ -1267,7 +1264,7 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, errors.Is(err, pgconn.ErrNoBytesSent)) + assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) From 66aaed7c9eb0751b2936dbdbf278963dda8804fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 27 Aug 2019 18:11:50 -0500 Subject: [PATCH 125/290] Remove public fields from PgConn - Access TxStatus via method - Make Config private fixes #7 --- auth_scram.go | 2 +- pgconn.go | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 4409a080..6d6d0651 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -31,7 +31,7 @@ const clientNonceLen = 18 // Perform SCRAM authentication. func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { - sc, err := newScramClient(serverAuthMechanisms, c.Config.Password) + sc, err := newScramClient(serverAuthMechanisms, c.config.Password) if err != nil { return err } diff --git a/pgconn.go b/pgconn.go index 347acf80..1e3f9515 100644 --- a/pgconn.go +++ b/pgconn.go @@ -69,10 +69,10 @@ type PgConn struct { pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server - TxStatus byte + txStatus byte frontend Frontend - Config *Config + config *Config status byte // One of connStatus* constants @@ -149,7 +149,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) - pgConn.Config = config + pgConn.config = config pgConn.wbuf = make([]byte, 0, 1024) var err error @@ -261,9 +261,9 @@ func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: - err = pgConn.txPasswordMessage(pgConn.Config.Password) + err = pgConn.txPasswordMessage(pgConn.config.Password) case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: err = pgConn.scramAuth(msg.SASLAuthMechanisms) @@ -390,7 +390,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - pgConn.TxStatus = msg.TxStatus + pgConn.txStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: @@ -399,12 +399,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: - if pgConn.Config.OnNotice != nil { - pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) + if pgConn.config.OnNotice != nil { + pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg)) } case *pgproto3.NotificationResponse: - if pgConn.Config.OnNotification != nil { - pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) + if pgConn.config.OnNotification != nil { + pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) } } @@ -421,6 +421,11 @@ func (pgConn *PgConn) PID() uint32 { return pgConn.pid } +// TxStatus returns the current TxStatus as reported by the server. +func (pgConn *PgConn) TxStatus() byte { + return pgConn.txStatus +} + // SecretKey returns the backend secret key used to send a cancel query message to the server. func (pgConn *PgConn) SecretKey() uint32 { return pgConn.secretKey @@ -618,7 +623,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } From 6bba3c4810ce93171830696896238f19911b7ca3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 11:55:02 -0500 Subject: [PATCH 126/290] Update pgproto3 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index cbeef02a..b54607b6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 + github.com/jackc/pgproto3/v2 v2.0.0-rc2 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index 0e853203..d7a6d087 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711 h1:vZp4bYotXUkFx7JUSm7U8KV/7Q0AOdrQxxBBj0ZmZsg= -github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg= +github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From 2fabfa3c18b7bcb4f204c365f2f0d2e09d4564eb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 15:44:54 -0500 Subject: [PATCH 127/290] Update to newest pgproto3 --- auth_scram.go | 34 ++++++++++++++++++++++------------ config.go | 2 +- go.mod | 2 +- go.sum | 8 ++------ pgconn.go | 40 ++++++++++++++++++++-------------------- 5 files changed, 46 insertions(+), 40 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 6d6d0651..665fc2c2 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -47,11 +47,11 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-first-message payload in a AuthenticationSASLContinue. - authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue) + saslContinue, err := c.rxSASLContinue() if err != nil { return err } - err = sc.recvServerFirstMessage(authMsg.SASLData) + err = sc.recvServerFirstMessage(saslContinue.Data) if err != nil { return err } @@ -66,27 +66,37 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { } // Receive server-final-message payload in a AuthenticationSASLFinal. - authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal) + saslFinal, err := c.rxSASLFinal() if err != nil { return err } - return sc.recvServerFinalMessage(authMsg.SASLData) + return sc.recvServerFinalMessage(saslFinal.Data) } -func (c *PgConn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) { +func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { msg, err := c.receiveMessage() if err != nil { return nil, err } - authMsg, ok := msg.(*pgproto3.Authentication) - if !ok { - return nil, errors.New("unexpected message type") - } - if authMsg.Type != typ { - return nil, errors.New("unexpected auth type") + saslContinue, ok := msg.(*pgproto3.AuthenticationSASLContinue) + if ok { + return saslContinue, nil } - return authMsg, nil + return nil, errors.New("expected AuthenticationSASLContinue message but received unexpected message") +} + +func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { + msg, err := c.receiveMessage() + if err != nil { + return nil, err + } + saslFinal, ok := msg.(*pgproto3.AuthenticationSASLFinal) + if ok { + return saslFinal, nil + } + + return nil, errors.New("expected AuthenticationSASLFinal message but received unexpected message") } type scramClient struct { diff --git a/config.go b/config.go index d24d0202..d1267621 100644 --- a/config.go +++ b/config.go @@ -501,7 +501,7 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { if err != nil { panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) } - frontend, _ := pgproto3.NewFrontend(cr, w) + frontend := pgproto3.NewFrontend(cr, w) return frontend } diff --git a/go.mod b/go.mod index b54607b6..6e270cd6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc2 + github.com/jackc/pgproto3/v2 v2.0.0-rc3 github.com/stretchr/testify v1.3.0 golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a golang.org/x/text v0.3.0 diff --git a/go.sum b/go.sum index d7a6d087..ed8eb401 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,13 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= -github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3/v2 v2.0.0-rc2 h1:u+jUsxBxiLY2C6mhr8cZhSy71n/y8Id2STOzJ7bl2Mg= -github.com/jackc/pgproto3/v2 v2.0.0-rc2/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/pgconn.go b/pgconn.go index 1e3f9515..d51eb76a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -210,11 +210,28 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.BackendKeyData: pgConn.pid = msg.ProcessID pgConn.secretKey = msg.SecretKey - case *pgproto3.Authentication: - if err = pgConn.rxAuthenticationX(msg); err != nil { + + case *pgproto3.AuthenticationOk: + case *pgproto3.AuthenticationCleartextPassword: + err = pgConn.txPasswordMessage(pgConn.config.Password) + if err != nil { pgConn.conn.Close() - return nil, &connectError{config: config, msg: "failed handle authentication message", err: err} + return nil, &connectError{config: config, msg: "failed to write password message", err: err} } + case *pgproto3.AuthenticationMD5Password: + digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) + err = pgConn.txPasswordMessage(digestedPassword) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed to write password message", err: err} + } + case *pgproto3.AuthenticationSASL: + err = pgConn.scramAuth(msg.AuthMechanisms) + if err != nil { + pgConn.conn.Close() + return nil, &connectError{config: config, msg: "failed SASL auth", err: err} + } + case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { @@ -257,23 +274,6 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return nil } -func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { - switch msg.Type { - case pgproto3.AuthTypeOk: - case pgproto3.AuthTypeCleartextPassword: - err = pgConn.txPasswordMessage(pgConn.config.Password) - case pgproto3.AuthTypeMD5Password: - digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) - err = pgConn.txPasswordMessage(digestedPassword) - case pgproto3.AuthTypeSASL: - err = pgConn.scramAuth(msg.SASLAuthMechanisms) - default: - err = errors.New("Received unknown authentication message") - } - - return -} - func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) From 2f6b8f3f5665228c0800e66b05e797ef119f3ef2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 17:01:54 -0500 Subject: [PATCH 128/290] Fix context timeout on connect --- go.mod | 11 ++++--- go.sum | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 3 ++ pgconn_test.go | 62 +++++++++++++++++++++++++++++++++++ 4 files changed, 159 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 6e270cd6..11692c10 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,11 @@ go 1.12 require ( github.com/jackc/chunkreader/v2 v2.0.0 github.com/jackc/pgio v1.0.0 + github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc3 - github.com/stretchr/testify v1.3.0 - golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a - golang.org/x/text v0.3.0 - golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 + github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 + github.com/stretchr/testify v1.4.0 + golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 + golang.org/x/text v0.3.2 + golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index ed8eb401..c1b3d405 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,111 @@ +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pgconn.go b/pgconn.go index d51eb76a..5c01d1dc 100644 --- a/pgconn.go +++ b/pgconn.go @@ -174,6 +174,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig func() { pgConn.conn.SetDeadline(time.Time{}) }, ) + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ diff --git a/pgconn_test.go b/pgconn_test.go index 3fbdf8df..4a67a2e0 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/jackc/pgconn" + "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -73,6 +74,67 @@ func TestConnectTLS(t *testing.T) { closeConn(t, conn) } +type pgmockWaitStep time.Duration + +func (s pgmockWaitStep) Step(*pgproto3.Backend) error { + time.Sleep(time.Duration(s)) + return nil +} + +func TestConnectWithContextThatTimesOut(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err = pgconn.Connect(ctx, connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) +} + func TestConnectInvalidUser(t *testing.T) { t.Parallel() From a8362ef96d23eb9e53a9eb57bb12889f8cbaa1c2 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2019 17:14:04 -0500 Subject: [PATCH 129/290] Parse postgresql:// protocol --- config.go | 2 +- config_test.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index d1267621..2ec6ae3f 100644 --- a/config.go +++ b/config.go @@ -152,7 +152,7 @@ func ParseConfig(connString string) (*Config, error) { if connString != "" { // connString may be a database URL or a DSN - if strings.HasPrefix(connString, "postgres://") { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { err := addURLSettings(settings, connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} diff --git a/config_test.go b/config_test.go index af42094d..090302a2 100644 --- a/config_test.go +++ b/config_test.go @@ -214,6 +214,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url postgresql protocol", + connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "DSN everything", connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", From f8be2b60ce34bf79b747009b9cc7fb718b918734 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Sep 2019 17:25:25 -0500 Subject: [PATCH 130/290] go.sum changes --- go.sum | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/go.sum b/go.sum index c1b3d405..d0a917fc 100644 --- a/go.sum +++ b/go.sum @@ -2,11 +2,11 @@ github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMe github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= @@ -19,10 +19,10 @@ github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye47 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= -github.com/jackc/pgproto3/v2 v2.0.0-rc3 h1:EHkgVE6iDyI7HZDfMPaZ2Xjdf7C29DikR6o39WVO61c= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= @@ -48,7 +48,6 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -63,7 +62,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -74,7 +72,6 @@ go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -86,12 +83,10 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -99,7 +94,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From b2ca5d8f521597a28e8dc0703b9b2a8c72d9866a Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Fri, 13 Sep 2019 17:26:09 +0300 Subject: [PATCH 131/290] validate all addresses resolved from hostname Signed-off-by: Artemiy Ryabinkov --- config.go | 9 ++++++++- pgconn.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 2ec6ae3f..57e65e13 100644 --- a/config.go +++ b/config.go @@ -37,6 +37,7 @@ type Config struct { Password string TLSConfig *tls.Config // nil disables TLS DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) @@ -77,7 +78,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) } else { network = "tcp" - address = fmt.Sprintf("%s:%d", host, port) + address = net.JoinHostPort(host, strconv.Itoa(int(port))) } return network, address } @@ -190,6 +191,8 @@ func ParseConfig(connString string) (*Config, error) { config.DialFunc = defaultDialer.DialContext } + config.LookupFunc = makeDefaultResolver().LookupHost + notRuntimeParams := map[string]struct{}{ "host": struct{}{}, "port": struct{}{}, @@ -495,6 +498,10 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } +func makeDefaultResolver() *net.Resolver { + return net.DefaultResolver +} + func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { return func(r io.Reader, w io.Writer) Frontend { cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) diff --git a/pgconn.go b/pgconn.go index 5c01d1dc..db2ebe73 100644 --- a/pgconn.go +++ b/pgconn.go @@ -43,6 +43,9 @@ type Notification struct { // DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// LookupFunc is a function that can be used to lookup IPs addrs from host. +type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) + // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend @@ -123,6 +126,15 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err } fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) + if err != nil { + return nil, &connectError{config: config, msg: "hostname resolving error", err: err} + } + + if len(fallbackConfigs) == 0 { + return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + } + for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { @@ -147,6 +159,27 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err return pgConn, nil } +func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { + var configs []*FallbackConfig + + for _, fb := range fallbacks { + ips, err := lookupFn(ctx, fb.Host) + if err != nil { + return nil, err + } + + for _, ip := range ips { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } + } + + return configs, nil +} + func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config From e538885fa71f92c5974c28edb49db682a0194a33 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Fri, 13 Sep 2019 17:52:01 +0300 Subject: [PATCH 132/290] skip resolve for unix sockets Signed-off-by: Artemiy Ryabinkov --- pgconn.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pgconn.go b/pgconn.go index db2ebe73..25f4f4d5 100644 --- a/pgconn.go +++ b/pgconn.go @@ -163,6 +163,17 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba var configs []*FallbackConfig for _, fb := range fallbacks { + // skip resolve for unix sockets + if strings.HasPrefix(fb.Host, "/") { + configs = append(configs, &FallbackConfig{ + Host: fb.Host, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + + continue + } + ips, err := lookupFn(ctx, fb.Host) if err != nil { return nil, err From 17d3d592e980720a8baa9a98e91a3de9fec06af7 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 14 Sep 2019 19:11:26 +0300 Subject: [PATCH 133/290] add test for custom lookup func Signed-off-by: Artemiy Ryabinkov --- pgconn_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 4a67a2e0..36499b68 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -188,6 +188,24 @@ func TestConnectCustomDialer(t *testing.T) { closeConn(t, conn) } +func TestConnectCustomLookup(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) +} + func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() From 99f22ac8e4c9c142d9541ab648274e7663357fab Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Sep 2019 18:37:33 -0500 Subject: [PATCH 134/290] Port DSN parser from pgx v3 Original implementation: 2d9d8dc52ac211c6191c08e050c03588aa633038 by Joshua Barone . Also changed DSN tests to use "dbname" as key rather than "database" as that is what the PostgreSQL documentation specifies. "database" still actually works but it should not be encouraged as it is non-standard. --- .travis.yml | 8 ++--- README.md | 2 +- config.go | 61 ++++++++++++++++++++++++++++++++++--- config_test.go | 82 +++++++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 139 insertions(+), 14 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2c547abf..abff8515 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,15 +17,15 @@ env: - GOPROXY=https://proxy.golang.org - GOFLAGS=-mod=readonly - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" + - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - PGX_TEST_MD5_PASSWORD_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test matrix: - - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx database=pgx_test" - - PGVERSION=10 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret database=pgx_test" - - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret database=pgx_test" + - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx dbname=pgx_test" + - PGVERSION=10 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" + - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" - PGVERSION=9.5 - PGVERSION=9.4 - PGVERSION=9.3 diff --git a/README.md b/README.md index 9e35a0f5..aa980b6d 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ create database pgx_test; Now you can run the tests: ``` -PGX_TEST_CONN_STRING="host=/var/run/postgresql database=pgx_test" go test ./... +PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... ``` ### Connection and Authentication Tests diff --git a/config.go b/config.go index 2ec6ae3f..6eb0065a 100644 --- a/config.go +++ b/config.go @@ -13,7 +13,6 @@ import ( "os" "os/user" "path/filepath" - "regexp" "strconv" "strings" "time" @@ -389,13 +388,65 @@ func addURLSettings(settings map[string]string, connString string) error { return nil } -var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} func addDSNSettings(settings map[string]string, s string) error { - m := dsnRegexp.FindAllStringSubmatch(s, -1) + nameMap := map[string]string{ + "dbname": "database", + } - for _, b := range m { - settings[b[1]] = b[2] + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return errors.New("invalid dsn") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + if k, ok := nameMap[key]; ok { + key = k + } + + settings[key] = val } return nil diff --git a/config_test.go b/config_test.go index 090302a2..9eb5df2f 100644 --- a/config_test.go +++ b/config_test.go @@ -228,7 +228,7 @@ func TestParseConfig(t *testing.T) { }, { name: "DSN everything", - connString: "user=jack password=secret host=localhost port=5432 database=mydb sslmode=disable application_name=pgxtest search_path=myschema", + connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema", config: &pgconn.Config{ User: "jack", Password: "secret", @@ -242,6 +242,80 @@ func TestParseConfig(t *testing.T) { }, }, }, + { + name: "DSN with escaped single quote", + connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack's", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with escaped backslash", + connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "sooper\\secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with single quoted values", + connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with single quoted value with escaped single quote", + connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack's", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with empty single quoted value", + connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "DSN with space between key and value", + connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'", + config: &pgconn.Config{ + User: "jack", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "URL multiple hosts", connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", @@ -294,7 +368,7 @@ func TestParseConfig(t *testing.T) { }, { name: "DSN multiple hosts one port", - connString: "user=jack password=secret host=foo,bar,baz port=5432 database=mydb sslmode=disable", + connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", @@ -319,7 +393,7 @@ func TestParseConfig(t *testing.T) { }, { name: "DSN multiple hosts multiple ports", - connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 database=mydb sslmode=disable", + connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", @@ -344,7 +418,7 @@ func TestParseConfig(t *testing.T) { }, { name: "multiple hosts and fallback tsl", - connString: "user=jack password=secret host=foo,bar,baz database=mydb sslmode=prefer", + connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer", config: &pgconn.Config{ User: "jack", Password: "secret", From bbc7f67a6f5907a413ff3106ebf6c54d1f09101a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 14 Sep 2019 20:22:50 -0500 Subject: [PATCH 135/290] Update to pgproto3 v2.0.0 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 11692c10..4a188cce 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 + github.com/jackc/pgproto3/v2 v2.0.0 github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index d0a917fc..51c55d12 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0 h1:FApgMJ/GtaXfI0s8Lvd0kaLaRwMOhs4VH92pwkwQQvU= +github.com/jackc/pgproto3/v2 v2.0.0/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= From f5eead90fca09203d8af956fea01861884ed9a8a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Sep 2019 21:04:14 -0500 Subject: [PATCH 136/290] Fix statement cache reuse bug --- stmtcache/lru.go | 4 +++- stmtcache/lru_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index fff4d0b7..d82ced19 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -104,8 +104,10 @@ func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescrip func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) + psd := oldest.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) if c.mode == ModePrepare { - return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", oldest.Value.(*pgconn.StatementDescription).Name)).Close() + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() } return nil } diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index b518364e..d2902dbb 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -2,6 +2,8 @@ package stmtcache_test import ( "context" + "fmt" + "math/rand" "os" "testing" "time" @@ -57,6 +59,30 @@ func TestLRUModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUModePrepareStress(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8) + require.EqualValues(t, 0, cache.Len()) + require.EqualValues(t, 8, cache.Cap()) + require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) + + for i := 0; i < 1000; i++ { + psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50))) + require.NoError(t, err) + require.NotNil(t, psd) + result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) + } +} + func TestLRUModeDescribe(t *testing.T) { t.Parallel() From d6b0287fcda8ef85425ef39a43e0e10921877449 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 19 Sep 2019 21:41:20 -0500 Subject: [PATCH 137/290] Release v1.0.1 --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..5384b031 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,3 @@ +# 1.0.1 (September 19, 2019) + +* Fix statement cache not properly cleaning discarded statements From 6c195c17b2af217104e20bb20da66c91d2b2f8f1 Mon Sep 17 00:00:00 2001 From: Francis Chuang <2263040+F21@users.noreply.github.com> Date: Thu, 3 Oct 2019 09:49:12 +1000 Subject: [PATCH 138/290] Fix minor errors and reword some sentences for readability --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index aa980b6d..5d14e914 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # pgconn -Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level is the C library libpq. +Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. Applications should handle normal queries with a higher level library and only use pgconn directly when required for low-level access to PostgreSQL functionality. @@ -17,7 +17,7 @@ if err != nil { } defer pgConn.Close() -result := pgConn.ExecParams(context.Background(), "select email from users where id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } @@ -29,7 +29,7 @@ if err != nil { ## Testing -pgconn tests need a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` +The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable handling. @@ -44,13 +44,13 @@ create database pgx_test; Now you can run the tests: -``` +```bash PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... ``` ### Connection and Authentication Tests -There are multiple connection types and means of authentication that pgconn supports. These tests are optional. They +Pgconn supports multiple connection types and means of authentication. These tests are optional. They will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being -skipped. Typical developers will not need to enable these tests. See travis.yml for example setup if you need change +skipped. Most developers will not need to enable these tests. See `travis.yml` for an example set up if you need change authentication code. From fcfd7d09a9079edbce62cf83d5d184e8b2dbc33e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Oct 2019 10:21:33 -0500 Subject: [PATCH 139/290] Add PgConn.IsBusy() method --- pgconn.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pgconn.go b/pgconn.go index 25f4f4d5..e3f3aaff 100644 --- a/pgconn.go +++ b/pgconn.go @@ -520,6 +520,11 @@ func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } +// IsBusy reports if the connection is busy. +func (pgConn *PgConn) IsBusy() bool { + return pgConn.status == connStatusBusy +} + // lock locks the connection. func (pgConn *PgConn) lock() error { switch pgConn.status { From 4df62cf3d029efb55dc1cc8d31144e9ed2d80d44 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Oct 2019 11:23:48 -0500 Subject: [PATCH 140/290] Release v1.1.0 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5384b031..92497f47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.1.0 (October 12, 2019) + +* Add PgConn.IsBusy() method + # 1.0.1 (September 19, 2019) * Fix statement cache not properly cleaning discarded statements From 81b6ad72f6dedf2162a06cdb3543de33b28ec2ff Mon Sep 17 00:00:00 2001 From: Skip Gibson Date: Wed, 16 Oct 2019 10:01:16 +0100 Subject: [PATCH 141/290] config: fix ValidateConnect comment --- config.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index f41c38b9..628deed8 100644 --- a/config.go +++ b/config.go @@ -43,9 +43,8 @@ type Config struct { Fallbacks []*FallbackConfig // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. - // It can be used validate that server is acceptable. If this returns an error the connection is closed and the next - // fallback config is tried. This allows implementing high availability behavior such as libpq does with - // target_session_attrs. + // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next + // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. ValidateConnect ValidateConnectFunc // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables From eb81d2926b3b0519cd1fe945c2795cceeabe236c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 18 Nov 2019 07:24:05 -0600 Subject: [PATCH 142/290] Ignore errors sending Terminate message while closing connection This mimics the behavior of libpq PGfinish. refs #637 --- pgconn.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pgconn.go b/pgconn.go index e3f3aaff..210d9979 100644 --- a/pgconn.go +++ b/pgconn.go @@ -492,15 +492,13 @@ func (pgConn *PgConn) Close(ctx context.Context) error { pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() - _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) - if err != nil { - return err - } - - _, err = pgConn.conn.Read(make([]byte, 1)) - if err != io.EOF { - return err - } + // Ignore any errors sending Terminate message and waiting for server to close connection. + // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully + // ignores errors. + // + // See https://github.com/jackc/pgx/issues/637 + pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.conn.Read(make([]byte, 1)) return pgConn.conn.Close() } From 32350bd1dc3288aa22f271f87065741da0a1bdb8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 18 Nov 2019 07:28:47 -0600 Subject: [PATCH 143/290] TestConnectCustomLookup must test with TCP connection Test (correctly) fails if run on a Unix domain socket. --- pgconn_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 36499b68..6f330efb 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -191,7 +191,12 @@ func TestConnectCustomDialer(t *testing.T) { func TestConnectCustomLookup(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) require.NoError(t, err) looked := false From bd0ce203e9563e2b966b5f796bab2cf9f555bc2b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 10:31:27 -0600 Subject: [PATCH 144/290] CopyFrom not table test was failing with syntax error --- pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 6f330efb..6b57dd09 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1444,7 +1444,7 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) { srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") require.Error(t, err) assert.IsType(t, &pgconn.PgError{}, err) assert.Equal(t, int64(0), res.RowsAffected()) From dd53b7488d920c44204098385e460cf708626a42 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 10:06:24 -0600 Subject: [PATCH 145/290] Restart signalMessage when receiving non-error message in CopyFrom fixes #21 --- pgconn.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pgconn.go b/pgconn.go index 210d9979..4c75d367 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1022,6 +1022,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) + default: + signalMessageChan = pgConn.signalMessage() } default: } From 18d1ed5ee5619f59c6b5e670e2ffa36f1ffe95fd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 14:37:09 -0600 Subject: [PATCH 146/290] Remove PostgreSQL 9.3 from Travis build matrix PostgreSQL 9.3 is EOL so it doesn't make sense for pgconn to specifically support. There are no known incompatibilities but it will not longer be tested. --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index abff8515..50f13881 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,6 @@ env: - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" - PGVERSION=9.5 - PGVERSION=9.4 - - PGVERSION=9.3 cache: directories: From 5fc867a833afcd0d51c7d05bbff19e16e2adb34d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 14:40:30 -0600 Subject: [PATCH 147/290] Remove unused travis environment variable --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 50f13881..c1688000 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,8 +24,8 @@ env: - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test matrix: - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx dbname=pgx_test" - - PGVERSION=10 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" - - PGVERSION=9.6 PGX_TEST_REPLICATION_CONN_STRING="host=127.0.0.1 port=6543 user=pgx_replication password=secret dbname=pgx_test" + - PGVERSION=10 + - PGVERSION=9.6 - PGVERSION=9.5 - PGVERSION=9.4 From 3e503b7b1a3beb8466a0ef126b87c209c5dc91e6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 Dec 2019 14:41:09 -0600 Subject: [PATCH 148/290] Add PostgreSQL 11 and 12 to the Travis build matrix --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index c1688000..87a0c058 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,6 +24,8 @@ env: - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test matrix: - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx dbname=pgx_test" + - PGVERSION=12 + - PGVERSION=11 - PGVERSION=10 - PGVERSION=9.6 - PGVERSION=9.5 From 89416dd80542cc62f45af214ca0722c32e6624ca Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 13:09:50 +0200 Subject: [PATCH 149/290] Enable passing nil context --- .gitignore | 3 +- doc.go | 3 + pgconn.go | 187 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 116 insertions(+), 77 deletions(-) diff --git a/.gitignore b/.gitignore index 6eb9d442..e980f555 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .envrc -vendor/ \ No newline at end of file +vendor/ +.vscode diff --git a/doc.go b/doc.go index cde58cd8..12ed6630 100644 --- a/doc.go +++ b/doc.go @@ -23,6 +23,9 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. +A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional +slight performance increase, if you don't need the operation to be cancellable. + The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. */ diff --git a/pgconn.go b/pgconn.go index 4c75d367..3b90b802 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,6 +116,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } + if ctx == nil { + ctx = context.Background() + } + // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -362,13 +366,15 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() n, err := pgConn.conn.Write(buf) if err != nil { @@ -392,13 +398,15 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() msg, err := pgConn.receiveMessage() if err != nil { @@ -489,8 +497,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + if ctx != nil { + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } // Ignore any errors sending Terminate message and waiting for server to close connection. // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully @@ -586,13 +596,15 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) @@ -673,18 +685,24 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) + _ctx := ctx + if _ctx == nil { + _ctx = context.Background() + } + cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) - contextWatcher.Watch(ctx) - defer contextWatcher.Unwatch() + if ctx != nil { + contextWatcher := ctxwatch.NewContextWatcher( + func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { cancelConn.SetDeadline(time.Time{}) }, + ) + contextWatcher.Watch(ctx) + defer contextWatcher.Unwatch() + } buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) @@ -712,14 +730,16 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - select { - case <-ctx.Done(): - return ctx.Err() - default: - } + if ctx != nil { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } for { msg, err := pgConn.receiveMessage() @@ -752,16 +772,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != nil { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } else { + pgConn.multiResultReader.ctx = context.Background() } - pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) @@ -808,7 +831,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -834,7 +857,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) - pgConn.execExtendedSuffix(ctx, buf, result) + pgConn.execExtendedSuffix(buf, result) return result } @@ -845,6 +868,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by ctx: ctx, } result := &pgConn.resultReader + if ctx == nil { + pgConn.resultReader.ctx = context.Background() + } if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) @@ -859,20 +885,22 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - select { - case <-ctx.Done(): - result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) - result.closed = true - pgConn.unlock() - return result - default: + if ctx != nil { + select { + case <-ctx.Done(): + result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) + result.closed = true + pgConn.unlock() + return result + default: + } + pgConn.contextWatcher.Watch(ctx) } - pgConn.contextWatcher.Watch(ctx) return result } -func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { +func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) @@ -893,14 +921,16 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - select { - case <-ctx.Done(): - pgConn.unlock() - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + pgConn.unlock() + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -952,13 +982,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - select { - case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} - default: + if ctx != nil { + select { + case <-ctx.Done(): + return nil, &contextAlreadyDoneError{err: ctx.Err()} + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf @@ -1344,16 +1376,19 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - - select { - case <-ctx.Done(): - multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} - pgConn.unlock() - return multiResult - default: + if ctx != nil { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } else { + pgConn.multiResultReader.ctx = context.Background() } - pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) From 719623452110bc4bce0e2358db9d3df658777eeb Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 13:10:04 +0200 Subject: [PATCH 150/290] Benchmark nil context execution --- benchmark_test.go | 156 +++++++++++++++++++++++++++------------------- 1 file changed, 93 insertions(+), 63 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 8067c985..1914e07a 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -14,9 +14,14 @@ func BenchmarkConnect(b *testing.B) { benchmarks := []struct { name string env string + ctx context.Context }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + // The first benchmark in the list sometimes executes faster, no matter how + // you reorder it. Nil context is still faster on average. + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.Background()}, + {"TCP", "PGX_TEST_TCP_CONN_STRING", context.Background()}, + {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, + {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, } for _, bm := range benchmarks { @@ -28,10 +33,10 @@ func BenchmarkConnect(b *testing.B) { } for i := 0; i < b.N; i++ { - conn, err := pgconn.Connect(context.Background(), connString) + conn, err := pgconn.Connect(bm.ctx, connString) require.Nil(b, err) - err = conn.Close(context.Background()) + err = conn.Close(bm.ctx) require.Nil(b, err) } }) @@ -39,46 +44,58 @@ func BenchmarkConnect(b *testing.B) { } func BenchmarkExec(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.Nil(b, err) - defer closeConn(b, conn) - expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} + benchmarks := []struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } - b.ResetTimer() + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) - for i := 0; i < b.N; i++ { - mrr := conn.Exec(context.Background(), "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") + b.ResetTimer() - for mrr.NextResult() { - rr := mrr.ResultReader() + for i := 0; i < b.N; i++ { + mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") - rowCount := 0 - for rr.NextRow() { - rowCount++ - if len(rr.Values()) != len(expectedValues) { - b.Fatalf("unexpected number of values: %d", len(rr.Values())) - } - for i := range rr.Values() { - if !bytes.Equal(rr.Values()[i], expectedValues[i]) { - b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + for mrr.NextResult() { + rr := mrr.ResultReader() + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) } } - } - _, err = rr.Close() - if err != nil { - b.Fatal(err) + err := mrr.Close() + if err != nil { + b.Fatal(err) + } } - if rowCount != 1 { - b.Fatalf("unexpected rowCount: %d", rowCount) - } - } - - err := mrr.Close() - if err != nil { - b.Fatal(err) - } + }) } } @@ -130,40 +147,53 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { } func BenchmarkExecPrepared(b *testing.B) { - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.Nil(b, err) - defer closeConn(b, conn) - - _, err = conn.Prepare(context.Background(), "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) - require.Nil(b, err) - expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} - b.ResetTimer() + benchmarks := []struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } - for i := 0; i < b.N; i++ { - rr := conn.ExecPrepared(context.Background(), "ps1", nil, nil, nil) + for _, bm := range benchmarks { + bm := bm + b.Run(bm.name, func(b *testing.B) { + conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.Nil(b, err) + defer closeConn(b, conn) - rowCount := 0 - for rr.NextRow() { - rowCount++ - if len(rr.Values()) != len(expectedValues) { - b.Fatalf("unexpected number of values: %d", len(rr.Values())) - } - for i := range rr.Values() { - if !bytes.Equal(rr.Values()[i], expectedValues[i]) { - b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) + require.Nil(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rr := conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil) + + rowCount := 0 + for rr.NextRow() { + rowCount++ + if len(rr.Values()) != len(expectedValues) { + b.Fatalf("unexpected number of values: %d", len(rr.Values())) + } + for i := range rr.Values() { + if !bytes.Equal(rr.Values()[i], expectedValues[i]) { + b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) + } + } + } + _, err = rr.Close() + + if err != nil { + b.Fatal(err) + } + if rowCount != 1 { + b.Fatalf("unexpected rowCount: %d", rowCount) } } - } - _, err = rr.Close() - - if err != nil { - b.Fatal(err) - } - if rowCount != 1 { - b.Fatalf("unexpected rowCount: %d", rowCount) - } + }) } } From 4d345164f1027d985717335e841868f60ca69ac2 Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 14:36:38 +0200 Subject: [PATCH 151/290] Branch tests for nil context --- README.md | 4 +- helper_test.go | 22 + pgconn_test.go | 1500 +++++++++++++++++++++++++----------------------- 3 files changed, 818 insertions(+), 708 deletions(-) diff --git a/README.md b/README.md index 5d14e914..ddbfeaf3 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ low-level access to PostgreSQL functionality. ## Example Usage ```go -pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) +pgConn, err := pgconn.Connect(nil, os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close() -result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(nil, "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } diff --git a/helper_test.go b/helper_test.go index 1a3ca75e..1cb05fd2 100644 --- a/helper_test.go +++ b/helper_test.go @@ -29,3 +29,25 @@ func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } + +// Run subtest both with a context.Background() and nil context +func splitOnContext(t *testing.T, test func(t *testing.T, ctx context.Context)) { + t.Helper() + + cases := [...]struct { + name string + ctx context.Context + }{ + {"background context", context.Background()}, + {"nil context", nil}, + } + + for i := range cases { + c := cases[i] + t.Run(c.name, func(t *testing.T) { + t.Helper() + t.Parallel() + test(t, c.ctx) + }) + } +} diff --git a/pgconn_test.go b/pgconn_test.go index 6b57dd09..30d20229 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -27,31 +27,33 @@ import ( ) func TestConnect(t *testing.T) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) - closeConn(t, conn) - }) - } + closeConn(t, conn) + }) + } + }) } // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure @@ -59,19 +61,21 @@ func TestConnect(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } - conn, err := pgconn.Connect(context.Background(), connString) - require.NoError(t, err) + conn, err := pgconn.Connect(ctx, connString) + require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } - closeConn(t, conn) + closeConn(t, conn) + }) } type pgmockWaitStep time.Duration @@ -138,233 +142,259 @@ func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectInvalidUser(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - config.User = "pgxinvalidusertest" + config.User = "pgxinvalidusertest" - _, err = pgconn.ConnectConfig(context.Background(), config) - require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } + _, err = pgconn.ConnectConfig(ctx, config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } + }) } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") - if err == nil { - conn.Close(context.Background()) - t.Fatal("Expected error establishing connection to bad port") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") + if err == nil { + conn.Close(ctx) + t.Fatal("Expected error establishing connection to bad port") + } + }) } func TestConnectCustomDialer(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) + }) } func TestConnectCustomLookup(t *testing.T) { t.Parallel() - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + splitOnContext(t, func(t *testing.T, ctx context.Context) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - looked := false - config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { - looked = true - return net.LookupHost(host) - } + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.True(t, looked) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) + }) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, conn) - result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) + result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) + }) } func TestConnectWithFallback(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + }) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount++ - return net.Dial(network, address) - } - - acceptConnCount := 0 - config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount++ - if acceptConnCount < 2 { - return errors.New("reject first conn") + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) } - return nil - } - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") + } + return nil + } + + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) } func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(context.Background(), config) - if !assert.NotNil(t, err) { - conn.Close(context.Background()) - } + conn, err := pgconn.ConnectConfig(ctx, config) + if !assert.NotNil(t, err) { + conn.Close(ctx) + } + }) } func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() - return err - } + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) - results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() - require.NoError(t, err) - defer closeConn(t, conn) + results, err := conn.Exec(ctx, "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) - assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) + }) } func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { t.Parallel() - config := &pgconn.Config{} + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config := &pgconn.Config{} - require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) + require.PanicsWithValue( + t, + "config must be created by ParseConfig", + func() { pgconn.ConnectConfig(ctx, config) }, + ) + }) } func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) + psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnPrepareContextPrecanceled(t *testing.T) { @@ -388,116 +418,126 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecEmpty(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), ";") + multiResult := pgConn.Exec(ctx, ";") - resultCount := 0 - for multiResult.NextResult() { - resultCount++ - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 2) + assert.Len(t, results, 2) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } + results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() - require.NotNil(t, err) + _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecContextCanceled(t *testing.T) { @@ -538,95 +578,103 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() - require.NotNil(t, result.Err) - var pgErr *pgconn.PgError - require.True(t, errors.As(result.Err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecParamsCanceled(t *testing.T) { @@ -671,86 +719,92 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -800,63 +854,67 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) - require.NoError(t, err) + _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) + }) } func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(ctx, setupSQL).ReadAll() + assert.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) - _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnExecBatchPrecanceled(t *testing.T) { @@ -895,76 +953,82 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - queryCount := 100000 - args := make([]string, queryCount) + queryCount := 100000 + args := make([]string, queryCount) - for i := range args { - args[i] = strconv.Itoa(i) - batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) - } + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } - results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, queryCount) + results, err := pgConn.ExecBatch(ctx, batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) - for i := range args { - require.Len(t, results[i].Rows, 1) - require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) - } + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } + }) } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) - batch.ExecParams("select 1/0", nil, nil, nil, nil) - _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() - require.Error(t, err) + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(ctx, batch).ReadAll() + require.Error(t, err) - result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() - require.Equal(t, "0", string(result.Rows[0][0])) + result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) + }) } func TestConnLocking(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") - _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "conn busy", err.Error()) - assert.True(t, pgconn.SafeToRetry(err)) + mrr := pgConn.Exec(ctx, "select 'Hello, world'") + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) - results, err := mrr.ReadAll() - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestCommandTag(t *testing.T) { @@ -993,91 +1057,97 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), `do $$ -begin - raise notice 'hello, world'; -end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + multiResult := pgConn.Exec(ctx, `do $$ + begin + raise notice 'hello, world'; + end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnOnNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - err = pgConn.WaitForNotification(context.Background()) - require.NoError(t, err) + err = pgConn.WaitForNotification(ctx) + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { @@ -1119,94 +1189,100 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) - - inputBytes := make([]byte, 0) - - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") - require.NoError(t, err) + inputBytes := make([]byte, 0) - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() + require.NoError(t, err) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } - ensureConnValid(t, pgConn) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + + res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") + require.NoError(t, err) + + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0)) + outputWriter := bytes.NewBuffer(make([]byte, 0)) - res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyToCanceled(t *testing.T) { @@ -1250,37 +1326,39 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - srcBuf := &bytes.Buffer{} - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - } + defer closeConn(t, pgConn) - ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) + srcBuf := &bytes.Buffer{} - assert.Equal(t, inputRows, result.Rows) + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } - ensureConnValid(t, pgConn) + ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromCanceled(t *testing.T) { @@ -1358,153 +1436,163 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) - - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) - - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) - - gw := gzip.NewWriter(f) - - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) - } + defer closeConn(t, pgConn) - err = gw.Close() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - _, err = f.Seek(0, 0) - require.NoError(t, err) + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) - gr, err := gzip.NewReader(f) - require.NoError(t, err) + gw := gzip.NewWriter(f) - ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + require.NoError(t, err) + } - err = gr.Close() - require.NoError(t, err) + err = gw.Close() + require.NoError(t, err) - err = f.Close() - require.NoError(t, err) + _, err = f.Seek(0, 0) + require.NoError(t, err) - err = os.Remove(f.Name()) - require.NoError(t, err) + gr, err := gzip.NewReader(f) + require.NoError(t, err) - result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) + ct, err := pgConn.CopyFrom(ctx, gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - assert.Equal(t, inputRows, result.Rows) + err = gr.Close() + require.NoError(t, err) - ensureConnValid(t, pgConn) + err = f.Close() + require.NoError(t, err) + + err = os.Remove(f.Name()) + require.NoError(t, err) + + result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + assert.Equal(t, inputRows, result.Rows) + + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(context.Background(), `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnEscapeString(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, - } - - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, } - } - ensureConnValid(t, pgConn) + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) + } + } + + ensureConnValid(t, pgConn) + }) } func TestConnCancelRequest(t *testing.T) { t.Parallel() - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + splitOnContext(t, func(t *testing.T, ctx context.Context) { + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) - err = pgConn.CancelRequest(context.Background()) - require.NoError(t, err) + err = pgConn.CancelRequest(ctx) + require.NoError(t, err) - for multiResult.NextResult() { - } - err = multiResult.Close() + for multiResult.NextResult() { + } + err = multiResult.Close() - require.IsType(t, &pgconn.PgError{}, err) - require.Equal(t, "57014", err.(*pgconn.PgError).Code) + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) - ensureConnValid(t, pgConn) + ensureConnValid(t, pgConn) + }) } func TestConnSendBytesAndReceiveMessage(t *testing.T) { @@ -1547,13 +1635,13 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(nil, os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } - defer pgConn.Close(context.Background()) + defer pgConn.Close(nil) - result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() + result := pgConn.ExecParams(nil, "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) } From 93722181071cd124ad5bb67122d33b31d4ada632 Mon Sep 17 00:00:00 2001 From: bakape Date: Wed, 1 Jan 2020 19:34:56 +0200 Subject: [PATCH 152/290] Don't synchronize with context.Background() --- benchmark_test.go | 12 +++++++---- doc.go | 4 +--- pgconn.go | 52 +++++++++++++++++++++++++++++++++-------------- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 1914e07a..4cce5a97 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -18,8 +18,10 @@ func BenchmarkConnect(b *testing.B) { }{ // The first benchmark in the list sometimes executes faster, no matter how // you reorder it. Nil context is still faster on average. - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.Background()}, - {"TCP", "PGX_TEST_TCP_CONN_STRING", context.Background()}, + // + // Using and empty context other than context.Background() to compare. + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.TODO()}, + {"TCP", "PGX_TEST_TCP_CONN_STRING", context.TODO()}, {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, } @@ -49,7 +51,8 @@ func BenchmarkExec(b *testing.B) { name string ctx context.Context }{ - {"background context", context.Background()}, + // Using and empty context other than context.Background() to compare. + {"empty context", context.TODO()}, {"nil context", nil}, } @@ -153,7 +156,8 @@ func BenchmarkExecPrepared(b *testing.B) { name string ctx context.Context }{ - {"background context", context.Background()}, + // Using and empty context other than context.Background() to compare. + {"empty context", context.TODO()}, {"nil context", nil}, } diff --git a/doc.go b/doc.go index 12ed6630..25382c68 100644 --- a/doc.go +++ b/doc.go @@ -22,9 +22,7 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. - -A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional -slight performance increase, if you don't need the operation to be cancellable. +A nil context can be passed for convenience. This has the same effect as passing context.Background(). The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/pgconn.go b/pgconn.go index 3b90b802..b8ea9df7 100644 --- a/pgconn.go +++ b/pgconn.go @@ -366,7 +366,9 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return &contextAlreadyDoneError{err: ctx.Err()} @@ -398,7 +400,9 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -497,7 +501,9 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } @@ -596,7 +602,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -695,7 +703,9 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { } defer cancelConn.Close() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, @@ -730,7 +740,9 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return ctx.Err() @@ -772,7 +784,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - if ctx != nil { + switch ctx { + case nil: + pgConn.multiResultReader.ctx = context.Background() + case context.Background(): + default: select { case <-ctx.Done(): multiResult.closed = true @@ -782,8 +798,6 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { default: } pgConn.contextWatcher.Watch(ctx) - } else { - pgConn.multiResultReader.ctx = context.Background() } buf := pgConn.wbuf @@ -885,7 +899,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) @@ -921,7 +937,9 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): pgConn.unlock() @@ -982,7 +1000,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - if ctx != nil { + switch ctx { + case nil, context.Background(): + default: select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -1376,7 +1396,11 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - if ctx != nil { + switch ctx { + case nil: + pgConn.multiResultReader.ctx = context.Background() + case context.Background(): + default: select { case <-ctx.Done(): multiResult.closed = true @@ -1386,8 +1410,6 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR default: } pgConn.contextWatcher.Watch(ctx) - } else { - pgConn.multiResultReader.ctx = context.Background() } batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) From 98b3c57584a2bde785c3f706afcd3d371d6faec3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 8 Jan 2020 10:02:32 -0600 Subject: [PATCH 153/290] Try to cancel any in-progress query when a conn is closed by ctx cancel See https://github.com/jackc/pgx/issues/659 --- pgconn.go | 58 +++++++++++++++++++++++++++++++------------------- pgconn_test.go | 43 +++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 22 deletions(-) diff --git a/pgconn.go b/pgconn.go index 4c75d367..70c33c4f 100644 --- a/pgconn.go +++ b/pgconn.go @@ -372,7 +372,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return &writeError{err: err, safeToRetry: n == 0} } @@ -429,7 +429,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { if err != nil { // Close on anything other than timeout error - everything else is fatal if err, ok := err.(net.Error); !(ok && err.Timeout()) { - pgConn.hardClose() + pgConn.ayncClose() } return nil, err @@ -442,7 +442,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - pgConn.hardClose() + pgConn.ayncClose() return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: @@ -503,14 +503,28 @@ func (pgConn *PgConn) Close(ctx context.Context) error { return pgConn.conn.Close() } -// hardClose closes the underlying connection without sending the exit message. -func (pgConn *PgConn) hardClose() error { +// ayncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying +// connection. +func (pgConn *PgConn) ayncClose() { if pgConn.status == connStatusClosed { - return nil + return } pgConn.status = connStatusClosed - return pgConn.conn.Close() + go func() { + defer pgConn.conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + + pgConn.CancelRequest(ctx) + + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + + pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.conn.Read(make([]byte, 1)) + }() } // IsClosed reports if the connection has been closed. @@ -601,7 +615,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } @@ -613,7 +627,7 @@ readloop: for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -768,7 +782,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true multiResult.err = &writeError{err: err, safeToRetry: n == 0} @@ -879,7 +893,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true @@ -908,7 +922,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() pgConn.unlock() return nil, &writeError{err: err, safeToRetry: n == 0} } @@ -919,7 +933,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -928,7 +942,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } case *pgproto3.ReadyForQuery: @@ -966,7 +980,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, &writeError{err: err, safeToRetry: n == 0} } @@ -977,7 +991,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for pendingCopyInResponse { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -1006,7 +1020,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } } @@ -1015,7 +1029,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case <-signalMessageChan: msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -1039,7 +1053,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -1047,7 +1061,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.hardClose() + pgConn.ayncClose() return nil, err } @@ -1092,7 +1106,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.pgConn.contextWatcher.Unwatch() mrr.err = err mrr.closed = true - mrr.pgConn.hardClose() + mrr.pgConn.ayncClose() return nil, mrr.err } @@ -1281,7 +1295,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { - rr.pgConn.hardClose() + rr.pgConn.ayncClose() } return nil, rr.err diff --git a/pgconn_test.go b/pgconn_test.go index 6b57dd09..7ae6fdc5 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1507,6 +1507,49 @@ func TestConnCancelRequest(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgx/issues/659 +func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pid := pgConn.PID() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)") + + for multiResult.NextResult() { + } + err = multiResult.Close() + assert.True(t, pgconn.Timeout(err)) + assert.True(t, pgConn.IsClosed()) + + otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, otherConn) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + for { + result := otherConn.ExecParams(ctx, + `select 1 from pg_stat_activity where pid=$1`, + [][]byte{[]byte(strconv.FormatInt(int64(pid), 10))}, + nil, + nil, + nil, + ).Read() + require.NoError(t, result.Err) + + if len(result.Rows) == 0 { + break + } + } +} + func TestConnSendBytesAndReceiveMessage(t *testing.T) { t.Parallel() From 9decdbc2ec3357581cd2911141b3ead982f5026f Mon Sep 17 00:00:00 2001 From: bakape Date: Sat, 11 Jan 2020 16:53:50 +0200 Subject: [PATCH 154/290] Revert nil context support --- README.md | 4 +- benchmark_test.go | 25 +- doc.go | 1 - helper_test.go | 22 - pgconn.go | 62 +- pgconn_test.go | 1500 +++++++++++++++++++++------------------------ 6 files changed, 731 insertions(+), 883 deletions(-) diff --git a/README.md b/README.md index ddbfeaf3..5d14e914 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ low-level access to PostgreSQL functionality. ## Example Usage ```go -pgConn, err := pgconn.Connect(nil, os.Getenv("DATABASE_URL")) +pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close() -result := pgConn.ExecParams(nil, "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) +result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } diff --git a/benchmark_test.go b/benchmark_test.go index 4cce5a97..3295a90f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -14,16 +14,9 @@ func BenchmarkConnect(b *testing.B) { benchmarks := []struct { name string env string - ctx context.Context }{ - // The first benchmark in the list sometimes executes faster, no matter how - // you reorder it. Nil context is still faster on average. - // - // Using and empty context other than context.Background() to compare. - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING", context.TODO()}, - {"TCP", "PGX_TEST_TCP_CONN_STRING", context.TODO()}, - {"Unix socket nil context", "PGX_TEST_UNIX_SOCKET_CONN_STRING", nil}, - {"TCP nil context", "PGX_TEST_TCP_CONN_STRING", nil}, + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, } for _, bm := range benchmarks { @@ -35,10 +28,10 @@ func BenchmarkConnect(b *testing.B) { } for i := 0; i < b.N; i++ { - conn, err := pgconn.Connect(bm.ctx, connString) + conn, err := pgconn.Connect(context.Background(), connString) require.Nil(b, err) - err = conn.Close(bm.ctx) + err = conn.Close(context.Background()) require.Nil(b, err) } }) @@ -51,9 +44,10 @@ func BenchmarkExec(b *testing.B) { name string ctx context.Context }{ - // Using and empty context other than context.Background() to compare. + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, {"empty context", context.TODO()}, - {"nil context", nil}, } for _, bm := range benchmarks { @@ -156,9 +150,10 @@ func BenchmarkExecPrepared(b *testing.B) { name string ctx context.Context }{ - // Using and empty context other than context.Background() to compare. + // Using an empty context other than context.Background() to compare + // performance + {"background context", context.Background()}, {"empty context", context.TODO()}, - {"nil context", nil}, } for _, bm := range benchmarks { diff --git a/doc.go b/doc.go index 25382c68..cde58cd8 100644 --- a/doc.go +++ b/doc.go @@ -22,7 +22,6 @@ Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. -A nil context can be passed for convenience. This has the same effect as passing context.Background(). The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. diff --git a/helper_test.go b/helper_test.go index 1cb05fd2..1a3ca75e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -29,25 +29,3 @@ func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } - -// Run subtest both with a context.Background() and nil context -func splitOnContext(t *testing.T, test func(t *testing.T, ctx context.Context)) { - t.Helper() - - cases := [...]struct { - name string - ctx context.Context - }{ - {"background context", context.Background()}, - {"nil context", nil}, - } - - for i := range cases { - c := cases[i] - t.Run(c.name, func(t *testing.T) { - t.Helper() - t.Parallel() - test(t, c.ctx) - }) - } -} diff --git a/pgconn.go b/pgconn.go index b8ea9df7..9763b319 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,10 +116,6 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } - if ctx == nil { - ctx = context.Background() - } - // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { @@ -366,9 +362,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return &contextAlreadyDoneError{err: ctx.Err()} @@ -400,9 +394,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -501,9 +493,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } @@ -602,9 +592,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -693,19 +681,13 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. serverAddr := pgConn.conn.RemoteAddr() - _ctx := ctx - if _ctx == nil { - _ctx = context.Background() - } - cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String()) + cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, @@ -740,9 +722,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return ctx.Err() @@ -784,11 +764,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { ctx: ctx, } multiResult := &pgConn.multiResultReader - switch ctx { - case nil: - pgConn.multiResultReader.ctx = context.Background() - case context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true @@ -882,9 +858,6 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by ctx: ctx, } result := &pgConn.resultReader - if ctx == nil { - pgConn.resultReader.ctx = context.Background() - } if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) @@ -899,9 +872,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) @@ -937,9 +908,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm return nil, err } - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): pgConn.unlock() @@ -1000,9 +969,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } defer pgConn.unlock() - switch ctx { - case nil, context.Background(): - default: + if ctx != context.Background() { select { case <-ctx.Done(): return nil, &contextAlreadyDoneError{err: ctx.Err()} @@ -1396,11 +1363,8 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR ctx: ctx, } multiResult := &pgConn.multiResultReader - switch ctx { - case nil: - pgConn.multiResultReader.ctx = context.Background() - case context.Background(): - default: + + if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true diff --git a/pgconn_test.go b/pgconn_test.go index 30d20229..6b57dd09 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -27,33 +27,31 @@ import ( ) func TestConnect(t *testing.T) { - splitOnContext(t, func(t *testing.T, ctx context.Context) { - tests := []struct { - name string - env string - }{ - {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, - {"TCP", "PGX_TEST_TCP_CONN_STRING"}, - {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, - {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, - {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, - } + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + {"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"}, + } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - connString := os.Getenv(tt.env) - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", tt.env) - } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) - closeConn(t, conn) - }) - } - }) + closeConn(t, conn) + }) + } } // TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure @@ -61,21 +59,19 @@ func TestConnect(t *testing.T) { func TestConnectTLS(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } - conn, err := pgconn.Connect(ctx, connString) - require.NoError(t, err) + conn, err := pgconn.Connect(context.Background(), connString) + require.NoError(t, err) - if _, ok := conn.Conn().(*tls.Conn); !ok { - t.Error("not a TLS connection") - } + if _, ok := conn.Conn().(*tls.Conn); !ok { + t.Error("not a TLS connection") + } - closeConn(t, conn) - }) + closeConn(t, conn) } type pgmockWaitStep time.Duration @@ -142,259 +138,233 @@ func TestConnectWithContextThatTimesOut(t *testing.T) { func TestConnectInvalidUser(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - config.User = "pgxinvalidusertest" + config.User = "pgxinvalidusertest" - _, err = pgconn.ConnectConfig(ctx, config) - require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } - }) + _, err = pgconn.ConnectConfig(context.Background(), config) + require.Error(t, err) + pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) + if !ok { + t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - // Presumably nothing is listening on 127.0.0.1:1 - conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1") - if err == nil { - conn.Close(ctx) - t.Fatal("Expected error establishing connection to bad port") - } - }) + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close(context.Background()) + t.Fatal("Expected error establishing connection to bad port") + } } func TestConnectCustomDialer(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialed := false - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialed = true - return net.Dial(network, address) - } + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - require.True(t, dialed) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, dialed) + closeConn(t, conn) } func TestConnectCustomLookup(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") - if connString == "" { - t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") - } + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } - config, err := pgconn.ParseConfig(connString) - require.NoError(t, err) + config, err := pgconn.ParseConfig(connString) + require.NoError(t, err) - looked := false - config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { - looked = true - return net.LookupHost(host) - } + looked := false + config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + looked = true + return net.LookupHost(host) + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - require.True(t, looked) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.True(t, looked) + closeConn(t, conn) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, conn) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, conn) - result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "pgxtest", string(result.Rows[0][0])) + result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "pgxtest", string(result.Rows[0][0])) - result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read() - require.Nil(t, result.Err) - assert.Equal(t, 1, len(result.Rows)) - assert.Equal(t, "myschema", string(result.Rows[0][0])) - }) + result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read() + require.Nil(t, result.Err) + assert.Equal(t, 1, len(result.Rows)) + assert.Equal(t, "myschema", string(result.Rows[0][0])) } func TestConnectWithFallback(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - // Prepend current primary config to fallbacks - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend current primary config to fallbacks + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - // Make primary config bad - config.Host = "localhost" - config.Port = 1 // presumably nothing listening here + // Make primary config bad + config.Host = "localhost" + config.Port = 1 // presumably nothing listening here - // Prepend bad first fallback - config.Fallbacks = append([]*pgconn.FallbackConfig{ - &pgconn.FallbackConfig{ - Host: "localhost", - Port: 1, - TLSConfig: config.TLSConfig, - }, - }, config.Fallbacks...) + // Prepend bad first fallback + config.Fallbacks = append([]*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "localhost", + Port: 1, + TLSConfig: config.TLSConfig, + }, + }, config.Fallbacks...) - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - closeConn(t, conn) - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - dialCount := 0 - config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialCount++ - return net.Dial(network, address) + dialCount := 0 + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialCount++ + return net.Dial(network, address) + } + + acceptConnCount := 0 + config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + acceptConnCount++ + if acceptConnCount < 2 { + return errors.New("reject first conn") } + return nil + } - acceptConnCount := 0 - config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - acceptConnCount++ - if acceptConnCount < 2 { - return errors.New("reject first conn") - } - return nil - } - - // Append current primary config to fallbacks - config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }) - - // Repeat fallbacks - config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) - - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - closeConn(t, conn) - - assert.True(t, dialCount > 1) - assert.True(t, acceptConnCount > 1) + // Append current primary config to fallbacks + config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, }) + + // Repeat fallbacks + config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) + + assert.True(t, dialCount > 1) + assert.True(t, acceptConnCount > 1) } func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite - config.RuntimeParams["default_transaction_read_only"] = "on" + config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(ctx, config) - if !assert.NotNil(t, err) { - conn.Close(ctx) - } - }) + conn, err := pgconn.ConnectConfig(context.Background(), config) + if !assert.NotNil(t, err) { + conn.Close(context.Background()) + } } func TestConnectWithAfterConnect(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { - _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() - return err - } + config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error { + _, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll() + return err + } - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) - results, err := conn.Exec(ctx, "show search_path;").ReadAll() - require.NoError(t, err) - defer closeConn(t, conn) + results, err := conn.Exec(context.Background(), "show search_path;").ReadAll() + require.NoError(t, err) + defer closeConn(t, conn) - assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) - }) + assert.Equal(t, []byte("foobar"), results[0].Rows[0][0]) } func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config := &pgconn.Config{} + config := &pgconn.Config{} - require.PanicsWithValue( - t, - "config must be created by ParseConfig", - func() { pgconn.ConnectConfig(ctx, config) }, - ) - }) + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) }) } func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil) - require.Nil(t, psd) - require.NotNil(t, err) + psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil) + require.Nil(t, psd) + require.NotNil(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnPrepareContextPrecanceled(t *testing.T) { @@ -418,126 +388,116 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { func TestConnExec(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecEmpty(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, ";") + multiResult := pgConn.Exec(context.Background(), ";") - resultCount := 0 - for multiResult.NextResult() { - resultCount++ - multiResult.ResultReader().Close() - } - assert.Equal(t, 0, resultCount) - err = multiResult.Close() - assert.NoError(t, err) + resultCount := 0 + for multiResult.NextResult() { + resultCount++ + multiResult.ResultReader().Close() + } + assert.Equal(t, 0, resultCount) + err = multiResult.Close() + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueries(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll() - assert.NoError(t, err) + results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll() + assert.NoError(t, err) - assert.Len(t, results, 2) + assert.Len(t, results, 2) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - assert.Nil(t, results[1].Err) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - assert.Len(t, results[1].Rows, 1) - assert.Equal(t, "1", string(results[1].Rows[0][0])) + assert.Nil(t, results[1].Err) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + assert.Len(t, results[1].Rows, 1) + assert.Equal(t, "1", string(results[1].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll() - require.NotNil(t, err) - if pgErr, ok := err.(*pgconn.PgError); ok { - assert.Equal(t, "22012", pgErr.Code) - } else { - t.Errorf("unexpected error: %v", err) - } + results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll() + require.NotNil(t, err) + if pgErr, ok := err.(*pgconn.PgError); ok { + assert.Equal(t, "22012", pgErr.Code) + } else { + t.Errorf("unexpected error: %v", err) + } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - _, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll() - require.NotNil(t, err) + _, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll() + require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecContextCanceled(t *testing.T) { @@ -578,103 +538,95 @@ func TestConnExecContextPrecanceled(t *testing.T) { func TestConnExecParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() - require.NotNil(t, result.Err) - var pgErr *pgconn.PgError - require.True(t, errors.As(result.Err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read() + require.NotNil(t, result.Err) + var pgErr *pgconn.PgError + require.True(t, errors.As(result.Err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsMaxNumberOfParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsTooManyParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecParamsCanceled(t *testing.T) { @@ -719,92 +671,86 @@ func TestConnExecParamsPrecanceled(t *testing.T) { func TestConnExecPrepared(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text", nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, 1) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, 1) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) - rowCount := 0 - for result.NextRow() { - rowCount += 1 - assert.Equal(t, "Hello, world", string(result.Values()[0])) - } - assert.Equal(t, 1, rowCount) - commandTag, err := result.Close() - assert.Equal(t, "SELECT 1", string(commandTag)) - assert.NoError(t, err) + result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedMaxNumberOfParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() - require.NoError(t, result.Err) - require.Len(t, result.Rows, paramCount) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.NoError(t, result.Err) + require.Len(t, result.Rows, paramCount) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedTooManyParams(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - paramCount := math.MaxUint16 + 1 - params := make([]string, 0, paramCount) - args := make([][]byte, 0, paramCount) - for i := 0; i < paramCount; i++ { - params = append(params, fmt.Sprintf("($%d::text)", i+1)) - args = append(args, []byte(strconv.Itoa(i))) - } - sql := "values" + strings.Join(params, ", ") + paramCount := math.MaxUint16 + 1 + params := make([]string, 0, paramCount) + args := make([][]byte, 0, paramCount) + for i := 0; i < paramCount; i++ { + params = append(params, fmt.Sprintf("($%d::text)", i+1)) + args = append(args, []byte(strconv.Itoa(i))) + } + sql := "values" + strings.Join(params, ", ") - psd, err := pgConn.Prepare(ctx, "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.Error(t, result.Err) + require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecPreparedCanceled(t *testing.T) { @@ -854,67 +800,63 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { func TestConnExecBatch(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil) - require.NoError(t, err) + _, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) - batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) - batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) - results, err := pgConn.ExecBatch(ctx, batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, 3) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil) + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, 3) - require.Len(t, results[0].Rows, 1) - require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + require.Len(t, results[0].Rows, 1) + require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - require.Len(t, results[1].Rows, 1) - require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) + require.Len(t, results[1].Rows, 1) + require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[1].CommandTag)) - require.Len(t, results[2].Rows, 1) - require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) - }) + require.Len(t, results[2].Rows, 1) + require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[2].CommandTag)) } func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - setupSQL := `create temporary table t ( - id text primary key, - n int not null, - unique (n) deferrable initially deferred - ); + setupSQL := `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); - insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` - _, err = pgConn.Exec(ctx, setupSQL).ReadAll() - assert.NoError(t, err) + _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() + assert.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) - _, err = pgConn.ExecBatch(ctx, batch).ReadAll() - require.NotNil(t, err) - var pgErr *pgconn.PgError - require.True(t, errors.As(err, &pgErr)) - require.Equal(t, "23505", pgErr.Code) + batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NotNil(t, err) + var pgErr *pgconn.PgError + require.True(t, errors.As(err, &pgErr)) + require.Equal(t, "23505", pgErr.Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnExecBatchPrecanceled(t *testing.T) { @@ -953,82 +895,76 @@ func TestConnExecBatchHuge(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - queryCount := 100000 - args := make([]string, queryCount) + queryCount := 100000 + args := make([]string, queryCount) - for i := range args { - args[i] = strconv.Itoa(i) - batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) - } + for i := range args { + args[i] = strconv.Itoa(i) + batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) + } - results, err := pgConn.ExecBatch(ctx, batch).ReadAll() - require.NoError(t, err) - require.Len(t, results, queryCount) + results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.NoError(t, err) + require.Len(t, results, queryCount) - for i := range args { - require.Len(t, results[i].Rows, 1) - require.Equal(t, args[i], string(results[i].Rows[0][0])) - assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) - } - }) + for i := range args { + require.Len(t, results[i].Rows, 1) + require.Equal(t, args[i], string(results[i].Rows[0][0])) + assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) + } } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() + require.NoError(t, err) - batch := &pgconn.Batch{} + batch := &pgconn.Batch{} - batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) - batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) - batch.ExecParams("select 1/0", nil, nil, nil, nil) - _, err = pgConn.ExecBatch(ctx, batch).ReadAll() - require.Error(t, err) + batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) + batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) + batch.ExecParams("select 1/0", nil, nil, nil, nil) + _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() + require.Error(t, err) - result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read() - require.Equal(t, "0", string(result.Rows[0][0])) - }) + result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() + require.Equal(t, "0", string(result.Rows[0][0])) } func TestConnLocking(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - mrr := pgConn.Exec(ctx, "select 'Hello, world'") - _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.Error(t, err) - assert.Equal(t, "conn busy", err.Error()) - assert.True(t, pgconn.SafeToRetry(err)) + mrr := pgConn.Exec(context.Background(), "select 'Hello, world'") + _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.Error(t, err) + assert.Equal(t, "conn busy", err.Error()) + assert.True(t, pgconn.SafeToRetry(err)) - results, err := mrr.ReadAll() - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Nil(t, results[0].Err) - assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + results, err := mrr.ReadAll() + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestCommandTag(t *testing.T) { @@ -1057,97 +993,91 @@ func TestCommandTag(t *testing.T) { func TestConnOnNotice(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message - } + var msg string + config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { + msg = notice.Message + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, `do $$ - begin - raise notice 'hello, world'; - end$$;`) - err = multiResult.Close() - require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + multiResult := pgConn.Exec(context.Background(), `do $$ +begin + raise notice 'hello, world'; +end$$;`) + err = multiResult.Close() + require.NoError(t, err) + assert.Equal(t, "hello, world", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnOnNotification(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, "select 1").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) - var msg string - config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { - msg = n.Payload - } + var msg string + config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { + msg = n.Payload + } - pgConn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, "listen foo").ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() + require.NoError(t, err) - notifier, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, notifier) - _, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll() - require.NoError(t, err) + notifier, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, notifier) + _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() + require.NoError(t, err) - err = pgConn.WaitForNotification(ctx) - require.NoError(t, err) + err = pgConn.WaitForNotification(context.Background()) + require.NoError(t, err) - assert.Equal(t, "bar", msg) + assert.Equal(t, "bar", msg) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { @@ -1189,100 +1119,94 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { func TestConnCopyToSmall(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json + )`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() + require.NoError(t, err) - inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + - "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") + inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") - require.NoError(t, err) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) - assert.Equal(t, int64(2), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) + assert.Equal(t, int64(2), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int2, + b int4, + c int8, + d varchar, + e text, + f date, + g json, + h bytea + )`).ReadAll() + require.NoError(t, err) + + inputBytes := make([]byte, 0) + + for i := 0; i < 1000; i++ { + _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() require.NoError(t, err) - defer closeConn(t, pgConn) + inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int2, - b int4, - c int8, - d varchar, - e text, - f date, - g json, - h bytea - )`).ReadAll() - require.NoError(t, err) + outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - inputBytes := make([]byte, 0) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") + require.NoError(t, err) - for i := 0; i < 1000; i++ { - _, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() - require.NoError(t, err) - inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) - } + assert.Equal(t, int64(1000), res.RowsAffected()) + assert.Equal(t, inputBytes, outputWriter.Bytes()) - outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) - - res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout") - require.NoError(t, err) - - assert.Equal(t, int64(1000), res.RowsAffected()) - assert.Equal(t, inputBytes, outputWriter.Bytes()) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - outputWriter := bytes.NewBuffer(make([]byte, 0)) + outputWriter := bytes.NewBuffer(make([]byte, 0)) - res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyToCanceled(t *testing.T) { @@ -1326,39 +1250,37 @@ func TestConnCopyToPrecanceled(t *testing.T) { func TestConnCopyFrom(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + srcBuf := &bytes.Buffer{} + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) - defer closeConn(t, pgConn) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - srcBuf := &bytes.Buffer{} + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } + assert.Equal(t, inputRows, result.Rows) - ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromCanceled(t *testing.T) { @@ -1436,163 +1358,153 @@ func TestConnCopyFromPrecanceled(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + f, err := ioutil.TempFile("", "*") + require.NoError(t, err) + + gw := gzip.NewWriter(f) + + inputRows := [][][]byte{} + for i := 0; i < 1000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) + _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) - defer closeConn(t, pgConn) + } - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + err = gw.Close() + require.NoError(t, err) - f, err := ioutil.TempFile("", "*") - require.NoError(t, err) + _, err = f.Seek(0, 0) + require.NoError(t, err) - gw := gzip.NewWriter(f) + gr, err := gzip.NewReader(f) + require.NoError(t, err) - inputRows := [][][]byte{} - for i := 0; i < 1000; i++ { - a := strconv.Itoa(i) - b := "foo " + a + " bar" - inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) - _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) - require.NoError(t, err) - } + ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) + assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) - err = gw.Close() - require.NoError(t, err) + err = gr.Close() + require.NoError(t, err) - _, err = f.Seek(0, 0) - require.NoError(t, err) + err = f.Close() + require.NoError(t, err) - gr, err := gzip.NewReader(f) - require.NoError(t, err) + err = os.Remove(f.Name()) + require.NoError(t, err) - ct, err := pgConn.CopyFrom(ctx, gr, "COPY foo FROM STDIN WITH (FORMAT csv)") - require.NoError(t, err) - assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) + result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) - err = gr.Close() - require.NoError(t, err) + assert.Equal(t, inputRows, result.Rows) - err = f.Close() - require.NoError(t, err) - - err = os.Remove(f.Name()) - require.NoError(t, err) - - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() - require.NoError(t, result.Err) - - assert.Equal(t, inputRows, result.Rows) - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - _, err = pgConn.Exec(ctx, `create temporary table foo( - a int4, - b varchar - )`).ReadAll() - require.NoError(t, err) + _, err = pgConn.Exec(context.Background(), `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - srcBuf := &bytes.Buffer{} + srcBuf := &bytes.Buffer{} - res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout") - require.Error(t, err) - assert.IsType(t, &pgconn.PgError{}, err) - assert.Equal(t, int64(0), res.RowsAffected()) + res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") + require.Error(t, err) + assert.IsType(t, &pgconn.PgError{}, err) + assert.Equal(t, int64(0), res.RowsAffected()) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnEscapeString(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - tests := []struct { - in string - out string - }{ - {in: "", out: ""}, - {in: "42", out: "42"}, - {in: "'", out: "''"}, - {in: "hi'there", out: "hi''there"}, - {in: "'hi there'", out: "''hi there''"}, + tests := []struct { + in string + out string + }{ + {in: "", out: ""}, + {in: "42", out: "42"}, + {in: "'", out: "''"}, + {in: "hi'there", out: "hi''there"}, + {in: "'hi there'", out: "''hi there''"}, + } + + for i, tt := range tests { + value, err := pgConn.EscapeString(tt.in) + if assert.NoErrorf(t, err, "%d.", i) { + assert.Equalf(t, tt.out, value, "%d.", i) } + } - for i, tt := range tests { - value, err := pgConn.EscapeString(tt.in) - if assert.NoErrorf(t, err, "%d.", i) { - assert.Equalf(t, tt.out, value, "%d.", i) - } - } - - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnCancelRequest(t *testing.T) { t.Parallel() - splitOnContext(t, func(t *testing.T, ctx context.Context) { - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, pgConn) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) - multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(2)") + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") - // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a - // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a - // few milliseconds. - time.Sleep(50 * time.Millisecond) + // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a + // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a + // few milliseconds. + time.Sleep(50 * time.Millisecond) - err = pgConn.CancelRequest(ctx) - require.NoError(t, err) + err = pgConn.CancelRequest(context.Background()) + require.NoError(t, err) - for multiResult.NextResult() { - } - err = multiResult.Close() + for multiResult.NextResult() { + } + err = multiResult.Close() - require.IsType(t, &pgconn.PgError{}, err) - require.Equal(t, "57014", err.(*pgconn.PgError).Code) + require.IsType(t, &pgconn.PgError{}, err) + require.Equal(t, "57014", err.(*pgconn.PgError).Code) - ensureConnValid(t, pgConn) - }) + ensureConnValid(t, pgConn) } func TestConnSendBytesAndReceiveMessage(t *testing.T) { @@ -1635,13 +1547,13 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { } func Example() { - pgConn, err := pgconn.Connect(nil, os.Getenv("PGX_TEST_CONN_STRING")) + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } - defer pgConn.Close(nil) + defer pgConn.Close(context.Background()) - result := pgConn.ExecParams(nil, "select generate_series(1,3)", nil, nil, nil, nil).Read() + result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) } From b6669ae6dda06f53fe221f80507123d967f7f099 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2020 18:23:41 -0600 Subject: [PATCH 155/290] Add PgError.SQLState method fixes #15 --- errors.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/errors.go b/errors.go index a088dcdd..7a21af98 100644 --- a/errors.go +++ b/errors.go @@ -55,6 +55,11 @@ func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } +// SQLState returns the SQLState of the error. +func (pe *PgError) SQLState() string { + return pe.Code +} + type connectError struct { config *Config msg string From fd2093cef8e97839e11bb13bb4a2c1b805ae62f5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2020 18:42:31 -0600 Subject: [PATCH 156/290] Add statement type convenience methods to CommandTag and optimize Added convenient way to check whether a statement was a select, insert, update, or delete. These methods do not allocate. RowsAffected now does not allocate even when a large number of rows are affected. It also is multiple times faster, though the absolute change is inconsequential. --- benchmark_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++ pgconn.go | 64 +++++++++++++++++++++++++++++++++++++++++--- pgconn_test.go | 25 ++++++++++++----- 3 files changed, 146 insertions(+), 11 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 3295a90f..ced785b6 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "os" + "strings" "testing" "github.com/jackc/pgconn" @@ -252,3 +253,70 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { // conn.ChanToSetDeadline().Ignore() // } // } + +func BenchmarkCommandTagRowsAffected(b *testing.B) { + benchmarks := []struct { + commandTag string + rowsAffected int64 + }{ + {"UPDATE 1", 1}, + {"UPDATE 123456789", 123456789}, + {"INSERT 0 1", 1}, + {"INSERT 0 123456789", 123456789}, + } + + for _, bm := range benchmarks { + ct := pgconn.CommandTag(bm.commandTag) + b.Run(bm.commandTag, func(b *testing.B) { + var n int64 + for i := 0; i < b.N; i++ { + n = ct.RowsAffected() + } + if n != bm.rowsAffected { + b.Errorf("expected %d got %d", bm.rowsAffected, n) + } + }) + } +} + +func BenchmarkCommandTagTypeFromString(b *testing.B) { + ct := pgconn.CommandTag("UPDATE 1") + + var update bool + for i := 0; i < b.N; i++ { + update = strings.HasPrefix(ct.String(), "UPDATE") + } + if !update { + b.Error("expected update") + } +} + +func BenchmarkCommandTagInsert(b *testing.B) { + benchmarks := []struct { + commandTag string + is bool + }{ + {"INSERT 1", true}, + {"INSERT 1234567890", true}, + {"UPDATE 1", false}, + {"UPDATE 1234567890", false}, + {"DELETE 1", false}, + {"DELETE 1234567890", false}, + {"SELECT 1", false}, + {"SELECT 1234567890", false}, + {"UNKNOWN 1234567890", false}, + } + + for _, bm := range benchmarks { + ct := pgconn.CommandTag(bm.commandTag) + b.Run(bm.commandTag, func(b *testing.B) { + var is bool + for i := 0; i < b.N; i++ { + is = ct.Insert() + } + if is != bm.is { + b.Errorf("expected %v got %v", bm.is, is) + } + }) + } +} diff --git a/pgconn.go b/pgconn.go index c46dc6a6..dce4bfb5 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1,7 +1,6 @@ package pgconn import ( - "bytes" "context" "crypto/md5" "crypto/tls" @@ -10,7 +9,6 @@ import ( "io" "math" "net" - "strconv" "strings" "sync" "time" @@ -579,11 +577,25 @@ type CommandTag []byte // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { - idx := bytes.LastIndexByte([]byte(ct), ' ') + // Find last non-digit + idx := -1 + for i := len(ct) - 1; i >= 0; i-- { + if ct[i] >= '0' && ct[i] <= '9' { + idx = i + } else { + break + } + } + if idx == -1 { return 0 } - n, _ := strconv.ParseInt(string([]byte(ct)[idx+1:]), 10, 64) + + var n int64 + for _, b := range ct[idx:] { + n = n*10 + int64(b-'0') + } + return n } @@ -591,6 +603,50 @@ func (ct CommandTag) String() string { return string(ct) } +// Insert is true if the command tag starts with "INSERT". +func (ct CommandTag) Insert() bool { + return len(ct) >= 6 && + ct[0] == 'I' && + ct[1] == 'N' && + ct[2] == 'S' && + ct[3] == 'E' && + ct[4] == 'R' && + ct[5] == 'T' +} + +// Update is true if the command tag starts with "UPDATE". +func (ct CommandTag) Update() bool { + return len(ct) >= 6 && + ct[0] == 'U' && + ct[1] == 'P' && + ct[2] == 'D' && + ct[3] == 'A' && + ct[4] == 'T' && + ct[5] == 'E' +} + +// Delete is true if the command tag starts with "DELETE". +func (ct CommandTag) Delete() bool { + return len(ct) >= 6 && + ct[0] == 'D' && + ct[1] == 'E' && + ct[2] == 'L' && + ct[3] == 'E' && + ct[4] == 'T' && + ct[5] == 'E' +} + +// Select is true if the command tag starts with "SELECT". +func (ct CommandTag) Select() bool { + return len(ct) >= 6 && + ct[0] == 'S' && + ct[1] == 'E' && + ct[2] == 'L' && + ct[3] == 'E' && + ct[4] == 'C' && + ct[5] == 'T' +} + type StatementDescription struct { Name string SQL string diff --git a/pgconn_test.go b/pgconn_test.go index 7ae6fdc5..2c303d81 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -973,20 +973,31 @@ func TestCommandTag(t *testing.T) { var tests = []struct { commandTag pgconn.CommandTag rowsAffected int64 + isInsert bool + isUpdate bool + isDelete bool + isSelect bool }{ - {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5}, - {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1}, - {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0}, - {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1}, + {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true}, + {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true}, + {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true}, + {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true}, + {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true}, + {commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true}, + {commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true}, + {commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true}, {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, } for i, tt := range tests { - actual := tt.commandTag.RowsAffected() - assert.Equalf(t, tt.rowsAffected, actual, "%d. %v", i, tt.commandTag) + ct := tt.commandTag + assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) + assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) } } From a48e9bf63c413ae0498d16e339f94c2884fa988e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2020 19:07:39 -0600 Subject: [PATCH 157/290] Update changelog --- CHANGELOG.md | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92497f47..1debb10b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,25 @@ +# 1.2.0 (January 11, 2020) + +## Features + +* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag. +* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases. + +## Performance + +* Improve performance when context.Background() is used. (bakape) +* CommandTag.RowsAffected is faster and does not allocate. + +## Fixes + +* Try to cancel any in-progress query when a conn is closed by ctx cancel. +* Handle NoticeResponse during CopyFrom. +* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish. + # 1.1.0 (October 12, 2019) -* Add PgConn.IsBusy() method +* Add PgConn.IsBusy() method. # 1.0.1 (September 19, 2019) -* Fix statement cache not properly cleaning discarded statements +* Fix statement cache not properly cleaning discarded statements. From 0df97353b8acf7c2751f7812f27d99a6974e596c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Jan 2020 16:27:46 -0600 Subject: [PATCH 158/290] Fix racy usage of pgConn.contextWatcher in ayncClose --- pgconn.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index dce4bfb5..3a6c598a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -518,13 +518,14 @@ func (pgConn *PgConn) ayncClose() { go func() { defer pgConn.conn.Close() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + deadline := time.Now().Add(time.Second * 15) + + ctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() pgConn.CancelRequest(ctx) - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() + pgConn.conn.SetDeadline(deadline) pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) pgConn.conn.Read(make([]byte, 1)) From 2582879459d09494565cda6b2fe91d7660623122 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Jan 2020 16:28:56 -0600 Subject: [PATCH 159/290] Fix typo - rename ayncClose to asyncClose --- pgconn.go | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/pgconn.go b/pgconn.go index 3a6c598a..89d7bb45 100644 --- a/pgconn.go +++ b/pgconn.go @@ -372,7 +372,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return &writeError{err: err, safeToRetry: n == 0} } @@ -431,7 +431,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { if err != nil { // Close on anything other than timeout error - everything else is fatal if err, ok := err.(net.Error); !(ok && err.Timeout()) { - pgConn.ayncClose() + pgConn.asyncClose() } return nil, err @@ -444,7 +444,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - pgConn.ayncClose() + pgConn.asyncClose() return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: @@ -507,9 +507,9 @@ func (pgConn *PgConn) Close(ctx context.Context) error { return pgConn.conn.Close() } -// ayncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying +// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying // connection. -func (pgConn *PgConn) ayncClose() { +func (pgConn *PgConn) asyncClose() { if pgConn.status == connStatusClosed { return } @@ -680,7 +680,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} } @@ -692,7 +692,7 @@ readloop: for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -852,7 +852,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true multiResult.err = &writeError{err: err, safeToRetry: n == 0} @@ -965,7 +965,7 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true @@ -996,7 +996,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() pgConn.unlock() return nil, &writeError{err: err, safeToRetry: n == 0} } @@ -1007,7 +1007,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1016,7 +1016,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } case *pgproto3.ReadyForQuery: @@ -1056,7 +1056,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co n, err := pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, &writeError{err: err, safeToRetry: n == 0} } @@ -1067,7 +1067,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for pendingCopyInResponse { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1096,7 +1096,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } } @@ -1105,7 +1105,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co case <-signalMessageChan: msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1129,7 +1129,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } _, err = pgConn.conn.Write(buf) if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1137,7 +1137,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.ayncClose() + pgConn.asyncClose() return nil, err } @@ -1182,7 +1182,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.pgConn.contextWatcher.Unwatch() mrr.err = err mrr.closed = true - mrr.pgConn.ayncClose() + mrr.pgConn.asyncClose() return nil, mrr.err } @@ -1371,7 +1371,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { - rr.pgConn.ayncClose() + rr.pgConn.asyncClose() } return nil, rr.err From e7dd01e064b5caf31bd290db23fadd13e60f8cd8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Jan 2020 08:48:32 -0600 Subject: [PATCH 160/290] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1debb10b..c79d4f0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.2.1 (January 13, 2020) + +* Fix data race in context cancellation introduced in v1.2.0. + # 1.2.0 (January 11, 2020) ## Features From 8be01d690fed6a2bd6d1cad7819c4fe00cb3611e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Jan 2020 17:38:07 -0600 Subject: [PATCH 161/290] Make Host comment more precise --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 628deed8..9876ac94 100644 --- a/config.go +++ b/config.go @@ -29,7 +29,7 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and // then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) + Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) Port uint16 Database string User string From 59525245114b2a264f25fbeeddda947a64e2c61e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Jan 2020 17:38:44 -0600 Subject: [PATCH 162/290] Add Hijack and Construct fixes #9 --- pgconn.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++++- pgconn_test.go | 26 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 89d7bb45..44a08cc8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -27,6 +27,8 @@ const ( connStatusBusy ) +const wbufLen = 1024 + // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -192,7 +194,7 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config - pgConn.wbuf = make([]byte, 0, 1024) + pgConn.wbuf = make([]byte, 0, wbufLen) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -1481,3 +1483,68 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } + +// HijackedConn is the result of hijacking a connection. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +type HijackedConn struct { + Conn net.Conn // the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + ParameterStatuses map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend Frontend + Config *Config +} + +// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. +// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the +// raw connection after that (e.g. a load balancer or proxy). +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func (pgConn *PgConn) Hijack() (*HijackedConn, error) { + if err := pgConn.lock(); err != nil { + return nil, err + } + pgConn.status = connStatusClosed + + return &HijackedConn{ + Conn: pgConn.conn, + PID: pgConn.pid, + SecretKey: pgConn.secretKey, + ParameterStatuses: pgConn.parameterStatuses, + TxStatus: pgConn.txStatus, + Frontend: pgConn.frontend, + Config: pgConn.config, + }, nil +} + +// Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of +// PgConn.Hijack. The connection must be in an idle state. +// +// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning +// compatibility. +func Construct(hc *HijackedConn) (*PgConn, error) { + pgConn := &PgConn{ + conn: hc.Conn, + pid: hc.PID, + secretKey: hc.SecretKey, + parameterStatuses: hc.ParameterStatuses, + txStatus: hc.TxStatus, + frontend: hc.Frontend, + config: hc.Config, + + status: connStatusIdle, + + wbuf: make([]byte, 0, wbufLen), + } + + pgConn.contextWatcher = ctxwatch.NewContextWatcher( + func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, + func() { pgConn.conn.SetDeadline(time.Time{}) }, + ) + + return pgConn, nil +} diff --git a/pgconn_test.go b/pgconn_test.go index 2c303d81..34982bb7 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1600,6 +1600,32 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) { ensureConnValid(t, pgConn) } +func TestHijackAndConstruct(t *testing.T) { + t.Parallel() + + origConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + hc, err := origConn.Hijack() + require.NoError(t, err) + + newConn, err := pgconn.Construct(hc) + require.NoError(t, err) + + defer closeConn(t, newConn) + + results, err := newConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + + assert.Len(t, results, 1) + assert.Nil(t, results[0].Err) + assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) + + ensureConnValid(t, newConn) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From a4375eb53f25d9dc139319d01d1921b2927179f9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 17 Jan 2020 17:42:20 -0600 Subject: [PATCH 163/290] Add test that Hijack'ed conn is no longer usable. --- pgconn_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 34982bb7..c37a2fb2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1609,6 +1609,9 @@ func TestHijackAndConstruct(t *testing.T) { hc, err := origConn.Hijack() require.NoError(t, err) + _, err = origConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() + require.Error(t, err) + newConn, err := pgconn.Construct(hc) require.NoError(t, err) From f909a64ff567aec10157dc6a7efb9e5c9365aac6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Jan 2020 20:55:52 -0600 Subject: [PATCH 164/290] Update pgproto3 to v2.0.1 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4a188cce..59e7e98e 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.0 + github.com/jackc/pgproto3/v2 v2.0.1 github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index 51c55d12..0c7fc9f1 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0 h1:FApgMJ/GtaXfI0s8Lvd0kaLaRwMOhs4VH92pwkwQQvU= github.com/jackc/pgproto3/v2 v2.0.0/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= +github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= From 6124b07bb1380523c4d6e01db8b55546e6c61136 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 23 Jan 2020 20:57:13 -0600 Subject: [PATCH 165/290] Update changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c79d4f0b..26e9c8c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.3.0 (January 23, 2020) + +* Add Hijack and Construct. +* Update pgproto3 to v2.0.1. + # 1.2.1 (January 13, 2020) * Fix data race in context cancellation introduced in v1.2.0. From 139342081ef84e9ca6933f2faa19e20059ad61a3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jan 2020 20:32:42 -0600 Subject: [PATCH 166/290] Fix CopyFrom deadlock when multiple NoticeResponse received during copy fixes #21 --- pgconn.go | 51 ++++++++++++++++++++++++++++++++++---------------- pgconn_test.go | 40 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 16 deletions(-) diff --git a/pgconn.go b/pgconn.go index 44a08cc8..271e6628 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1084,26 +1084,44 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy data - buf = make([]byte, 0, 65536) - buf = append(buf, 'd') - sp := len(buf) - var readErr error + abortCopyChan := make(chan struct{}) + copyErrChan := make(chan error) signalMessageChan := pgConn.signalMessage() - for readErr == nil && pgErr == nil { - var n int - n, readErr = r.Read(buf[5:cap(buf)]) - if n > 0 { - buf = buf[0 : n+5] - pgio.SetInt32(buf[sp:], int32(n+4)) - _, err = pgConn.conn.Write(buf) - if err != nil { - pgConn.asyncClose() - return nil, err + go func() { + buf := make([]byte, 0, 65536) + buf = append(buf, 'd') + sp := len(buf) + + for { + n, readErr := r.Read(buf[5:cap(buf)]) + if n > 0 { + buf = buf[0 : n+5] + pgio.SetInt32(buf[sp:], int32(n+4)) + + _, writeErr := pgConn.conn.Write(buf) + if writeErr != nil { + copyErrChan <- writeErr + return + } + } + if readErr != nil { + copyErrChan <- readErr + return + } + + select { + case <-abortCopyChan: + return + default: } } + }() + var copyErr error + for copyErr == nil && pgErr == nil { select { + case copyErr = <-copyErrChan: case <-signalMessageChan: msg, err := pgConn.receiveMessage() if err != nil { @@ -1120,13 +1138,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co default: } } + close(abortCopyChan) buf = buf[:0] - if readErr == io.EOF || pgErr != nil { + if copyErr == io.EOF || pgErr != nil { copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { - copyFail := &pgproto3.CopyFail{Message: readErr.Error()} + copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) diff --git a/pgconn_test.go b/pgconn_test.go index c37a2fb2..19ad3a0a 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1463,6 +1463,46 @@ func TestConnCopyFromQueryNoTableError(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgconn/issues/21 +func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { + t.Parallel() + + ctx := context.Background() + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, `create temporary table sentences( + t text, + ts tsvector + )`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$ + begin + new.ts := to_tsvector(new.t); + return new; + end + $$ language plpgsql;`).ReadAll() + require.NoError(t, err) + + _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll() + require.NoError(t, err) + + longString := make([]byte, 10001) + for i := range longString { + longString[i] = 'x' + } + + buf := &bytes.Buffer{} + for i := 0; i < 1000; i++ { + buf.Write([]byte(fmt.Sprintf("%s\n", string(longString)))) + } + + _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)") + require.NoError(t, err) +} + func TestConnEscapeString(t *testing.T) { t.Parallel() From 67f2418279fabea76c16c3b613b9893a3b86e7d8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jan 2020 20:39:18 -0600 Subject: [PATCH 167/290] Make copyErrChan buffered so goroutine can always terminate It is possible the goroutine that is reading from copyErrChan will not read in case of error. --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 271e6628..e34b4cfe 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1085,7 +1085,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Send copy data abortCopyChan := make(chan struct{}) - copyErrChan := make(chan error) + copyErrChan := make(chan error, 1) signalMessageChan := pgConn.signalMessage() go func() { From c9abb86f21f0b89b909e9d112829e21daf3c06d8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jan 2020 20:40:21 -0600 Subject: [PATCH 168/290] Ensure write failure in CopyFrom closes connection --- pgconn.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pgconn.go b/pgconn.go index e34b4cfe..f56575ca 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1101,6 +1101,9 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co _, writeErr := pgConn.conn.Write(buf) if writeErr != nil { + // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. + pgConn.conn.Close() + copyErrChan <- writeErr return } From 406afa0eb7f8a23c96e0c6ec7bb56cbce3fc1ca4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Feb 2020 11:06:09 -0600 Subject: [PATCH 169/290] Release v1.3.1 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26e9c8c7..5a9ca414 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.3.1 (February 5, 2020) + +* Fix CopyFrom deadlock when multiple NoticeResponse received during copy + # 1.3.0 (January 23, 2020) * Add Hijack and Construct. From 06c4e181b1abf6d6d531b3da38b40f8a1932d21b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 5 Feb 2020 11:49:40 -0600 Subject: [PATCH 170/290] go mod tidy --- go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.sum b/go.sum index 0c7fc9f1..c23d4412 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0 h1:FApgMJ/GtaXfI0s8Lvd0kaLaRwMOhs4VH92pwkwQQvU= -github.com/jackc/pgproto3/v2 v2.0.0/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= From ac364e7a4366fc363b67cbfc06edf41594d9d8cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 7 Feb 2020 15:40:50 -0600 Subject: [PATCH 171/290] Use writeError for Write error --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index f56575ca..751d8fc0 100644 --- a/pgconn.go +++ b/pgconn.go @@ -683,7 +683,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() - return nil, &pgconnError{msg: "write failed", err: err, safeToRetry: n == 0} + return nil, &writeError{err: err, safeToRetry: n == 0} } psd := &StatementDescription{Name: name, SQL: sql} From 6db848c6fca46bd3c67b1a66b5f764fbb16807ba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Feb 2020 17:56:59 -0600 Subject: [PATCH 172/290] Update chunkreader to v2.0.1 --- CHANGELOG.md | 4 ++++ go.mod | 2 +- go.sum | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a9ca414..eb099dc2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.3.2 (February 14, 2020) + +* Update chunkreader to v2.0.1 for optimized default buffer size. + # 1.3.1 (February 5, 2020) * Fix CopyFrom deadlock when multiple NoticeResponse received during copy diff --git a/go.mod b/go.mod index 59e7e98e..37590559 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/jackc/pgconn go 1.12 require ( - github.com/jackc/chunkreader/v2 v2.0.0 + github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 diff --git a/go.sum b/go.sum index c23d4412..28f094e7 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZb github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= From 911e727d78134c87d39b064aaee2bfda30f7afde Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Mar 2020 10:55:28 -0600 Subject: [PATCH 173/290] ExecParams and ExecPrepared handle empty query An empty query does not return CommandComplete. Instead it returns EmptyQueryResponse. --- pgconn.go | 2 ++ pgconn_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pgconn.go b/pgconn.go index 751d8fc0..6155281d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1406,6 +1406,8 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) + case *pgproto3.EmptyQueryResponse: + rr.concludeCommand(nil, nil) case *pgproto3.ErrorResponse: rr.concludeCommand(nil, ErrorResponseToPgError(msg)) } diff --git a/pgconn_test.go b/pgconn_test.go index 19ad3a0a..17b40343 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -668,6 +668,24 @@ func TestConnExecParamsPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecParamsEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read() + assert.Nil(t, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecPrepared(t *testing.T) { t.Parallel() @@ -797,6 +815,27 @@ func TestConnExecPreparedPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecPreparedEmptySQL(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "", nil) + require.NoError(t, err) + + result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read() + assert.Nil(t, result.CommandTag) + assert.Len(t, result.Rows, 0) + assert.NoError(t, result.Err) + + ensureConnValid(t, pgConn) +} + func TestConnExecBatch(t *testing.T) { t.Parallel() From cfbd2519e3a9dd64906a0888c38ee05d78e19889 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Mar 2020 13:17:39 -0600 Subject: [PATCH 174/290] Add PGSERVICE and PGSERVICEFILE support --- config.go | 93 +++++++++++++++++++++++++++++++++++++++++------- config_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 3 +- go.sum | 4 +++ 4 files changed, 182 insertions(+), 13 deletions(-) diff --git a/config.go b/config.go index 9876ac94..19521a8f 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,7 @@ import ( "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgservicefile" errors "golang.org/x/xerrors" ) @@ -108,6 +109,8 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGUSER // PGPASSWORD // PGPASSFILE +// PGSERVICE +// PGSERVICEFILE // PGSSLMODE // PGSSLCERT // PGSSLKEY @@ -145,25 +148,40 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // min_read_buffer_size // The minimum size of the internal read buffer. Default 8192. +// servicefile +// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a +// part of the connection string. func ParseConfig(connString string) (*Config, error) { - settings := defaultSettings() - addEnvSettings(settings) + defaultSettings := defaultSettings() + envSettings := parseEnvSettings() + connStringSettings := make(map[string]string) if connString != "" { + var err error // connString may be a database URL or a DSN if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { - err := addURLSettings(settings, connString) + connStringSettings, err = parseURLSettings(connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { - err := addDSNSettings(settings, connString) + connStringSettings, err = parseDSNSettings(connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } + settings := mergeSettings(defaultSettings, envSettings, connStringSettings) + if service, present := settings["service"]; present { + serviceSettings, err := parseServiceSettings(settings["servicefile"], service) + if err != nil { + return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} + } + + settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) + } + minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) if err != nil { return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} @@ -205,6 +223,8 @@ func ParseConfig(connString string) (*Config, error) { "sslrootcert": struct{}{}, "target_session_attrs": struct{}{}, "min_read_buffer_size": struct{}{}, + "service": struct{}{}, + "servicefile": struct{}{}, } for k, v := range settings { @@ -293,6 +313,7 @@ func defaultSettings() map[string]string { if err == nil { settings["user"] = user.Username settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") } settings["target_session_attrs"] = "any" @@ -321,7 +342,21 @@ func defaultHost() string { return "localhost" } -func addEnvSettings(settings map[string]string) { +func mergeSettings(settingSets ...map[string]string) map[string]string { + settings := make(map[string]string) + + for _, s2 := range settingSets { + for k, v := range s2 { + settings[k] = v + } + } + + return settings +} + +func parseEnvSettings() map[string]string { + settings := make(map[string]string) + nameMap := map[string]string{ "PGHOST": "host", "PGPORT": "port", @@ -336,6 +371,8 @@ func addEnvSettings(settings map[string]string) { "PGSSLCERT": "sslcert", "PGSSLROOTCERT": "sslrootcert", "PGTARGETSESSIONATTRS": "target_session_attrs", + "PGSERVICE": "service", + "PGSERVICEFILE": "servicefile", } for envname, realname := range nameMap { @@ -344,12 +381,16 @@ func addEnvSettings(settings map[string]string) { settings[realname] = value } } + + return settings } -func addURLSettings(settings map[string]string, connString string) error { +func parseURLSettings(connString string) (map[string]string, error) { + settings := make(map[string]string) + url, err := url.Parse(connString) if err != nil { - return err + return nil, err } if url.User != nil { @@ -387,12 +428,14 @@ func addURLSettings(settings map[string]string, connString string) error { settings[k] = v[0] } - return nil + return settings, nil } var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} -func addDSNSettings(settings map[string]string, s string) error { +func parseDSNSettings(s string) (map[string]string, error) { + settings := make(map[string]string) + nameMap := map[string]string{ "dbname": "database", } @@ -401,7 +444,7 @@ func addDSNSettings(settings map[string]string, s string) error { var key, val string eqIdx := strings.IndexRune(s, '=') if eqIdx < 0 { - return errors.New("invalid dsn") + return nil, errors.New("invalid dsn") } key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") @@ -434,7 +477,7 @@ func addDSNSettings(settings map[string]string, s string) error { } } if end == len(s) { - return errors.New("unterminated quoted string in connection info string") + return nil, errors.New("unterminated quoted string in connection info string") } val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) if end == len(s) { @@ -451,7 +494,33 @@ func addDSNSettings(settings map[string]string, s string) error { settings[key] = val } - return nil + return settings, nil +} + +func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { + servicefile, err := pgservicefile.ReadServicefile(servicefilePath) + if err != nil { + fmt.Errorf("failed to read service file: %v", servicefile) + } + + service, err := servicefile.GetService(serviceName) + if err != nil { + fmt.Errorf("unable to find service: %v", servicefile) + } + + nameMap := map[string]string{ + "dbname": "database", + } + + settings := make(map[string]string, len(service.Settings)) + for k, v := range service.Settings { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v + } + + return settings, nil } type pgTLSArgs struct { diff --git a/config_test.go b/config_test.go index 9eb5df2f..0819740f 100644 --- a/config_test.go +++ b/config_test.go @@ -648,6 +648,101 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { assertConfigsEqual(t, expected, actual, "passfile") } +func TestParseConfigReadsPgServiceFile(t *testing.T) { + t.Parallel() + + tf, err := ioutil.TempFile("", "") + require.NoError(t, err) + + defer tf.Close() + defer os.Remove(tf.Name()) + + _, err = tf.Write([]byte(` +[abc] +host=abc.example.com +port=9999 +dbname=abcdb +user=abcuser + +[def] +host = def.example.com +dbname = defdb +user = defuser +application_name = spaced string +`)) + require.NoError(t, err) + + tests := []struct { + name string + connString string + config *pgconn.Config + }{ + { + name: "abc", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "abc.example.com", + Database: "abcdb", + User: "abcuser", + Port: 9999, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "abc.example.com", + Port: 9999, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "def", + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"), + config: &pgconn.Config{ + Host: "def.example.com", + Port: 5432, + Database: "defdb", + User: "defuser", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + RuntimeParams: map[string]string{"application_name": "spaced string"}, + Fallbacks: []*pgconn.FallbackConfig{ + &pgconn.FallbackConfig{ + Host: "def.example.com", + Port: 5432, + TLSConfig: nil, + }, + }, + }, + }, + { + name: "conn string has precedence", + connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"), + config: &pgconn.Config{ + Host: "other.example.com", + Database: "abcdb", + User: "abcuser", + Port: 7777, + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + } + + for i, tt := range tests { + config, err := pgconn.ParseConfig(tt.connString) + if !assert.NoErrorf(t, err, "Test %d (%s)", i, tt.name) { + continue + } + + assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) + } +} + func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { t.Parallel() diff --git a/go.mod b/go.mod index 37590559..b306e1e4 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,8 @@ require ( github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 - github.com/stretchr/testify v1.4.0 + github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 + github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 golang.org/x/text v0.3.2 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 diff --git a/go.sum b/go.sum index 28f094e7..13f276b2 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= @@ -69,6 +71,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= From ccf634cf2e2816d97bdc40644bf47b8dd3e5cd97 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 7 Mar 2020 13:21:51 -0600 Subject: [PATCH 175/290] Release 1.4.0 --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb099dc2..e5b11b7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.4.0 (March 7, 2020) + +* Fix ExecParams and ExecPrepared handling of empty query. +* Support reading config from PostgreSQL service files. + # 1.3.2 (February 14, 2020) * Update chunkreader to v2.0.1 for optimized default buffer size. From 4ed48d05d2378ae2828b6f767544aad61aa99a9c Mon Sep 17 00:00:00 2001 From: Greg Curtis Date: Tue, 17 Mar 2020 23:30:56 -0700 Subject: [PATCH 176/290] Implement "verify-ca" SSL mode ParseConfig currently treats the libpq "verify-ca" SSL mode as "verify-full". This is okay from a security standpoint because "verify-full" performs certificate verification and hostname verification, whereas "verify-ca" only performs certificate verification. The downside to this approach is that checking the hostname is unnecessary when the server's certificate has been signed by a private CA. It can also cause the SSL handshake to fail when connecting to an instance by IP. For example, a Google Cloud SQL instance typically doesn't have a hostname and uses its own private CA to sign its server and client certs. This change uses the tls.Config.VerifyPeerCertificate function to perform certificate verification without checking the hostname when the "verify-ca" SSL mode is set. This brings pgconn's behavior closer to that of libpq. See https://github.com/golang/go/issues/21971#issuecomment-332693931 and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate for more details on how this is implemented. --- config.go | 41 +++++++++++++++++++++++++++++++++++------ config_test.go | 4 +++- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/config.go b/config.go index 19521a8f..70e6073a 100644 --- a/config.go +++ b/config.go @@ -132,11 +132,6 @@ func NetworkAddress(host string, port uint16) (network, address string) { // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // security each sslmode provides. // -// "verify-ca" mode currently is treated as "verify-full". e.g. It has stronger -// security guarantees than it would with libpq. Do not rely on this behavior as it -// may be possible to match libpq in the future. If you need full security use -// "verify-full". -// // Other known differences with libpq: // // If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. @@ -554,7 +549,41 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { tlsConfig.InsecureSkipVerify = true case "require": tlsConfig.InsecureSkipVerify = sslrootcert == "" - case "verify-ca", "verify-full": + case "verify-ca": + // Don't perform the default certificate verification because it + // will verify the hostname. Instead, verify the server's + // certificate chain ourselves in VerifyPeerCertificate and + // ignore the server name. This emulates libpq's verify-ca + // behavior. + // + // See https://github.com/golang/go/issues/21971#issuecomment-332693931 + // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate + // for more info. + tlsConfig.InsecureSkipVerify = true + tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { + certs := make([]*x509.Certificate, len(certificates)) + for i, asn1Data := range certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + return errors.New("failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + + // Leave DNSName empty to skip hostname verification. + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + Intermediates: x509.NewCertPool(), + } + // Skip the first cert because it's the leaf. All others + // are intermediates. + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := certs[0].Verify(opts) + return err + } + case "verify-full": tlsConfig.ServerName = host default: return nil, errors.New("sslmode is invalid") diff --git a/config_test.go b/config_test.go index 0819740f..b6068cc8 100644 --- a/config_test.go +++ b/config_test.go @@ -132,7 +132,9 @@ func TestParseConfig(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", - TLSConfig: &tls.Config{ServerName: "localhost"}, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, RuntimeParams: map[string]string{}, }, }, From 11d9f4e54fb9a1534259b4b9375bcaa392f30425 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 30 Mar 2020 11:09:29 -0500 Subject: [PATCH 177/290] Update golang.org/x/crypto for security fix --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index b306e1e4..4dc095ca 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/jackc/pgproto3/v2 v2.0.1 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 github.com/stretchr/testify v1.5.1 - golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 + golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/text v0.3.2 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index 13f276b2..23fb8b32 100644 --- a/go.sum +++ b/go.sum @@ -83,6 +83,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= From e4f3224f4c6d615b7199c9a606c4e3385efd1f21 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 30 Mar 2020 11:15:08 -0500 Subject: [PATCH 178/290] Update changelog for v1.5.0 --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e5b11b7c..c4c3b2d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.5.0 (March 30, 2020) + +* Update golang.org/x/crypto for security fix +* Implement "verify-ca" SSL mode (Greg Curtis) + # 1.4.0 (March 7, 2020) * Fix ExecParams and ExecPrepared handling of empty query. From 84aee0ab4443115da0c34114c300a50a410e5402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=20Jedin=C3=BD?= Date: Wed, 8 Apr 2020 00:08:53 +0200 Subject: [PATCH 179/290] Fix behavior of sslmode=require with sslrootcert present According to PostgreSQL documentation the behavior should be the same as that of verify-ca sslmode https://www.postgresql.org/docs/12/libpq-ssl.html --- config.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 70e6073a..06184b02 100644 --- a/config.go +++ b/config.go @@ -548,7 +548,17 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { case "allow", "prefer": tlsConfig.InsecureSkipVerify = true case "require": - tlsConfig.InsecureSkipVerify = sslrootcert == "" + // According to PostgreSQL documentation, if a root CA file exists, + // the behavior of sslmode=require should be the same as that of verify-ca + // + // See https://www.postgresql.org/docs/12/libpq-ssl.html + if sslrootcert != "" { + goto nextCase + } + tlsConfig.InsecureSkipVerify = true + break + nextCase: + fallthrough case "verify-ca": // Don't perform the default certificate verification because it // will verify the hostname. Instead, verify the server's From 5d2be99c254e76f7dfb8b481db1791dd613b5d4c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Apr 2020 19:38:21 -0500 Subject: [PATCH 180/290] Fix panic when closing conn during cancellable query fixes #29 --- internal/ctxwatch/context_watcher_test.go | 11 +++++++++++ pgconn.go | 7 +++++++ pgconn_test.go | 13 +++++++++++++ 3 files changed, 31 insertions(+) diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 0b491bf8..6348b729 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -59,6 +59,17 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) { require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") } +func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + cw.Unwatch() // unwatch when not / never watching + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + cw.Unwatch() + cw.Unwatch() // double unwatch +} + func TestContextWatcherStress(t *testing.T) { var cancelFuncCalls int64 var cleanupFuncCalls int64 diff --git a/pgconn.go b/pgconn.go index 6155281d..d5a424ac 100644 --- a/pgconn.go +++ b/pgconn.go @@ -494,6 +494,13 @@ func (pgConn *PgConn) Close(ctx context.Context) error { defer pgConn.conn.Close() if ctx != context.Background() { + // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when + // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any + // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. + // + // See https://github.com/jackc/pgconn/issues/29 + pgConn.contextWatcher.Unwatch() + pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } diff --git a/pgconn_test.go b/pgconn_test.go index 17b40343..e29a36b2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1708,6 +1708,19 @@ func TestHijackAndConstruct(t *testing.T) { ensureConnValid(t, newConn) } +func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + ctx, _ := context.WithCancel(context.Background()) + pgConn.Exec(ctx, "select n from generate_series(1,10) n") + + closeCtx, _ := context.WithCancel(context.Background()) + pgConn.Close(closeCtx) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 8f3f335b0f54dcf33019f333c9bcf7d78e2fb0ba Mon Sep 17 00:00:00 2001 From: Tobias Salzmann <796084+Eun@users.noreply.github.com> Date: Thu, 30 Apr 2020 11:22:43 +0200 Subject: [PATCH 181/290] concludeCommand should not throw away fieldDescriptions --- pgconn.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index d5a424ac..e518744a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1412,7 +1412,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.RowDescription: rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: - rr.concludeCommand(CommandTag(msg.CommandTag), nil) + rr.concludeCommand(CommandTa/g(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: rr.concludeCommand(nil, nil) case *pgproto3.ErrorResponse: @@ -1429,7 +1429,6 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { rr.commandTag = commandTag rr.err = err - rr.fieldDescriptions = nil rr.rowValues = nil rr.commandConcluded = true } From 8d9293e1e7bebc0adf7bbca40fdf5579bfa8b5e9 Mon Sep 17 00:00:00 2001 From: Tobias Salzmann <796084+Eun@users.noreply.github.com> Date: Thu, 30 Apr 2020 11:27:01 +0200 Subject: [PATCH 182/290] Update pgconn.go --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index e518744a..4ff3c706 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1412,7 +1412,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.RowDescription: rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: - rr.concludeCommand(CommandTa/g(msg.CommandTag), nil) + rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: rr.concludeCommand(nil, nil) case *pgproto3.ErrorResponse: From 391e1ef2ced76042fae145dc82285d98dd85d2c1 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Sat, 2 May 2020 16:35:22 +0300 Subject: [PATCH 183/290] Parse connect timeout setting into Config. Restrict context timeout via Config.ConnectTimeout on .Connect() call. --- config.go | 41 ++++++++++-------- config_test.go | 57 +++++++++++++------------ pgconn.go | 5 +++ pgconn_test.go | 113 ++++++++++++++++++++++++++++++------------------- 4 files changed, 129 insertions(+), 87 deletions(-) diff --git a/config.go b/config.go index 06184b02..4f23f7c2 100644 --- a/config.go +++ b/config.go @@ -30,16 +30,17 @@ type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and // then it can be modified. A manually initialized Config will cause ConnectConfig to panic. type Config struct { - Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) - Port uint16 - Database string - User string - Password string - TLSConfig *tls.Config // nil disables TLS - DialFunc DialFunc // e.g. net.Dialer.DialContext - LookupFunc LookupFunc // e.g. net.Resolver.LookupHost - BuildFrontend BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) + Port uint16 + Database string + User string + Password string + TLSConfig *tls.Config // nil disables TLS + ConnectTimeout time.Duration + DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost + BuildFrontend BuildFrontendFunc + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) Fallbacks []*FallbackConfig @@ -191,12 +192,13 @@ func ParseConfig(connString string) (*Config, error) { BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), } - if connectTimeout, present := settings["connect_timeout"]; present { - dialFunc, err := makeConnectTimeoutDialFunc(connectTimeout) + if connectTimeoutSetting, present := settings["connect_timeout"]; present { + connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) if err != nil { return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} } - config.DialFunc = dialFunc + config.ConnectTimeout = connectTimeout + config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) } else { defaultDialer := makeDefaultDialer() config.DialFunc = defaultDialer.DialContext @@ -672,18 +674,21 @@ func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { } } -func makeConnectTimeoutDialFunc(s string) (DialFunc, error) { +func parseConnectTimeoutSetting(s string) (time.Duration, error) { timeout, err := strconv.ParseInt(s, 10, 64) if err != nil { - return nil, err + return 0, err } if timeout < 0 { - return nil, errors.New("negative timeout") + return 0, errors.New("negative timeout") } + return time.Duration(timeout) * time.Second, nil +} +func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { d := makeDefaultDialer() - d.Timeout = time.Duration(timeout) * time.Second - return d.DialContext, nil + d.Timeout = timeout + return d.DialContext } // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible diff --git a/config_test.go b/config_test.go index b6068cc8..35f6899e 100644 --- a/config_test.go +++ b/config_test.go @@ -7,6 +7,7 @@ import ( "os" "os/user" "testing" + "time" "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" @@ -127,11 +128,11 @@ func TestParseConfig(t *testing.T) { name: "sslmode verify-ca", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, @@ -153,14 +154,15 @@ func TestParseConfig(t *testing.T) { }, { name: "database url everything", - connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema", + connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, RuntimeParams: map[string]string{ "application_name": "pgxtest", "search_path": "myschema", @@ -230,14 +232,15 @@ func TestParseConfig(t *testing.T) { }, { name: "DSN everything", - connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema", + connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", - TLSConfig: nil, + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + ConnectTimeout: 5 * time.Second, RuntimeParams: map[string]string{ "application_name": "pgxtest", "search_path": "myschema", @@ -501,6 +504,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) assert.Equalf(t, expected.User, actual.User, "%s - User", testName) assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) + assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. @@ -590,13 +594,14 @@ func TestParseConfigEnvLibpq(t *testing.T) { "PGAPPNAME": "pgxtest", }, config: &pgconn.Config{ - Host: "123.123.123.123", - Port: 7777, - Database: "foo", - User: "bar", - Password: "baz", - TLSConfig: nil, - RuntimeParams: map[string]string{"application_name": "pgxtest"}, + Host: "123.123.123.123", + Port: 7777, + Database: "foo", + User: "bar", + Password: "baz", + ConnectTimeout: 10 * time.Second, + TLSConfig: nil, + RuntimeParams: map[string]string{"application_name": "pgxtest"}, }, }, } diff --git a/pgconn.go b/pgconn.go index d5a424ac..932984c8 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,6 +116,11 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } + if config.ConnectTimeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout) + defer cancel() + } // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { diff --git a/pgconn_test.go b/pgconn_test.go index e29a36b2..2f7974ea 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "fmt" + "github.com/jackc/pgmock" "io" "io/ioutil" "log" @@ -18,7 +19,6 @@ import ( "time" "github.com/jackc/pgconn" - "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" @@ -81,58 +81,85 @@ func (s pgmockWaitStep) Step(*pgproto3.Backend) error { return nil } -func TestConnectWithContextThatTimesOut(t *testing.T) { +func TestConnectTimeout(t *testing.T) { t.Parallel() - - script := &pgmock.Script{ - Steps: []pgmock.Step{ - pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), - pgmock.SendMessage(&pgproto3.AuthenticationOk{}), - pgmockWaitStep(time.Millisecond * 500), - pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), - pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + tests := []struct { + name string + connect func(connStr string) error + }{ + { + name: "via context that times out", + connect: func(connStr string) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel() + _, err := pgconn.Connect(ctx, connStr) + return err + }, + }, + { + name: "via config ConnectTimeout", + connect: func(connStr string) error { + conf, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + conf.ConnectTimeout = time.Microsecond * 50 + _, err = pgconn.ConnectConfig(context.Background(), conf) + return err + }, }, } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + script := &pgmock.Script{ + Steps: []pgmock.Step{ + pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), + pgmock.SendMessage(&pgproto3.AuthenticationOk{}), + pgmockWaitStep(time.Millisecond * 500), + pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }, + } - ln, err := net.Listen("tcp", "127.0.0.1:") - require.NoError(t, err) - defer ln.Close() + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() - serverErrChan := make(chan error, 1) - go func() { - defer close(serverErrChan) + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) - conn, err := ln.Accept() - if err != nil { - serverErrChan <- err - return - } - defer conn.Close() + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() - err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) - if err != nil { - serverErrChan <- err - return - } + err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) + if err != nil { + serverErrChan <- err + return + } - err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) - if err != nil { - serverErrChan <- err - return - } - }() + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() - parts := strings.Split(ln.Addr().String(), ":") - host := parts[0] - port := parts[1] - connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) - tooLate := time.Now().Add(time.Millisecond * 500) + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + tooLate := time.Now().Add(time.Millisecond * 500) - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) - defer cancel() - _, err = pgconn.Connect(ctx, connStr) - require.True(t, pgconn.Timeout(err), err) - require.True(t, time.Now().Before(tooLate)) + err = tt.connect(connStr) + require.True(t, pgconn.Timeout(err), err) + require.True(t, time.Now().Before(tooLate)) + }) + } } func TestConnectInvalidUser(t *testing.T) { From 2d5a17beab6e8f40c60b56efe0a92e9528f2a424 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Sat, 2 May 2020 16:39:51 +0300 Subject: [PATCH 184/290] Add comment. --- pgconn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgconn.go b/pgconn.go index 932984c8..69f42621 100644 --- a/pgconn.go +++ b/pgconn.go @@ -116,6 +116,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err panic("config must be created by ParseConfig") } + // ConnectTimeout restricts the whole connection process. if config.ConnectTimeout != 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, config.ConnectTimeout) From 01a7510ae90d37ffbee1438612f340e1f988bb17 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Sat, 2 May 2020 16:43:02 +0300 Subject: [PATCH 185/290] Reformat imports --- pgconn_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index 2f7974ea..9a75dede 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -6,7 +6,6 @@ import ( "context" "crypto/tls" "fmt" - "github.com/jackc/pgmock" "io" "io/ioutil" "log" @@ -18,6 +17,8 @@ import ( "testing" "time" + "github.com/jackc/pgmock" + "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" errors "golang.org/x/xerrors" From c4e6445cc73142e773864439198a9ccf72767cb1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 May 2020 10:19:39 -0500 Subject: [PATCH 186/290] Explicitly test supported Go and PostgreSQL versions --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 87a0c058..0371101f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,8 @@ language: go go: - - 1.x + - 1.14.x + - 1.13.x - tip git: @@ -29,7 +30,6 @@ env: - PGVERSION=10 - PGVERSION=9.6 - PGVERSION=9.5 - - PGVERSION=9.4 cache: directories: From 08d071c0944e1d60af437ce4e11da55d7150385e Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Fri, 8 May 2020 13:38:34 +0200 Subject: [PATCH 187/290] Handle IPv6 in connection URLs Previously IPv6 addresses were wrongly split and lead to a parse error. This commit fixes the behavior. --- config.go | 20 +++++++++++++++----- config_test.go | 46 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/config.go b/config.go index 06184b02..a9b19d67 100644 --- a/config.go +++ b/config.go @@ -399,13 +399,19 @@ func parseURLSettings(connString string) (map[string]string, error) { var hosts []string var ports []string for _, host := range strings.Split(url.Host, ",") { - parts := strings.SplitN(host, ":", 2) - if parts[0] != "" { - hosts = append(hosts, parts[0]) + if host == "" { + continue } - if len(parts) == 2 { - ports = append(ports, parts[1]) + if isIPOnly(host) { + hosts = append(hosts, strings.Trim(host, "[]")) + continue } + h, p, err := net.SplitHostPort(host) + if err != nil { + return nil, errors.Errorf("failed to split host:port in '%s', err: %w", host, err) + } + hosts = append(hosts, h) + ports = append(ports, p) } if len(hosts) > 0 { settings["host"] = strings.Join(hosts, ",") @@ -426,6 +432,10 @@ func parseURLSettings(connString string) (map[string]string, error) { return settings, nil } +func isIPOnly(host string) bool { + return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") +} + var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} func parseDSNSettings(s string) (map[string]string, error) { diff --git a/config_test.go b/config_test.go index b6068cc8..d932a605 100644 --- a/config_test.go +++ b/config_test.go @@ -127,11 +127,11 @@ func TestParseConfig(t *testing.T) { name: "sslmode verify-ca", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", config: &pgconn.Config{ - User: "jack", - Password: "secret", - Host: "localhost", - Port: 5432, - Database: "mydb", + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, @@ -228,6 +228,42 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url IPv4 with port", + connString: "postgresql://jack@127.0.0.1:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "127.0.0.1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 with port", + connString: "postgresql://jack@[2001:db8::1]:5433/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5433, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + name: "database url IPv6 no port", + connString: "postgresql://jack@[2001:db8::1]/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Host: "2001:db8::1", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "DSN everything", connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema", From 2ccb66fe2159792f5b28d01e65e2461795a6f854 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 16 May 2020 18:48:05 -0500 Subject: [PATCH 188/290] Doc fix --- pgconn.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index 4a6ef430..541f280e 100644 --- a/pgconn.go +++ b/pgconn.go @@ -890,11 +890,11 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { // ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). // // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. ExecParams will panic if +// binary format. If paramFormats is nil all params are text format. ExecParams will panic if // len(paramFormats) is not 0, 1, or len(paramValues). // // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. +// binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { @@ -917,11 +917,11 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // paramValues are the parameter values. It must be encoded in the format given by paramFormats. // // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or -// binary format. If paramFormats is nil all results will be in text protocol. ExecPrepared will panic if +// binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if // len(paramFormats) is not 0, 1, or len(paramValues). // // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or -// binary format. If resultFormats is nil all results will be in text protocol. +// binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { From 8c33aa24430a9bbd9d34af6d8c211ab632f22e17 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 May 2020 11:47:42 -0500 Subject: [PATCH 189/290] Remove CPU wasting empty default statement fixes #39 --- pgconn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 541f280e..43edbb6b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1151,7 +1151,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co default: signalMessageChan = pgConn.signalMessage() } - default: } } close(abortCopyChan) From 2647eff5675f7a45d02b82b633580357b11e05ad Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 25 May 2020 11:49:37 -0500 Subject: [PATCH 190/290] Fix ValidateConnect with cancelable context fixes #40 --- pgconn.go | 7 +++++++ pgconn_test.go | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index 43edbb6b..5644904a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -288,6 +288,13 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { + // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid + // the watch already in progress panic. This is that last thing done by this method so there is no need to + // restart the watch after ValidateConnect returns. + // + // See https://github.com/jackc/pgconn/issues/40. + pgConn.contextWatcher.Unwatch() + err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() diff --git a/pgconn_test.go b/pgconn_test.go index 9a75dede..6362c51b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -346,9 +346,12 @@ func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) { config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite config.RuntimeParams["default_transaction_read_only"] = "on" - conn, err := pgconn.ConnectConfig(context.Background(), config) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := pgconn.ConnectConfig(ctx, config) if !assert.NotNil(t, err) { - conn.Close(context.Background()) + conn.Close(ctx) } } From 8d541d00043bb65c60ee837927dc06c163729954 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Mon, 1 Jun 2020 19:20:17 +0300 Subject: [PATCH 191/290] Add Config.Copy() method that return a smart copy of the config. --- config.go | 9 +++++++++ config_test.go | 33 +++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 ++ 4 files changed, 45 insertions(+) diff --git a/config.go b/config.go index 299d4784..6640038b 100644 --- a/config.go +++ b/config.go @@ -17,6 +17,8 @@ import ( "strings" "time" + "github.com/mohae/deepcopy" + "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" @@ -62,6 +64,13 @@ type Config struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +func (c *Config) Copy() *Config { + newConf := deepcopy.Copy(c).(*Config) + // We need to set this field manually because it's unexported and deep copy won't touch it. + newConf.createdByParseConfig = c.createdByParseConfig + return newConf +} + // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a // network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. type FallbackConfig struct { diff --git a/config_test.go b/config_test.go index 515ea6d3..72b775d4 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "context" "crypto/tls" "fmt" "io/ioutil" @@ -527,6 +528,38 @@ func TestParseConfig(t *testing.T) { } } +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") +} + +func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + copied.Port = uint16(5433) + + assert.Equal(t, uint16(5432), original.Port) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_CONN_STRING") + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgconn.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { if !assert.NotNil(t, expected) { return diff --git a/go.mod b/go.mod index 4dc095ca..841eccc7 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index 23fb8b32..1514a339 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= From a6d9265506df51336914769fd786ef8712997c31 Mon Sep 17 00:00:00 2001 From: georgysavva Date: Mon, 1 Jun 2020 20:52:08 +0300 Subject: [PATCH 192/290] Implement deep copy manually, stop using an external deep copy library. Add comment to the Config.Copy() method. --- config.go | 30 +++++++++++++++++++++++++----- config_test.go | 10 ++++++++-- go.mod | 1 - go.sum | 2 -- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/config.go b/config.go index 6640038b..7ed99096 100644 --- a/config.go +++ b/config.go @@ -17,8 +17,6 @@ import ( "strings" "time" - "github.com/mohae/deepcopy" - "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" @@ -64,10 +62,32 @@ type Config struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the TLSConfig field: +// according to the tls.Config docs it must not be modified after creation. func (c *Config) Copy() *Config { - newConf := deepcopy.Copy(c).(*Config) - // We need to set this field manually because it's unexported and deep copy won't touch it. - newConf.createdByParseConfig = c.createdByParseConfig + newConf := new(Config) + *newConf = *c + if newConf.TLSConfig != nil { + newConf.TLSConfig = c.TLSConfig.Clone() + } + if newConf.RuntimeParams != nil { + newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) + for k, v := range c.RuntimeParams { + newConf.RuntimeParams[k] = v + } + } + if newConf.Fallbacks != nil { + newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) + for i, fallback := range c.Fallbacks { + newFallback := new(FallbackConfig) + *newFallback = *fallback + if newFallback.TLSConfig != nil { + newFallback.TLSConfig = fallback.TLSConfig.Clone() + } + newConf.Fallbacks[i] = newFallback + } + } return newConf } diff --git a/config_test.go b/config_test.go index 72b775d4..ebe627b1 100644 --- a/config_test.go +++ b/config_test.go @@ -529,7 +529,7 @@ func TestParseConfig(t *testing.T) { } func TestConfigCopyReturnsEqualConfig(t *testing.T) { - connString := "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5" + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) @@ -538,14 +538,20 @@ func TestConfigCopyReturnsEqualConfig(t *testing.T) { } func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { - connString := "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5" + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") + copied.Port = uint16(5433) + copied.RuntimeParams["foo"] = "bar" + copied.Fallbacks[0].Port = uint16(5433) assert.Equal(t, uint16(5432), original.Port) + assert.Equal(t, "", original.RuntimeParams["foo"]) + assert.Equal(t, uint16(5432), original.Fallbacks[0].Port) } func TestConfigCopyCanBeUsedToConnect(t *testing.T) { diff --git a/go.mod b/go.mod index 841eccc7..4dc095ca 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 - github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/text v0.3.2 diff --git a/go.sum b/go.sum index 1514a339..23fb8b32 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,6 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= -github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= From 6cd2127b96fdbc7cdddcec0f8cdfbe6a322cbf24 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 10:44:22 -0500 Subject: [PATCH 193/290] Update pgproto3 dependency --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4dc095ca..9b6baf5b 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.1 + github.com/jackc/pgproto3/v2 v2.0.2 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 diff --git a/go.sum b/go.sum index 23fb8b32..2063a801 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9vjY= +github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= From 59a0074b0a32d05ee32ccc39c6c9ca013013a69d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Jun 2020 10:51:44 -0500 Subject: [PATCH 194/290] Release v1.6.0 --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4c3b2d2..68b151d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +# 1.6.0 (June 6, 2020) + +* Fix panic when closing conn during cancellable query +* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný) +* Fix field descriptions available after command concluded (Tobias Salzmann) +* Support connect_timeout (georgysavva) +* Handle IPv6 in connection URLs (Lukas Vogel) +* Fix ValidateConnect with cancelable context +* Improve CopyFrom performance +* Add Config.Copy (georgysavva) + # 1.5.0 (March 30, 2020) * Update golang.org/x/crypto for security fix From 6b254a445e49cc9f23f25a1e2eca7cd98fd65850 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Jun 2020 20:51:40 -0500 Subject: [PATCH 195/290] Fix doc for ParseConfig --- config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config.go b/config.go index 7ed99096..279ae400 100644 --- a/config.go +++ b/config.go @@ -112,7 +112,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// ParseConfig builds a []*Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same +// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same // defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. // It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the // .pgpass file. From a1b9eb4d4e06feaa3587b1633165b7a52c80b4e7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 11 Jun 2020 20:55:41 -0500 Subject: [PATCH 196/290] Fix parseServiceSettings not returning error --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 279ae400..f15b33b4 100644 --- a/config.go +++ b/config.go @@ -536,12 +536,12 @@ func parseDSNSettings(s string) (map[string]string, error) { func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { servicefile, err := pgservicefile.ReadServicefile(servicefilePath) if err != nil { - fmt.Errorf("failed to read service file: %v", servicefile) + return nil, fmt.Errorf("failed to read service file: %v", servicefile) } service, err := servicefile.GetService(serviceName) if err != nil { - fmt.Errorf("unable to find service: %v", servicefile) + return nil, fmt.Errorf("unable to find service: %v", servicefile) } nameMap := map[string]string{ From f27e874d554167e7bf0bead3d5b5bf8923abba05 Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Fri, 12 Jun 2020 13:01:57 +0200 Subject: [PATCH 197/290] redact passwords in parse config errors Redact passwords when printing the parseConfigError in a best effort manner. This prevents people from leaking the password into logs, if they just print the error in logs. --- errors.go | 30 ++++++++++++++++++++++++++++-- errors_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ export_test.go | 11 +++++++++++ 3 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 errors_test.go create mode 100644 export_test.go diff --git a/errors.go b/errors.go index 7a21af98..b746c825 100644 --- a/errors.go +++ b/errors.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net" + "net/url" + "regexp" "strings" errors "golang.org/x/xerrors" @@ -98,10 +100,11 @@ type parseConfigError struct { } func (e *parseConfigError) Error() string { + connString := redactPW(e.connString) if e.err == nil { - return fmt.Sprintf("cannot parse `%s`: %s", e.connString, e.msg) + return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) } - return fmt.Sprintf("cannot parse `%s`: %s (%s)", e.connString, e.msg, e.err.Error()) + return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) } func (e *parseConfigError) Unwrap() error { @@ -164,3 +167,26 @@ func (e *writeError) SafeToRetry() bool { func (e *writeError) Unwrap() error { return e.err } + +func redactPW(connString string) string { + if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { + if u, err := url.Parse(connString); err == nil { + return redactURL(u) + } + } + quotedDSN := regexp.MustCompile(`password='[^']*'`) + connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + plainDSN := regexp.MustCompile(`password=[^ ]*`) + connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + return connString +} + +func redactURL(u *url.URL) string { + if u == nil { + return "" + } + if _, pwSet := u.User.Password(); pwSet { + u.User = url.UserPassword(u.User.Username(), "xxxxx") + } + return u.String() +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 00000000..bef835f8 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,44 @@ +package pgconn_test + +import ( + "testing" + + "github.com/jackc/pgconn" + "github.com/stretchr/testify/assert" +) + +func TestConfigError(t *testing.T) { + tests := []struct { + name string + err error + expectedMsg string + }{ + { + name: "url with password", + err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg", + }, + { + name: "dsn with password unquoted", + err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "dsn with password quoted", + err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil), + expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", + }, + { + name: "weird url", + err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), + expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.EqualError(t, tt.err, tt.expectedMsg) + }) + } +} diff --git a/export_test.go b/export_test.go new file mode 100644 index 00000000..2a0bad8b --- /dev/null +++ b/export_test.go @@ -0,0 +1,11 @@ +// File export_test exports some methods for better testing. + +package pgconn + +func NewParseConfigError(conn, msg string, err error) error { + return &parseConfigError{ + connString: conn, + msg: msg, + err: err, + } +} From 7cf5101bb27a95b5c5af77632b7dc0ddcef20690 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Jun 2020 12:58:15 -0500 Subject: [PATCH 198/290] Add NewConfig() refs #42 --- config.go | 21 ++++++++++++++++++++- pgconn_test.go | 18 ++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index f15b33b4..b6c27ce5 100644 --- a/config.go +++ b/config.go @@ -112,6 +112,18 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } +// NewConfig returns an *Config without parsing a connection string or reading the standard PG* environment variables. +// Host, Port, Database, User, and Password must be set before the config can be used to establish a connection. +func NewConfig() *Config { + return &Config{ + DialFunc: makeDefaultDialer().DialContext, + LookupFunc: makeDefaultResolver().LookupHost, + BuildFrontend: makeDefaultBuildFrontendFunc(8192), + RuntimeParams: map[string]string{}, + createdByParseConfig: true, + } +} + // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same // defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. // It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the @@ -154,7 +166,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are // usually but not always the environment variable name downcased and without the "PG" prefix. // -// Important TLS Security Notes: +// Important Security Notes: // // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if // not set. @@ -162,6 +174,13 @@ func NetworkAddress(host string, port uint16) (network, address string) { // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // security each sslmode provides. // +// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of +// the Config struct. If the main TLS config is manually changed it will not affect the fallbacks. For example, in the +// case of sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try +// the fallback which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config +// is manually changed later but the unencrypted fallback is present. Remove or update all fallbacks or use NewConfig +// to build the config manually. +// // Other known differences with libpq: // // If a host name resolves into multiple addresses, libpq will try all addresses. pgconn will only try the first. diff --git a/pgconn_test.go b/pgconn_test.go index 6362c51b..2d3e482b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -198,6 +198,24 @@ func TestConnectWithConnectionRefused(t *testing.T) { } } +func TestConnectConfigFromNewConfig(t *testing.T) { + t.Parallel() + + baseConfig, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + config := pgconn.NewConfig() + config.Host = baseConfig.Host + config.Port = baseConfig.Port + config.Database = baseConfig.Database + config.User = baseConfig.User + config.Password = baseConfig.Password + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + closeConn(t, conn) +} + func TestConnectCustomDialer(t *testing.T) { t.Parallel() From 473062b114e54e039d7af4a951e877b430ea0c67 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:29:21 -0500 Subject: [PATCH 199/290] Remove NewConfig and add more docs for ParseConfig refs #42 --- config.go | 31 ++++++++++++------------------- pgconn_test.go | 18 ------------------ 2 files changed, 12 insertions(+), 37 deletions(-) diff --git a/config.go b/config.go index b6c27ce5..75292125 100644 --- a/config.go +++ b/config.go @@ -27,8 +27,8 @@ import ( type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig and -// then it can be modified. A manually initialized Config will cause ConnectConfig to panic. +// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A +// manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) Port uint16 @@ -112,18 +112,6 @@ func NetworkAddress(host string, port uint16) (network, address string) { return network, address } -// NewConfig returns an *Config without parsing a connection string or reading the standard PG* environment variables. -// Host, Port, Database, User, and Password must be set before the config can be used to establish a connection. -func NewConfig() *Config { - return &Config{ - DialFunc: makeDefaultDialer().DialContext, - LookupFunc: makeDefaultResolver().LookupHost, - BuildFrontend: makeDefaultBuildFrontendFunc(8192), - RuntimeParams: map[string]string{}, - createdByParseConfig: true, - } -} - // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same // defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. // It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the @@ -135,6 +123,11 @@ func NewConfig() *Config { // # Example URL // postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca // +// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done +// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be +// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should +// not be modified individually. They should all be modified or all left unchanged. +// // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // values that will be tried in order. This can be used as part of a high availability system. See // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. @@ -175,11 +168,11 @@ func NewConfig() *Config { // security each sslmode provides. // // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of -// the Config struct. If the main TLS config is manually changed it will not affect the fallbacks. For example, in the -// case of sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try -// the fallback which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config -// is manually changed later but the unencrypted fallback is present. Remove or update all fallbacks or use NewConfig -// to build the config manually. +// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of +// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback +// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually +// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting +// TLCConfig. // // Other known differences with libpq: // diff --git a/pgconn_test.go b/pgconn_test.go index 2d3e482b..6362c51b 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -198,24 +198,6 @@ func TestConnectWithConnectionRefused(t *testing.T) { } } -func TestConnectConfigFromNewConfig(t *testing.T) { - t.Parallel() - - baseConfig, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - - config := pgconn.NewConfig() - config.Host = baseConfig.Host - config.Port = baseConfig.Port - config.Database = baseConfig.Database - config.User = baseConfig.User - config.Password = baseConfig.Password - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - closeConn(t, conn) -} - func TestConnectCustomDialer(t *testing.T) { t.Parallel() From 82c2752e7151902340c65d2b61e483424b308c96 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:35:23 -0500 Subject: [PATCH 200/290] Update golang.org/x/text to 0.3.3 golang.org/x/text had a vulnerability: https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14040 pgconn does not appear to use the affected code path, but it is still worth updating away from the vulnerable version. fixes #44 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 9b6baf5b..45aa8a46 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,6 @@ require ( github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 - golang.org/x/text v0.3.2 + golang.org/x/text v0.3.3 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index 2063a801..29b3ebd8 100644 --- a/go.sum +++ b/go.sum @@ -102,6 +102,8 @@ golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= From 65717779e443e346ee1d7183f1ba1e2fb3947e7b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:46:16 -0500 Subject: [PATCH 201/290] Fix crash when PGSERVICE not found --- config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 75292125..44953e0f 100644 --- a/config.go +++ b/config.go @@ -548,12 +548,12 @@ func parseDSNSettings(s string) (map[string]string, error) { func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { servicefile, err := pgservicefile.ReadServicefile(servicefilePath) if err != nil { - return nil, fmt.Errorf("failed to read service file: %v", servicefile) + return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) } service, err := servicefile.GetService(serviceName) if err != nil { - return nil, fmt.Errorf("unable to find service: %v", servicefile) + return nil, fmt.Errorf("unable to find service: %v", serviceName) } nameMap := map[string]string{ From bd7ffdb480379b6d0e73a0bf7fdf9d7050f9fa54 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:48:20 -0500 Subject: [PATCH 202/290] Update golang.org/x/crypto dependency --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 45aa8a46..2487c271 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/jackc/pgproto3/v2 v2.0.2 github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 github.com/stretchr/testify v1.5.1 - golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 + golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/text v0.3.3 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/go.sum b/go.sum index 29b3ebd8..2440dd48 100644 --- a/go.sum +++ b/go.sum @@ -87,6 +87,8 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1X golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= From 503c2b445f76da704197860e7158bf75ce2a9ef0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 27 Jun 2020 11:51:30 -0500 Subject: [PATCH 203/290] Release v1.6.1 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 68b151d8..25376301 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.6.1 (June 27, 2020) + +* Update golang.org/x/crypto to latest +* Update golang.org/x/text to 0.3.3 +* Fix error handling for bad PGSERVICE definition +* Redact passwords in ParseConfig errors (Lukas Vogel) + # 1.6.0 (June 6, 2020) * Fix panic when closing conn during cancellable query From 12752ce5d63917f9fa710ba0b117aa1b550b43ba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 13 Jul 2020 19:34:45 -0500 Subject: [PATCH 204/290] Update pgservicefile --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2487c271..d3550ca8 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.2 - github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 + github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/text v0.3.3 diff --git a/go.sum b/go.sum index 2440dd48..0b144d0f 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9v github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= From 9295bf7483021745c921e818151ef3b735090b4f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 14 Jul 2020 12:07:27 -0500 Subject: [PATCH 205/290] Update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 25376301..c3088dd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.6.2 (July 14, 2020) + +* Update pgservicefile library + # 1.6.1 (June 27, 2020) * Update golang.org/x/crypto to latest From 271b0ac95ee4426f3495a2577b624296c5372a70 Mon Sep 17 00:00:00 2001 From: vahid-sohrabloo Date: Fri, 17 Jul 2020 20:31:10 +0430 Subject: [PATCH 206/290] AppendCertsFromPEM doesn't have error and removes pgTLSArgs AppendCertsFromPEM doesn't have error and removes pgTLSArgs because not used --- config.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/config.go b/config.go index 44953e0f..906ed7f4 100644 --- a/config.go +++ b/config.go @@ -571,13 +571,6 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin return settings, nil } -type pgTLSArgs struct { - sslMode string - sslRootCert string - sslCert string - sslKey string -} - // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. @@ -662,7 +655,7 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { } if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, errors.Errorf("unable to add CA to cert pool: %w", err) + return nil, errors.New("unable to add CA to cert pool") } tlsConfig.RootCAs = caCertPool From 37c9edc242e83750fcfbef327001fd65603d63d0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 22 Jul 2020 06:43:39 -0500 Subject: [PATCH 207/290] Release v1.6.3 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3088dd0..58481415 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.6.3 (July 22, 2020) + +* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo) + # 1.6.2 (July 14, 2020) * Update pgservicefile library From 4e4c4ea5410aba437bc6d6e2c5a93c4acf6cce73 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 29 Jul 2020 21:47:23 -0500 Subject: [PATCH 208/290] Fix deadlock on error after CommandComplete but before ReadyForQuery See: https://github.com/jackc/pgx/issues/800 --- pgconn.go | 7 +++++- pgconn_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 5644904a..50607095 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1435,12 +1435,17 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { + // Keep the first error that is recorded. Store the error before checking if the command is already concluded to + // allow for receiving an error after CommandComplete but before ReadyForQuery. + if err != nil && rr.err == nil { + rr.err = err + } + if rr.commandConcluded { return } rr.commandTag = commandTag - rr.err = err rr.rowValues = nil rr.commandConcluded = true } diff --git a/pgconn_test.go b/pgconn_test.go index 6362c51b..379aa266 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1752,6 +1752,71 @@ func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { pgConn.Close(closeCtx) } +// https://github.com/jackc/pgx/issues/800 +func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { + t.Parallel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) + if err != nil { + serverErrChan <- err + return + } + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil) + + for rr.NextRow() { + } + + _, err = rr.Close() + require.Error(t, err) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { From 44079b0d2c9ac3629a8ea9cafe4d75568b376f9e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 29 Jul 2020 22:11:15 -0500 Subject: [PATCH 209/290] Fix panic on parsing DSN with trailing '=' Also correctly return error with leading '='. fixes #47 --- config.go | 7 ++++++- config_test.go | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 906ed7f4..b2583546 100644 --- a/config.go +++ b/config.go @@ -497,7 +497,8 @@ func parseDSNSettings(s string) (map[string]string, error) { key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") - if s[0] != '\'' { + if len(s) == 0 { + } else if s[0] != '\'' { end := 0 for ; end < len(s); end++ { if asciiSpace[s[end]] == 1 { @@ -539,6 +540,10 @@ func parseDSNSettings(s string) (map[string]string, error) { key = k } + if key == "" { + return nil, errors.New("invalid dsn") + } + settings[key] = val } diff --git a/config_test.go b/config_test.go index ebe627b1..264eb299 100644 --- a/config_test.go +++ b/config_test.go @@ -528,6 +528,17 @@ func TestParseConfig(t *testing.T) { } } +// https://github.com/jackc/pgconn/issues/47 +func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { + _, err := pgconn.ParseConfig("host= user= password= port= database=") + require.NoError(t, err) +} + +func TestParseConfigDSNLeadingEqual(t *testing.T) { + _, err := pgconn.ParseConfig("= user=jack") + require.Error(t, err) +} + func TestConfigCopyReturnsEqualConfig(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) From f45b4d6b76091608f30b1f8ff5de046a32080d3d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 29 Jul 2020 22:17:02 -0500 Subject: [PATCH 210/290] Release v1.6.4 --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58481415..a6668fb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 1.6.4 (July 29, 2020) + +* Fix deadlock on error after CommandComplete but before ReadyForQuery +* Fix panic on parsing DSN with trailing '=' + # 1.6.3 (July 22, 2020) * Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo) From b6e34b44e5c0657be2eb7c36f5b12cc5c88dfe1f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 31 Jul 2020 17:04:18 -0500 Subject: [PATCH 211/290] Update pgproto3 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index d3550ca8..a20501c5 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.2 + github.com/jackc/pgproto3/v2 v2.0.3 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 diff --git a/go.sum b/go.sum index 0b144d0f..d3226ebc 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9vjY= github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.3 h1:2S4PhE00mvdvaSiCYR1ZCmR1NAxeYfTSsqqSKxE1vzo= +github.com/jackc/pgproto3/v2 v2.0.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= From c894ca8b7d2a9e3dcf03f8cc319461a73f6a7fc6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 1 Aug 2020 05:49:56 -0500 Subject: [PATCH 212/290] Update pgproto3 to v2.0.4 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index a20501c5..a74028c8 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.3 + github.com/jackc/pgproto3/v2 v2.0.4 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 diff --git a/go.sum b/go.sum index d3226ebc..61a2896b 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9v github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.3 h1:2S4PhE00mvdvaSiCYR1ZCmR1NAxeYfTSsqqSKxE1vzo= github.com/jackc/pgproto3/v2 v2.0.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.4 h1:RHkX5ZUD9bl/kn0f9dYUWs1N7Nwvo1wwUYvKiR26Zco= +github.com/jackc/pgproto3/v2 v2.0.4/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= From 3eb5432c4738bc58b1e52a91c58decf07324130f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Aug 2020 22:00:21 -0500 Subject: [PATCH 213/290] Add PgConn.CleanupChan --- CHANGELOG.md | 4 ++++ helper_test.go | 5 +++++ pgconn.go | 22 +++++++++++++++++++++- pgconn_test.go | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6668fb0..8b988590 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# Unreleased + +* Add PgConn.CleanupChan so connection pools can determine when async close is complete + # 1.6.4 (July 29, 2020) * Fix deadlock on error after CommandComplete but before ReadyForQuery diff --git a/helper_test.go b/helper_test.go index 1a3ca75e..abb04905 100644 --- a/helper_test.go +++ b/helper_test.go @@ -15,6 +15,11 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() require.NoError(t, conn.Close(ctx)) + select { + case <-conn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } // Do a simple query to ensure the connection is still usable diff --git a/pgconn.go b/pgconn.go index 50607095..c132b26b 100644 --- a/pgconn.go +++ b/pgconn.go @@ -89,6 +89,8 @@ type PgConn struct { resultReader ResultReader multiResultReader MultiResultReader contextWatcher *ctxwatch.ContextWatcher + + cleanupChan chan struct{} } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -201,6 +203,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn := new(PgConn) pgConn.config = config pgConn.wbuf = make([]byte, 0, wbufLen) + pgConn.cleanupChan = make(chan struct{}) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -504,6 +507,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { } pgConn.status = connStatusClosed + defer close(pgConn.cleanupChan) defer pgConn.conn.Close() if ctx != context.Background() { @@ -538,6 +542,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.status = connStatusClosed go func() { + defer close(pgConn.cleanupChan) defer pgConn.conn.Close() deadline := time.Now().Add(time.Second * 15) @@ -554,7 +559,21 @@ func (pgConn *PgConn) asyncClose() { }() } +// CleanupChan returns a channel that will be closed after all underlying resources have been cleaned up. A closed +// connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing +// yet. This is because certain errors such as a context cancellation require that the interrupted function call return +// immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are +// closed asynchronously. +// +// This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while +// an old connection is still being cleaned up and thereby exceeding the maximum pool size. +func (pgConn *PgConn) CleanupChan() chan (struct{}) { + return pgConn.cleanupChan +} + // IsClosed reports if the connection has been closed. +// +// CleanupChan() can be used to determine if all cleanup has been completed. func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } @@ -1585,7 +1604,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) { status: connStatusIdle, - wbuf: make([]byte, 0, wbufLen), + wbuf: make([]byte, 0, wbufLen), + cleanupChan: make(chan struct{}), } pgConn.contextWatcher = ctxwatch.NewContextWatcher( diff --git a/pgconn_test.go b/pgconn_test.go index 379aa266..56afc1c2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -547,6 +547,11 @@ func TestConnExecContextCanceled(t *testing.T) { err = multiResult.Close() assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnExecContextPrecanceled(t *testing.T) { @@ -680,6 +685,11 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnExecParamsPrecanceled(t *testing.T) { @@ -824,6 +834,11 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnExecPreparedPrecanceled(t *testing.T) { @@ -1306,6 +1321,11 @@ func TestConnCopyToCanceled(t *testing.T) { assert.Equal(t, pgconn.CommandTag(nil), res) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnCopyToPrecanceled(t *testing.T) { @@ -1397,6 +1417,11 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.Error(t, err) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } func TestConnCopyFromPrecanceled(t *testing.T) { @@ -1647,6 +1672,11 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { err = multiResult.Close() assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) @@ -1750,6 +1780,11 @@ func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { closeCtx, _ := context.WithCancel(context.Background()) pgConn.Close(closeCtx) + select { + case <-pgConn.CleanupChan(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } } // https://github.com/jackc/pgx/issues/800 From fdfc783345f6b5df05b2039666c59ebd29a7e683 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Aug 2020 22:08:40 -0500 Subject: [PATCH 214/290] Rename CleanupChan to CleanupDone --- CHANGELOG.md | 2 +- helper_test.go | 2 +- pgconn.go | 18 +++++++++--------- pgconn_test.go | 14 +++++++------- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b988590..497e00a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Unreleased -* Add PgConn.CleanupChan so connection pools can determine when async close is complete +* Add PgConn.CleanupDone so connection pools can determine when async close is complete # 1.6.4 (July 29, 2020) diff --git a/helper_test.go b/helper_test.go index abb04905..87613dc9 100644 --- a/helper_test.go +++ b/helper_test.go @@ -16,7 +16,7 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { defer cancel() require.NoError(t, conn.Close(ctx)) select { - case <-conn.CleanupChan(): + case <-conn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } diff --git a/pgconn.go b/pgconn.go index c132b26b..d031b7a1 100644 --- a/pgconn.go +++ b/pgconn.go @@ -90,7 +90,7 @@ type PgConn struct { multiResultReader MultiResultReader contextWatcher *ctxwatch.ContextWatcher - cleanupChan chan struct{} + cleanupDone chan struct{} } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) @@ -203,7 +203,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn := new(PgConn) pgConn.config = config pgConn.wbuf = make([]byte, 0, wbufLen) - pgConn.cleanupChan = make(chan struct{}) + pgConn.cleanupDone = make(chan struct{}) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) @@ -507,7 +507,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { } pgConn.status = connStatusClosed - defer close(pgConn.cleanupChan) + defer close(pgConn.cleanupDone) defer pgConn.conn.Close() if ctx != context.Background() { @@ -542,7 +542,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.status = connStatusClosed go func() { - defer close(pgConn.cleanupChan) + defer close(pgConn.cleanupDone) defer pgConn.conn.Close() deadline := time.Now().Add(time.Second * 15) @@ -559,7 +559,7 @@ func (pgConn *PgConn) asyncClose() { }() } -// CleanupChan returns a channel that will be closed after all underlying resources have been cleaned up. A closed +// CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed // connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing // yet. This is because certain errors such as a context cancellation require that the interrupted function call return // immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are @@ -567,13 +567,13 @@ func (pgConn *PgConn) asyncClose() { // // This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while // an old connection is still being cleaned up and thereby exceeding the maximum pool size. -func (pgConn *PgConn) CleanupChan() chan (struct{}) { - return pgConn.cleanupChan +func (pgConn *PgConn) CleanupDone() chan (struct{}) { + return pgConn.cleanupDone } // IsClosed reports if the connection has been closed. // -// CleanupChan() can be used to determine if all cleanup has been completed. +// CleanupDone() can be used to determine if all cleanup has been completed. func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } @@ -1605,7 +1605,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) { status: connStatusIdle, wbuf: make([]byte, 0, wbufLen), - cleanupChan: make(chan struct{}), + cleanupDone: make(chan struct{}), } pgConn.contextWatcher = ctxwatch.NewContextWatcher( diff --git a/pgconn_test.go b/pgconn_test.go index 56afc1c2..f6750a60 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -548,7 +548,7 @@ func TestConnExecContextCanceled(t *testing.T) { assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -686,7 +686,7 @@ func TestConnExecParamsCanceled(t *testing.T) { assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -835,7 +835,7 @@ func TestConnExecPreparedCanceled(t *testing.T) { assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -1322,7 +1322,7 @@ func TestConnCopyToCanceled(t *testing.T) { assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -1418,7 +1418,7 @@ func TestConnCopyFromCanceled(t *testing.T) { assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -1673,7 +1673,7 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } @@ -1781,7 +1781,7 @@ func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { closeCtx, _ := context.WithCancel(context.Background()) pgConn.Close(closeCtx) select { - case <-pgConn.CleanupChan(): + case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } From 1debbfeec4c2b878d81b3491499f0bc6f5c5a40d Mon Sep 17 00:00:00 2001 From: Sebastiaan Mannem Date: Sun, 2 Aug 2020 16:43:47 +0200 Subject: [PATCH 215/290] Adding SendBytesWithResults option to receive data after sending a message (used by copy-both) --- pgconn.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/pgconn.go b/pgconn.go index d031b7a1..decdc03d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -904,6 +904,51 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult } +// SendBytesWithResults sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as +// error to call SendBytes while reading the result of a query. +// +// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. +// See https://www.postgresql.org/docs/current/protocol.html. +// +// So far this only seems required with CopyDone handling. +func (pgConn *PgConn) SendBytesWithResults(ctx context.Context, buf []byte) *MultiResultReader { + if err := pgConn.lock(); err != nil { + return &MultiResultReader{ + closed: true, + err: err, + } + } + + pgConn.multiResultReader = MultiResultReader{ + pgConn: pgConn, + ctx: ctx, + } + multiResult := &pgConn.multiResultReader + if ctx != context.Background() { + select { + case <-ctx.Done(): + multiResult.closed = true + multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + pgConn.unlock() + return multiResult + default: + } + pgConn.contextWatcher.Watch(ctx) + } + + n, err := pgConn.conn.Write(buf) + if err != nil { + pgConn.asyncClose() + pgConn.contextWatcher.Unwatch() + multiResult.closed = true + multiResult.err = &writeError{err: err, safeToRetry: n == 0} + pgConn.unlock() + return multiResult + } + + return multiResult +} + // ExecParams executes a command via the PostgreSQL extended query protocol. // // sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, From 5db484908cf74895bb9e03414d1ba022a24e11bd Mon Sep 17 00:00:00 2001 From: Sebastiaan Mannem Date: Sun, 23 Aug 2020 00:21:46 +0200 Subject: [PATCH 216/290] Changing SendBytesWithResults to ReceiveResults (that only does the reading). --- pgconn.go | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/pgconn.go b/pgconn.go index decdc03d..e2ab5c13 100644 --- a/pgconn.go +++ b/pgconn.go @@ -904,14 +904,12 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { return multiResult } -// SendBytesWithResults sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as -// error to call SendBytes while reading the result of a query. +// ReceiveResults reads the result that might be returned by Postgres after a SendBytes +// (e.a. after sending a CopyDone in a copy-both situation). // // This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. // See https://www.postgresql.org/docs/current/protocol.html. -// -// So far this only seems required with CopyDone handling. -func (pgConn *PgConn) SendBytesWithResults(ctx context.Context, buf []byte) *MultiResultReader { +func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, @@ -936,16 +934,6 @@ func (pgConn *PgConn) SendBytesWithResults(ctx context.Context, buf []byte) *Mul pgConn.contextWatcher.Watch(ctx) } - n, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.asyncClose() - pgConn.contextWatcher.Unwatch() - multiResult.closed = true - multiResult.err = &writeError{err: err, safeToRetry: n == 0} - pgConn.unlock() - return multiResult - } - return multiResult } From 0d4f029683fc678cb3084b4dd714e1fde88856e3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Sep 2020 13:14:11 -0500 Subject: [PATCH 217/290] Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded Previously, it wasn't loaded until NextRow was called the first time. --- pgconn.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++-- pgconn_test.go | 38 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/pgconn.go b/pgconn.go index e2ab5c13..ff812069 100644 --- a/pgconn.go +++ b/pgconn.go @@ -84,6 +84,8 @@ type PgConn struct { bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error + peekedMsg pgproto3.BackendMessage + // Reusable / preallocated resources wbuf []byte // write buffer resultReader ResultReader @@ -427,8 +429,12 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa return msg, err } -// receiveMessage receives a message without setting up context cancellation -func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { +// peekMessage peeks at the next message without setting up context cancellation. +func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { + if pgConn.peekedMsg != nil { + return pgConn.peekedMsg, nil + } + var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { @@ -455,6 +461,23 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return nil, err } + pgConn.peekedMsg = msg + return msg, nil +} + +// receiveMessage receives a message without setting up context cancellation +func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.peekMessage() + if err != nil { + // Close on anything other than timeout error - everything else is fatal + if err, ok := err.(net.Error); !(ok && err.Timeout()) { + pgConn.asyncClose() + } + + return nil, err + } + pgConn.peekedMsg = nil + switch msg := msg.(type) { case *pgproto3.ReadyForQuery: pgConn.txStatus = msg.TxStatus @@ -1044,7 +1067,10 @@ func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() + return } + + result.readUntilRowDescription() } // CopyTo executes the copy command sql and copies the results to w. @@ -1454,6 +1480,26 @@ func (rr *ResultReader) Close() (CommandTag, error) { return rr.commandTag, rr.err } +// readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any +// error will be stored in the ResultReader. +func (rr *ResultReader) readUntilRowDescription() { + for !rr.commandConcluded { + // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. + // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are + // manually used to construct a query that does not issue a describe statement. + msg, _ := rr.pgConn.peekMessage() + if _, ok := msg.(*pgproto3.DataRow); ok { + return + } + + // Consume the message + msg, _ = rr.receiveMessage() + if _, ok := msg.(*pgproto3.RowDescription); ok { + return + } + } +} + func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { msg, err = rr.pgConn.receiveMessage() diff --git a/pgconn_test.go b/pgconn_test.go index f6750a60..24200e73 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -481,6 +481,34 @@ func TestConnExecMultipleQueries(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num") + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.True(t, mrr.NextResult()) + require.Len(t, mrr.ResultReader().FieldDescriptions(), 1) + assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name) + _, err = mrr.ResultReader().Close() + require.NoError(t, err) + + require.False(t, mrr.NextResult()) + + require.NoError(t, mrr.Close()) + + ensureConnValid(t, pgConn) +} + func TestConnExecMultipleQueriesError(t *testing.T) { t.Parallel() @@ -578,7 +606,10 @@ func TestConnExecParams(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) - result := pgConn.ExecParams(context.Background(), "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + rowCount := 0 for result.NextRow() { rowCount += 1 @@ -734,13 +765,16 @@ func TestConnExecPrepared(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) - psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil) + psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil) require.NoError(t, err) require.NotNil(t, psd) assert.Len(t, psd.ParamOIDs, 1) assert.Len(t, psd.Fields, 1) result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + rowCount := 0 for result.NextRow() { rowCount += 1 From b6b3a8631050ce8a4398a68c4c165269c1a14450 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 5 Sep 2020 13:26:56 -0500 Subject: [PATCH 218/290] Update CI Go versions --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 0371101f..95dce226 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: + - 1.15.x - 1.14.x - - 1.13.x - tip git: From be69c1c10b10bcaeb5cb7d1e7b72022060c4222d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 10 Sep 2020 19:40:46 -0500 Subject: [PATCH 219/290] Fix parseDSNSettings with bad backslash fixes #49 --- config.go | 3 +++ config_test.go | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/config.go b/config.go index b2583546..b05727ca 100644 --- a/config.go +++ b/config.go @@ -506,6 +506,9 @@ func parseDSNSettings(s string) (map[string]string, error) { } if s[end] == '\\' { end++ + if end == len(s) { + return nil, errors.New("invalid backslash") + } } } val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) diff --git a/config_test.go b/config_test.go index 264eb299..d322f65a 100644 --- a/config_test.go +++ b/config_test.go @@ -539,6 +539,13 @@ func TestParseConfigDSNLeadingEqual(t *testing.T) { require.Error(t, err) } +// https://github.com/jackc/pgconn/issues/49 +func TestParseConfigDSNTrailingBackslash(t *testing.T) { + _, err := pgconn.ParseConfig(`x=x\`) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid backslash") +} + func TestConfigCopyReturnsEqualConfig(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) From 28d24269e93ebc5aacc9271320226e6faae0c4dc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Sep 2020 11:35:23 -0500 Subject: [PATCH 220/290] Upgrade pgproto3 to v2.0.5 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index a74028c8..f2c10401 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.4 + github.com/jackc/pgproto3/v2 v2.0.5 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 diff --git a/go.sum b/go.sum index 61a2896b..08a11e19 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/jackc/pgproto3/v2 v2.0.3 h1:2S4PhE00mvdvaSiCYR1ZCmR1NAxeYfTSsqqSKxE1v github.com/jackc/pgproto3/v2 v2.0.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.4 h1:RHkX5ZUD9bl/kn0f9dYUWs1N7Nwvo1wwUYvKiR26Zco= github.com/jackc/pgproto3/v2 v2.0.4/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.5 h1:NUbEWPmCQZbMmYlTjVoNPhc0CfnYyz2bfUAh6A5ZVJM= +github.com/jackc/pgproto3/v2 v2.0.5/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= From 035868ca0c24b120f199e4bef6ac29a333e76baa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 26 Sep 2020 11:39:23 -0500 Subject: [PATCH 221/290] Release v1.7.0 --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 497e00a1..e7444fcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ -# Unreleased +# 1.7.0 (September 26, 2020) +* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded +* Add ReceiveResults (Sebastiaan Mannem) +* Fix parsing DSN connection with bad backslash * Add PgConn.CleanupDone so connection pools can determine when async close is complete # 1.6.4 (July 29, 2020) From 416f037e777022678e4138e780381b8f5b58364f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 5 Oct 2020 19:39:05 -0500 Subject: [PATCH 222/290] Fix docs for Timeout --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index b746c825..164b0848 100644 --- a/errors.go +++ b/errors.go @@ -20,7 +20,7 @@ func SafeToRetry(err error) bool { } // Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a -// context.Canceled, context.Canceled or an implementer of net.Error where Timeout() is true. +// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return true From f3f5b70a872eb9875c7bc0cbc6f7b3876c08d92b Mon Sep 17 00:00:00 2001 From: Feike Steenbergen Date: Thu, 29 Oct 2020 18:59:15 +0100 Subject: [PATCH 223/290] Ensure the example code snippet compiles again There were 2 errors when using the example code: - not enough arguments in call to pgConn.Close - no new variables on left side of := With these changes, the example works again. --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5d14e914..6a68e230 100644 --- a/README.md +++ b/README.md @@ -15,16 +15,16 @@ pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } -defer pgConn.Close() +defer pgConn.Close(context.Background()) result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } -_, err := result.Close() +_, err = result.Close() if err != nil { log.Fatalln("failed reading result:", err) -}) +} ``` ## Testing From 340bfece2c33b6375414a694688d05b56f6c31af Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 29 Oct 2020 21:20:28 -0500 Subject: [PATCH 224/290] Do not asyncClose in response to a FATAL PG error This will reduce spurious server log messages on authentication failures. See https://github.com/jackc/pgconn/pull/53. --- pgconn.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index ff812069..3652cedb 100644 --- a/pgconn.go +++ b/pgconn.go @@ -485,7 +485,8 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { - pgConn.asyncClose() + pgConn.status = connStatusClosed + pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: From 9c2888b49ee8af394820dd9dd5c66ec81cea7685 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Oct 2020 16:25:01 -0500 Subject: [PATCH 225/290] Release v1.7.1 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7444fcd..e9753526 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.7.1 (October 31, 2020) + +* Do not asyncClose after receiving FATAL error from PostgreSQL server + # 1.7.0 (September 26, 2020) * Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded From 0f17ba2cf3b307aeddfa5cd6ada0d1fe7ad3e46c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 3 Nov 2020 19:17:52 -0600 Subject: [PATCH 226/290] Fix unconstrained data value slices See https://github.com/jackc/pgx/issues/859 --- go.mod | 2 +- go.sum | 2 ++ pgconn_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index f2c10401..7e578765 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.5 + github.com/jackc/pgproto3/v2 v2.0.6 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 diff --git a/go.sum b/go.sum index 08a11e19..f3eb0e08 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/jackc/pgproto3/v2 v2.0.4 h1:RHkX5ZUD9bl/kn0f9dYUWs1N7Nwvo1wwUYvKiR26Z github.com/jackc/pgproto3/v2 v2.0.4/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.5 h1:NUbEWPmCQZbMmYlTjVoNPhc0CfnYyz2bfUAh6A5ZVJM= github.com/jackc/pgproto3/v2 v2.0.5/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= diff --git a/pgconn_test.go b/pgconn_test.go index 24200e73..b71e7d3f 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -758,6 +758,32 @@ func TestConnExecParamsEmptySQL(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgx/issues/859 +func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil) + require.Len(t, result.FieldDescriptions(), 1) + assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name) + + rowCount := 0 + for result.NextRow() { + rowCount += 1 + assert.Equal(t, "Hello, world", string(result.Values()[0])) + assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0])) + } + assert.Equal(t, 1, rowCount) + commandTag, err := result.Close() + assert.Equal(t, "SELECT 1", string(commandTag)) + assert.NoError(t, err) + + ensureConnValid(t, pgConn) +} + func TestConnExecPrepared(t *testing.T) { t.Parallel() From b82b993fa8aa3fd6d8aac15689301db049d5504f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 3 Nov 2020 19:20:03 -0600 Subject: [PATCH 227/290] Release v1.7.2 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9753526..92b1de06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.7.2 (November 3, 2020) + +* Fix data value slices into work buffer with capacities larger than length. + # 1.7.1 (October 31, 2020) * Do not asyncClose after receiving FATAL error from PostgreSQL server From a885de9c949c36c1359edc6de00cff0bc4b16bb1 Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Mon, 9 Nov 2020 08:20:34 -0500 Subject: [PATCH 228/290] stmtcache: add new StatementErrored method This patch adds a new StatementErrored method to the stmtcache. This routine MUST be called by users of the cache whenever the execution of a statement results in an error. This will allow the cache to make an intelligent decision about whether or not the statement needs to be purged from the cache. --- stmtcache/lru.go | 50 ++++++++++++++++++++++++++++++ stmtcache/lru_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++ stmtcache/stmtcache.go | 8 +++++ 3 files changed, 127 insertions(+) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index d82ced19..2f183f90 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -20,6 +20,7 @@ type LRU struct { m map[string]*list.Element l *list.List psNamePrefix string + stmtsToClear []string } // NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. @@ -41,6 +42,17 @@ func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + // flush an outstanding bad statements + txStatus := c.conn.TxStatus() + if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { + for _, stmt := range c.stmtsToClear { + err := c.clearStmt(ctx, stmt) + if err != nil { + return nil, err + } + } + } + if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.StatementDescription), nil @@ -76,6 +88,44 @@ func (c *LRU) Clear(ctx context.Context) error { return nil } +func (c *LRU) StatementErrored(ctx context.Context, sql string, err error) error { + pgErr, ok := err.(*pgconn.PgError) + if !ok { + // we don't know how to handle this error + return nil + } + + isInvalidCachedPlanError := pgErr.Severity == "ERROR" && + pgErr.Code == "0A000" && + pgErr.Message == "cached plan must not change result type" + if !isInvalidCachedPlanError { + // only flush if a plan has been changed out from under us + return nil + } + + c.stmtsToClear = append(c.stmtsToClear, sql) + + return nil +} + +func (c *LRU) clearStmt(ctx context.Context, sql string) error { + elem, inMap := c.m[sql] + if !inMap { + // The statement probably fell off the back of the list. In that case, we've + // ensured that it isn't in the cache, so we can declare victory. + return nil + } + + c.l.Remove(elem) + + psd := elem.Value.(*pgconn.StatementDescription) + delete(c.m, psd.SQL) + if c.mode == ModePrepare { + return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() + } + return nil +} + // Len returns the number of cached prepared statement descriptions. func (c *LRU) Len() int { return c.l.Len() diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index d2902dbb..75925509 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -59,6 +59,75 @@ func TestLRUModePrepare(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUStmtInvalidation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + // we construct a fake error because its not super straightforward to actually call + // a prepared statement from the LRU cache without the helper routines which live + // in pgx proper. + fakeInvalidCachePlanError := &pgconn.PgError{ + Severity: "ERROR", + Code: "0A000", + Message: "cached plan must not change result type", + } + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + + // + // outside of a transaction, we eagerly flush the statement + // + + _, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) + require.NoError(t, err) + _, err = cache.Get(ctx, "select 2") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) + + err = cache.Clear(ctx) + require.NoError(t, err) + + // + // within an errored transaction, we defer the flush to after the first get + // that happens after the transaction is rolled back + // + + _, err = cache.Get(ctx, "select 1") + require.NoError(t, err) + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) + + res := conn.Exec(ctx, "begin") + require.NoError(t, res.Close()) + require.Equal(t, byte('T'), conn.TxStatus()) + + res = conn.Exec(ctx, "selec") + require.Error(t, res.Close()) + require.Equal(t, byte('E'), conn.TxStatus()) + + err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) + require.EqualValues(t, 1, cache.Len()) + + res = conn.Exec(ctx, "rollback") + require.NoError(t, res.Close()) + + _, err = cache.Get(ctx, "select 2") + require.EqualValues(t, 1, cache.Len()) + require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) +} + func TestLRUModePrepareStress(t *testing.T) { t.Parallel() diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index 96215799..6e88ba54 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -20,6 +20,14 @@ type Cache interface { // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. Clear(ctx context.Context) error + // StatementErrored informs the cache that the given statement resulted in an error when it + // was last used against the database. In some cases, this will cause the cache to flush + // the statement from the cache. It will only do so when the underlying `*pgconn.PgConn` + // is not currently in a transaction. If the connection is in the middle of a transaction, + // the bad statement will instead be flushed during the next call to Get that occurrs outside + // of a transaction. + StatementErrored(ctx context.Context, sql string, err error) error + // Len returns the number of cached prepared statement descriptions. Len() int From 426124b32fb35daaee23175487b5a4117e38244e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 11 Nov 2020 15:48:49 -0600 Subject: [PATCH 229/290] Add stmtcache.LRU test thjat integrates over the database --- stmtcache/lru_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index 75925509..58a0c378 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -128,6 +128,44 @@ func TestLRUStmtInvalidation(t *testing.T) { require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) } +func TestLRUStmtInvalidationIntegration(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) + + result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + sql := "select * from stmtcache_table" + sd1, err := cache.Get(ctx, sql) + require.NoError(t, err) + + result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) + + result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() + require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)") + + cache.StatementErrored(ctx, sql, result.Err) + + sd2, err := cache.Get(ctx, sql) + require.NoError(t, err) + require.NotEqual(t, sd1.Name, sd2.Name) + + result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read() + require.NoError(t, result.Err) +} + func TestLRUModePrepareStress(t *testing.T) { t.Parallel() From cba610c245265ff50ea3c56a9961da218ed7d730 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 11 Nov 2020 15:52:59 -0600 Subject: [PATCH 230/290] StatementErrored does not need context nor return an error --- stmtcache/lru.go | 14 ++++---------- stmtcache/lru_test.go | 7 +++---- stmtcache/stmtcache.go | 10 ++++------ 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index 2f183f90..f58f2ac3 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -88,24 +88,18 @@ func (c *LRU) Clear(ctx context.Context) error { return nil } -func (c *LRU) StatementErrored(ctx context.Context, sql string, err error) error { +func (c *LRU) StatementErrored(sql string, err error) { pgErr, ok := err.(*pgconn.PgError) if !ok { - // we don't know how to handle this error - return nil + return } isInvalidCachedPlanError := pgErr.Severity == "ERROR" && pgErr.Code == "0A000" && pgErr.Message == "cached plan must not change result type" - if !isInvalidCachedPlanError { - // only flush if a plan has been changed out from under us - return nil + if isInvalidCachedPlanError { + c.stmtsToClear = append(c.stmtsToClear, sql) } - - c.stmtsToClear = append(c.stmtsToClear, sql) - - return nil } func (c *LRU) clearStmt(ctx context.Context, sql string) error { diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index 58a0c378..2d620905 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -89,8 +89,7 @@ func TestLRUStmtInvalidation(t *testing.T) { require.EqualValues(t, 1, cache.Len()) require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) - err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) - require.NoError(t, err) + cache.StatementErrored("select 1", fakeInvalidCachePlanError) _, err = cache.Get(ctx, "select 2") require.NoError(t, err) require.EqualValues(t, 1, cache.Len()) @@ -117,7 +116,7 @@ func TestLRUStmtInvalidation(t *testing.T) { require.Error(t, res.Close()) require.Equal(t, byte('E'), conn.TxStatus()) - err = cache.StatementErrored(ctx, "select 1", fakeInvalidCachePlanError) + cache.StatementErrored("select 1", fakeInvalidCachePlanError) require.EqualValues(t, 1, cache.Len()) res = conn.Exec(ctx, "rollback") @@ -156,7 +155,7 @@ func TestLRUStmtInvalidationIntegration(t *testing.T) { result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)") - cache.StatementErrored(ctx, sql, result.Err) + cache.StatementErrored(sql, result.Err) sd2, err := cache.Get(ctx, sql) require.NoError(t, err) diff --git a/stmtcache/stmtcache.go b/stmtcache/stmtcache.go index 6e88ba54..d083e1b4 100644 --- a/stmtcache/stmtcache.go +++ b/stmtcache/stmtcache.go @@ -21,12 +21,10 @@ type Cache interface { Clear(ctx context.Context) error // StatementErrored informs the cache that the given statement resulted in an error when it - // was last used against the database. In some cases, this will cause the cache to flush - // the statement from the cache. It will only do so when the underlying `*pgconn.PgConn` - // is not currently in a transaction. If the connection is in the middle of a transaction, - // the bad statement will instead be flushed during the next call to Get that occurrs outside - // of a transaction. - StatementErrored(ctx context.Context, sql string, err error) error + // was last used against the database. In some cases, this will cause the cache to maer that + // statement as bad. The bad statement will instead be flushed during the next call to Get + // that occurs outside of a failed transaction. + StatementErrored(sql string, err error) // Len returns the number of cached prepared statement descriptions. Len() int From 3742d6209e5f0a4b70b173477c6c40a0aaf21ce9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 3 Dec 2020 19:12:18 -0600 Subject: [PATCH 231/290] Release v1.8.0 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92b1de06..787853b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.8.0 (December 3, 2020) + +* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes) + # 1.7.2 (November 3, 2020) * Fix data value slices into work buffer with capacities larger than length. From a581247a126989e47bee6507f0b0f3c0ac9b4167 Mon Sep 17 00:00:00 2001 From: "ip.novikov" Date: Sat, 5 Dec 2020 15:28:01 +0300 Subject: [PATCH 232/290] Add check for url with broken password replace broken password in parseConfigError message --- errors.go | 2 ++ errors_test.go | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/errors.go b/errors.go index 164b0848..369c8ca3 100644 --- a/errors.go +++ b/errors.go @@ -178,6 +178,8 @@ func redactPW(connString string) string { connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") plainDSN := regexp.MustCompile(`password=[^ ]*`) connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + brokenURL := regexp.MustCompile(`:\w.*@`) + connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") return connString } diff --git a/errors_test.go b/errors_test.go index bef835f8..1bff3656 100644 --- a/errors_test.go +++ b/errors_test.go @@ -33,6 +33,16 @@ func TestConfigError(t *testing.T) { err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", }, + { + name: "weird url with slash in password", + err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil), + expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg", + }, + { + name: "url without password", + err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil), + expectedMsg: "cannot parse `postgresql://other@host/db`: msg", + }, } for _, tt := range tests { tt := tt From e0d22c1100233860131b45abe453a1c196391f98 Mon Sep 17 00:00:00 2001 From: "ip.novikov" Date: Sat, 5 Dec 2020 22:11:52 +0300 Subject: [PATCH 233/290] improve regexp get shortest sequence between : and @ --- errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/errors.go b/errors.go index 369c8ca3..b37b1d97 100644 --- a/errors.go +++ b/errors.go @@ -178,7 +178,7 @@ func redactPW(connString string) string { connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") plainDSN := regexp.MustCompile(`password=[^ ]*`) connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") - brokenURL := regexp.MustCompile(`:\w.*@`) + brokenURL := regexp.MustCompile(`:[^:@]+?@`) connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") return connString } From e276d9b832bfd155cb35b9080b21827bc5d0f996 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Wed, 23 Dec 2020 12:21:34 -0600 Subject: [PATCH 234/290] Add more documentation to TxStatus --- pgconn.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 3652cedb..53e32252 100644 --- a/pgconn.go +++ b/pgconn.go @@ -512,7 +512,14 @@ func (pgConn *PgConn) PID() uint32 { return pgConn.pid } -// TxStatus returns the current TxStatus as reported by the server. +// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. +// +// Possible return values: +// 'I' - idle / not in transaction +// 'T' - in a transaction +// 'E' - in a failed transaction +// +// See https://www.postgresql.org/docs/current/protocol-message-formats.html. func (pgConn *PgConn) TxStatus() byte { return pgConn.txStatus } From 724bf94515c9c2f671860d27d278cb8593f6055b Mon Sep 17 00:00:00 2001 From: Moshe Katz Date: Sat, 2 Jan 2021 22:51:02 -0500 Subject: [PATCH 235/290] use proper pgpass location on Windows --- config.go | 43 -------------------------------------- config_test.go | 18 ++++++++++++++-- defaults.go | 51 +++++++++++++++++++++++++++++++++++++++++++++ defaults_windows.go | 46 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 45 deletions(-) create mode 100644 defaults.go create mode 100644 defaults_windows.go diff --git a/config.go b/config.go index b05727ca..e4ee244a 100644 --- a/config.go +++ b/config.go @@ -11,7 +11,6 @@ import ( "net" "net/url" "os" - "os/user" "path/filepath" "strconv" "strings" @@ -338,48 +337,6 @@ func ParseConfig(connString string) (*Config, error) { return config, nil } -func defaultSettings() map[string]string { - settings := make(map[string]string) - - settings["host"] = defaultHost() - settings["port"] = "5432" - - // Default to the OS user name. Purposely ignoring err getting user name from - // OS. The client application will simply have to specify the user in that - // case (which they typically will be doing anyway). - user, err := user.Current() - if err == nil { - settings["user"] = user.Username - settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") - settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") - } - - settings["target_session_attrs"] = "any" - - settings["min_read_buffer_size"] = "8192" - - return settings -} - -// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost -// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it -// checks the existence of common locations. -func defaultHost() string { - candidatePaths := []string{ - "/var/run/postgresql", // Debian - "/private/tmp", // OSX - homebrew - "/tmp", // standard PostgreSQL - } - - for _, path := range candidatePaths { - if _, err := os.Stat(path); err == nil { - return path - } - } - - return "localhost" -} - func mergeSettings(settingSets ...map[string]string) map[string]string { settings := make(map[string]string) diff --git a/config_test.go b/config_test.go index d322f65a..f6391672 100644 --- a/config_test.go +++ b/config_test.go @@ -7,6 +7,8 @@ import ( "io/ioutil" "os" "os/user" + "runtime" + "strings" "testing" "time" @@ -21,7 +23,13 @@ func TestParseConfig(t *testing.T) { var osUserName string osUser, err := user.Current() if err == nil { - osUserName = osUser.Username + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } } tests := []struct { @@ -630,7 +638,13 @@ func TestParseConfigEnvLibpq(t *testing.T) { var osUserName string osUser, err := user.Current() if err == nil { - osUserName = osUser.Username + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { + osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] + } else { + osUserName = osUser.Username + } } pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} diff --git a/defaults.go b/defaults.go new file mode 100644 index 00000000..d3313481 --- /dev/null +++ b/defaults.go @@ -0,0 +1,51 @@ +// +build !windows + +package pgconn + +import ( + "os" + "os/user" + "path/filepath" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + if err == nil { + settings["user"] = user.Username + settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + } + + settings["target_session_attrs"] = "any" + + settings["min_read_buffer_size"] = "8192" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + candidatePaths := []string{ + "/var/run/postgresql", // Debian + "/private/tmp", // OSX - homebrew + "/tmp", // standard PostgreSQL + } + + for _, path := range candidatePaths { + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "localhost" +} diff --git a/defaults_windows.go b/defaults_windows.go new file mode 100644 index 00000000..55243700 --- /dev/null +++ b/defaults_windows.go @@ -0,0 +1,46 @@ +package pgconn + +import ( + "os" + "os/user" + "path/filepath" + "strings" +) + +func defaultSettings() map[string]string { + settings := make(map[string]string) + + settings["host"] = defaultHost() + settings["port"] = "5432" + + // Default to the OS user name. Purposely ignoring err getting user name from + // OS. The client application will simply have to specify the user in that + // case (which they typically will be doing anyway). + user, err := user.Current() + appData := os.Getenv("APPDATA") + if err == nil { + // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, + // but the libpq default is just the `user` portion, so we strip off the first part. + username := user.Username + if strings.Contains(username, "\\") { + username = username[strings.LastIndex(username, "\\")+1:] + } + + settings["user"] = username + settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") + settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + } + + settings["target_session_attrs"] = "any" + + settings["min_read_buffer_size"] = "8192" + + return settings +} + +// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost +// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it +// checks the existence of common locations. +func defaultHost() string { + return "localhost" +} From 120139a206078c030cdab77ee1d05984bb503fe5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 14 Jan 2021 18:22:18 -0600 Subject: [PATCH 236/290] Add link to PG docs for connString format fixes #62 --- config.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index e4ee244a..38e94f26 100644 --- a/config.go +++ b/config.go @@ -112,9 +112,10 @@ func NetworkAddress(host string, port uint16) (network, address string) { } // ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same -// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. connString may be a URL or a DSN. -// It also may be empty to only read from the environment. If a password is not supplied it will attempt to read the -// .pgpass file. +// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches +// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be +// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // // # Example DSN // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca From 7d8845a9d8f32c059555e20783828da2534e52f8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 12:47:34 -0600 Subject: [PATCH 237/290] Initial import from pgtype --- .github/workflows/ci.yml | 52 ++++++++++++++++++++++++++++++++++++++++ README.md | 1 + 2 files changed, 53 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..27ea2d4d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,52 @@ +name: CI + +on: + push: + branches: [ github-ci-wip ] + pull_request: + branches: [ github-ci-wip ] + +jobs: + + test: + name: Test + runs-on: ubuntu-latest + + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: secret + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + + steps: + + - name: Set up Go 1.x + uses: actions/setup-go@v2 + with: + go-version: ^1.13 + + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + + - name: Create hstore extension + run: psql -c 'create extension hstore' + env: + PGHOST: localhost + PGUSER: postgres + PGPASSWORD: secret + PGSSLMODE: disable + + - name: Test + run: go test -v ./... + env: + PGHOST: localhost + PGUSER: postgres + PGPASSWORD: secret + PGSSLMODE: disable diff --git a/README.md b/README.md index 6a68e230..d7238c39 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ [![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) [![Build Status](https://travis-ci.org/jackc/pgconn.svg)](https://travis-ci.org/jackc/pgconn) +![CI](https://github.com/jackc/pgtype/workflows/CI/badge.svg) # pgconn From 63bcdfde61d2395e45710c959bed950feeaa5bde Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 12:48:58 -0600 Subject: [PATCH 238/290] Fix CI link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d7238c39..feead016 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) [![Build Status](https://travis-ci.org/jackc/pgconn.svg)](https://travis-ci.org/jackc/pgconn) -![CI](https://github.com/jackc/pgtype/workflows/CI/badge.svg) +![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg) # pgconn From 6c2a423dbc25d634270b04ecaac7a1d644037945 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 12:58:25 -0600 Subject: [PATCH 239/290] Try to debug failing CI test --- config_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/config_test.go b/config_test.go index f6391672..7b8b8937 100644 --- a/config_test.go +++ b/config_test.go @@ -568,7 +568,16 @@ func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { original, err := pgconn.ParseConfig(connString) require.NoError(t, err) + fmt.Printf("original: %#v\n", original) + for i, f := range original.Fallbacks { + fmt.Printf("original fallback %d: %#v\n", i, f) + } + copied := original.Copy() + fmt.Printf("copied: %#v\n", copied) + for i, f := range copied.Fallbacks { + fmt.Printf("copied fallback %d: %#v\n", i, f) + } assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") copied.Port = uint16(5433) From a9c2b5c3cbb210352546ca3763dd259d7b752771 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 13:01:27 -0600 Subject: [PATCH 240/290] Revert "Try to debug failing CI test" This reverts commit 6c2a423dbc25d634270b04ecaac7a1d644037945. --- config_test.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/config_test.go b/config_test.go index 7b8b8937..f6391672 100644 --- a/config_test.go +++ b/config_test.go @@ -568,16 +568,7 @@ func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { original, err := pgconn.ParseConfig(connString) require.NoError(t, err) - fmt.Printf("original: %#v\n", original) - for i, f := range original.Fallbacks { - fmt.Printf("original fallback %d: %#v\n", i, f) - } - copied := original.Copy() - fmt.Printf("copied: %#v\n", copied) - for i, f := range copied.Fallbacks { - fmt.Printf("copied fallback %d: %#v\n", i, f) - } assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") copied.Port = uint16(5433) From 74517d73154ecdf045aad3fedcf47d66499b5548 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 13:03:56 -0600 Subject: [PATCH 241/290] Fix test when PGSSLMODE=disable When PGSSLMODE=disable no fallback config was created which would cause the check that fallbacks are deep copied to crash on: copied.Fallbacks[0].Port = uint16(5433) --- config_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config_test.go b/config_test.go index f6391672..e869422d 100644 --- a/config_test.go +++ b/config_test.go @@ -564,7 +564,7 @@ func TestConfigCopyReturnsEqualConfig(t *testing.T) { } func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { - connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5&sslmode=prefer" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) From eb322859067bf699fbfe8e8a5a8c6c89a1f5ff7e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:22:38 -0600 Subject: [PATCH 242/290] Use native PostgreSQL package Also remove travis integration. --- .github/workflows/ci.yml | 32 ++++-------- .travis.yml | 49 ------------------- {travis => ci}/script.bash | 0 .../before_install.bash => ci/setup_test.bash | 12 +++++ travis/before_script.bash | 17 ------- 5 files changed, 21 insertions(+), 89 deletions(-) delete mode 100644 .travis.yml rename {travis => ci}/script.bash (100%) rename travis/before_install.bash => ci/setup_test.bash (73%) delete mode 100755 travis/before_script.bash diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 27ea2d4d..3e3c1ed8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,19 +12,6 @@ jobs: name: Test runs-on: ubuntu-latest - services: - postgres: - image: postgres - env: - POSTGRES_PASSWORD: secret - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - steps: - name: Set up Go 1.x @@ -35,18 +22,17 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v2 - - name: Create hstore extension - run: psql -c 'create extension hstore' + - name: Setup database server for testing + run: ci/setup_test.bash env: - PGHOST: localhost - PGUSER: postgres - PGPASSWORD: secret - PGSSLMODE: disable + PGVERSION: 12 - name: Test run: go test -v ./... env: - PGHOST: localhost - PGUSER: postgres - PGPASSWORD: secret - PGSSLMODE: disable + PGX_TEST_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test + PGX_TEST_UNIX_SOCKET_CONN_STRING: "host=/var/run/postgresql dbname=pgx_test" + PGX_TEST_TCP_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test + PGX_TEST_TLS_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + PGX_TEST_MD5_PASSWORD_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test + PGX_TEST_PLAIN_PASSWORD_CONN_STRING: postgres://pgx_pw:secret@127.0.0.1/pgx_test diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 95dce226..00000000 --- a/.travis.yml +++ /dev/null @@ -1,49 +0,0 @@ -language: go - -go: - - 1.15.x - - 1.14.x - - tip - -git: - depth: 1 - -# Derived from https://github.com/lib/pq/blob/master/.travis.yml -before_install: - - ./travis/before_install.bash - -env: - global: - - GO111MODULE=on - - GOPROXY=https://proxy.golang.org - - GOFLAGS=-mod=readonly - - PGX_TEST_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" - - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - - PGX_TEST_MD5_PASSWORD_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test - - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test - matrix: - - CRATEVERSION=2.1 PGX_TEST_CRATEDB_CONN_STRING="host=127.0.0.1 port=6543 user=pgx dbname=pgx_test" - - PGVERSION=12 - - PGVERSION=11 - - PGVERSION=10 - - PGVERSION=9.6 - - PGVERSION=9.5 - -cache: - directories: - - $HOME/.cache/go-build - - $HOME/gopath/pkg/mod - -before_script: - - ./travis/before_script.bash - -install: go mod download - -script: - - ./travis/script.bash - -matrix: - allow_failures: - - go: tip diff --git a/travis/script.bash b/ci/script.bash similarity index 100% rename from travis/script.bash rename to ci/script.bash diff --git a/travis/before_install.bash b/ci/setup_test.bash similarity index 73% rename from travis/before_install.bash rename to ci/setup_test.bash index 23c7d9cf..78e30383 100755 --- a/travis/before_install.bash +++ b/ci/setup_test.bash @@ -24,6 +24,18 @@ then echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf fi sudo /etc/init.d/postgresql restart + + # The tricky test user, below, has to actually exist so that it can be used in a test + # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. + psql -U postgres -c 'create database pgx_test' + psql -U postgres pgx_test -c 'create extension hstore' + psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' + psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" + psql -U postgres -c "create user travis" + psql -U postgres -c "create user pgx_replication with replication password 'secret'" + psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi if [ "${CRATEVERSION-}" != "" ] diff --git a/travis/before_script.bash b/travis/before_script.bash deleted file mode 100755 index 923b7d06..00000000 --- a/travis/before_script.bash +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash -set -eux - -if [ "${PGVERSION-}" != "" ] -then - # The tricky test user, below, has to actually exist so that it can be used in a test - # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. - psql -U postgres -c 'create database pgx_test' - psql -U postgres pgx_test -c 'create extension hstore' - psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' - psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user travis" - psql -U postgres -c "create user pgx_replication with replication password 'secret'" - psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" -fi From c107f909a2aba0b35ed9817cafc6acf872861a89 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:28:27 -0600 Subject: [PATCH 243/290] Create user for Unix domain socket --- ci/setup_test.bash | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/setup_test.bash b/ci/setup_test.bash index 78e30383..144e93fd 100755 --- a/ci/setup_test.bash +++ b/ci/setup_test.bash @@ -33,7 +33,7 @@ then psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" - psql -U postgres -c "create user travis" + psql -U postgres -c "create user `whoami`" psql -U postgres -c "create user pgx_replication with replication password 'secret'" psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi From c10c60cad5d4a336c46cfb324c7879185266f34b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:38:58 -0600 Subject: [PATCH 244/290] Add build matrix for Go and PG --- .github/workflows/ci.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3e3c1ed8..b37ca273 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,12 +12,17 @@ jobs: name: Test runs-on: ubuntu-latest + strategy: + matrix: + go_version: [1.14, 1.15] + pg_version: [9.6, 10, 11, 12, 13] + steps: - name: Set up Go 1.x uses: actions/setup-go@v2 with: - go-version: ^1.13 + go-version: ${{ matrix.go_version }} - name: Check out code into the Go module directory uses: actions/checkout@v2 @@ -25,7 +30,7 @@ jobs: - name: Setup database server for testing run: ci/setup_test.bash env: - PGVERSION: 12 + PGVERSION: ${{ matrix.pg_version }} - name: Test run: go test -v ./... From ed0090f61043e2bce64be49da76fe6b7e4a1fbca Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:44:17 -0600 Subject: [PATCH 245/290] Use race detector on Github CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b37ca273..5acb0eea 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: PGVERSION: ${{ matrix.pg_version }} - name: Test - run: go test -v ./... + run: go test -v -race ./... env: PGX_TEST_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test PGX_TEST_UNIX_SOCKET_CONN_STRING: "host=/var/run/postgresql dbname=pgx_test" From 609cd81d64b4689ca9126322fd54f5eecaaf909f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:47:51 -0600 Subject: [PATCH 246/290] Remove obsolete Travis badge --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index feead016..c651f483 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,4 @@ [![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) -[![Build Status](https://travis-ci.org/jackc/pgconn.svg)](https://travis-ci.org/jackc/pgconn) ![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg) # pgconn From 9cf57526250f6cd3e6cbf4fd7269c882e66898ce Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 30 Jan 2021 16:48:51 -0600 Subject: [PATCH 247/290] Change Github CI to run on master --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5acb0eea..862235ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: CI on: push: - branches: [ github-ci-wip ] + branches: [ master ] pull_request: - branches: [ github-ci-wip ] + branches: [ master ] jobs: From a78ab5bdcda1e98bb43673be2ffda39435b91fda Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Feb 2021 09:39:42 -0600 Subject: [PATCH 248/290] Test should abort if cannot setup database --- pgconn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn_test.go b/pgconn_test.go index b71e7d3f..76156420 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -990,7 +990,7 @@ func TestConnExecBatchDeferredError(t *testing.T) { insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);` _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() - assert.NoError(t, err) + require.NoError(t, err) batch := &pgconn.Batch{} From d05c52217a6e39cdc3ad75808786189aace7b71b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Feb 2021 10:47:22 -0600 Subject: [PATCH 249/290] Initial CockroachDB testing --- pgconn_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 10 deletions(-) diff --git a/pgconn_test.go b/pgconn_test.go index 76156420..564a0c51 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -524,9 +524,18 @@ func TestConnExecMultipleQueriesError(t *testing.T) { t.Errorf("unexpected error: %v", err) } - assert.Len(t, results, 1) - assert.Len(t, results[0].Rows, 1) - assert.Equal(t, "1", string(results[0].Rows[0][0])) + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB starts the second query result set and then sends the divide by zero error. + require.Len(t, results, 2) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + assert.Len(t, results[1].Rows, 0) + } else { + // PostgreSQL sends the divide by zero and never sends the second query result set. + require.Len(t, results, 1) + assert.Len(t, results[0].Rows, 1) + assert.Equal(t, "1", string(results[0].Rows[0][0])) + } ensureConnValid(t, pgConn) } @@ -538,6 +547,10 @@ func TestConnExecDeferredError(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + setupSQL := `create temporary table t ( id text primary key, n int not null, @@ -630,6 +643,10 @@ func TestConnExecParamsDeferredError(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + setupSQL := `create temporary table t ( id text primary key, n int not null, @@ -860,14 +877,19 @@ func TestConnExecPreparedTooManyParams(t *testing.T) { sql := "values" + strings.Join(params, ", ") psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil) - require.NoError(t, err) - require.NotNil(t, psd) - assert.Len(t, psd.ParamOIDs, paramCount) - assert.Len(t, psd.Fields, 1) + if pgConn.ParameterStatus("crdb_version") != "" { + // CockroachDB rejects preparing a statement with more than 65535 parameters. + require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)") + } else { + // PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol. + require.NoError(t, err) + require.NotNil(t, psd) + assert.Len(t, psd.ParamOIDs, paramCount) + assert.Len(t, psd.Fields, 1) - result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() - require.Error(t, result.Err) - require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error()) + result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read() + require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters") + } ensureConnValid(t, pgConn) } @@ -981,6 +1003,10 @@ func TestConnExecBatchDeferredError(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + } + setupSQL := `create temporary table t ( id text primary key, n int not null, @@ -1161,6 +1187,10 @@ func TestConnOnNotice(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)") + } + multiResult := pgConn.Exec(context.Background(), `do $$ begin raise notice 'hello, world'; @@ -1187,6 +1217,10 @@ func TestConnOnNotification(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() require.NoError(t, err) @@ -1219,6 +1253,10 @@ func TestConnWaitForNotification(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") + } + _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() require.NoError(t, err) @@ -1279,6 +1317,10 @@ func TestConnCopyToSmall(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int2, b int4, @@ -1317,6 +1359,10 @@ func TestConnCopyToLarge(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does support COPY TO") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int2, b int4, @@ -1372,6 +1418,10 @@ func TestConnCopyToCanceled(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + outputWriter := &bytes.Buffer{} ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -1415,6 +1465,10 @@ func TestConnCopyFrom(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar @@ -1451,6 +1505,10 @@ func TestConnCopyFromCanceled(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar @@ -1528,6 +1586,10 @@ func TestConnCopyFromGzipReader(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") + } + _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar @@ -1627,6 +1689,10 @@ func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)") + } + _, err = pgConn.Exec(ctx, `create temporary table sentences( t text, ts tsvector @@ -1693,6 +1759,10 @@ func TestConnCancelRequest(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") + } + multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a From 4bde08d1a63976925a721ab2f4e000ad594fb34f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Feb 2021 11:19:09 -0600 Subject: [PATCH 250/290] LRU statement cache tests handle CockroackDB --- stmtcache/lru_test.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index 2d620905..a4108155 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "os" + "regexp" "testing" "time" @@ -239,7 +240,19 @@ func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgCon require.NoError(t, result.Err) var statements []string for _, r := range result.Rows { - statements = append(statements, string(r[0])) + statement := string(r[0]) + if conn.ParameterStatus("crdb_version") != "" { + if statement == "PREPARE AS select statement from pg_prepared_statements" { + // CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it. + continue + } + + // CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended + // protocol will PostgreSQL does not. Normalize the statement. + re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `) + statement = re.ReplaceAllString(statement, "") + } + statements = append(statements, statement) } return statements } From fb88a34cb4995248d154b5eaadde52136de25547 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Feb 2021 16:40:16 -0600 Subject: [PATCH 251/290] Skip test with known issue on CockroachDB --- pgconn_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgconn_test.go b/pgconn_test.go index 564a0c51..87edefc2 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -1098,6 +1098,10 @@ func TestConnExecBatchImplicitTransaction(t *testing.T) { require.NoError(t, err) defer closeConn(t, pgConn) + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)") + } + _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() require.NoError(t, err) From b9a1aad8d94163ffdf29aaedb48b78d3e2329ee3 Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Thu, 4 Mar 2021 17:58:49 +0100 Subject: [PATCH 252/290] add failing test to highlight issue #65 if frontend returns a message with "Severity: FATAL", even after calling "conn.Close()", the 'CleanupDone()' channel is still blocking --- frontend_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 frontend_test.go diff --git a/frontend_test.go b/frontend_test.go new file mode 100644 index 00000000..b82552bf --- /dev/null +++ b/frontend_test.go @@ -0,0 +1,70 @@ +package pgconn_test + +import ( + "context" + "io" + "os" + "testing" + + "github.com/jackc/pgconn" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// frontendWrapper allows to hijack a regular frontend, and inject a specific response +type frontendWrapper struct { + front pgconn.Frontend + + msg pgproto3.BackendMessage +} + +// frontendWrapper implements the pgconn.Frontend interface +var _ pgconn.Frontend = (*frontendWrapper)(nil) + +func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) { + if f.msg != nil { + return f.msg, nil + } + + return f.front.Receive() +} + +func TestFrontendFatalErrExec(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + buildFrontend := config.BuildFrontend + var front *frontendWrapper + + config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend { + wrapped := buildFrontend(r, w) + front = &frontendWrapper{wrapped, nil} + + return front + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, conn) + require.NotNil(t, front) + + // set frontend to return a "FATAL" message on next call + front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"} + + _, err = conn.Exec(context.Background(), "SELECT 1").ReadAll() + assert.Error(t, err) + + err = conn.Close(context.Background()) + assert.NoError(t, err) + + select { + case <-conn.CleanupDone(): + t.Log("ok, CleanupDone() is not blocking") + + default: + assert.Fail(t, "connection closed but CleanupDone() still blocking") + } +} From 36c8fb8257391de896e4c934ace6e82ea5631f3a Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Thu, 4 Mar 2021 18:07:41 +0100 Subject: [PATCH 253/290] fix #65 : close cleanupDone channel on "FATAL" messages --- pgconn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pgconn.go b/pgconn.go index 53e32252..0c1717ff 100644 --- a/pgconn.go +++ b/pgconn.go @@ -487,6 +487,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { if msg.Severity == "FATAL" { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. + close(pgConn.cleanupDone) return nil, ErrorResponseToPgError(msg) } case *pgproto3.NoticeResponse: From 3b0400a0d401491f45add1f347ed0383ca6a76a1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 14:42:22 -0600 Subject: [PATCH 254/290] Test Go 1.15 and 1.16 in CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 862235ae..fa5c9e8f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - go_version: [1.14, 1.15] + go_version: [1.15, 1.16] pg_version: [9.6, 10, 11, 12, 13] steps: From cf5894e0927e66175468e7622712d2a4c6df0964 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 14:45:33 -0600 Subject: [PATCH 255/290] Use std errors instead of golang.org/x/xerrors New error functionality was introduced in Go 1.13. pgconn only officially supports 1.15+. Transitional xerrors package can now be removed. --- auth_scram.go | 6 +++--- config.go | 8 ++++---- errors.go | 3 +-- go.mod | 1 - pgconn.go | 5 +++-- pgconn_test.go | 6 ++---- 6 files changed, 13 insertions(+), 16 deletions(-) diff --git a/auth_scram.go b/auth_scram.go index 665fc2c2..6a143fcd 100644 --- a/auth_scram.go +++ b/auth_scram.go @@ -18,13 +18,13 @@ import ( "crypto/rand" "crypto/sha256" "encoding/base64" + "errors" "fmt" "strconv" "github.com/jackc/pgproto3/v2" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" - errors "golang.org/x/xerrors" ) const clientNonceLen = 18 @@ -192,12 +192,12 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { var err error sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) if err != nil { - return errors.Errorf("invalid SCRAM salt received from server: %w", err) + return fmt.Errorf("invalid SCRAM salt received from server: %w", err) } sc.iterations, err = strconv.Atoi(string(iterationsStr)) if err != nil || sc.iterations <= 0 { - return errors.Errorf("invalid SCRAM iteration count received from server: %w", err) + return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) } if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { diff --git a/config.go b/config.go index 38e94f26..c162d3c3 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "io/ioutil" @@ -20,7 +21,6 @@ import ( "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgservicefile" - errors "golang.org/x/xerrors" ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error @@ -409,7 +409,7 @@ func parseURLSettings(connString string) (map[string]string, error) { } h, p, err := net.SplitHostPort(host) if err != nil { - return nil, errors.Errorf("failed to split host:port in '%s', err: %w", host, err) + return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) } hosts = append(hosts, h) ports = append(ports, p) @@ -617,7 +617,7 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { caPath := sslrootcert caCert, err := ioutil.ReadFile(caPath) if err != nil { - return nil, errors.Errorf("unable to read CA file: %w", err) + return nil, fmt.Errorf("unable to read CA file: %w", err) } if !caCertPool.AppendCertsFromPEM(caCert) { @@ -635,7 +635,7 @@ func configTLS(settings map[string]string) ([]*tls.Config, error) { if sslcert != "" && sslkey != "" { cert, err := tls.LoadX509KeyPair(sslcert, sslkey) if err != nil { - return nil, errors.Errorf("unable to read cert: %w", err) + return nil, fmt.Errorf("unable to read cert: %w", err) } tlsConfig.Certificates = []tls.Certificate{cert} diff --git a/errors.go b/errors.go index b37b1d97..77adfcf0 100644 --- a/errors.go +++ b/errors.go @@ -2,13 +2,12 @@ package pgconn import ( "context" + "errors" "fmt" "net" "net/url" "regexp" "strings" - - errors "golang.org/x/xerrors" ) // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. diff --git a/go.mod b/go.mod index 7e578765..2dc0cd4d 100644 --- a/go.mod +++ b/go.mod @@ -12,5 +12,4 @@ require ( github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/text v0.3.3 - golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 ) diff --git a/pgconn.go b/pgconn.go index 0c1717ff..20233e57 100644 --- a/pgconn.go +++ b/pgconn.go @@ -6,6 +6,8 @@ import ( "crypto/tls" "encoding/binary" "encoding/hex" + "errors" + "fmt" "io" "math" "net" @@ -16,7 +18,6 @@ import ( "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" - errors "golang.org/x/xerrors" ) const ( @@ -1043,7 +1044,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by } if len(paramValues) > math.MaxUint16 { - result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) + result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result diff --git a/pgconn_test.go b/pgconn_test.go index 87edefc2..7ceda791 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -5,6 +5,7 @@ import ( "compress/gzip" "context" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -17,12 +18,9 @@ import ( "testing" "time" - "github.com/jackc/pgmock" - "github.com/jackc/pgconn" + "github.com/jackc/pgmock" "github.com/jackc/pgproto3/v2" - errors "golang.org/x/xerrors" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) From a0350a932a7e4313c547e36e6e2e8b7ccd8ce3d1 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 15:01:44 -0600 Subject: [PATCH 256/290] ci.yml consistently uses kebab case --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fa5c9e8f..77d32cb7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,15 +14,15 @@ jobs: strategy: matrix: - go_version: [1.15, 1.16] - pg_version: [9.6, 10, 11, 12, 13] + go-version: [1.15, 1.16] + pg-version: [9.6, 10, 11, 12, 13] steps: - name: Set up Go 1.x uses: actions/setup-go@v2 with: - go-version: ${{ matrix.go_version }} + go-version: ${{ matrix.go-version }} - name: Check out code into the Go module directory uses: actions/checkout@v2 @@ -30,7 +30,7 @@ jobs: - name: Setup database server for testing run: ci/setup_test.bash env: - PGVERSION: ${{ matrix.pg_version }} + PGVERSION: ${{ matrix.pg-version }} - name: Test run: go test -v -race ./... From 7de3392269f1eb7d43900b8406392ea767fae479 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 15:15:03 -0600 Subject: [PATCH 257/290] Manually specify all build matrix options - Saves some CI time by only testing older version of Go once - Specify connection --- .github/workflows/ci.yml | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 77d32cb7..67ffeaab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,8 +14,30 @@ jobs: strategy: matrix: - go-version: [1.15, 1.16] - pg-version: [9.6, 10, 11, 12, 13] + include: + - go-version: 1.15 + pg-version: 13 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 9.6 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 10 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 11 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 12 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + - go-version: 1.16 + pg-version: 13 + pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test steps: @@ -35,9 +57,9 @@ jobs: - name: Test run: go test -v -race ./... env: - PGX_TEST_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_UNIX_SOCKET_CONN_STRING: "host=/var/run/postgresql dbname=pgx_test" - PGX_TEST_TCP_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_TLS_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require - PGX_TEST_MD5_PASSWORD_CONN_STRING: postgres://pgx_md5:secret@127.0.0.1/pgx_test - PGX_TEST_PLAIN_PASSWORD_CONN_STRING: postgres://pgx_pw:secret@127.0.0.1/pgx_test + PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }} + PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} + PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} + PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} + PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} + PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} From 1e905d8e38f6c9344707931ccd2afa03a2f34273 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 15:20:03 -0600 Subject: [PATCH 258/290] Refactor connection strings into build matrix This is in preparation for adding CockroachDB to the build matrix. --- .github/workflows/ci.yml | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 67ffeaab..6880ae90 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,24 +14,38 @@ jobs: strategy: matrix: + go-version: [1.15, 1.16] + pg-version: [9.6, 10, 11, 12, 13] include: - - go-version: 1.15 - pg-version: 13 + - pg-version: 9.6 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 9.6 + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 10 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 10 + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 11 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 11 + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 12 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 12 - pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test - - go-version: 1.16 - pg-version: 13 + pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" + pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test + pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: 13 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test From 0d307bcc5e8ce129be1875bce1595a397aa46140 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 15:49:50 -0600 Subject: [PATCH 259/290] Add CockroachDB to CI --- .github/workflows/ci.yml | 6 ++++-- ci/setup_test.bash | 10 +++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6880ae90..d84462da 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,12 +10,12 @@ jobs: test: name: Test - runs-on: ubuntu-latest + runs-on: ubuntu-18.04 strategy: matrix: go-version: [1.15, 1.16] - pg-version: [9.6, 10, 11, 12, 13] + pg-version: [9.6, 10, 11, 12, 13, cockroachdb] include: - pg-version: 9.6 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test @@ -52,6 +52,8 @@ jobs: pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test + - pg-version: cockroachdb + pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" steps: diff --git a/ci/setup_test.bash b/ci/setup_test.bash index 144e93fd..f71bd98c 100755 --- a/ci/setup_test.bash +++ b/ci/setup_test.bash @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -eux -if [ "${PGVERSION-}" != "" ] +if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]] then sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common sudo rm -rf /var/lib/postgresql @@ -38,6 +38,14 @@ then psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi +if [[ "${PGVERSION-}" =~ ^cockroach ]] +then + wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz + sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/ + cockroach start-single-node --insecure --background --listen-addr=localhost + cockroach sql --insecure -e 'create database pgx_test' +fi + if [ "${CRATEVERSION-}" != "" ] then docker run \ From 5daa019e4eb52df3409ebf17c83116b7c0e827e5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Mar 2021 16:08:38 -0600 Subject: [PATCH 260/290] Update README.md to authentication test setup --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c651f483..1c698a11 100644 --- a/README.md +++ b/README.md @@ -52,5 +52,5 @@ PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... Pgconn supports multiple connection types and means of authentication. These tests are optional. They will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being -skipped. Most developers will not need to enable these tests. See `travis.yml` for an example set up if you need change +skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change authentication code. From 26ccb4ee08e9895ad83905cbfbd7dc782261f8c3 Mon Sep 17 00:00:00 2001 From: Andrey Borodin Date: Wed, 10 Mar 2021 22:19:41 +0500 Subject: [PATCH 261/290] Resume fallback on server error When server responds with "TLS required" or too "many connections for role" fallbacks are not traversed any further. This could be OK, but fallbacks without TLS are added autoatically so that if we have multiple hosts requiring TLS we never traverse beyond first one. --- pgconn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgconn.go b/pgconn.go index 20233e57..a245159d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -151,7 +151,7 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err if err == nil { break } else if err, ok := err.(*PgError); ok { - return nil, &connectError{config: config, msg: "server error", err: err} + err = &connectError{config: config, msg: "server error", err: err} } } From 70be4b4a02e4c00a3cf4199749f60a0544e12d9b Mon Sep 17 00:00:00 2001 From: Andrey Borodin Date: Wed, 10 Mar 2021 22:29:01 +0500 Subject: [PATCH 262/290] Fix incoherent type assignment --- pgconn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index a245159d..826d70e9 100644 --- a/pgconn.go +++ b/pgconn.go @@ -150,8 +150,8 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err pgConn, err = connect(ctx, config, fc) if err == nil { break - } else if err, ok := err.(*PgError); ok { - err = &connectError{config: config, msg: "server error", err: err} + } else if pgerr, ok := err.(*PgError); ok { + err = &connectError{config: config, msg: "server error", err: pgerr} } } From b6027e37f43987793a1e39b97b99598777218547 Mon Sep 17 00:00:00 2001 From: Andrey Borodin Date: Fri, 12 Mar 2021 11:48:43 +0500 Subject: [PATCH 263/290] Stop fallback in case of invalid password --- pgconn.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pgconn.go b/pgconn.go index 826d70e9..668808aa 100644 --- a/pgconn.go +++ b/pgconn.go @@ -152,6 +152,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} + ERRCODE_INVALID_PASSWORD := "28P01" + if pgerr.Code == ERRCODE_INVALID_PASSWORD { + break; + } } } From 8990c125cf4a71bcf938328b43d52a289053725e Mon Sep 17 00:00:00 2001 From: Andrey Borodin Date: Fri, 12 Mar 2021 11:55:01 +0500 Subject: [PATCH 264/290] Stop fallback on ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION --- pgconn.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pgconn.go b/pgconn.go index 668808aa..197aad4a 100644 --- a/pgconn.go +++ b/pgconn.go @@ -152,9 +152,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} - ERRCODE_INVALID_PASSWORD := "28P01" - if pgerr.Code == ERRCODE_INVALID_PASSWORD { - break; + ERRCODE_INVALID_PASSWORD := "28P01" // worng password + ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION := "28000" // db does not exist + if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION { + break } } } From e8f75629d095956d7bff81362bdcd17e37d02464 Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Mon, 22 Mar 2021 13:51:08 -0400 Subject: [PATCH 265/290] upgrade x/crypto to avoid CVE-2020-9283 I found this when scanning for security issues in some dependencies. I doubt that this CVE will impact pgconn since I don't think it uses the ssh cropto module, but I think it is worth being fairly agressive about upgrading security sensative libraries and this doesn't seem to be a breaking change. --- go.mod | 2 +- go.sum | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 2dc0cd4d..e9003cb7 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,6 @@ require ( github.com/jackc/pgproto3/v2 v2.0.6 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 - golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/text v0.3.3 ) diff --git a/go.sum b/go.sum index f3eb0e08..58bb1286 100644 --- a/go.sum +++ b/go.sum @@ -99,10 +99,13 @@ golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+v golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -111,6 +114,8 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= From cdb667b5b002eb70aaac3666814309b07539895d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 25 Mar 2021 09:09:55 -0400 Subject: [PATCH 266/290] Update copyright date --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index c1c4f50f..aebadd6c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2019 Jack Christensen +Copyright (c) 2019-2021 Jack Christensen MIT License From 464a7d88d9ccf1ca9f76a84984d95a5657ac3faa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 25 Mar 2021 09:15:34 -0400 Subject: [PATCH 267/290] Release v1.8.1 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 787853b2..c377b3ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.8.1 (March 25, 2021) + +* Better connection string sanitization (ip.novikov) +* Use proper pgpass location on Windows (Moshe Katz) +* Use errors instead of golang.org/x/xerrors +* Resume fallback on server error in Connect (Andrey Borodin) + # 1.8.0 (December 3, 2020) * Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes) From 3f76b98073687a376f84a10c0972c3dd0c5de55c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 9 Apr 2021 18:20:06 -0500 Subject: [PATCH 268/290] Allow dbname query parameter in URL conn string fixes #69 --- config.go | 8 ++++++++ config_test.go | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/config.go b/config.go index c162d3c3..6991e1de 100644 --- a/config.go +++ b/config.go @@ -426,7 +426,15 @@ func parseURLSettings(connString string) (map[string]string, error) { settings["database"] = database } + nameMap := map[string]string{ + "dbname": "database", + } + for k, v := range url.Query() { + if k2, present := nameMap[k]; present { + k = k2 + } + settings[k] = v[0] } diff --git a/config_test.go b/config_test.go index e869422d..11dd23dc 100644 --- a/config_test.go +++ b/config_test.go @@ -227,6 +227,18 @@ func TestParseConfig(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + name: "database url dbname", + connString: "postgres://localhost/?dbname=foo&sslmode=disable", + config: &pgconn.Config{ + User: osUserName, + Host: "localhost", + Port: 5432, + Database: "foo", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "database url postgresql protocol", connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", From fb42201c18fcd016c235d4b613f76b2fc1599588 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 May 2021 18:39:31 -0500 Subject: [PATCH 269/290] Fix default host when parsing URL without host but with port fixes https://github.com/jackc/pgconn/issues/72 --- config.go | 8 ++++++-- config_test.go | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/config.go b/config.go index 6991e1de..16480589 100644 --- a/config.go +++ b/config.go @@ -411,8 +411,12 @@ func parseURLSettings(connString string) (map[string]string, error) { if err != nil { return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) } - hosts = append(hosts, h) - ports = append(ports, p) + if h != "" { + hosts = append(hosts, h) + } + if p != "" { + ports = append(ports, p) + } } if len(hosts) > 0 { settings["host"] = strings.Join(hosts, ",") diff --git a/config_test.go b/config_test.go index 11dd23dc..d29173d1 100644 --- a/config_test.go +++ b/config_test.go @@ -32,6 +32,10 @@ func TestParseConfig(t *testing.T) { } } + config, err := pgconn.ParseConfig("") + require.NoError(t, err) + defaultHost := config.Host + tests := []struct { name string connString string @@ -428,6 +432,20 @@ func TestParseConfig(t *testing.T) { }, }, }, + // https://github.com/jackc/pgconn/issues/72 + { + name: "URL without host but with port still uses default host", + connString: "postgres://jack:secret@:1/mydb?sslmode=disable", + config: &pgconn.Config{ + User: "jack", + Password: "secret", + Host: defaultHost, + Port: 1, + Database: "mydb", + TLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { name: "DSN multiple hosts one port", connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", From cfcd61d0cbf58dfa254d094864bcfc22c7a3e104 Mon Sep 17 00:00:00 2001 From: Sivabalan Thirunavukkarasu Date: Thu, 17 Jun 2021 20:17:10 +0800 Subject: [PATCH 270/290] Updating dependency versions --- go.mod | 4 ++-- go.sum | 33 ++++++++++----------------------- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/go.mod b/go.mod index e9003cb7..233fa205 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.12 require ( github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 + github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.0.6 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 - golang.org/x/text v0.3.3 + golang.org/x/text v0.3.6 ) diff --git a/go.sum b/go.sum index 58bb1286..14121a04 100644 --- a/go.sum +++ b/go.sum @@ -8,17 +8,18 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= -github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= @@ -26,22 +27,9 @@ github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29 h1:f2HwOeI1NIJyNFVVeh1gUISyt57iw/fmI/IXJfH3ATE= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.1 h1:Rdjp4NFjwHnEslx2b66FfCI2S0LhO4itac3hXz6WX9M= -github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.2 h1:q1Hsy66zh4vuNsajBUF2PNqfAMMfxU5mk594lPE9vjY= -github.com/jackc/pgproto3/v2 v2.0.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.3 h1:2S4PhE00mvdvaSiCYR1ZCmR1NAxeYfTSsqqSKxE1vzo= -github.com/jackc/pgproto3/v2 v2.0.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.4 h1:RHkX5ZUD9bl/kn0f9dYUWs1N7Nwvo1wwUYvKiR26Zco= -github.com/jackc/pgproto3/v2 v2.0.4/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.0.5 h1:NUbEWPmCQZbMmYlTjVoNPhc0CfnYyz2bfUAh6A5ZVJM= -github.com/jackc/pgproto3/v2 v2.0.5/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8 h1:Q3tB+ExeflWUW7AFcAhXqk40s9mnNYLk1nOkKNZ5GnU= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= @@ -81,7 +69,6 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= @@ -93,12 +80,9 @@ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -114,20 +98,23 @@ golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From bacf81fb4eada115882c2f4f29d0d42047902be5 Mon Sep 17 00:00:00 2001 From: Sivabalan Thirunavukkarasu Date: Thu, 17 Jun 2021 20:43:54 +0800 Subject: [PATCH 271/290] Bumping versions for other dependencies --- go.mod | 6 +++--- go.sum | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 233fa205..57f773b1 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,9 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.0.6 + github.com/jackc/pgproto3/v2 v2.1.0 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b - github.com/stretchr/testify v1.5.1 - golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 + github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e golang.org/x/text v0.3.6 ) diff --git a/go.sum b/go.sum index 14121a04..eedcac1b 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,9 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.6 h1:b1105ZGEMFe7aCvrT1Cca3VoVb4ZFMaFJLJcg/3zD+8= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.0 h1:h2yg3kjIyAGSZKDijYn1/gXHlYLCwl9ZjEh2PU0yVxE= +github.com/jackc/pgproto3/v2 v2.1.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= @@ -70,8 +71,9 @@ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoH github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -83,8 +85,8 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaE golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -100,6 +102,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -119,5 +122,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From bf76d1ed51099a78209a2dc109d826cab20d286e Mon Sep 17 00:00:00 2001 From: mgoddard Date: Sat, 19 Jun 2021 07:16:00 -0400 Subject: [PATCH 272/290] Solve issue with 'sslmode=verify-full' when there are multiple hosts --- config.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index 16480589..172e7478 100644 --- a/config.go +++ b/config.go @@ -297,7 +297,7 @@ func ParseConfig(connString string) (*Config, error) { tlsConfigs = append(tlsConfigs, nil) } else { var err error - tlsConfigs, err = configTLS(settings) + tlsConfigs, err = configTLS(settings, host) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } @@ -552,8 +552,8 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin // configTLS uses libpq's TLS parameters to construct []*tls.Config. It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. -func configTLS(settings map[string]string) ([]*tls.Config, error) { - host := settings["host"] +func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) { + host := thisHost sslmode := settings["sslmode"] sslrootcert := settings["sslrootcert"] sslcert := settings["sslcert"] From a123e5b4e575b5eb3c68ae4ab87c508d341242df Mon Sep 17 00:00:00 2001 From: Joshua Brindle Date: Mon, 21 Jun 2021 15:25:10 -0400 Subject: [PATCH 273/290] Add defaults for sslcert, sslkey, and sslrootcert per https://www.postgresql.org/docs/current/libpq-ssl.html psql will use client certs located in ~/.postgresql on posix systems or %APPDATA%\postgresql on Windows systems. --- defaults.go | 13 +++++++++++++ defaults_windows.go | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/defaults.go b/defaults.go index d3313481..f69cad31 100644 --- a/defaults.go +++ b/defaults.go @@ -22,6 +22,19 @@ func defaultSettings() map[string]string { settings["user"] = user.Username settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } } settings["target_session_attrs"] = "any" diff --git a/defaults_windows.go b/defaults_windows.go index 55243700..71eb77db 100644 --- a/defaults_windows.go +++ b/defaults_windows.go @@ -29,6 +29,19 @@ func defaultSettings() map[string]string { settings["user"] = username settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") + sslcert := filepath.Join(appData, "postgresql", "postgresql.crt") + sslkey := filepath.Join(appData, "postgresql", "postgresql.key") + if _, err := os.Stat(sslcert); err == nil { + if _, err := os.Stat(sslkey); err == nil { + // Both the cert and key must be present to use them, or do not use either + settings["sslcert"] = sslcert + settings["sslkey"] = sslkey + } + } + sslrootcert := filepath.Join(appData, "postgresql", "root.crt") + if _, err := os.Stat(sslrootcert); err == nil { + settings["sslrootcert"] = sslrootcert + } } settings["target_session_attrs"] = "any" From c0b4d3bc05e51a6df4c011de058f0e4daf7e154f Mon Sep 17 00:00:00 2001 From: Michael Darr Date: Tue, 29 Jun 2021 14:24:09 -0400 Subject: [PATCH 274/290] Implement timeout error Signed-off-by: Michael Darr --- errors.go | 47 ++++++++++++++++++++++++++++++++++++++++------- pgconn.go | 37 +++++++++++++++++++++++++------------ 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/errors.go b/errors.go index 77adfcf0..5df851d5 100644 --- a/errors.go +++ b/errors.go @@ -18,15 +18,11 @@ func SafeToRetry(err error) bool { return false } -// Timeout checks if err was was caused by a timeout. To be specific, it is true if err is or was caused by a +// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true - } - - var netErr net.Error - return errors.As(err, &netErr) && netErr.Timeout() + var timeoutErr *ErrTimeout + return errors.As(err, &timeoutErr) } // PgError represents an error reported by the PostgreSQL server. See @@ -134,6 +130,32 @@ func (e *pgconnError) Unwrap() error { return e.err } +// ErrTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. +type ErrTimeout struct { + err error +} + +func (e *ErrTimeout) Error() string { + return fmt.Sprintf("timeout: %s", e.err.Error()) +} + +func (e *ErrTimeout) SafeToRetry() bool { + var ctxErr *contextAlreadyDoneError + if errors.As(e, &ctxErr) { + return ctxErr.SafeToRetry() + } + var netErr net.Error + if errors.As(e, &netErr) { + return netErr.Temporary() + } + return false +} + +func (e *ErrTimeout) Unwrap() error { + return e.err +} + type contextAlreadyDoneError struct { err error } @@ -150,6 +172,17 @@ func (e *contextAlreadyDoneError) Unwrap() error { return e.err } +// newContextAlreadyDoneError wraps a context error in `contextAlreadyDoneError`. If the context was cancelled or its +// deadline passed, the returned error is also wrapped by `ErrTimeout`. +func newContextAlreadyDoneError(ctx context.Context) (err error) { + ctxErr := ctx.Err() + err = &contextAlreadyDoneError{err: ctxErr} + if ctxErr != nil { + err = &ErrTimeout{err: err} + } + return err +} + type writeError struct { err error safeToRetry bool diff --git a/pgconn.go b/pgconn.go index 197aad4a..74e24257 100644 --- a/pgconn.go +++ b/pgconn.go @@ -217,6 +217,10 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + err = &ErrTimeout{err: err} + } return nil, &connectError{config: config, msg: "dial error", err: err} } @@ -389,7 +393,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { if ctx != context.Background() { select { case <-ctx.Done(): - return &contextAlreadyDoneError{err: ctx.Err()} + return newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -421,7 +425,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -451,7 +455,8 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { pgConn.bufferingReceive = false // If a timeout error happened in the background try the read again. - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { msg, err = pgConn.frontend.Receive() } } else { @@ -460,8 +465,12 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { if err != nil { // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !(ok && err.Timeout()) { + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() + } else if isNetErr && netErr.Timeout() { + err = &ErrTimeout{err: err} } return nil, err @@ -476,8 +485,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := pgConn.peekMessage() if err != nil { // Close on anything other than timeout error - everything else is fatal - if err, ok := err.(net.Error); !(ok && err.Timeout()) { + var netErr net.Error + isNetErr := errors.As(err, &netErr) + if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() + } else if isNetErr && netErr.Timeout() { + err = &ErrTimeout{err: err} } return nil, err @@ -745,7 +758,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -918,7 +931,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: @@ -964,7 +977,7 @@ func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: @@ -1058,7 +1071,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by if ctx != context.Background() { select { case <-ctx.Done(): - result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) + result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) result.closed = true pgConn.unlock() return result @@ -1098,7 +1111,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm select { case <-ctx.Done(): pgConn.unlock() - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1158,7 +1171,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co if ctx != context.Background() { select { case <-ctx.Done(): - return nil, &contextAlreadyDoneError{err: ctx.Err()} + return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) @@ -1601,7 +1614,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR select { case <-ctx.Done(): multiResult.closed = true - multiResult.err = &contextAlreadyDoneError{err: ctx.Err()} + multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: From b3e64d3cdb6e805e32adce9c4a148c2ebf6e9cee Mon Sep 17 00:00:00 2001 From: Michael Darr Date: Tue, 6 Jul 2021 15:36:46 -0400 Subject: [PATCH 275/290] Simplify SafeToRetry for ErrTimeout Signed-off-by: Michael Darr --- errors.go | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/errors.go b/errors.go index 5df851d5..0bb322cd 100644 --- a/errors.go +++ b/errors.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "net/url" "regexp" "strings" @@ -141,15 +140,7 @@ func (e *ErrTimeout) Error() string { } func (e *ErrTimeout) SafeToRetry() bool { - var ctxErr *contextAlreadyDoneError - if errors.As(e, &ctxErr) { - return ctxErr.SafeToRetry() - } - var netErr net.Error - if errors.As(e, &netErr) { - return netErr.Temporary() - } - return false + return SafeToRetry(e.err) } func (e *ErrTimeout) Unwrap() error { From 9a9830c00d579aaa709b095acd2ab96162e3a564 Mon Sep 17 00:00:00 2001 From: Michael Darr Date: Tue, 6 Jul 2021 15:43:26 -0400 Subject: [PATCH 276/290] Always double-wrap contextAlreadyDoneError Signed-off-by: Michael Darr --- errors.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/errors.go b/errors.go index 0bb322cd..ab83b3a5 100644 --- a/errors.go +++ b/errors.go @@ -163,15 +163,9 @@ func (e *contextAlreadyDoneError) Unwrap() error { return e.err } -// newContextAlreadyDoneError wraps a context error in `contextAlreadyDoneError`. If the context was cancelled or its -// deadline passed, the returned error is also wrapped by `ErrTimeout`. +// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `ErrTimeout`. func newContextAlreadyDoneError(ctx context.Context) (err error) { - ctxErr := ctx.Err() - err = &contextAlreadyDoneError{err: ctxErr} - if ctxErr != nil { - err = &ErrTimeout{err: err} - } - return err + return &ErrTimeout{&contextAlreadyDoneError{err: ctx.Err()}} } type writeError struct { From a50d96d4915cae7d1a28601ce9e7a57b0ea5ae41 Mon Sep 17 00:00:00 2001 From: Michael Darr Date: Tue, 6 Jul 2021 21:44:44 -0400 Subject: [PATCH 277/290] Make timeout error private Signed-off-by: Michael Darr --- errors.go | 16 ++++++++-------- pgconn.go | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/errors.go b/errors.go index ab83b3a5..64401d65 100644 --- a/errors.go +++ b/errors.go @@ -20,7 +20,7 @@ func SafeToRetry(err error) bool { // Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { - var timeoutErr *ErrTimeout + var timeoutErr *errTimeout return errors.As(err, &timeoutErr) } @@ -129,21 +129,21 @@ func (e *pgconnError) Unwrap() error { return e.err } -// ErrTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is +// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is // context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. -type ErrTimeout struct { +type errTimeout struct { err error } -func (e *ErrTimeout) Error() string { +func (e *errTimeout) Error() string { return fmt.Sprintf("timeout: %s", e.err.Error()) } -func (e *ErrTimeout) SafeToRetry() bool { +func (e *errTimeout) SafeToRetry() bool { return SafeToRetry(e.err) } -func (e *ErrTimeout) Unwrap() error { +func (e *errTimeout) Unwrap() error { return e.err } @@ -163,9 +163,9 @@ func (e *contextAlreadyDoneError) Unwrap() error { return e.err } -// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `ErrTimeout`. +// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`. func newContextAlreadyDoneError(ctx context.Context) (err error) { - return &ErrTimeout{&contextAlreadyDoneError{err: ctx.Err()}} + return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} } type writeError struct { diff --git a/pgconn.go b/pgconn.go index 74e24257..a17a108d 100644 --- a/pgconn.go +++ b/pgconn.go @@ -219,7 +219,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - err = &ErrTimeout{err: err} + err = &errTimeout{err: err} } return nil, &connectError{config: config, msg: "dial error", err: err} } @@ -470,7 +470,7 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() } else if isNetErr && netErr.Timeout() { - err = &ErrTimeout{err: err} + err = &errTimeout{err: err} } return nil, err @@ -490,7 +490,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() } else if isNetErr && netErr.Timeout() { - err = &ErrTimeout{err: err} + err = &errTimeout{err: err} } return nil, err From 5b7c6a3c8e9f0191a7383abb3440f3656a1efdc4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 10 Jul 2021 09:54:24 -0500 Subject: [PATCH 278/290] Upgrade to pgproto3 v2.1.1 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 57f773b1..dad81ebe 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgproto3/v2 v2.1.0 + github.com/jackc/pgproto3/v2 v2.1.1 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e diff --git a/go.sum b/go.sum index eedcac1b..54405c28 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,8 @@ github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1: github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.0 h1:h2yg3kjIyAGSZKDijYn1/gXHlYLCwl9ZjEh2PU0yVxE= github.com/jackc/pgproto3/v2 v2.1.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= From 13d454882b790b8a8fa00e049e9dc2c0e84318fc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 10 Jul 2021 09:54:39 -0500 Subject: [PATCH 279/290] Release v1.9.0 --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c377b3ed..c496ea30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +# 1.9.0 (July 10, 2021) + +* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr) +* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle) +* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard) +* Fix default host when parsing URL without host but with port +* Allow dbname query parameter in URL conn string +* Update underlying dependencies + # 1.8.1 (March 25, 2021) * Better connection string sanitization (ip.novikov) From 6996e8d6c546d45bab6f1e8b24c010f40f095e6e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 09:09:22 -0500 Subject: [PATCH 280/290] Context errors returned instead of net.Error The net.Error caused by using SetDeadline to implement context cancellation shouldn't leak. fixes #80 --- errors.go | 10 ++++++++++ pgconn.go | 28 ++++++++++++++-------------- pgconn_test.go | 5 ++++- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/errors.go b/errors.go index 64401d65..a32b29c9 100644 --- a/errors.go +++ b/errors.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "net/url" "regexp" "strings" @@ -105,6 +106,15 @@ func (e *parseConfigError) Unwrap() error { return e.err } +// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == +// true. Otherwise returns err. +func preferContextOverNetTimeoutError(ctx context.Context, err error) error { + if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { + return &errTimeout{err: ctx.Err()} + } + return err +} + type pgconnError struct { msg string err error diff --git a/pgconn.go b/pgconn.go index a17a108d..43b13e43 100644 --- a/pgconn.go +++ b/pgconn.go @@ -271,7 +271,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err, ok := err.(*PgError); ok { return nil, err } - return nil, &connectError{config: config, msg: "failed to receive message", err: err} + return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} } switch msg := msg.(type) { @@ -434,7 +434,10 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa msg, err := pgConn.receiveMessage() if err != nil { - err = &pgconnError{msg: "receive message failed", err: err, safeToRetry: true} + err = &pgconnError{ + msg: "receive message failed", + err: preferContextOverNetTimeoutError(ctx, err), + safeToRetry: true} } return msg, err } @@ -469,8 +472,6 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { isNetErr := errors.As(err, &netErr) if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() - } else if isNetErr && netErr.Timeout() { - err = &errTimeout{err: err} } return nil, err @@ -489,8 +490,6 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { isNetErr := errors.As(err, &netErr) if !(isNetErr && netErr.Timeout()) { pgConn.asyncClose() - } else if isNetErr && netErr.Timeout() { - err = &errTimeout{err: err} } return nil, err @@ -785,7 +784,7 @@ readloop: msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -888,7 +887,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { if ctx != context.Background() { select { case <-ctx.Done(): - return ctx.Err() + return newContextAlreadyDoneError(ctx) default: } @@ -899,7 +898,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { for { msg, err := pgConn.receiveMessage() if err != nil { - return err + return preferContextOverNetTimeoutError(ctx, err) } switch msg.(type) { @@ -1136,7 +1135,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1196,7 +1195,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1255,7 +1254,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1287,7 +1286,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() - return nil, err + return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { @@ -1329,7 +1328,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) if err != nil { mrr.pgConn.contextWatcher.Unwatch() - mrr.err = err + mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.asyncClose() return nil, mrr.err @@ -1536,6 +1535,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { + err = preferContextOverNetTimeoutError(rr.ctx, err) rr.concludeCommand(nil, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true diff --git a/pgconn_test.go b/pgconn_test.go index 7ceda791..c20b7425 100644 --- a/pgconn_test.go +++ b/pgconn_test.go @@ -585,6 +585,7 @@ func TestConnExecContextCanceled(t *testing.T) { } err = multiResult.Close() assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) assert.True(t, pgConn.IsClosed()) select { case <-pgConn.CleanupDone(): @@ -729,6 +730,7 @@ func TestConnExecParamsCanceled(t *testing.T) { commandTag, err := result.Close() assert.Equal(t, pgconn.CommandTag(nil), commandTag) assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) assert.True(t, pgConn.IsClosed()) select { @@ -1289,7 +1291,7 @@ func TestConnWaitForNotificationPrecanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() err = pgConn.WaitForNotification(ctx) - require.Equal(t, context.Canceled, err) + require.ErrorIs(t, err, context.Canceled) ensureConnValid(t, pgConn) } @@ -1308,6 +1310,7 @@ func TestConnWaitForNotificationTimeout(t *testing.T) { err = pgConn.WaitForNotification(ctx) cancel() assert.True(t, pgconn.Timeout(err)) + assert.ErrorIs(t, err, context.DeadlineExceeded) ensureConnValid(t, pgConn) } From d89c8390a530599c1ba1b6f68bbb0de092cbd6cb Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 10:25:08 -0500 Subject: [PATCH 281/290] Update dependencies and go mod tidy --- go.mod | 4 ++-- go.sum | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index dad81ebe..6fdd0e97 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,11 @@ go 1.12 require ( github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 - github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd + github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.1.1 github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b github.com/stretchr/testify v1.7.0 - golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/text v0.3.6 ) diff --git a/go.sum b/go.sum index 54405c28..3c77ee21 100644 --- a/go.sum +++ b/go.sum @@ -15,11 +15,13 @@ github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9 github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= -github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= @@ -29,8 +31,6 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgproto3/v2 v2.1.0 h1:h2yg3kjIyAGSZKDijYn1/gXHlYLCwl9ZjEh2PU0yVxE= -github.com/jackc/pgproto3/v2 v2.1.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= @@ -87,8 +87,9 @@ golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaE golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= From 53f5fed36c570f0b5c98d6ec2415658c7b9bd11c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Jul 2021 10:52:26 -0500 Subject: [PATCH 282/290] Release v1.10.0 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c496ea30..45c02f1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 1.10.0 (July 24, 2021) + +* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned. + # 1.9.0 (July 10, 2021) * pgconn.Timeout only is true for errors originating in pgconn (Michael Darr) From 3bee0c6398156fb4c1c302a0ce7b0b5bd6108ce9 Mon Sep 17 00:00:00 2001 From: Kei Kamikawa Date: Fri, 13 Aug 2021 12:53:24 +0900 Subject: [PATCH 283/290] removed lines to read conn --- pgconn.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pgconn.go b/pgconn.go index 43b13e43..a1d22394 100644 --- a/pgconn.go +++ b/pgconn.go @@ -578,7 +578,6 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // // See https://github.com/jackc/pgx/issues/637 pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) - pgConn.conn.Read(make([]byte, 1)) return pgConn.conn.Close() } @@ -605,7 +604,6 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) - pgConn.conn.Read(make([]byte, 1)) }() } From 290ee79d1e8d48c3ff1c1381e01ba76d6b71985a Mon Sep 17 00:00:00 2001 From: Rueian Date: Mon, 27 Sep 2021 14:29:53 +0800 Subject: [PATCH 284/290] feat: remove unnecessary pending for CopyInResponse --- pgconn.go | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/pgconn.go b/pgconn.go index a1d22394..382ad33c 100644 --- a/pgconn.go +++ b/pgconn.go @@ -1185,27 +1185,6 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co return nil, &writeError{err: err, safeToRetry: n == 0} } - // Read until copy in response or error. - var commandTag CommandTag - var pgErr error - pendingCopyInResponse := true - for pendingCopyInResponse { - msg, err := pgConn.receiveMessage() - if err != nil { - pgConn.asyncClose() - return nil, preferContextOverNetTimeoutError(ctx, err) - } - - switch msg := msg.(type) { - case *pgproto3.CopyInResponse: - pendingCopyInResponse = false - case *pgproto3.ErrorResponse: - pgErr = ErrorResponseToPgError(msg) - case *pgproto3.ReadyForQuery: - return commandTag, pgErr - } - } - // Send copy data abortCopyChan := make(chan struct{}) copyErrChan := make(chan error, 1) @@ -1244,6 +1223,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } }() + var pgErr error var copyErr error for copyErr == nil && pgErr == nil { select { @@ -1280,6 +1260,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Read results + var commandTag CommandTag for { msg, err := pgConn.receiveMessage() if err != nil { From 162dc65eff6f037c98baa36f1f4c75658408d65a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 6 Nov 2021 08:57:49 -0500 Subject: [PATCH 285/290] Make ContextWatcher concurrency safe fixes #94 --- internal/ctxwatch/context_watcher.go | 13 +++++++++++-- internal/ctxwatch/context_watcher_test.go | 15 ++++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/internal/ctxwatch/context_watcher.go b/internal/ctxwatch/context_watcher.go index 391f0b79..b39cb3ee 100644 --- a/internal/ctxwatch/context_watcher.go +++ b/internal/ctxwatch/context_watcher.go @@ -2,6 +2,7 @@ package ctxwatch import ( "context" + "sync" ) // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a @@ -10,8 +11,10 @@ type ContextWatcher struct { onCancel func() onUnwatchAfterCancel func() unwatchChan chan struct{} - watchInProgress bool - onCancelWasCalled bool + + lock sync.Mutex + watchInProgress bool + onCancelWasCalled bool } // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. @@ -29,6 +32,9 @@ func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWat // Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. func (cw *ContextWatcher) Watch(ctx context.Context) { + cw.lock.Lock() + defer cw.lock.Unlock() + if cw.watchInProgress { panic("Watch already in progress") } @@ -54,6 +60,9 @@ func (cw *ContextWatcher) Watch(ctx context.Context) { // Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was // called then onUnwatchAfterCancel will also be called. func (cw *ContextWatcher) Unwatch() { + cw.lock.Lock() + defer cw.lock.Unlock() + if cw.watchInProgress { cw.unwatchChan <- struct{}{} if cw.onCancelWasCalled { diff --git a/internal/ctxwatch/context_watcher_test.go b/internal/ctxwatch/context_watcher_test.go index 6348b729..289606c3 100644 --- a/internal/ctxwatch/context_watcher_test.go +++ b/internal/ctxwatch/context_watcher_test.go @@ -59,7 +59,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) { require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") } -func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { +func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { cw := ctxwatch.NewContextWatcher(func() {}, func() {}) cw.Unwatch() // unwatch when not / never watching @@ -70,6 +70,19 @@ func TestContextWatcherUnwatchIsAlwaysSafe(t *testing.T) { cw.Unwatch() // double unwatch } +func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(func() {}, func() {}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + cw.Watch(ctx) + + go cw.Unwatch() + go cw.Unwatch() + + <-ctx.Done() +} + func TestContextWatcherStress(t *testing.T) { var cancelFuncCalls int64 var cleanupFuncCalls int64 From 141f132ae7e1428ba7bcf519ff618c39a6d07fea Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Mon, 8 Nov 2021 21:00:05 +0100 Subject: [PATCH 286/290] add a unit test on LRU context check TestLRUContext highlights the lack of context check when querying for a cached value --- stmtcache/lru_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index a4108155..f594ceac 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -235,6 +235,40 @@ func TestLRUModeDescribe(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUContext(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) + + // test 1 : getting a value for the first time with a cancelled context returns an error + ctx1, cancel1 := context.WithCancel(ctx) + cancel1() + + desc, err := cache.Get(ctx1, "SELECT 1") + require.Error(t, err) + require.Nil(t, desc) + + // test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error + ctx2, cancel2 := context.WithCancel(ctx) + + desc, err = cache.Get(ctx2, "SELECT 2") + require.NoError(t, err) + require.NotNil(t, desc) + + cancel2() + + desc, err = cache.Get(ctx2, "SELECT 2") + require.Error(t, err) + require.Nil(t, desc) +} + func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() require.NoError(t, result.Err) From cd7dcd58025f5936f76170cf8d9d2fa467b3c189 Mon Sep 17 00:00:00 2001 From: Georges Varouchas Date: Mon, 8 Nov 2021 21:00:24 +0100 Subject: [PATCH 287/290] have lru.Get() always check if context is already expired --- stmtcache/lru.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index f58f2ac3..a4106457 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -53,6 +53,14 @@ func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription } } + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.StatementDescription), nil From 146268e829bdea59e5381b3199e4e3b5f5388b0b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Nov 2021 04:12:35 -0600 Subject: [PATCH 288/290] Move context test above bad statement cleanup --- stmtcache/lru.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/stmtcache/lru.go b/stmtcache/lru.go index a4106457..90fb76c2 100644 --- a/stmtcache/lru.go +++ b/stmtcache/lru.go @@ -42,6 +42,14 @@ func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { + if ctx != context.Background() { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + // flush an outstanding bad statements txStatus := c.conn.TxStatus() if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { @@ -53,14 +61,6 @@ func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription } } - if ctx != context.Background() { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - } - if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.StatementDescription), nil From 662ecb496ffc8c64f7bfa156694e0fe525a97685 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 20 Nov 2021 09:56:46 -0600 Subject: [PATCH 289/290] Release v1.10.1 --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45c02f1e..63933a3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# 1.10.1 (November 20, 2021) + +* Close without waiting for response (Kei Kamikawa) +* Save waiting for network round-trip in CopyFrom (Rueian) +* Fix concurrency issue with ContextWatcher +* LRU.Get always checks context for cancellation / expiration (Georges Varouchas) + # 1.10.0 (July 24, 2021) * net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned. From 19ec4d505ffaf0d1fecefb733c722c319e5df081 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Dec 2021 13:51:24 -0600 Subject: [PATCH 290/290] Import to pgx main repo in pgconn subdir --- {.github => pgconn/.github}/workflows/ci.yml | 0 .gitignore => pgconn/.gitignore | 0 CHANGELOG.md => pgconn/CHANGELOG.md | 0 LICENSE => pgconn/LICENSE | 0 README.md => pgconn/README.md | 0 auth_scram.go => pgconn/auth_scram.go | 0 benchmark_test.go => pgconn/benchmark_test.go | 0 {ci => pgconn/ci}/script.bash | 0 {ci => pgconn/ci}/setup_test.bash | 0 config.go => pgconn/config.go | 0 config_test.go => pgconn/config_test.go | 0 defaults.go => pgconn/defaults.go | 0 defaults_windows.go => pgconn/defaults_windows.go | 0 doc.go => pgconn/doc.go | 0 errors.go => pgconn/errors.go | 0 errors_test.go => pgconn/errors_test.go | 0 export_test.go => pgconn/export_test.go | 0 frontend_test.go => pgconn/frontend_test.go | 0 go.mod => pgconn/go.mod | 0 go.sum => pgconn/go.sum | 0 helper_test.go => pgconn/helper_test.go | 0 {internal => pgconn/internal}/ctxwatch/context_watcher.go | 0 {internal => pgconn/internal}/ctxwatch/context_watcher_test.go | 0 pgconn.go => pgconn/pgconn.go | 0 pgconn_stress_test.go => pgconn/pgconn_stress_test.go | 0 pgconn_test.go => pgconn/pgconn_test.go | 0 {stmtcache => pgconn/stmtcache}/lru.go | 0 {stmtcache => pgconn/stmtcache}/lru_test.go | 0 {stmtcache => pgconn/stmtcache}/stmtcache.go | 0 29 files changed, 0 insertions(+), 0 deletions(-) rename {.github => pgconn/.github}/workflows/ci.yml (100%) rename .gitignore => pgconn/.gitignore (100%) rename CHANGELOG.md => pgconn/CHANGELOG.md (100%) rename LICENSE => pgconn/LICENSE (100%) rename README.md => pgconn/README.md (100%) rename auth_scram.go => pgconn/auth_scram.go (100%) rename benchmark_test.go => pgconn/benchmark_test.go (100%) rename {ci => pgconn/ci}/script.bash (100%) rename {ci => pgconn/ci}/setup_test.bash (100%) rename config.go => pgconn/config.go (100%) rename config_test.go => pgconn/config_test.go (100%) rename defaults.go => pgconn/defaults.go (100%) rename defaults_windows.go => pgconn/defaults_windows.go (100%) rename doc.go => pgconn/doc.go (100%) rename errors.go => pgconn/errors.go (100%) rename errors_test.go => pgconn/errors_test.go (100%) rename export_test.go => pgconn/export_test.go (100%) rename frontend_test.go => pgconn/frontend_test.go (100%) rename go.mod => pgconn/go.mod (100%) rename go.sum => pgconn/go.sum (100%) rename helper_test.go => pgconn/helper_test.go (100%) rename {internal => pgconn/internal}/ctxwatch/context_watcher.go (100%) rename {internal => pgconn/internal}/ctxwatch/context_watcher_test.go (100%) rename pgconn.go => pgconn/pgconn.go (100%) rename pgconn_stress_test.go => pgconn/pgconn_stress_test.go (100%) rename pgconn_test.go => pgconn/pgconn_test.go (100%) rename {stmtcache => pgconn/stmtcache}/lru.go (100%) rename {stmtcache => pgconn/stmtcache}/lru_test.go (100%) rename {stmtcache => pgconn/stmtcache}/stmtcache.go (100%) diff --git a/.github/workflows/ci.yml b/pgconn/.github/workflows/ci.yml similarity index 100% rename from .github/workflows/ci.yml rename to pgconn/.github/workflows/ci.yml diff --git a/.gitignore b/pgconn/.gitignore similarity index 100% rename from .gitignore rename to pgconn/.gitignore diff --git a/CHANGELOG.md b/pgconn/CHANGELOG.md similarity index 100% rename from CHANGELOG.md rename to pgconn/CHANGELOG.md diff --git a/LICENSE b/pgconn/LICENSE similarity index 100% rename from LICENSE rename to pgconn/LICENSE diff --git a/README.md b/pgconn/README.md similarity index 100% rename from README.md rename to pgconn/README.md diff --git a/auth_scram.go b/pgconn/auth_scram.go similarity index 100% rename from auth_scram.go rename to pgconn/auth_scram.go diff --git a/benchmark_test.go b/pgconn/benchmark_test.go similarity index 100% rename from benchmark_test.go rename to pgconn/benchmark_test.go diff --git a/ci/script.bash b/pgconn/ci/script.bash similarity index 100% rename from ci/script.bash rename to pgconn/ci/script.bash diff --git a/ci/setup_test.bash b/pgconn/ci/setup_test.bash similarity index 100% rename from ci/setup_test.bash rename to pgconn/ci/setup_test.bash diff --git a/config.go b/pgconn/config.go similarity index 100% rename from config.go rename to pgconn/config.go diff --git a/config_test.go b/pgconn/config_test.go similarity index 100% rename from config_test.go rename to pgconn/config_test.go diff --git a/defaults.go b/pgconn/defaults.go similarity index 100% rename from defaults.go rename to pgconn/defaults.go diff --git a/defaults_windows.go b/pgconn/defaults_windows.go similarity index 100% rename from defaults_windows.go rename to pgconn/defaults_windows.go diff --git a/doc.go b/pgconn/doc.go similarity index 100% rename from doc.go rename to pgconn/doc.go diff --git a/errors.go b/pgconn/errors.go similarity index 100% rename from errors.go rename to pgconn/errors.go diff --git a/errors_test.go b/pgconn/errors_test.go similarity index 100% rename from errors_test.go rename to pgconn/errors_test.go diff --git a/export_test.go b/pgconn/export_test.go similarity index 100% rename from export_test.go rename to pgconn/export_test.go diff --git a/frontend_test.go b/pgconn/frontend_test.go similarity index 100% rename from frontend_test.go rename to pgconn/frontend_test.go diff --git a/go.mod b/pgconn/go.mod similarity index 100% rename from go.mod rename to pgconn/go.mod diff --git a/go.sum b/pgconn/go.sum similarity index 100% rename from go.sum rename to pgconn/go.sum diff --git a/helper_test.go b/pgconn/helper_test.go similarity index 100% rename from helper_test.go rename to pgconn/helper_test.go diff --git a/internal/ctxwatch/context_watcher.go b/pgconn/internal/ctxwatch/context_watcher.go similarity index 100% rename from internal/ctxwatch/context_watcher.go rename to pgconn/internal/ctxwatch/context_watcher.go diff --git a/internal/ctxwatch/context_watcher_test.go b/pgconn/internal/ctxwatch/context_watcher_test.go similarity index 100% rename from internal/ctxwatch/context_watcher_test.go rename to pgconn/internal/ctxwatch/context_watcher_test.go diff --git a/pgconn.go b/pgconn/pgconn.go similarity index 100% rename from pgconn.go rename to pgconn/pgconn.go diff --git a/pgconn_stress_test.go b/pgconn/pgconn_stress_test.go similarity index 100% rename from pgconn_stress_test.go rename to pgconn/pgconn_stress_test.go diff --git a/pgconn_test.go b/pgconn/pgconn_test.go similarity index 100% rename from pgconn_test.go rename to pgconn/pgconn_test.go diff --git a/stmtcache/lru.go b/pgconn/stmtcache/lru.go similarity index 100% rename from stmtcache/lru.go rename to pgconn/stmtcache/lru.go diff --git a/stmtcache/lru_test.go b/pgconn/stmtcache/lru_test.go similarity index 100% rename from stmtcache/lru_test.go rename to pgconn/stmtcache/lru_test.go diff --git a/stmtcache/stmtcache.go b/pgconn/stmtcache/stmtcache.go similarity index 100% rename from stmtcache/stmtcache.go rename to pgconn/stmtcache/stmtcache.go