Return net.Addr from networkAddress
Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
@@ -98,20 +98,24 @@ type ConnConfig struct {
|
|||||||
TargetSessionAttrs string
|
TargetSessionAttrs string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *ConnConfig) networkAddress() (network, address string) {
|
func (cc *ConnConfig) networkAddress() net.Addr {
|
||||||
network = "tcp"
|
// See if host is a valid path, if yes connect with a unix socket
|
||||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
|
||||||
// See if host is a valid path, if yes connect with a socket
|
|
||||||
if _, err := os.Stat(cc.Host); err == nil {
|
if _, err := os.Stat(cc.Host); err == nil {
|
||||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||||
network = "unix"
|
network := "unix"
|
||||||
address = cc.Host
|
address := cc.Host
|
||||||
|
|
||||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return &net.UnixAddr{Name: address, Net: network}
|
||||||
}
|
}
|
||||||
|
|
||||||
return network, address
|
return &net.TCPAddr{
|
||||||
|
Port: int(cc.Port),
|
||||||
|
IP: net.ParseIP(cc.Host),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||||
@@ -277,48 +281,80 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
|||||||
|
|
||||||
c.onNotice = config.OnNotice
|
c.onNotice = config.OnNotice
|
||||||
|
|
||||||
// TODO: Parse multi-hosts
|
|
||||||
network, address := c.config.networkAddress()
|
|
||||||
if c.config.Dial == nil {
|
if c.config.Dial == nil {
|
||||||
d := defaultDialer()
|
d := defaultDialer()
|
||||||
c.config.Dial = d.Dial
|
c.config.Dial = d.Dial
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.shouldLog(LogLevelInfo) {
|
// TODO: Parse multi-hosts
|
||||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
|
hostAddr := c.config.networkAddress()
|
||||||
}
|
|
||||||
|
addrs := []net.Addr{hostAddr}
|
||||||
|
|
||||||
|
var errs []error
|
||||||
|
for _, addr := range addrs {
|
||||||
|
network, address := addr.Network(), addr.String()
|
||||||
|
|
||||||
|
if c.shouldLog(LogLevelInfo) {
|
||||||
|
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{
|
||||||
|
"network": network,
|
||||||
|
"address": address,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Start loop for all hosts [host0 .. hostN]
|
|
||||||
for {
|
|
||||||
err = c.connect(config, network, address, config.TLSConfig)
|
err = c.connect(config, network, address, config.TLSConfig)
|
||||||
if err != nil && config.UseFallbackTLS {
|
if err != nil && config.UseFallbackTLS {
|
||||||
if c.shouldLog(LogLevelInfo) {
|
if c.shouldLog(LogLevelInfo) {
|
||||||
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
|
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{
|
||||||
|
"err": err,
|
||||||
|
"network": network,
|
||||||
|
"address": address,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if c.shouldLog(LogLevelError) {
|
if c.shouldLog(LogLevelError) {
|
||||||
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
|
c.log(LogLevelError, "connect failed", map[string]interface{}{
|
||||||
|
"err": err,
|
||||||
|
"network": network,
|
||||||
|
"address": address,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
// TODO: Collect error
|
|
||||||
|
errs = append(errs, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.writeable()
|
err = c.writeable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: Log info about not writable host
|
if c.shouldLog(LogLevelInfo) {
|
||||||
// TODO: Collect error
|
c.log(LogLevelInfo, "host is not writable", map[string]interface{}{
|
||||||
|
"err": err,
|
||||||
|
"network": network,
|
||||||
|
"address": address,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
errs = append(errs, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// To keep backwards, if specific error type expected.
|
||||||
|
if len(errs) == 1 {
|
||||||
|
return nil, errs[0]
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Return collected errors
|
var errmsg string
|
||||||
return nil, nil
|
for _, err := range errs {
|
||||||
|
errmsg += ";" + err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New(errmsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) writeable() error {
|
func (c *Conn) writeable() error {
|
||||||
@@ -331,7 +367,7 @@ func (c *Conn) writeable() error {
|
|||||||
Scan(st)
|
Scan(st)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to fetch transaction_read_only state")
|
return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state")
|
||||||
}
|
}
|
||||||
|
|
||||||
if st == "on" {
|
if st == "on" {
|
||||||
|
|||||||
Reference in New Issue
Block a user