2
0

Return net.Addr from networkAddress

Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
Artemiy Ryabinkov
2019-06-16 14:03:43 +03:00
parent 9538d15c29
commit 9f031bb8f9
+59 -23
View File
@@ -98,20 +98,24 @@ type ConnConfig struct {
TargetSessionAttrs string
}
func (cc *ConnConfig) networkAddress() (network, address string) {
network = "tcp"
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
// See if host is a valid path, if yes connect with a socket
func (cc *ConnConfig) networkAddress() net.Addr {
// See if host is a valid path, if yes connect with a unix socket
if _, err := os.Stat(cc.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
network = "unix"
address = cc.Host
network := "unix"
address := cc.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10)
}
return &net.UnixAddr{Name: address, Net: network}
}
return network, address
return &net.TCPAddr{
Port: int(cc.Port),
IP: net.ParseIP(cc.Host),
}
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
@@ -277,48 +281,80 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
c.onNotice = config.OnNotice
// TODO: Parse multi-hosts
network, address := c.config.networkAddress()
if c.config.Dial == nil {
d := defaultDialer()
c.config.Dial = d.Dial
}
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
}
// TODO: Parse multi-hosts
hostAddr := c.config.networkAddress()
addrs := []net.Addr{hostAddr}
var errs []error
for _, addr := range addrs {
network, address := addr.Network(), addr.String()
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{
"network": network,
"address": address,
})
}
// TODO: Start loop for all hosts [host0 .. hostN]
for {
err = c.connect(config, network, address, config.TLSConfig)
if err != nil && config.UseFallbackTLS {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{
"err": err,
"network": network,
"address": address,
})
}
err = c.connect(config, network, address, config.FallbackTLSConfig)
}
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
c.log(LogLevelError, "connect failed", map[string]interface{}{
"err": err,
"network": network,
"address": address,
})
}
// TODO: Collect error
errs = append(errs, err)
continue
}
err = c.writeable()
if err != nil {
// TODO: Log info about not writable host
// TODO: Collect error
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "host is not writable", map[string]interface{}{
"err": err,
"network": network,
"address": address,
})
}
errs = append(errs, err)
continue
}
return c, nil
}
// To keep backwards, if specific error type expected.
if len(errs) == 1 {
return nil, errs[0]
}
// TODO: Return collected errors
return nil, nil
var errmsg string
for _, err := range errs {
errmsg += ";" + err.Error()
}
return nil, errors.New(errmsg)
}
func (c *Conn) writeable() error {
@@ -331,7 +367,7 @@ func (c *Conn) writeable() error {
Scan(st)
if err != nil {
return errors.Wrap(err, "failed to fetch transaction_read_only state")
return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state")
}
if st == "on" {