From 9538d15c29005e5044da6ba3f4c8ff06daec1278 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Mon, 3 Jun 2019 23:51:48 +0300 Subject: [PATCH] Draft of connection writable checking Signed-off-by: Artemiy Ryabinkov --- conn.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index cb24748c..e0169d6d 100644 --- a/conn.go +++ b/conn.go @@ -89,6 +89,13 @@ type ConnConfig struct { // used by default. The same functionality can be controlled on a per query // basis by setting QueryExOptions.SimpleProtocol. PreferSimpleProtocol bool + + // TargetSessionAttr allows to specify which servers are accepted for this connection. + // "any", meaning that any kind of servers can be accepted. This is as well the default value. + // "read-write", to disallow connections to read-only servers, hot standbys for example. + // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com + // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + TargetSessionAttrs string } func (cc *ConnConfig) networkAddress() (network, address string) { @@ -262,8 +269,15 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } + if c.config.TargetSessionAttrs != "" && + c.config.TargetSessionAttrs != "any" && + c.config.TargetSessionAttrs != "read-write" { + return nil, errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") + } + c.onNotice = config.OnNotice + // TODO: Parse multi-hosts network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() @@ -273,22 +287,58 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) } - err = c.connect(config, network, address, config.TLSConfig) - if err != nil && config.UseFallbackTLS { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + + // TODO: Start loop for all hosts [host0 .. hostN] + for { + err = c.connect(config, network, address, config.TLSConfig) + if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + } + err = c.connect(config, network, address, config.FallbackTLSConfig) } - err = c.connect(config, network, address, config.FallbackTLSConfig) + + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + } + // TODO: Collect error + continue + } + + err = c.writeable() + if err != nil { + // TODO: Log info about not writable host + // TODO: Collect error + continue + } + + return c, nil } + + // TODO: Return collected errors + return nil, nil +} + +func (c *Conn) writeable() error { + if c.config.TargetSessionAttrs == "" || c.config.TargetSessionAttrs == "any" { + return nil + } + + var st string + err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). + Scan(st) + if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) - } - return nil, err + return errors.Wrap(err, "failed to fetch transaction_read_only state") } - return c, nil + if st == "on" { + return errors.New("writable connection disabled by server") + } + + return nil } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -709,6 +759,10 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig { cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol + if other.TargetSessionAttrs != "" { + cc.TargetSessionAttrs = other.TargetSessionAttrs + } + cc.RuntimeParams = make(map[string]string) for k, v := range old.RuntimeParams { cc.RuntimeParams[k] = v