From 9538d15c29005e5044da6ba3f4c8ff06daec1278 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Mon, 3 Jun 2019 23:51:48 +0300 Subject: [PATCH 01/15] Draft of connection writable checking Signed-off-by: Artemiy Ryabinkov --- conn.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index cb24748c..e0169d6d 100644 --- a/conn.go +++ b/conn.go @@ -89,6 +89,13 @@ type ConnConfig struct { // used by default. The same functionality can be controlled on a per query // basis by setting QueryExOptions.SimpleProtocol. PreferSimpleProtocol bool + + // TargetSessionAttr allows to specify which servers are accepted for this connection. + // "any", meaning that any kind of servers can be accepted. This is as well the default value. + // "read-write", to disallow connections to read-only servers, hot standbys for example. + // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com + // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + TargetSessionAttrs string } func (cc *ConnConfig) networkAddress() (network, address string) { @@ -262,8 +269,15 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } + if c.config.TargetSessionAttrs != "" && + c.config.TargetSessionAttrs != "any" && + c.config.TargetSessionAttrs != "read-write" { + return nil, errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") + } + c.onNotice = config.OnNotice + // TODO: Parse multi-hosts network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() @@ -273,22 +287,58 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) } - err = c.connect(config, network, address, config.TLSConfig) - if err != nil && config.UseFallbackTLS { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + + // TODO: Start loop for all hosts [host0 .. hostN] + for { + err = c.connect(config, network, address, config.TLSConfig) + if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + } + err = c.connect(config, network, address, config.FallbackTLSConfig) } - err = c.connect(config, network, address, config.FallbackTLSConfig) + + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + } + // TODO: Collect error + continue + } + + err = c.writeable() + if err != nil { + // TODO: Log info about not writable host + // TODO: Collect error + continue + } + + return c, nil } + + // TODO: Return collected errors + return nil, nil +} + +func (c *Conn) writeable() error { + if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { + return nil + } + + var st string + err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). + Scan(st) + if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) - } - return nil, err + return errors.Wrap(err, "failed to fetch transaction_read_only state") } - return c, nil + if st == "on" { + return errors.New("writable connection disabled by server") + } + + return nil } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -709,6 +759,10 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig { cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol + if other.TargetSessionAttrs != "" { + cc.TargetSessionAttrs = other.TargetSessionAttrs + } + cc.RuntimeParams = make(map[string]string) for k, v := range old.RuntimeParams { cc.RuntimeParams[k] = v From 9f031bb8f9bea60bd51ebc1cbaaa8e5db779b191 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 16 Jun 2019 14:03:43 +0300 Subject: [PATCH 02/15] Return net.Addr from networkAddress Signed-off-by: Artemiy Ryabinkov --- conn.go | 82 +++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index e0169d6d..323875a5 100644 --- a/conn.go +++ b/conn.go @@ -98,20 +98,24 @@ type ConnConfig struct { TargetSessionAttrs string } -func (cc *ConnConfig) networkAddress() (network, address string) { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - // See if host is a valid path, if yes connect with a socket +func (cc *ConnConfig) networkAddress() net.Addr { + // See if host is a valid path, if yes connect with a unix socket if _, err := os.Stat(cc.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = cc.Host + network := "unix" + address := cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) } + + return &net.UnixAddr{Name: address, Net: network} } - return network, address + return &net.TCPAddr{ + Port: int(cc.Port), + IP: net.ParseIP(cc.Host), + } } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -277,48 +281,80 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.onNotice = config.OnNotice - // TODO: Parse multi-hosts - network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() c.config.Dial = d.Dial } - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) - } + // TODO: Parse multi-hosts + hostAddr := c.config.networkAddress() + + addrs := []net.Addr{hostAddr} + + var errs []error + for _, addr := range addrs { + network, address := addr.Network(), addr.String() + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{ + "network": network, + "address": address, + }) + } - // TODO: Start loop for all hosts [host0 .. hostN] - for { err = c.connect(config, network, address, config.TLSConfig) if err != nil && config.UseFallbackTLS { if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) } err = c.connect(config, network, address, config.FallbackTLSConfig) } if err != nil { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + c.log(LogLevelError, "connect failed", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) } - // TODO: Collect error + + errs = append(errs, err) continue } err = c.writeable() if err != nil { - // TODO: Log info about not writable host - // TODO: Collect error + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + + errs = append(errs, err) continue } return c, nil } + // To keep backwards, if specific error type expected. + if len(errs) == 1 { + return nil, errs[0] + } - // TODO: Return collected errors - return nil, nil + var errmsg string + for _, err := range errs { + errmsg += ";" + err.Error() + } + + return nil, errors.New(errmsg) } func (c *Conn) writeable() error { @@ -331,7 +367,7 @@ func (c *Conn) writeable() error { Scan(st) if err != nil { - return errors.Wrap(err, "failed to fetch transaction_read_only state") + return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") } if st == "on" { From 25e1f674a2a02cacc2db24924545539366fac825 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 16 Jun 2019 14:36:54 +0300 Subject: [PATCH 03/15] Fix doCancel with addr from networkAddress Signed-off-by: Artemiy Ryabinkov --- conn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 323875a5..287ce7db 100644 --- a/conn.go +++ b/conn.go @@ -1741,8 +1741,8 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) + addr := c.config.networkAddress() + cancelConn, err := c.config.Dial(addr.Network(), addr.String()) if err != nil { return err } From 6ec815a7489b64856307091567f59635b9a65bbe Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Tue, 18 Jun 2019 16:02:09 +0300 Subject: [PATCH 04/15] Support Multiple Hosts in ConnConfig Signed-off-by: Artemiy Ryabinkov --- conn.go | 146 ++++++++++++++++++++++++++++++------ conn_config_test.go.example | 3 + conn_config_test.go.travis | 2 + conn_test.go | 101 +++++++++++++++++++++++++ pgpass.go | 1 + 5 files changed, 232 insertions(+), 21 deletions(-) diff --git a/conn.go b/conn.go index 287ce7db..a6ea78f1 100644 --- a/conn.go +++ b/conn.go @@ -63,10 +63,29 @@ type DialFunc func(network, addr string) (net.Conn, error) // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name + // Name of host to connect to. (e.g. localhost) + // If a host name begins with a slash, it specifies Unix-domain communication + // rather than TCP/IP communication; the value is the name of the directory + // in which the socket file is stored. (e.g. /private/tmp) + // The default behavior when host is not specified, or is empty, is to connect to localhost. + // + // A comma-separated list of host names is also accepted, + // in which case each host name in the list is tried in order; + // an empty item in the list selects the default behavior as explained above. + // @see https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS + Host string + + // Port number to connect to at the server host, + // or socket file name extension for Unix-domain connections. + // An empty or zero value, specifies the default port number — 5432. + // + // If multiple hosts were given in the Host parameter, then + // this parameter may specify a single port number to be used for all hosts, + // or for those that haven't port explicitly defined. + Port uint16 + Database string + User string // default: OS user name + // TODO: Allow password to be different for each host/port pair if a password file is used Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa @@ -95,10 +114,34 @@ type ConnConfig struct { // "read-write", to disallow connections to read-only servers, hot standbys for example. // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + // + // The query SHOW transaction_read_only will be sent upon any successful connection; + // if it returns on, the connection will be closed. + // If multiple hosts were specified in the connection string, + // any remaining servers will be tried just as if the connection attempt had failed. + // The default value of this parameter, any, regards all connections as acceptable. TargetSessionAttrs string } -func (cc *ConnConfig) networkAddress() net.Addr { +// hostAddr represents network end point defined as hostname or IP + port. +type hostAddr struct { + Host string + Port uint16 +} + +// Network returns the address's network name, "tcp". +func (a *hostAddr) Network() string { return "tcp" } + +// String implements net.Addr String method. +func (a *hostAddr) String() string { + if a == nil { + return "" + } + + return net.JoinHostPort(a.Host, strconv.Itoa(int(a.Port))) +} + +func (cc *ConnConfig) networkAddresses() ([]net.Addr, error) { // See if host is a valid path, if yes connect with a unix socket if _, err := os.Stat(cc.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred @@ -109,13 +152,50 @@ func (cc *ConnConfig) networkAddress() net.Addr { address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) } - return &net.UnixAddr{Name: address, Net: network} + addrs := []net.Addr{ + &net.UnixAddr{Name: address, Net: network}, + } + + return addrs, nil } - return &net.TCPAddr{ - Port: int(cc.Port), - IP: net.ParseIP(cc.Host), + if cc.Host == "" { + addrs := []net.Addr{ + &net.TCPAddr{Port: int(cc.Port)}, + } + + return addrs, nil } + + var addrs []net.Addr + + hostports := strings.Split(cc.Host, ",") + for i, hostport := range hostports { + if hostport == "" { + return nil, fmt.Errorf("multi-host part %d is empty, at least host or port must be defined", i) + } + + // It's not possible to use net.TCPAddr here, cuz host may be hostname. + addr := hostAddr{ + Host: hostport, + Port: cc.Port, + } + + pos := strings.IndexByte(hostport, ':') + if pos != -1 { + p, err := strconv.ParseUint(hostport[pos+1:], 10, 16) + if err != nil { + return nil, fmt.Errorf("multi-host part %d (%s) has invalid port format", i, hostport) + } + + addr.Host = hostport[:pos] + addr.Port = uint16(p) + } + + addrs = append(addrs, &addr) + } + + return addrs, nil } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -156,6 +236,10 @@ type Conn struct { ConnInfo *pgtype.ConnInfo frontend *pgproto3.Frontend + + // In case of Multiple Hosts we need to know what addr was used to connect. + // This address will be used to send a cancellation request. + addr net.Addr } // PreparedStatement is a description of a prepared statement @@ -286,10 +370,10 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.config.Dial = d.Dial } - // TODO: Parse multi-hosts - hostAddr := c.config.networkAddress() - - addrs := []net.Addr{hostAddr} + addrs, err := c.config.networkAddresses() + if err != nil { + return nil, err + } var errs []error for _, addr := range addrs { @@ -323,12 +407,23 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) }) } + // On any auth errors return immediately + if pgErr, ok := err.(PgError); ok { + switch pgErr.Code { + // @see: https://www.postgresql.org/docs/current/errcodes-appendix.html + case "28000", "28P01": // Invalid Authorization Specification + return nil, pgErr + } + } + errs = append(errs, err) continue } - err = c.writeable() + err = c.writable() if err != nil { + c.die(err) + if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ "err": err, @@ -341,6 +436,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) continue } + c.addr = addr + return c, nil } @@ -351,30 +448,34 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) var errmsg string for _, err := range errs { - errmsg += ";" + err.Error() + errmsg += "; " + err.Error() } return nil, errors.New(errmsg) } -func (c *Conn) writeable() error { +func (c *Conn) writable() error { if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { return nil } var st string err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). - Scan(st) + Scan(&st) if err != nil { return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") } - if st == "on" { + switch st { + case "on": return errors.New("writable connection disabled by server") + case "off": + // If transaction_read_only = off, then connection is writable. + return nil } - return nil + return errors.New("unexpected \"transaction_read_only\" status") } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -958,6 +1059,8 @@ func ParseDSN(s string) (ConnConfig, error) { // ParseConnectionString parses either a URI or a DSN connection string. // see ParseURI and ParseDSN for details. func ParseConnectionString(s string) (ConnConfig, error) { + // TODO: Multiple Hosts support + // @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) } @@ -981,6 +1084,8 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// TODO: PGTARGETSESSIONATTRS support +// @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This @@ -1741,8 +1846,7 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - addr := c.config.networkAddress() - cancelConn, err := c.config.Dial(addr.Network(), addr.String()) + cancelConn, err := c.config.Dial(c.addr.Network(), c.addr.String()) if err != nil { return err } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 096e1354..620b0ea1 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -7,6 +7,7 @@ import ( // "go/build" // "io/ioutil" // "path" + // "net" "github.com/jackc/pgx" ) @@ -14,6 +15,7 @@ import ( var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // To skip tests for specific connection / authentication types set that connection param to nil +var multihostConnConfig *pgx.ConnConfig = nil var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil @@ -24,6 +26,7 @@ var customDialerConnConfig *pgx.ConnConfig = nil var replicationConnConfig *pgx.ConnConfig = nil var cratedbConnConfig *pgx.ConnConfig = nil +// var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} // var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index cf29a743..738f1112 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -5,9 +5,11 @@ import ( "github.com/jackc/pgx" "os" "strconv" + "net" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_test.go b/conn_test.go index 6ca00c6d..14efbeca 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,6 +84,107 @@ func TestConnect(t *testing.T) { } } + +func TestConnectWithMultiHost(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + conn, err := pgx.Connect(*multihostConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + + +func TestConnectWithMultiHostWritable(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + connConfig := *multihostConnConfig + connConfig.TargetSessionAttrs = "read-write" + + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err := conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + var st string + err = conn.QueryRow("SHOW transaction_read_only").Scan(&st) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + + if st == "on" { + t.Error("Connection is not writable") + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithUnixSocketDirectory(t *testing.T) { t.Parallel() diff --git a/pgpass.go b/pgpass.go index 34b9bdf5..ff97e5f0 100644 --- a/pgpass.go +++ b/pgpass.go @@ -57,6 +57,7 @@ func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *strin return &parts[4] } +// TODO: Multi-host support func pgpass(cfg *ConnConfig) (found bool) { passfile := os.Getenv("PGPASSFILE") if passfile == "" { From 2837818b67f39d5dfb49a6b37f7d8a23ef263896 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Tue, 18 Jun 2019 17:09:38 +0300 Subject: [PATCH 05/15] fix typo Signed-off-by: Artemiy Ryabinkov --- conn.go | 2 +- conn_config_test.go.example | 1 + conn_config_test.go.travis | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index a6ea78f1..153c7a3d 100644 --- a/conn.go +++ b/conn.go @@ -441,7 +441,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return c, nil } - // To keep backwards, if specific error type expected. + // To keep backwards compatibility, if specific error type expected. if len(errs) == 1 { return nil, errs[0] } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 620b0ea1..2ca84ac3 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -8,6 +8,7 @@ import ( // "io/ioutil" // "path" // "net" + // "time" "github.com/jackc/pgx" ) diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index 738f1112..fbfb5252 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -6,6 +6,7 @@ import ( "os" "strconv" "net" + "time" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} From bcb2afe2be3d755f0ca53f3df0b262f3ca64096f Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Wed, 10 Jul 2019 22:59:17 +0300 Subject: [PATCH 06/15] TargetSessionAttrs as custom type Signed-off-by: Artemiy Ryabinkov --- conn.go | 28 ++++++++++++++++++++++------ conn_test.go | 2 +- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 153c7a3d..f0cf20be 100644 --- a/conn.go +++ b/conn.go @@ -61,6 +61,24 @@ type NoticeHandler func(*Conn, *Notice) // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) +// TargetSessionType represents target session attrs configuration parameter. +type TargetSessionType string + +// Block enumerates available values for TargetSessionType. +const ( + AnyTargetSession = "any" + ReadWriteTargetSession = "read-write" +) + +func (t TargetSessionType) isValid() error { + switch t { + case "", AnyTargetSession, ReadWriteTargetSession: + return nil + } + + return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") +} + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { // Name of host to connect to. (e.g. localhost) @@ -120,7 +138,7 @@ type ConnConfig struct { // If multiple hosts were specified in the connection string, // any remaining servers will be tried just as if the connection attempt had failed. // The default value of this parameter, any, regards all connections as acceptable. - TargetSessionAttrs string + TargetSessionAttrs TargetSessionType } // hostAddr represents network end point defined as hostname or IP + port. @@ -357,10 +375,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } - if c.config.TargetSessionAttrs != "" && - c.config.TargetSessionAttrs != "any" && - c.config.TargetSessionAttrs != "read-write" { - return nil, errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") + if err := c.config.TargetSessionAttrs.isValid(); err != nil { + return nil, err } c.onNotice = config.OnNotice @@ -455,7 +471,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } func (c *Conn) writable() error { - if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { + if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == AnyTargetSession { return nil } diff --git a/conn_test.go b/conn_test.go index 14efbeca..28bfe48b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -138,7 +138,7 @@ func TestConnectWithMultiHostWritable(t *testing.T) { } connConfig := *multihostConnConfig - connConfig.TargetSessionAttrs = "read-write" + connConfig.TargetSessionAttrs = pgx.ReadWriteTargetSession conn := mustConnect(t, connConfig) defer closeConn(t, conn) From 7d4215cb88d63e43baa8b8735bdafbfbd673c8bd Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Wed, 10 Jul 2019 23:16:46 +0300 Subject: [PATCH 07/15] fix error message building from errors array on connection establishing Signed-off-by: Artemiy Ryabinkov --- conn.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index f0cf20be..e570249e 100644 --- a/conn.go +++ b/conn.go @@ -462,12 +462,12 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return nil, errs[0] } - var errmsg string + errmsgs := make([]string, len(errs)) for _, err := range errs { - errmsg += "; " + err.Error() + errmsgs = append(errmsgs, err.Error()) } - return nil, errors.New(errmsg) + return nil, errors.New(strings.Join(errmsgs, ";")) } func (c *Conn) writable() error { From 75b4ba635c0224e046dc3c3d5e8d5d30c5b65d61 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 00:16:58 +0300 Subject: [PATCH 08/15] try to improve readability of writable checking Signed-off-by: Artemiy Ryabinkov --- conn.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index e570249e..975cf337 100644 --- a/conn.go +++ b/conn.go @@ -79,6 +79,10 @@ func (t TargetSessionType) isValid() error { return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") } +func (t TargetSessionType) writableRequired() bool { + return t == ReadWriteTargetSession +} + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { // Name of host to connect to. (e.g. localhost) @@ -436,7 +440,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) continue } - err = c.writable() + err = c.checkWritable() if err != nil { c.die(err) @@ -470,8 +474,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) return nil, errors.New(strings.Join(errmsgs, ";")) } -func (c *Conn) writable() error { - if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == AnyTargetSession { +func (c *Conn) checkWritable() error { + if !c.config.TargetSessionAttrs.writableRequired() { return nil } @@ -485,7 +489,7 @@ func (c *Conn) writable() error { switch st { case "on": - return errors.New("writable connection disabled by server") + return errors.New("writable transactions disabled by server") case "off": // If transaction_read_only = off, then connection is writable. return nil From 18189fafd54ca2b678681a0550d632d1a5434f2c Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 20:28:04 +0300 Subject: [PATCH 09/15] ParseConnectionString supports Multi-Hosts Signed-off-by: Artemiy Ryabinkov --- conn.go | 123 +++++++++++++++++++++++++++++++++++++++++---------- conn_test.go | 76 +++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index 975cf337..fd134461 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "net/url" "os" @@ -947,16 +948,26 @@ func ParseURI(uri string) (ConnConfig, error) { cp.Password, _ = url.User.Password() } - parts := strings.SplitN(url.Host, ":", 2) - cp.Host = parts[0] - if len(parts) == 2 { - p, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return cp, err + hasMuliHosts := strings.IndexByte(url.Host, ',') != -1 + if !hasMuliHosts { + parts := strings.SplitN(url.Host, ":", 2) + cp.Host = parts[0] + if len(parts) == 2 { + p, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return cp, err + } + cp.Port = uint16(p) } - cp.Port = uint16(p) + } else { + cp.Host = url.Host } + cp.Database = strings.TrimLeft(url.Path, "/") + cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs")) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { timeout, err := strconv.ParseInt(pgtimeout, 10, 64) @@ -980,11 +991,12 @@ func ParseURI(uri string) (ConnConfig, error) { } ignoreKeys := map[string]struct{}{ - "connect_timeout": {}, - "sslcert": {}, - "sslkey": {}, - "sslmode": {}, - "sslrootcert": {}, + "connect_timeout": {}, + "sslcert": {}, + "sslkey": {}, + "sslmode": {}, + "sslrootcert": {}, + "target_session_attrs": {}, } cp.RuntimeParams = make(map[string]string) @@ -1029,6 +1041,7 @@ func ParseDSN(s string) (ConnConfig, error) { cp.RuntimeParams = make(map[string]string) + var hostval, portval string for _, b := range m { switch b[1] { case "user": @@ -1036,13 +1049,9 @@ func ParseDSN(s string) (ConnConfig, error) { case "password": cp.Password = b[2] case "host": - cp.Host = b[2] + hostval = b[2] case "port": - p, err := strconv.ParseUint(b[2], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) + portval = b[2] case "dbname": cp.Database = b[2] case "sslmode": @@ -1061,26 +1070,94 @@ func ParseDSN(s string) (ConnConfig, error) { d := defaultDialer() d.Timeout = time.Duration(timeout) * time.Second cp.Dial = d.Dial + case "target_session_attrs": + cp.TargetSessionAttrs = TargetSessionType(b[2]) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } default: cp.RuntimeParams[b[1]] = b[2] } } - err := configTLS(tlsArgs, &cp) + host, port, err := parseHostPortDSN(hostval, portval) if err != nil { return cp, err } + + cp.Host, cp.Port = host, port + + err = configTLS(tlsArgs, &cp) + if err != nil { + return cp, err + } + if cp.Password == "" { pgpass(&cp) } + return cp, nil } -// ParseConnectionString parses either a URI or a DSN connection string. -// see ParseURI and ParseDSN for details. +func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) { + if portval == "" { + return hostval, 0, nil + } + + hosts := strings.Split(hostval, ",") + ports := strings.Split(portval, ",") + + if len(ports) == 1 { + port, err := parsePort(portval) + if err != nil { + return "", 0, errors.Errorf("invalid port: %v", err) + } + + return hostval, port, nil + } + + if len(hosts) != len(ports) { + return "", 0, errors.New("the number of hosts and ports must be the same") + } + + hostports := make([]string, len(hosts)) + for i, host := range hosts { + hostports[i] = host + ":" + ports[i] + } + + return strings.Join(hostports, ","), 0, nil +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// ParseConnectionString supports specifying multiple hosts in similar manner to libpq. +// Host and port may include comma separated values that will be tried in order. +// This can be used as part of a high availability system. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// # Example DSN +// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca func ParseConnectionString(s string) (ConnConfig, error) { - // TODO: Multiple Hosts support - // @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) } diff --git a/conn_test.go b/conn_test.go index 28bfe48b..7719bec7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -622,6 +622,38 @@ func TestParseURI(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "foo.example.com:5432,bar.example.com:5432", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost,10.10.20.30", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + }, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { @@ -748,6 +780,50 @@ func TestParseDSN(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost1,localhost2", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "100.200.220.50:5432,localhost43:5433", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { From 39cbdf789d3448c56ba394557e8100a143694c56 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 20:56:44 +0300 Subject: [PATCH 10/15] Support of PGTARGETSESSIONATTRS ENV variable Signed-off-by: Artemiy Ryabinkov --- conn.go | 8 +++++++- pgpass.go | 1 - 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index fd134461..76e09e73 100644 --- a/conn.go +++ b/conn.go @@ -1016,6 +1016,7 @@ func ParseURI(uri string) (ConnConfig, error) { if cp.Password == "" { pgpass(&cp) } + return cp, nil } @@ -1181,7 +1182,7 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT -// TODO: PGTARGETSESSIONATTRS support +// PGTARGETSESSIONATTRS // @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: @@ -1228,6 +1229,11 @@ func ParseEnvLibpq() (ConnConfig, error) { } } + cc.TargetSessionAttrs = TargetSessionType(os.Getenv("PGTARGETSESSIONATTRS")) + if err := cc.TargetSessionAttrs.isValid(); err != nil { + return cc, err + } + tlsArgs := configTLSArgs{ sslMode: os.Getenv("PGSSLMODE"), sslKey: os.Getenv("PGSSLKEY"), diff --git a/pgpass.go b/pgpass.go index ff97e5f0..34b9bdf5 100644 --- a/pgpass.go +++ b/pgpass.go @@ -57,7 +57,6 @@ func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *strin return &parts[4] } -// TODO: Multi-host support func pgpass(cfg *ConnConfig) (found bool) { passfile := os.Getenv("PGPASSFILE") if passfile == "" { From f87825cac7b1ae3311a31a2093bcb00065667ba6 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Thu, 11 Jul 2019 21:38:29 +0300 Subject: [PATCH 11/15] remove TODO that PR will not cover Signed-off-by: Artemiy Ryabinkov --- conn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/conn.go b/conn.go index 76e09e73..e195bb2f 100644 --- a/conn.go +++ b/conn.go @@ -108,7 +108,6 @@ type ConnConfig struct { Port uint16 Database string User string // default: OS user name - // TODO: Allow password to be different for each host/port pair if a password file is used Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa From 98acf573cce94af544d41b3e2bbfc9d86b7494cf Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 13 Jul 2019 21:21:23 +0300 Subject: [PATCH 12/15] fix errors collecting on multi-host Signed-off-by: Artemiy Ryabinkov --- conn.go | 4 +- examples/multihosts/README.md | 25 ++++++++++++ examples/multihosts/main.go | 74 +++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 2 deletions(-) create mode 100644 examples/multihosts/README.md create mode 100644 examples/multihosts/main.go diff --git a/conn.go b/conn.go index e195bb2f..7f2ae6a8 100644 --- a/conn.go +++ b/conn.go @@ -467,8 +467,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } errmsgs := make([]string, len(errs)) - for _, err := range errs { - errmsgs = append(errmsgs, err.Error()) + for i, err := range errs { + errmsgs[i] = err.Error() } return nil, errors.New(strings.Join(errmsgs, ";")) diff --git a/examples/multihosts/README.md b/examples/multihosts/README.md new file mode 100644 index 00000000..4b73eb51 --- /dev/null +++ b/examples/multihosts/README.md @@ -0,0 +1,25 @@ +# Description + +This is a sample chat program implemented using PostgreSQL's listen/notify +functionality with pgx. + +Start multiple instances of this program connected to the same database to chat +between them. + +## Connection configuration + +The database connection is configured via the standard PostgreSQL environment variables. + +* PGHOST - defaults to localhost +* PGUSER - defaults to current OS user +* PGPASSWORD - defaults to empty string +* PGDATABASE - defaults to user name + +You can either export them then run chat: + + export PGHOST=/private/tmp + ./chat + +Or you can prefix the chat execution with the environment variables: + + PGHOST=/private/tmp ./chat diff --git a/examples/multihosts/main.go b/examples/multihosts/main.go new file mode 100644 index 00000000..83b16c02 --- /dev/null +++ b/examples/multihosts/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + + "github.com/jackc/pgx" +) + +var pool *pgx.ConnPool + +func main() { + config, err := pgx.ParseEnvLibpq() + if err != nil { + fmt.Fprintln(os.Stderr, "Unable to parse environment:", err) + os.Exit(1) + } + + pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config}) + if err != nil { + fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) + os.Exit(1) + } + + go listen() + + fmt.Println(`Type a message and press enter. + +This message should appear in any other chat instances connected to the same +database. + +Type "exit" to quit.`) + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + msg := scanner.Text() + if msg == "exit" { + os.Exit(0) + } + + _, err = pool.Exec("select pg_notify('chat', $1)", msg) + if err != nil { + fmt.Fprintln(os.Stderr, "Error sending notification:", err) + os.Exit(1) + } + } + if err := scanner.Err(); err != nil { + fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) + os.Exit(1) + } +} + +func listen() { + conn, err := pool.Acquire() + if err != nil { + fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) + os.Exit(1) + } + defer pool.Release(conn) + + conn.Listen("chat") + + for { + notification, err := conn.WaitForNotification(context.Background()) + if err != nil { + fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) + os.Exit(1) + } + + fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) + } +} From a2b647c393b3349c9ddf568d279e7fcd71520f88 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sat, 13 Jul 2019 22:17:03 +0300 Subject: [PATCH 13/15] drop extra example Signed-off-by: Artemiy Ryabinkov --- conn.go | 2 +- examples/multihosts/README.md | 25 ------------ examples/multihosts/main.go | 74 ----------------------------------- 3 files changed, 1 insertion(+), 100 deletions(-) delete mode 100644 examples/multihosts/README.md delete mode 100644 examples/multihosts/main.go diff --git a/conn.go b/conn.go index 7f2ae6a8..01a572d7 100644 --- a/conn.go +++ b/conn.go @@ -471,7 +471,7 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) errmsgs[i] = err.Error() } - return nil, errors.New(strings.Join(errmsgs, ";")) + return nil, errors.New(strings.Join(errmsgs, "; ")) } func (c *Conn) checkWritable() error { diff --git a/examples/multihosts/README.md b/examples/multihosts/README.md deleted file mode 100644 index 4b73eb51..00000000 --- a/examples/multihosts/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# Description - -This is a sample chat program implemented using PostgreSQL's listen/notify -functionality with pgx. - -Start multiple instances of this program connected to the same database to chat -between them. - -## Connection configuration - -The database connection is configured via the standard PostgreSQL environment variables. - -* PGHOST - defaults to localhost -* PGUSER - defaults to current OS user -* PGPASSWORD - defaults to empty string -* PGDATABASE - defaults to user name - -You can either export them then run chat: - - export PGHOST=/private/tmp - ./chat - -Or you can prefix the chat execution with the environment variables: - - PGHOST=/private/tmp ./chat diff --git a/examples/multihosts/main.go b/examples/multihosts/main.go deleted file mode 100644 index 83b16c02..00000000 --- a/examples/multihosts/main.go +++ /dev/null @@ -1,74 +0,0 @@ -package main - -import ( - "bufio" - "context" - "fmt" - "os" - - "github.com/jackc/pgx" -) - -var pool *pgx.ConnPool - -func main() { - config, err := pgx.ParseEnvLibpq() - if err != nil { - fmt.Fprintln(os.Stderr, "Unable to parse environment:", err) - os.Exit(1) - } - - pool, err = pgx.NewConnPool(pgx.ConnPoolConfig{ConnConfig: config}) - if err != nil { - fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) - os.Exit(1) - } - - go listen() - - fmt.Println(`Type a message and press enter. - -This message should appear in any other chat instances connected to the same -database. - -Type "exit" to quit.`) - - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - msg := scanner.Text() - if msg == "exit" { - os.Exit(0) - } - - _, err = pool.Exec("select pg_notify('chat', $1)", msg) - if err != nil { - fmt.Fprintln(os.Stderr, "Error sending notification:", err) - os.Exit(1) - } - } - if err := scanner.Err(); err != nil { - fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) - os.Exit(1) - } -} - -func listen() { - conn, err := pool.Acquire() - if err != nil { - fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) - os.Exit(1) - } - defer pool.Release(conn) - - conn.Listen("chat") - - for { - notification, err := conn.WaitForNotification(context.Background()) - if err != nil { - fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) - os.Exit(1) - } - - fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) - } -} From 1ecc111e17995b5aba2e0b7b1fd57c616f9172a7 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 14 Jul 2019 18:29:08 +0300 Subject: [PATCH 14/15] Reuse pool.connInfo for createConnectionUnlocked method Signed-off-by: Artemiy Ryabinkov --- conn_pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index 47a0b391..d43b6337 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -319,7 +319,7 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := Connect(p.config) + c, err := connect(p.config, p.connInfo) p.cond.L.Lock() p.inProgressConnects-- From 8e0e1123dfa4f7280ad56da42fa211bb91ea39f4 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 14 Jul 2019 20:04:55 +0300 Subject: [PATCH 15/15] use deepCopy of connInfo in createConnectionUnlocked method Signed-off-by: Artemiy Ryabinkov --- conn_pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index d43b6337..e8972a0b 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -319,7 +319,7 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := connect(p.config, p.connInfo) + c, err := connect(p.config, p.connInfo.DeepCopy()) p.cond.L.Lock() p.inProgressConnects--