2
0

Draft of connection writable checking

Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
Artemiy Ryabinkov
2019-06-03 23:51:48 +03:00
parent 9be6a06c27
commit 9538d15c29
+64 -10
View File
@@ -89,6 +89,13 @@ type ConnConfig struct {
// used by default. The same functionality can be controlled on a per query
// basis by setting QueryExOptions.SimpleProtocol.
PreferSimpleProtocol bool
// TargetSessionAttr allows to specify which servers are accepted for this connection.
// "any", meaning that any kind of servers can be accepted. This is as well the default value.
// "read-write", to disallow connections to read-only servers, hot standbys for example.
// @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com
// @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/
TargetSessionAttrs string
}
func (cc *ConnConfig) networkAddress() (network, address string) {
@@ -262,8 +269,15 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
}
}
if c.config.TargetSessionAttrs != "" &&
c.config.TargetSessionAttrs != "any" &&
c.config.TargetSessionAttrs != "read-write" {
return nil, errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"")
}
c.onNotice = config.OnNotice
// TODO: Parse multi-hosts
network, address := c.config.networkAddress()
if c.config.Dial == nil {
d := defaultDialer()
@@ -273,22 +287,58 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
}
err = c.connect(config, network, address, config.TLSConfig)
if err != nil && config.UseFallbackTLS {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
// TODO: Start loop for all hosts [host0 .. hostN]
for {
err = c.connect(config, network, address, config.TLSConfig)
if err != nil && config.UseFallbackTLS {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
}
err = c.connect(config, network, address, config.FallbackTLSConfig)
}
err = c.connect(config, network, address, config.FallbackTLSConfig)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
}
// TODO: Collect error
continue
}
err = c.writeable()
if err != nil {
// TODO: Log info about not writable host
// TODO: Collect error
continue
}
return c, nil
}
// TODO: Return collected errors
return nil, nil
}
func (c *Conn) writeable() error {
if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" {
return nil
}
var st string
err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}).
Scan(st)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
}
return nil, err
return errors.Wrap(err, "failed to fetch transaction_read_only state")
}
return c, nil
if st == "on" {
return errors.New("writable connection disabled by server")
}
return nil
}
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
@@ -709,6 +759,10 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig {
cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol
if other.TargetSessionAttrs != "" {
cc.TargetSessionAttrs = other.TargetSessionAttrs
}
cc.RuntimeParams = make(map[string]string)
for k, v := range old.RuntimeParams {
cc.RuntimeParams[k] = v