Return net.Addr from networkAddress
Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
@@ -98,20 +98,24 @@ type ConnConfig struct {
|
||||
TargetSessionAttrs string
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) networkAddress() (network, address string) {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
func (cc *ConnConfig) networkAddress() net.Addr {
|
||||
// See if host is a valid path, if yes connect with a unix socket
|
||||
if _, err := os.Stat(cc.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
network = "unix"
|
||||
address = cc.Host
|
||||
network := "unix"
|
||||
address := cc.Host
|
||||
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10)
|
||||
}
|
||||
|
||||
return &net.UnixAddr{Name: address, Net: network}
|
||||
}
|
||||
|
||||
return network, address
|
||||
return &net.TCPAddr{
|
||||
Port: int(cc.Port),
|
||||
IP: net.ParseIP(cc.Host),
|
||||
}
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
@@ -277,48 +281,80 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
||||
|
||||
c.onNotice = config.OnNotice
|
||||
|
||||
// TODO: Parse multi-hosts
|
||||
network, address := c.config.networkAddress()
|
||||
if c.config.Dial == nil {
|
||||
d := defaultDialer()
|
||||
c.config.Dial = d.Dial
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
|
||||
}
|
||||
// TODO: Parse multi-hosts
|
||||
hostAddr := c.config.networkAddress()
|
||||
|
||||
addrs := []net.Addr{hostAddr}
|
||||
|
||||
var errs []error
|
||||
for _, addr := range addrs {
|
||||
network, address := addr.Network(), addr.String()
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
}
|
||||
|
||||
// TODO: Start loop for all hosts [host0 .. hostN]
|
||||
for {
|
||||
err = c.connect(config, network, address, config.TLSConfig)
|
||||
if err != nil && config.UseFallbackTLS {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
|
||||
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{
|
||||
"err": err,
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
}
|
||||
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
|
||||
c.log(LogLevelError, "connect failed", map[string]interface{}{
|
||||
"err": err,
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
}
|
||||
// TODO: Collect error
|
||||
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.writeable()
|
||||
if err != nil {
|
||||
// TODO: Log info about not writable host
|
||||
// TODO: Collect error
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "host is not writable", map[string]interface{}{
|
||||
"err": err,
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
}
|
||||
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// To keep backwards, if specific error type expected.
|
||||
if len(errs) == 1 {
|
||||
return nil, errs[0]
|
||||
}
|
||||
|
||||
// TODO: Return collected errors
|
||||
return nil, nil
|
||||
var errmsg string
|
||||
for _, err := range errs {
|
||||
errmsg += ";" + err.Error()
|
||||
}
|
||||
|
||||
return nil, errors.New(errmsg)
|
||||
}
|
||||
|
||||
func (c *Conn) writeable() error {
|
||||
@@ -331,7 +367,7 @@ func (c *Conn) writeable() error {
|
||||
Scan(st)
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to fetch transaction_read_only state")
|
||||
return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state")
|
||||
}
|
||||
|
||||
if st == "on" {
|
||||
|
||||
Reference in New Issue
Block a user