diff --git a/conn.go b/conn.go index 287ce7db..a6ea78f1 100644 --- a/conn.go +++ b/conn.go @@ -63,10 +63,29 @@ type DialFunc func(network, addr string) (net.Conn, error) // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 - Database string - User string // default: OS user name + // Name of host to connect to. (e.g. localhost) + // If a host name begins with a slash, it specifies Unix-domain communication + // rather than TCP/IP communication; the value is the name of the directory + // in which the socket file is stored. (e.g. /private/tmp) + // The default behavior when host is not specified, or is empty, is to connect to localhost. + // + // A comma-separated list of host names is also accepted, + // in which case each host name in the list is tried in order; + // an empty item in the list selects the default behavior as explained above. + // @see https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS + Host string + + // Port number to connect to at the server host, + // or socket file name extension for Unix-domain connections. + // An empty or zero value, specifies the default port number — 5432. + // + // If multiple hosts were given in the Host parameter, then + // this parameter may specify a single port number to be used for all hosts, + // or for those that haven't port explicitly defined. + Port uint16 + Database string + User string // default: OS user name + // TODO: Allow password to be different for each host/port pair if a password file is used Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa @@ -95,10 +114,34 @@ type ConnConfig struct { // "read-write", to disallow connections to read-only servers, hot standbys for example. // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + // + // The query SHOW transaction_read_only will be sent upon any successful connection; + // if it returns on, the connection will be closed. + // If multiple hosts were specified in the connection string, + // any remaining servers will be tried just as if the connection attempt had failed. + // The default value of this parameter, any, regards all connections as acceptable. TargetSessionAttrs string } -func (cc *ConnConfig) networkAddress() net.Addr { +// hostAddr represents network end point defined as hostname or IP + port. +type hostAddr struct { + Host string + Port uint16 +} + +// Network returns the address's network name, "tcp". +func (a *hostAddr) Network() string { return "tcp" } + +// String implements net.Addr String method. +func (a *hostAddr) String() string { + if a == nil { + return "" + } + + return net.JoinHostPort(a.Host, strconv.Itoa(int(a.Port))) +} + +func (cc *ConnConfig) networkAddresses() ([]net.Addr, error) { // See if host is a valid path, if yes connect with a unix socket if _, err := os.Stat(cc.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred @@ -109,13 +152,50 @@ func (cc *ConnConfig) networkAddress() net.Addr { address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) } - return &net.UnixAddr{Name: address, Net: network} + addrs := []net.Addr{ + &net.UnixAddr{Name: address, Net: network}, + } + + return addrs, nil } - return &net.TCPAddr{ - Port: int(cc.Port), - IP: net.ParseIP(cc.Host), + if cc.Host == "" { + addrs := []net.Addr{ + &net.TCPAddr{Port: int(cc.Port)}, + } + + return addrs, nil } + + var addrs []net.Addr + + hostports := strings.Split(cc.Host, ",") + for i, hostport := range hostports { + if hostport == "" { + return nil, fmt.Errorf("multi-host part %d is empty, at least host or port must be defined", i) + } + + // It's not possible to use net.TCPAddr here, cuz host may be hostname. + addr := hostAddr{ + Host: hostport, + Port: cc.Port, + } + + pos := strings.IndexByte(hostport, ':') + if pos != -1 { + p, err := strconv.ParseUint(hostport[pos+1:], 10, 16) + if err != nil { + return nil, fmt.Errorf("multi-host part %d (%s) has invalid port format", i, hostport) + } + + addr.Host = hostport[:pos] + addr.Port = uint16(p) + } + + addrs = append(addrs, &addr) + } + + return addrs, nil } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -156,6 +236,10 @@ type Conn struct { ConnInfo *pgtype.ConnInfo frontend *pgproto3.Frontend + + // In case of Multiple Hosts we need to know what addr was used to connect. + // This address will be used to send a cancellation request. + addr net.Addr } // PreparedStatement is a description of a prepared statement @@ -286,10 +370,10 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.config.Dial = d.Dial } - // TODO: Parse multi-hosts - hostAddr := c.config.networkAddress() - - addrs := []net.Addr{hostAddr} + addrs, err := c.config.networkAddresses() + if err != nil { + return nil, err + } var errs []error for _, addr := range addrs { @@ -323,12 +407,23 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) }) } + // On any auth errors return immediately + if pgErr, ok := err.(PgError); ok { + switch pgErr.Code { + // @see: https://www.postgresql.org/docs/current/errcodes-appendix.html + case "28000", "28P01": // Invalid Authorization Specification + return nil, pgErr + } + } + errs = append(errs, err) continue } - err = c.writeable() + err = c.writable() if err != nil { + c.die(err) + if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ "err": err, @@ -341,6 +436,8 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) continue } + c.addr = addr + return c, nil } @@ -351,30 +448,34 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) var errmsg string for _, err := range errs { - errmsg += ";" + err.Error() + errmsg += "; " + err.Error() } return nil, errors.New(errmsg) } -func (c *Conn) writeable() error { +func (c *Conn) writable() error { if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { return nil } var st string err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). - Scan(st) + Scan(&st) if err != nil { return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") } - if st == "on" { + switch st { + case "on": return errors.New("writable connection disabled by server") + case "off": + // If transaction_read_only = off, then connection is writable. + return nil } - return nil + return errors.New("unexpected \"transaction_read_only\" status") } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -958,6 +1059,8 @@ func ParseDSN(s string) (ConnConfig, error) { // ParseConnectionString parses either a URI or a DSN connection string. // see ParseURI and ParseDSN for details. func ParseConnectionString(s string) (ConnConfig, error) { + // TODO: Multiple Hosts support + // @see: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) } @@ -981,6 +1084,8 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// TODO: PGTARGETSESSIONATTRS support +// @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This @@ -1741,8 +1846,7 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - addr := c.config.networkAddress() - cancelConn, err := c.config.Dial(addr.Network(), addr.String()) + cancelConn, err := c.config.Dial(c.addr.Network(), c.addr.String()) if err != nil { return err } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 096e1354..620b0ea1 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -7,6 +7,7 @@ import ( // "go/build" // "io/ioutil" // "path" + // "net" "github.com/jackc/pgx" ) @@ -14,6 +15,7 @@ import ( var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // To skip tests for specific connection / authentication types set that connection param to nil +var multihostConnConfig *pgx.ConnConfig = nil var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil @@ -24,6 +26,7 @@ var customDialerConnConfig *pgx.ConnConfig = nil var replicationConnConfig *pgx.ConnConfig = nil var cratedbConnConfig *pgx.ConnConfig = nil +// var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} // var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index cf29a743..738f1112 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -5,9 +5,11 @@ import ( "github.com/jackc/pgx" "os" "strconv" + "net" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_test.go b/conn_test.go index 6ca00c6d..14efbeca 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,6 +84,107 @@ func TestConnect(t *testing.T) { } } + +func TestConnectWithMultiHost(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + conn, err := pgx.Connect(*multihostConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + + +func TestConnectWithMultiHostWritable(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + connConfig := *multihostConnConfig + connConfig.TargetSessionAttrs = "read-write" + + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err := conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + var st string + err = conn.QueryRow("SHOW transaction_read_only").Scan(&st) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + + if st == "on" { + t.Error("Connection is not writable") + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithUnixSocketDirectory(t *testing.T) { t.Parallel() diff --git a/pgpass.go b/pgpass.go index 34b9bdf5..ff97e5f0 100644 --- a/pgpass.go +++ b/pgpass.go @@ -57,6 +57,7 @@ func parsepgpass(line, cfgHost, cfgPort, cfgDatabase, cfgUsername string) *strin return &parts[4] } +// TODO: Multi-host support func pgpass(cfg *ConnConfig) (found bool) { passfile := os.Getenv("PGPASSFILE") if passfile == "" {