From 9538d15c29005e5044da6ba3f4c8ff06daec1278 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Mon, 3 Jun 2019 23:51:48 +0300 Subject: [PATCH 01/29] Draft of connection writable checking Signed-off-by: Artemiy Ryabinkov --- conn.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index cb24748c..e0169d6d 100644 --- a/conn.go +++ b/conn.go @@ -89,6 +89,13 @@ type ConnConfig struct { // used by default. The same functionality can be controlled on a per query // basis by setting QueryExOptions.SimpleProtocol. PreferSimpleProtocol bool + + // TargetSessionAttr allows to specify which servers are accepted for this connection. + // "any", meaning that any kind of servers can be accepted. This is as well the default value. + // "read-write", to disallow connections to read-only servers, hot standbys for example. + // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com + // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + TargetSessionAttrs string } func (cc *ConnConfig) networkAddress() (network, address string) { @@ -262,8 +269,15 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } + if c.config.TargetSessionAttrs != "" && + c.config.TargetSessionAttrs != "any" && + c.config.TargetSessionAttrs != "read-write" { + return nil, errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") + } + c.onNotice = config.OnNotice + // TODO: Parse multi-hosts network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() @@ -273,22 +287,58 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) } - err = c.connect(config, network, address, config.TLSConfig) - if err != nil && config.UseFallbackTLS { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + + // TODO: Start loop for all hosts [host0 .. hostN] + for { + err = c.connect(config, network, address, config.TLSConfig) + if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + } + err = c.connect(config, network, address, config.FallbackTLSConfig) } - err = c.connect(config, network, address, config.FallbackTLSConfig) + + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + } + // TODO: Collect error + continue + } + + err = c.writeable() + if err != nil { + // TODO: Log info about not writable host + // TODO: Collect error + continue + } + + return c, nil } + + // TODO: Return collected errors + return nil, nil +} + +func (c *Conn) writeable() error { + if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { + return nil + } + + var st string + err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). + Scan(st) + if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) - } - return nil, err + return errors.Wrap(err, "failed to fetch transaction_read_only state") } - return c, nil + if st == "on" { + return errors.New("writable connection disabled by server") + } + + return nil } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -709,6 +759,10 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig { cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol + if other.TargetSessionAttrs != "" { + cc.TargetSessionAttrs = other.TargetSessionAttrs + } + cc.RuntimeParams = make(map[string]string) for k, v := range old.RuntimeParams { cc.RuntimeParams[k] = v From 9f031bb8f9bea60bd51ebc1cbaaa8e5db779b191 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 16 Jun 2019 14:03:43 +0300 Subject: [PATCH 02/29] Return net.Addr from networkAddress Signed-off-by: Artemiy Ryabinkov --- conn.go | 82 +++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index e0169d6d..323875a5 100644 --- a/conn.go +++ b/conn.go @@ -98,20 +98,24 @@ type ConnConfig struct { TargetSessionAttrs string } -func (cc *ConnConfig) networkAddress() (network, address string) { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - // See if host is a valid path, if yes connect with a socket +func (cc *ConnConfig) networkAddress() net.Addr { + // See if host is a valid path, if yes connect with a unix socket if _, err := os.Stat(cc.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = cc.Host + network := "unix" + address := cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) } + + return &net.UnixAddr{Name: address, Net: network} } - return network, address + return &net.TCPAddr{ + Port: int(cc.Port), + IP: net.ParseIP(cc.Host), + } } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -277,48 +281,80 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.onNotice = config.OnNotice - // TODO: Parse multi-hosts - network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() c.config.Dial = d.Dial } - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) - } + // TODO: Parse multi-hosts + hostAddr := c.config.networkAddress() + + addrs := []net.Addr{hostAddr} + + var errs []error + for _, addr := range addrs { + network, address := addr.Network(), addr.String() + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{ + "network": network, + "address": address, + }) + } - // TODO: Start loop for all hosts [host0 .. hostN] - for { err = c.connect(config, network, address, config.TLSConfig) if err != nil && config.UseFallbackTLS { if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) } err = c.connect(config, network, address, config.FallbackTLSConfig) } if err != nil { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + c.log(LogLevelError, "connect failed", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) } - // TODO: Collect error + + errs = append(errs, err) continue } err = c.writeable() if err != nil { - // TODO: Log info about not writable host - // TODO: Collect error + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + + errs = append(errs, err) continue } return c, nil } + // To keep backwards, if specific error type expected. + if len(errs) == 1 { + return nil, errs[0] + } - // TODO: Return collected errors - return nil, nil + var errmsg string + for _, err := range errs { + errmsg += ";" + err.Error() + } + + return nil, errors.New(errmsg) } func (c *Conn) writeable() error { @@ -331,7 +367,7 @@ func (c *Conn) writeable() error { Scan(st) if err != nil { - return errors.Wrap(err, "failed to fetch transaction_read_only state") + return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") } if st == "on" { From 25e1f674a2a02cacc2db24924545539366fac825 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 16 Jun 2019 14:36:54 +0300 Subject: [PATCH 03/29] Fix doCancel with addr from networkAddress Signed-off-by: Artemiy Ryabinkov --- conn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 323875a5..287ce7db 100644 --- a/conn.go +++ b/conn.go @@ -1741,8 +1741,8 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) + addr := c.config.networkAddress() + cancelConn, err := c.config.Dial(addr.Network(), addr.String()) if err != nil { return err } From 6ec815a7489b64856307091567f59635b9a65bbe Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Tue, 18 Jun 2019 16:02:09 +0300 Subject: [PATCH 04/29] Support Multiple Hosts in ConnConfig Signed-off-by: Artemiy Ryabinkov --- conn.go | 146 ++++++++++++++++++++++++++++++------ conn_config_test.go.example | 3 + conn_config_test.go.travis | 2 + conn_test.go | 101 +++++++++++++++++++++++++ pgpass.go | 1 + 5 files changed, 232 insertions(+), 21 deletions(-) diff --git a/conn.go b/conn.go index 287ce7db..a6ea78f1 100644 --- a/conn.go +++ b/conn.go @@ -63,10 +63,29 @@ type DialFunc func(network, addr string) (net.Conn, error) // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name + // Name of host to connect to. (e.g. localhost) + // If a host name begins with a slash, it specifies Unix-domain communication + // rather than TCP/IP communication; the value is the name of the directory + // in which the socket file is stored. (e.g. /private/tmp) + // The default behavior when host is not specified, or is empty, is to connect to localhost. + // + // A comma-separated list of host names is also accepted, + // in which case each host name in the list is tried in order; + // an empty item in the list selects the default behavior as explained above. + // @see https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS + Host string + + // Port number to connect to at the server host, + // or socket file name extension for Unix-domain connections. + // An empty or zero value, specifies the default port number — 5432. + // + // If multiple hosts were given in the Host parameter, then + // this parameter may specify a single port number to be used for all hosts, + // or for those that haven't port explicitly defined. + Port uint16 + Database string + User string // default: OS user name + // TODO: Allow password to be different for each host/port pair if a password file is used Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa @@ -95,10 +114,34 @@ type ConnConfig struct { // "read-write", to disallow connections to read-only servers, hot standbys for example. // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + // + // The query SHOW transaction_read_only will be sent upon any successful connection; + // if it returns on, the connection will be closed. + // If multiple hosts were specified in the connection string, + // any remaining servers will be tried just as if the connection attempt had failed. + // The default value of this parameter, any, regards all connections as acceptable. TargetSessionAttrs string } -func (cc *ConnConfig) networkAddress() net.Addr { +// hostAddr represents network end point defined as hostname or IP + port. +type hostAddr struct { + Host string + Port uint16 +} + +// Network returns the address's network name, "tcp". +func (a *hostAddr) Network() string { return "tcp" } + +// String implements net.Addr String method. +func (a *hostAddr) String() string { + if a == nil { + return "" + } + + return net.JoinHostPort(a.Host, strconv.Itoa(int(a.Port))) +} + +func (cc *ConnConfig) networkAddresses() ([]net.Addr, error) { // See if host is a valid path, if yes connect with a unix socket if _, err := os.Stat(cc.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred @@ -109,13 +152,50 @@ func (cc *ConnConfig) networkAddress() net.Addr { address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) } - return &net.UnixAddr{Name: address, Net: network} + addrs := []net.Addr{ + &net.UnixAddr{Name: address, Net: network}, + } + + return addrs, nil } - return &net.TCPAddr{ - Port: int(cc.Port), - IP: net.ParseIP(cc.Host), + if cc.Host == "" { + addrs := []net.Addr{ + &net.TCPAddr{Port: int(cc.Port)}, + } + + return addrs, nil } + + var addrs []net.Addr + + hostports := strings.Split(cc.Host, ",") + for i, hostport := range hostports { + if hostport == "" { + return nil, fmt.Errorf("multi-host part %d is empty, at least host or port must be defined", i) + } + + // It's not possible to use net.TCPAddr here, cuz host may be hostname. + addr := hostAddr{ + Host: hostport, + Port: cc.Port, + } + + pos := strings.IndexByte(hostport, ':') + if pos != -1 { + p, err := strconv.ParseUint(hostport[pos+1:], 10, 16) + if err != nil { + return nil, fmt.Errorf("multi-host part %d (%s) has invalid port format", i, hostport) + } + + addr.Host = hostport[:pos] + addr.Port = uint16(p) + } + + addrs = append(addrs, &addr) + } + + return addrs, nil } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -156,6 +236,10 @@ type Conn struct { ConnInfo *pgtype.ConnInfo frontend *pgproto3.Frontend + + // In case of Multiple Hosts we need to know what addr was used to connect. + // This address will be used to send a cancellation request. + addr net.Addr } // PreparedStatement is a description of a prepared statement @@ -286,10 +370,10 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.config.Dial = d.Dial } - // TODO: Parse multi-hosts - hostAddr := c.config.networkAddress() - - addrs := []net.Addr{hostAddr} + addrs, err := c.config.networkAddresses() + if err != nil { + return nil, err + } var errs []error for _, addr := range addrs { @@ -323,12 +407,23 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) }) } + // On any auth errors return immediately + if pgErr, ok := err.(PgError); ok { + switch pgErr.Code { + // @see: https://www.postgresql.org/docs/current/errcodes-appendix.html + case "28000", "28P01": // Invalid Authorization Specification + return nil, pgErr + } + } + errs = append(errs, err) continue } - err = c.writeable() + err = c.writable() if err != nil { + c.die(err) + if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ "err": err, @@ -341,6 +436,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) continue } + c.addr = addr + return c, nil } @@ -351,30 +448,34 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) var errmsg string for _, err := range errs { - errmsg += ";" + err.Error() + errmsg += "; " + err.Error() } return nil, errors.New(errmsg) } -func (c *Conn) writeable() error { +func (c *Conn) writable() error { if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { return nil } var st string err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). - Scan(st) + Scan(&st) if err != nil { return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") } - if st == "on" { + switch st { + case "on": return errors.New("writable connection disabled by server") + case "off": + // If transaction_read_only = off, then connection is writable. + return nil } - return nil + return errors.New("unexpected \"transaction_read_only\" status") } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -958,6 +1059,8 @@ func ParseDSN(s string) (ConnConfig, error) { // ParseConnectionString parses either a URI or a DSN connection string. // see ParseURI and ParseDSN for details. func ParseConnectionString(s string) (ConnConfig, error) { + // TODO: Multiple Hosts support + // @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) } @@ -981,6 +1084,8 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// TODO: PGTARGETSESSIONATTRS support +// @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This @@ -1741,8 +1846,7 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - addr := c.config.networkAddress() - cancelConn, err := c.config.Dial(addr.Network(), addr.String()) + cancelConn, err := c.config.Dial(c.addr.Network(), c.addr.String()) if err != nil { return err } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 096e1354..620b0ea1 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -7,6 +7,7 @@ import ( // "go/build" // "io/ioutil" // "path" + // "net" "github.com/jackc/pgx" ) @@ -14,6 +15,7 @@ import ( var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // To skip tests for specific connection / authentication types set that connection param to nil +var multihostConnConfig *pgx.ConnConfig = nil var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil @@ -24,6 +26,7 @@ var customDialerConnConfig *pgx.ConnConfig = nil var replicationConnConfig *pgx.ConnConfig = nil var cratedbConnConfig *pgx.ConnConfig = nil +// var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} // var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index cf29a743..738f1112 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -5,9 +5,11 @@ import ( "github.com/jackc/pgx" "os" "strconv" + "net" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_test.go b/conn_test.go index 6ca00c6d..14efbeca 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,6 +84,107 @@ func TestConnect(t *testing.T) { } } + +func TestConnectWithMultiHost(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + conn, err := pgx.Connect(*multihostConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + + +func TestConnectWithMultiHostWritable(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + connConfig := *multihostConnConfig + connConfig.TargetSessionAttrs = "read-write" + + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err := conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + var st string + err = conn.QueryRow("SHOW transaction_read_only").Scan(&st) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + + if st == "on" { + t.Error("Connection is not writable") + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithUnixSocketDirectory(t *testing.T) { t.Parallel() diff --git a/pgpass.go b/pgpass.go index 34b9bdf5..ff97e5f0 100644 --- a/pgpass.go +++ b/pgpass.go @@ -57,6 +57,7 @@ func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *strin return &parts[4] } +// TODO: Multi-host support func pgpass(cfg *ConnConfig) (found bool) { passfile := os.Getenv("PGPASSFILE") if passfile == "" { From 2837818b67f39d5dfb49a6b37f7d8a23ef263896 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Tue, 18 Jun 2019 17:09:38 +0300 Subject: [PATCH 05/29] fix typo Signed-off-by: Artemiy Ryabinkov --- conn.go | 2 +- conn_config_test.go.example | 1 + conn_config_test.go.travis | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index a6ea78f1..153c7a3d 100644 --- a/conn.go +++ b/conn.go @@ -441,7 +441,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return c, nil } - // To keep backwards, if specific error type expected. + // To keep backwards compatibility, if specific error type expected. if len(errs) == 1 { return nil, errs[0] } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 620b0ea1..2ca84ac3 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -8,6 +8,7 @@ import ( // "io/ioutil" // "path" // "net" + // "time" "github.com/jackc/pgx" ) diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index 738f1112..fbfb5252 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -6,6 +6,7 @@ import ( "os" "strconv" "net" + "time" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} From bcb2afe2be3d755f0ca53f3df0b262f3ca64096f Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Wed, 10 Jul 2019 22:59:17 +0300 Subject: [PATCH 06/29] TargetSessionAttrs as custom type Signed-off-by: Artemiy Ryabinkov --- conn.go | 28 ++++++++++++++++++++++------ conn_test.go | 2 +- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 153c7a3d..f0cf20be 100644 --- a/conn.go +++ b/conn.go @@ -61,6 +61,24 @@ type NoticeHandler func(*Conn, *Notice) // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) +// TargetSessionType represents target session attrs configuration parameter. +type TargetSessionType string + +// Block enumerates available values for TargetSessionType. +const ( + AnyTargetSession = "any" + ReadWriteTargetSession = "read-write" +) + +func (t TargetSessionType) isValid() error { + switch t { + case "", AnyTargetSession, ReadWriteTargetSession: + return nil + } + + return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") +} + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { // Name of host to connect to. (e.g. localhost) @@ -120,7 +138,7 @@ type ConnConfig struct { // If multiple hosts were specified in the connection string, // any remaining servers will be tried just as if the connection attempt had failed. // The default value of this parameter, any, regards all connections as acceptable. - TargetSessionAttrs string + TargetSessionAttrs TargetSessionType } // hostAddr represents network end point defined as hostname or IP + port. @@ -357,10 +375,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } - if c.config.TargetSessionAttrs != "" && - c.config.TargetSessionAttrs != "any" && - c.config.TargetSessionAttrs != "read-write" { - return nil, errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") + if err := c.config.TargetSessionAttrs.isValid(); err != nil { + return nil, err } c.onNotice = config.OnNotice @@ -455,7 +471,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } func (c *Conn) writable() error { - if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { + if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == AnyTargetSession { return nil } diff --git a/conn_test.go b/conn_test.go index 14efbeca..28bfe48b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -138,7 +138,7 @@ func TestConnectWithMultiHostWritable(t *testing.T) { } connConfig := *multihostConnConfig - connConfig.TargetSessionAttrs = "read-write" + connConfig.TargetSessionAttrs = pgx.ReadWriteTargetSession conn := mustConnect(t, connConfig) defer closeConn(t, conn) From 7d4215cb88d63e43baa8b8735bdafbfbd673c8bd Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Wed, 10 Jul 2019 23:16:46 +0300 Subject: [PATCH 07/29] fix error message building from errors array on connection establishing Signed-off-by: Artemiy Ryabinkov --- conn.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index f0cf20be..e570249e 100644 --- a/conn.go +++ b/conn.go @@ -462,12 +462,12 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return nil, errs[0] } - var errmsg string + errmsgs := make([]string, len(errs)) for _, err := range errs { - errmsg += "; " + err.Error() + errmsgs = append(errmsgs, err.Error()) } - return nil, errors.New(errmsg) + return nil, errors.New(strings.Join(errmsgs, ";")) } func (c *Conn) writable() error { From 75b4ba635c0224e046dc3c3d5e8d5d30c5b65d61 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 00:16:58 +0300 Subject: [PATCH 08/29] try to improve readability of writable checking Signed-off-by: Artemiy Ryabinkov --- conn.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index e570249e..975cf337 100644 --- a/conn.go +++ b/conn.go @@ -79,6 +79,10 @@ func (t TargetSessionType) isValid() error { return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") } +func (t TargetSessionType) writableRequired() bool { + return t == ReadWriteTargetSession +} + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { // Name of host to connect to. (e.g. localhost) @@ -436,7 +440,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) continue } - err = c.writable() + err = c.checkWritable() if err != nil { c.die(err) @@ -470,8 +474,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return nil, errors.New(strings.Join(errmsgs, ";")) } -func (c *Conn) writable() error { - if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == AnyTargetSession { +func (c *Conn) checkWritable() error { + if !c.config.TargetSessionAttrs.writableRequired() { return nil } @@ -485,7 +489,7 @@ func (c *Conn) writable() error { switch st { case "on": - return errors.New("writable connection disabled by server") + return errors.New("writable transactions disabled by server") case "off": // If transaction_read_only = off, then connection is writable. return nil From 18189fafd54ca2b678681a0550d632d1a5434f2c Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 20:28:04 +0300 Subject: [PATCH 09/29] ParseConnectionString supports Multi-Hosts Signed-off-by: Artemiy Ryabinkov --- conn.go | 123 +++++++++++++++++++++++++++++++++++++++++---------- conn_test.go | 76 +++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index 975cf337..fd134461 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "net/url" "os" @@ -947,16 +948,26 @@ func ParseURI(uri string) (ConnConfig, error) { cp.Password, _ = url.User.Password() } - parts := strings.SplitN(url.Host, ":", 2) - cp.Host = parts[0] - if len(parts) == 2 { - p, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return cp, err + hasMuliHosts := strings.IndexByte(url.Host, ',') != -1 + if !hasMuliHosts { + parts := strings.SplitN(url.Host, ":", 2) + cp.Host = parts[0] + if len(parts) == 2 { + p, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return cp, err + } + cp.Port = uint16(p) } - cp.Port = uint16(p) + } else { + cp.Host = url.Host } + cp.Database = strings.TrimLeft(url.Path, "/") + cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs")) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { timeout, err := strconv.ParseInt(pgtimeout, 10, 64) @@ -980,11 +991,12 @@ func ParseURI(uri string) (ConnConfig, error) { } ignoreKeys := map[string]struct{}{ - "connect_timeout": {}, - "sslcert": {}, - "sslkey": {}, - "sslmode": {}, - "sslrootcert": {}, + "connect_timeout": {}, + "sslcert": {}, + "sslkey": {}, + "sslmode": {}, + "sslrootcert": {}, + "target_session_attrs": {}, } cp.RuntimeParams = make(map[string]string) @@ -1029,6 +1041,7 @@ func ParseDSN(s string) (ConnConfig, error) { cp.RuntimeParams = make(map[string]string) + var hostval, portval string for _, b := range m { switch b[1] { case "user": @@ -1036,13 +1049,9 @@ func ParseDSN(s string) (ConnConfig, error) { case "password": cp.Password = b[2] case "host": - cp.Host = b[2] + hostval = b[2] case "port": - p, err := strconv.ParseUint(b[2], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) + portval = b[2] case "dbname": cp.Database = b[2] case "sslmode": @@ -1061,26 +1070,94 @@ func ParseDSN(s string) (ConnConfig, error) { d := defaultDialer() d.Timeout = time.Duration(timeout) * time.Second cp.Dial = d.Dial + case "target_session_attrs": + cp.TargetSessionAttrs = TargetSessionType(b[2]) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } default: cp.RuntimeParams[b[1]] = b[2] } } - err := configTLS(tlsArgs, &cp) + host, port, err := parseHostPortDSN(hostval, portval) if err != nil { return cp, err } + + cp.Host, cp.Port = host, port + + err = configTLS(tlsArgs, &cp) + if err != nil { + return cp, err + } + if cp.Password == "" { pgpass(&cp) } + return cp, nil } -// ParseConnectionString parses either a URI or a DSN connection string. -// see ParseURI and ParseDSN for details. +func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) { + if portval == "" { + return hostval, 0, nil + } + + hosts := strings.Split(hostval, ",") + ports := strings.Split(portval, ",") + + if len(ports) == 1 { + port, err := parsePort(portval) + if err != nil { + return "", 0, errors.Errorf("invalid port: %v", err) + } + + return hostval, port, nil + } + + if len(hosts) != len(ports) { + return "", 0, errors.New("the number of hosts and ports must be the same") + } + + hostports := make([]string, len(hosts)) + for i, host := range hosts { + hostports[i] = host + ":" + ports[i] + } + + return strings.Join(hostports, ","), 0, nil +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// ParseConnectionString supports specifying multiple hosts in similar manner to libpq. +// Host and port may include comma separated values that will be tried in order. +// This can be used as part of a high availability system. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// # Example DSN +// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca func ParseConnectionString(s string) (ConnConfig, error) { - // TODO: Multiple Hosts support - // @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) } diff --git a/conn_test.go b/conn_test.go index 28bfe48b..7719bec7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -622,6 +622,38 @@ func TestParseURI(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "foo.example.com:5432,bar.example.com:5432", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost,10.10.20.30", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + }, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { @@ -748,6 +780,50 @@ func TestParseDSN(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost1,localhost2", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "100.200.220.50:5432,localhost43:5433", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { From 39cbdf789d3448c56ba394557e8100a143694c56 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 20:56:44 +0300 Subject: [PATCH 10/29] Support of PGTARGETSESSIONATTRS ENV variable Signed-off-by: Artemiy Ryabinkov --- conn.go | 8 +++++++- pgpass.go | 1 - 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index fd134461..76e09e73 100644 --- a/conn.go +++ b/conn.go @@ -1016,6 +1016,7 @@ func ParseURI(uri string) (ConnConfig, error) { if cp.Password == "" { pgpass(&cp) } + return cp, nil } @@ -1181,7 +1182,7 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT -// TODO: PGTARGETSESSIONATTRS support +// PGTARGETSESSIONATTRS // @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: @@ -1228,6 +1229,11 @@ func ParseEnvLibpq() (ConnConfig, error) { } } + cc.TargetSessionAttrs = TargetSessionType(os.Getenv("PGTARGETSESSIONATTRS")) + if err := cc.TargetSessionAttrs.isValid(); err != nil { + return cc, err + } + tlsArgs := configTLSArgs{ sslMode: os.Getenv("PGSSLMODE"), sslKey: os.Getenv("PGSSLKEY"), diff --git a/pgpass.go b/pgpass.go index ff97e5f0..34b9bdf5 100644 --- a/pgpass.go +++ b/pgpass.go @@ -57,7 +57,6 @@ func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *strin return &parts[4] } -// TODO: Multi-host support func pgpass(cfg *ConnConfig) (found bool) { passfile := os.Getenv("PGPASSFILE") if passfile == "" { From f87825cac7b1ae3311a31a2093bcb00065667ba6 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 21:38:29 +0300 Subject: [PATCH 11/29] remove TODO that PR will not cover Signed-off-by: Artemiy Ryabinkov --- conn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/conn.go b/conn.go index 76e09e73..e195bb2f 100644 --- a/conn.go +++ b/conn.go @@ -108,7 +108,6 @@ type ConnConfig struct { Port uint16 Database string User string // default: OS user name - // TODO: Allow password to be different for each host/port pair if a password file is used Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa From 98acf573cce94af544d41b3e2bbfc9d86b7494cf Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 13 Jul 2019 21:21:23 +0300 Subject: [PATCH 12/29] fix errors collecting on multi-host Signed-off-by: Artemiy Ryabinkov --- conn.go | 4 +- examples/multihosts/README.md | 25 ++++++++++++ examples/multihosts/main.go | 74 +++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 examples/multihosts/README.md create mode 100644 examples/multihosts/main.go diff --git a/conn.go b/conn.go index e195bb2f..7f2ae6a8 100644 --- a/conn.go +++ b/conn.go @@ -467,8 +467,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } errmsgs := make([]string, len(errs)) - for _, err := range errs { - errmsgs = append(errmsgs, err.Error()) + for i, err := range errs { + errmsgs[i] = err.Error() } return nil, errors.New(strings.Join(errmsgs, ";")) diff --git a/examples/multihosts/README.md b/examples/multihosts/README.md new file mode 100644 index 00000000..4b73eb51 --- /dev/null +++ b/examples/multihosts/README.md @@ -0,0 +1,25 @@ +# Description + +This is a sample chat program implemented using PostgreSQL's listen/notify +functionality with pgx. + +Start multiple instances of this program connected to the same database to chat +between them. + +## Connection configuration + +The database connection is configured via the standard PostgreSQL environment variables. + +* PGHOST - defaults to localhost +* PGUSER - defaults to current OS user +* PGPASSWORD - defaults to empty string +* PGDATABASE - defaults to user name + +You can either export them then run chat: + + export PGHOST=/private/tmp + ./chat + +Or you can prefix the chat execution with the environment variables: + + PGHOST=/private/tmp ./chat diff --git a/examples/multihosts/main.go b/examples/multihosts/main.go new file mode 100644 index 00000000..83b16c02 --- /dev/null +++ b/examples/multihosts/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + + "github.com/jackc/pgx" +) + +var pool *pgx.ConnPool + +func main() { + config, err := pgx.ParseEnvLibpq() + if err != nil { + fmt.Fprintln(os.Stderr, "Unable to parse environment:", err) + os.Exit(1) + } + + pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config}) + if err != nil { + fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) + os.Exit(1) + } + + go listen() + + fmt.Println(`Type a message and press enter. + +This message should appear in any other chat instances connected to the same +database. + +Type "exit" to quit.`) + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + msg := scanner.Text() + if msg == "exit" { + os.Exit(0) + } + + _, err = pool.Exec("select pg_notify('chat', $1)", msg) + if err != nil { + fmt.Fprintln(os.Stderr, "Error sending notification:", err) + os.Exit(1) + } + } + if err := scanner.Err(); err != nil { + fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) + os.Exit(1) + } +} + +func listen() { + conn, err := pool.Acquire() + if err != nil { + fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) + os.Exit(1) + } + defer pool.Release(conn) + + conn.Listen("chat") + + for { + notification, err := conn.WaitForNotification(context.Background()) + if err != nil { + fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) + os.Exit(1) + } + + fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) + } +} From a2b647c393b3349c9ddf568d279e7fcd71520f88 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 13 Jul 2019 22:17:03 +0300 Subject: [PATCH 13/29] drop extra example Signed-off-by: Artemiy Ryabinkov --- conn.go | 2 +- examples/multihosts/README.md | 25 ------------ examples/multihosts/main.go | 74 ----------------------------------- 3 files changed, 1 insertion(+), 100 deletions(-) delete mode 100644 examples/multihosts/README.md delete mode 100644 examples/multihosts/main.go diff --git a/conn.go b/conn.go index 7f2ae6a8..01a572d7 100644 --- a/conn.go +++ b/conn.go @@ -471,7 +471,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) errmsgs[i] = err.Error() } - return nil, errors.New(strings.Join(errmsgs, ";")) + return nil, errors.New(strings.Join(errmsgs, "; ")) } func (c *Conn) checkWritable() error { diff --git a/examples/multihosts/README.md b/examples/multihosts/README.md deleted file mode 100644 index 4b73eb51..00000000 --- a/examples/multihosts/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# Description - -This is a sample chat program implemented using PostgreSQL's listen/notify -functionality with pgx. - -Start multiple instances of this program connected to the same database to chat -between them. - -## Connection configuration - -The database connection is configured via the standard PostgreSQL environment variables. - -* PGHOST - defaults to localhost -* PGUSER - defaults to current OS user -* PGPASSWORD - defaults to empty string -* PGDATABASE - defaults to user name - -You can either export them then run chat: - - export PGHOST=/private/tmp - ./chat - -Or you can prefix the chat execution with the environment variables: - - PGHOST=/private/tmp ./chat diff --git a/examples/multihosts/main.go b/examples/multihosts/main.go deleted file mode 100644 index 83b16c02..00000000 --- a/examples/multihosts/main.go +++ /dev/null @@ -1,74 +0,0 @@ -package main - -import ( - "bufio" - "context" - "fmt" - "os" - - "github.com/jackc/pgx" -) - -var pool *pgx.ConnPool - -func main() { - config, err := pgx.ParseEnvLibpq() - if err != nil { - fmt.Fprintln(os.Stderr, "Unable to parse environment:", err) - os.Exit(1) - } - - pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config}) - if err != nil { - fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) - os.Exit(1) - } - - go listen() - - fmt.Println(`Type a message and press enter. - -This message should appear in any other chat instances connected to the same -database. - -Type "exit" to quit.`) - - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - msg := scanner.Text() - if msg == "exit" { - os.Exit(0) - } - - _, err = pool.Exec("select pg_notify('chat', $1)", msg) - if err != nil { - fmt.Fprintln(os.Stderr, "Error sending notification:", err) - os.Exit(1) - } - } - if err := scanner.Err(); err != nil { - fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) - os.Exit(1) - } -} - -func listen() { - conn, err := pool.Acquire() - if err != nil { - fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) - os.Exit(1) - } - defer pool.Release(conn) - - conn.Listen("chat") - - for { - notification, err := conn.WaitForNotification(context.Background()) - if err != nil { - fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) - os.Exit(1) - } - - fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) - } -} From 1ecc111e17995b5aba2e0b7b1fd57c616f9172a7 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 14 Jul 2019 18:29:08 +0300 Subject: [PATCH 14/29] Reuse pool.connInfo for createConnectionUnlocked method Signed-off-by: Artemiy Ryabinkov --- conn_pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index 47a0b391..d43b6337 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -319,7 +319,7 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := Connect(p.config) + c, err := connect(p.config, p.connInfo) p.cond.L.Lock() p.inProgressConnects-- From 8e0e1123dfa4f7280ad56da42fa211bb91ea39f4 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 14 Jul 2019 20:04:55 +0300 Subject: [PATCH 15/29] use deepCopy of connInfo in createConnectionUnlocked method Signed-off-by: Artemiy Ryabinkov --- conn_pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index d43b6337..e8972a0b 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -319,7 +319,7 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := connect(p.config, p.connInfo) + c, err := connect(p.config, p.connInfo.DeepCopy()) p.cond.L.Lock() p.inProgressConnects-- From fc020c24ac9590f6547f8ad1d291fc75b4873a84 Mon Sep 17 00:00:00 2001 From: Nicholas Wilson Date: Wed, 24 Jul 2019 12:32:18 +0100 Subject: [PATCH 16/29] Add support for pgtype.UUID to write into any [16]byte type --- pgtype/convert.go | 29 +++++++++++++++++++++++++++++ pgtype/uuid.go | 2 +- pgtype/uuid_test.go | 21 +++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/pgtype/convert.go b/pgtype/convert.go index 5dfb738e..ee6907c4 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -163,6 +163,27 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return time.Time{}, false } +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false +} + // underlyingSliceType gets the underlying slice type func underlyingSliceType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) @@ -401,6 +422,14 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { } } + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)) + nextDst := dstPtr.Convert(baseArrayType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + return nil, false } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 5e1eead5..8d33d8f8 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -39,7 +39,7 @@ func (dst *UUID) Set(src interface{}) error { } *dst = UUID{Bytes: uuid, Status: Present} default: - if originalSrc, ok := underlyingPtrType(src); ok { + if originalSrc, ok := underlyingUUIDType(src); ok { return dst.Set(originalSrc) } return errors.Errorf("cannot convert %v to UUID", value) diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 162d999f..1eddeda1 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -15,6 +15,8 @@ func TestUUIDTranscode(t *testing.T) { }) } +type SomeUUIDType [16]byte + func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} @@ -32,6 +34,10 @@ func TestUUIDSet(t *testing.T) { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, + { + source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, { source: ([]byte)(nil), result: pgtype.UUID{Status: pgtype.Null}, @@ -86,6 +92,21 @@ func TestUUIDAssignTo(t *testing.T) { } } + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst SomeUUIDType + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + { src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string From 251e6b7730c7b31b600e6fe06162e541f3032604 Mon Sep 17 00:00:00 2001 From: Nicholas Wilson Date: Wed, 24 Jul 2019 12:32:43 +0100 Subject: [PATCH 17/29] Tidying: make underlyingTimeType consistent with other underlyingFooType The first return value is ignored when returning false - so there's no point returning an empty time.Time when it can be nil. --- pgtype/convert.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pgtype/convert.go b/pgtype/convert.go index ee6907c4..029e3d48 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -149,7 +149,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { switch refVal.Kind() { case reflect.Ptr: if refVal.IsNil() { - return time.Time{}, false + return nil, false } convVal := refVal.Elem().Interface() return convVal, true @@ -160,7 +160,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return refVal.Convert(timeType).Interface(), true } - return time.Time{}, false + return nil, false } // underlyingUUIDType gets the underlying type that can be converted to [16]byte From 92cd1ad639bf07d9395db46faecbbe73ac7d59ef Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Mon, 29 Jul 2019 21:19:36 +0300 Subject: [PATCH 18/29] Set 8KB as default size of ChunkReader buffer Signed-off-by: Artemiy Ryabinkov --- chunkreader/chunkreader.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go index f8d437b2..5c36292d 100644 --- a/chunkreader/chunkreader.go +++ b/chunkreader/chunkreader.go @@ -28,7 +28,11 @@ func NewChunkReader(r io.Reader) *ChunkReader { func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { if options.MinBufLen == 0 { - options.MinBufLen = 4096 + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru + options.MinBufLen = 8192 } return &ChunkReader{ From 95ea78048a9569250c078d1965a235a214239960 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Aug 2019 09:45:04 -0500 Subject: [PATCH 19/29] Remove 0 bytes when sanitizing identifiers fixes #562 --- conn.go | 9 +++++---- conn_test.go | 30 +++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index 03f9d190..121297b8 100644 --- a/conn.go +++ b/conn.go @@ -105,9 +105,9 @@ type ConnConfig struct { // If multiple hosts were given in the Host parameter, then // this parameter may specify a single port number to be used for all hosts, // or for those that haven't port explicitly defined. - Port uint16 - Database string - User string // default: OS user name + Port uint16 + Database string + User string // default: OS user name Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa @@ -307,7 +307,8 @@ type Identifier []string func (ident Identifier) Sanitize() string { parts := make([]string, len(ident)) for i := range ident { - parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + s := strings.Replace(ident[i], string([]byte{0}), "", -1) + parts[i] = `"` + strings.Replace(s, `"`, `""`, -1) + `"` } return strings.Join(parts, ".") } diff --git a/conn_test.go b/conn_test.go index 7719bec7..fea3b659 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,7 +84,6 @@ func TestConnect(t *testing.T) { } } - func TestConnectWithMultiHost(t *testing.T) { t.Parallel() @@ -129,7 +128,6 @@ func TestConnectWithMultiHost(t *testing.T) { } } - func TestConnectWithMultiHostWritable(t *testing.T) { t.Parallel() @@ -818,9 +816,9 @@ func TestParseDSN(t *testing.T) { TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, TargetSessionAttrs: pgx.ReadWriteTargetSession, }, }, @@ -2319,6 +2317,24 @@ func TestSetLogLevel(t *testing.T) { } } +func TestIdentifierSanitizeNullSentToServer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ident := pgx.Identifier{"foo" + string([]byte{0}) + "bar"} + + var n int64 + err := conn.QueryRow(`select 1 as ` + ident.Sanitize()).Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatal("unexpected n") + } +} + func TestIdentifierSanitize(t *testing.T) { t.Parallel() @@ -2346,6 +2362,10 @@ func TestIdentifierSanitize(t *testing.T) { ident: pgx.Identifier{`you should " not do this`, `please don't`}, expected: `"you should "" not do this"."please don't"`, }, + { + ident: pgx.Identifier{`you should ` + string([]byte{0}) + `not do this`}, + expected: `"you should not do this"`, + }, } for i, tt := range tests { From 7fe7f33557938739e5342d82d0720523c344eb71 Mon Sep 17 00:00:00 2001 From: "Andrew S. Brown" Date: Sun, 4 Aug 2019 15:31:32 -0700 Subject: [PATCH 20/29] Terminate context prior to releasing when killing batch connection --- batch.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/batch.go b/batch.go index 4b624387..8c924e8d 100644 --- a/batch.go +++ b/batch.go @@ -135,7 +135,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { _, err = b.conn.conn.Write(buf) if err != nil { - b.conn.die(err) + b.die(err) return err } @@ -281,10 +281,13 @@ func (b *Batch) die(err error) { } b.err = err - b.conn.die(err) + if b.conn != nil { + err = b.conn.termContext(err) + b.conn.die(err) - if b.conn != nil && b.connPool != nil { - b.connPool.Release(b.conn) + if b.connPool != nil { + b.connPool.Release(b.conn) + } } } From ca9de512569587bfd5f26ffd2a5e266a7bbfbef5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 6 Aug 2019 16:42:20 -0500 Subject: [PATCH 21/29] Return deferred errors Deferred errors are sent after the CommandComplete message. They could be silently dropped depending on the context in which it occurred. fixes #570 --- batch.go | 17 +++++++++++++++++ batch_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++ conn_test.go | 26 ++++++++++++++++++++++++++ query.go | 19 +++++++++++++++++++ query_test.go | 43 +++++++++++++++++++++++++++++++++++++++++- 5 files changed, 156 insertions(+), 1 deletion(-) diff --git a/batch.go b/batch.go index 8c924e8d..7f5422dc 100644 --- a/batch.go +++ b/batch.go @@ -268,6 +268,23 @@ func (b *Batch) Close() (err error) { } } + for b.conn.pendingReadyForQueryCount > 0 { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return b.conn.rxErrorResponse(msg) + default: + err = b.conn.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } + if err = b.conn.ensureConnectionReadyForQuery(); err != nil { return err } diff --git a/batch_test.go b/batch_test.go index 61bbe357..d0e26875 100644 --- a/batch_test.go +++ b/batch_test.go @@ -701,3 +701,55 @@ func TestTxBeginBatchRollback(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnBeginBatchDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + batch := conn.BeginBatch() + batch.Queue(`update t set n=n+1 where id='b' returning *`, + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + err = batch.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} diff --git a/conn_test.go b/conn_test.go index fea3b659..c6ce50cc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1370,6 +1370,32 @@ func TestExecFailure(t *testing.T) { } } +func TestExecDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + _, err := conn.Exec(`update t set n=n+1 where id='b'`) + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestExecFailureWithArguments(t *testing.T) { t.Parallel() diff --git a/query.go b/query.go index 5c6cbf7f..bf4ec561 100644 --- a/query.go +++ b/query.go @@ -69,6 +69,25 @@ func (rows *Rows) Close() { return } + // If there is no error and a batch operation is not in progress read until we get the ReadyForQuery message or the + // ErrorResponse. This is necessary to detect a deferred constraint violation where the ErrorResponse is sent after + // CommandComplete. + if rows.err == nil && rows.batch == nil && rows.conn.pendingReadyForQueryCount == 1 { + for rows.conn.pendingReadyForQueryCount > 0 { + msg, err := rows.conn.rxMsg() + if err != nil { + rows.err = err + break + } + + err = rows.conn.processContextFreeMsg(msg) + if err != nil { + rows.err = err + break + } + } + } + if rows.unlockConn { rows.conn.unlock() rows.unlockConn = false diff --git a/query_test.go b/query_test.go index 06b7b8b7..ea1fd66e 100644 --- a/query_test.go +++ b/query_test.go @@ -14,7 +14,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - "github.com/satori/go.uuid" + uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" ) @@ -424,6 +424,47 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } +// https://github.com/jackc/pgx/issues/570 +func TestConnQueryDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + rows, err := conn.Query(`update t set n=n+1 where id='b' returning *`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + if rows.Err() == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := rows.Err().(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestQueryEncodeError(t *testing.T) { t.Parallel() From 9e3f51e5c6759ee9d6eadfa33240b9503e39b096 Mon Sep 17 00:00:00 2001 From: Nathaniel Caza Date: Wed, 7 Aug 2019 13:49:34 -0500 Subject: [PATCH 22/29] Allow specifying LevelRepeatableRead --- stdlib/sql.go | 2 +- stdlib/sql_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index ec5933f3..e564152f 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -277,7 +277,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.IsoLevel = pgx.ReadUncommitted case sql.LevelReadCommitted: pgxOpts.IsoLevel = pgx.ReadCommitted - case sql.LevelSnapshot: + case sql.LevelRepeatableRead, sql.LevelSnapshot: pgxOpts.IsoLevel = pgx.RepeatableRead case sql.LevelSerializable: pgxOpts.IsoLevel = pgx.Serializable diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index cf2b91b1..895ee583 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -629,6 +629,7 @@ func TestConnBeginTxIsolation(t *testing.T) { {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, } From 50b92ce0f591145c4b3da1c1e8fb9d98db845ebb Mon Sep 17 00:00:00 2001 From: Ian Stapleton Cordasco Date: Sun, 11 Aug 2019 08:16:48 -0500 Subject: [PATCH 23/29] Correct WaitForNotification example While working on a project that was using this, I tried using the example code but instead found that WaitForNotification expects a Context (which makes sense). This corrects the docs for folks using that as a jumping off point. --- doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc.go b/doc.go index 5808c09d..0c2b35d3 100644 --- a/doc.go +++ b/doc.go @@ -225,7 +225,7 @@ notification. return nil } - if notification, err := conn.WaitForNotification(time.Second); err != nil { + if notification, err := conn.WaitForNotification(context.TODO()); err != nil { // do something with notification } From 809600d6671eeea159f1560abff7af084d71f1a0 Mon Sep 17 00:00:00 2001 From: Jonathan Yoder Date: Thu, 15 Aug 2019 09:31:38 -0400 Subject: [PATCH 24/29] Clarify stdlib.AcquireConn Comment --- stdlib/sql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index e564152f..3cd2d941 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -43,8 +43,8 @@ // // AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard // database/sql.DB connection pool. This allows operations that must be -// performed on a single connection, but should not be run in a transaction or -// to use pgx specific functionality. +// performed on a single connection without running in a transaction, and it +// supports operations that use pgx specific functionality. // // conn, err := stdlib.AcquireConn(db) // if err != nil { From 7829081b8c1eebc860dab63378b60eb47456bea2 Mon Sep 17 00:00:00 2001 From: Dmitriy Garanzha Date: Fri, 16 Aug 2019 13:22:16 +0300 Subject: [PATCH 25/29] Load user-defined array type oids. --- conn.go | 2 +- pgmock/pgmock.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 121297b8..9dc4cdbf 100644 --- a/conn.go +++ b/conn.go @@ -615,7 +615,7 @@ left join pg_namespace nsp on t.typnamespace=nsp.oid left join pg_class cls on t.typrelid=cls.oid where ( t.typtype in('b', 'p', 'r', 'e', 'c') - and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) + and (base_type.oid is null or base_type.typtype in('b', 'p', 'r', 'c')) and (cls.oid is null or cls.relkind='c') )` ) diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index d4ab0d13..5c3fdc27 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -214,7 +214,7 @@ left join pg_namespace nsp on t.typnamespace=nsp.oid left join pg_class cls on t.typrelid=cls.oid where ( t.typtype in('b', 'p', 'r', 'e', 'c') - and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) + and (base_type.oid is null or base_type.typtype in('b', 'p', 'r', 'c')) and (cls.oid is null or cls.relkind='c') )`, }), From 12c6319244e4836c5bb6bbff2f786bf73487c574 Mon Sep 17 00:00:00 2001 From: Kale Blankenship Date: Wed, 28 Aug 2019 12:50:51 -0700 Subject: [PATCH 26/29] Include ParameterOIDs when preparing statements on new pool connections ParameterOIDs passed to ConnPool.PrepareEx are used to prepare the statement on existing connections in the pool. If additional connections are later created ParameterOIDs are omitted, potentially causing query failures. --- conn_pool.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index e8972a0b..344f00d7 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -341,7 +341,8 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { } for _, ps := range p.preparedStatements { - if _, err := c.Prepare(ps.Name, ps.SQL); err != nil { + opts := &PrepareExOptions{ParameterOIDs: ps.ParameterOIDs} + if _, err := c.PrepareEx(context.Background(), ps.Name, ps.SQL, opts); err != nil { c.die(err) return nil, err } From 78f498fc43f957b2eccdac1d002798ee3c277a5c Mon Sep 17 00:00:00 2001 From: Kale Blankenship Date: Sat, 31 Aug 2019 10:27:19 -0700 Subject: [PATCH 27/29] Add ConnPool.AcquireEx --- conn_pool.go | 19 ++++++ conn_pool_test.go | 144 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+) diff --git a/conn_pool.go b/conn_pool.go index 344f00d7..95e1b015 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -110,6 +110,25 @@ func (p *ConnPool) Acquire() (*Conn, error) { return c, err } +func (p *ConnPool) AcquireEx(ctx context.Context) (*Conn, error) { + var deadline *time.Time + + if p.acquireTimeout > 0 { + tmp := time.Now().Add(p.acquireTimeout) + deadline = &tmp + } + + ctxDeadline, ok := ctx.Deadline() + if ok && (deadline == nil || ctxDeadline.Before(*deadline)) { + deadline = &ctxDeadline + } + + p.cond.L.Lock() + c, err := p.acquire(deadline) + p.cond.L.Unlock() + return c, err +} + // deadlinePassed returns true if the given deadline has passed. func (p *ConnPool) deadlinePassed(deadline *time.Time) bool { return deadline != nil && time.Now().After(*deadline) diff --git a/conn_pool_test.go b/conn_pool_test.go index 84a74aed..83bdf1fd 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -45,6 +45,12 @@ func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) return c, time.Since(startTime), err } +func acquireExWithTimeTaken(pool *pgx.ConnPool, ctx context.Context) (*pgx.Conn, time.Duration, error) { + startTime := time.Now() + c, err := pool.AcquireEx(ctx) + return c, time.Since(startTime), err +} + func TestNewConnPool(t *testing.T) { t.Parallel() @@ -315,6 +321,144 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { } } +func TestPoolWithAcquireExContextTimeoutSet(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutLower(t *testing.T) { + t.Parallel() + + connAllocTimeout := 2 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < connAllocTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) + } + if timeTaken > ctxTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutHigher(t *testing.T) { + t.Parallel() + + connAllocTimeout := 5 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } + if timeTaken > connAllocTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", connAllocTimeout, timeTaken) + } +} + +func TestPoolWithoutAcquireExTimeoutSet(t *testing.T) { + t.Parallel() + + maxConnections := 1 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, maxConnections) + + // ... then try to consume 1 more. It should hang forever. + // To unblock it we release the previously taken connection in a goroutine. + stopDeadWaitTimeout := 5 * time.Second + timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { + releaseAllConnections(pool, allConnections) + }) + defer timer.Stop() + + conn, timeTaken, err := acquireExWithTimeTaken(pool, context.Background()) + if err == nil { + pool.Release(conn) + } else { + t.Fatalf("Expected error to be nil, instead it was '%v'", err) + } + if timeTaken < stopDeadWaitTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) + } +} + func TestPoolErrClosedPool(t *testing.T) { t.Parallel() From 35908df25f9079a73270eefb1cf4c3df635ee876 Mon Sep 17 00:00:00 2001 From: Dmitriy Garanzha Date: Mon, 2 Sep 2019 16:57:21 +0300 Subject: [PATCH 28/29] Filter automatically created table array types. --- conn.go | 2 ++ pgmock/pgmock.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/conn.go b/conn.go index 9dc4cdbf..b98434f7 100644 --- a/conn.go +++ b/conn.go @@ -611,12 +611,14 @@ func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_class base_cls ON base_type.typrelid = base_cls.oid left join pg_namespace nsp on t.typnamespace=nsp.oid left join pg_class cls on t.typrelid=cls.oid where ( t.typtype in('b', 'p', 'r', 'e', 'c') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r', 'c')) and (cls.oid is null or cls.relkind='c') + and (base_cls.oid is null or base_cls.relkind = 'c') )` ) diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 5c3fdc27..7b9e7991 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -210,12 +210,14 @@ func PgxInitSteps() []Step { end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_class base_cls ON base_type.typrelid = base_cls.oid left join pg_namespace nsp on t.typnamespace=nsp.oid left join pg_class cls on t.typrelid=cls.oid where ( t.typtype in('b', 'p', 'r', 'e', 'c') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r', 'c')) and (cls.oid is null or cls.relkind='c') + and (base_cls.oid is null or base_cls.relkind = 'c') )`, }), ExpectMessage(&pgproto3.Describe{ From f26e4c0e6921395ee2556c61c0152b031254ff6c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 2 Sep 2019 12:19:55 -0500 Subject: [PATCH 29/29] Update status of v4 --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b7051f65..0a4cacc3 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,8 @@ if err != nil { ## v4 Coming Soon -This is the current stable v3 version. v4 is currently is in prelease status. Consider using [v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. +This is the current stable v3 version. v4 is currently is in release candidate status. Consider using +[v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. ## Features