From 9f031bb8f9bea60bd51ebc1cbaaa8e5db779b191 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Sun, 16 Jun 2019 14:03:43 +0300 Subject: [PATCH] Return net.Addr from networkAddress Signed-off-by: Artemiy Ryabinkov --- conn.go | 82 +++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index e0169d6d..323875a5 100644 --- a/conn.go +++ b/conn.go @@ -98,20 +98,24 @@ type ConnConfig struct { TargetSessionAttrs string } -func (cc *ConnConfig) networkAddress() (network, address string) { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - // See if host is a valid path, if yes connect with a socket +func (cc *ConnConfig) networkAddress() net.Addr { + // See if host is a valid path, if yes connect with a unix socket if _, err := os.Stat(cc.Host); err == nil { // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = cc.Host + network := "unix" + address := cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) } + + return &net.UnixAddr{Name: address, Net: network} } - return network, address + return &net.TCPAddr{ + Port: int(cc.Port), + IP: net.ParseIP(cc.Host), + } } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -277,48 +281,80 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) c.onNotice = config.OnNotice - // TODO: Parse multi-hosts - network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() c.config.Dial = d.Dial } - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) - } + // TODO: Parse multi-hosts + hostAddr := c.config.networkAddress() + + addrs := []net.Addr{hostAddr} + + var errs []error + for _, addr := range addrs { + network, address := addr.Network(), addr.String() + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{ + "network": network, + "address": address, + }) + } - // TODO: Start loop for all hosts [host0 .. hostN] - for { err = c.connect(config, network, address, config.TLSConfig) if err != nil && config.UseFallbackTLS { if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) } err = c.connect(config, network, address, config.FallbackTLSConfig) } if err != nil { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + c.log(LogLevelError, "connect failed", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) } - // TODO: Collect error + + errs = append(errs, err) continue } err = c.writeable() if err != nil { - // TODO: Log info about not writable host - // TODO: Collect error + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + + errs = append(errs, err) continue } return c, nil } + // To keep backwards, if specific error type expected. + if len(errs) == 1 { + return nil, errs[0] + } - // TODO: Return collected errors - return nil, nil + var errmsg string + for _, err := range errs { + errmsg += ";" + err.Error() + } + + return nil, errors.New(errmsg) } func (c *Conn) writeable() error { @@ -331,7 +367,7 @@ func (c *Conn) writeable() error { Scan(st) if err != nil { - return errors.Wrap(err, "failed to fetch transaction_read_only state") + return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") } if st == "on" {