2
0

Return net.Addr from networkAddress

Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
Artemiy Ryabinkov
2019-06-16 14:03:43 +03:00
parent 9538d15c29
commit 9f031bb8f9
+59 -23
View File
@@ -98,20 +98,24 @@ type ConnConfig struct {
TargetSessionAttrs string TargetSessionAttrs string
} }
func (cc *ConnConfig) networkAddress() (network, address string) { func (cc *ConnConfig) networkAddress() net.Addr {
network = "tcp" // See if host is a valid path, if yes connect with a unix socket
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
// See if host is a valid path, if yes connect with a socket
if _, err := os.Stat(cc.Host); err == nil { if _, err := os.Stat(cc.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred // For backward compatibility accept socket file paths -- but directories are now preferred
network = "unix" network := "unix"
address = cc.Host address := cc.Host
if !strings.Contains(address, "/.s.PGSQL.") { if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10)
} }
return &net.UnixAddr{Name: address, Net: network}
} }
return network, address return &net.TCPAddr{
Port: int(cc.Port),
IP: net.ParseIP(cc.Host),
}
} }
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
@@ -277,48 +281,80 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
c.onNotice = config.OnNotice c.onNotice = config.OnNotice
// TODO: Parse multi-hosts
network, address := c.config.networkAddress()
if c.config.Dial == nil { if c.config.Dial == nil {
d := defaultDialer() d := defaultDialer()
c.config.Dial = d.Dial c.config.Dial = d.Dial
} }
if c.shouldLog(LogLevelInfo) { // TODO: Parse multi-hosts
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) hostAddr := c.config.networkAddress()
}
addrs := []net.Addr{hostAddr}
var errs []error
for _, addr := range addrs {
network, address := addr.Network(), addr.String()
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{
"network": network,
"address": address,
})
}
// TODO: Start loop for all hosts [host0 .. hostN]
for {
err = c.connect(config, network, address, config.TLSConfig) err = c.connect(config, network, address, config.TLSConfig)
if err != nil && config.UseFallbackTLS { if err != nil && config.UseFallbackTLS {
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{
"err": err,
"network": network,
"address": address,
})
} }
err = c.connect(config, network, address, config.FallbackTLSConfig) err = c.connect(config, network, address, config.FallbackTLSConfig)
} }
if err != nil { if err != nil {
if c.shouldLog(LogLevelError) { if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) c.log(LogLevelError, "connect failed", map[string]interface{}{
"err": err,
"network": network,
"address": address,
})
} }
// TODO: Collect error
errs = append(errs, err)
continue continue
} }
err = c.writeable() err = c.writeable()
if err != nil { if err != nil {
// TODO: Log info about not writable host if c.shouldLog(LogLevelInfo) {
// TODO: Collect error c.log(LogLevelInfo, "host is not writable", map[string]interface{}{
"err": err,
"network": network,
"address": address,
})
}
errs = append(errs, err)
continue continue
} }
return c, nil return c, nil
} }
// To keep backwards, if specific error type expected.
if len(errs) == 1 {
return nil, errs[0]
}
// TODO: Return collected errors var errmsg string
return nil, nil for _, err := range errs {
errmsg += ";" + err.Error()
}
return nil, errors.New(errmsg)
} }
func (c *Conn) writeable() error { func (c *Conn) writeable() error {
@@ -331,7 +367,7 @@ func (c *Conn) writeable() error {
Scan(st) Scan(st)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to fetch transaction_read_only state") return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state")
} }
if st == "on" { if st == "on" {