Support using a custom dialer
For example I may want to use a dialer which retries transient network errors (e.g. DNS issues). Signed-off-by: Lewis Marshall <lewis@lmars.net>
This commit is contained in:
@@ -20,6 +20,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type DialFunc func(network, addr string) (net.Conn, error)
|
||||
|
||||
// ConnConfig contains all the options used to establish a connection.
|
||||
type ConnConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
@@ -29,6 +31,7 @@ type ConnConfig struct {
|
||||
Password string
|
||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
||||
Logger Logger
|
||||
Dial DialFunc
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
@@ -122,30 +125,26 @@ func Connect(config ConnConfig) (c *Conn, err error) {
|
||||
c.logger.Debug("Using default connection config", "Port", c.config.Port)
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
_, err = os.Stat(c.config.Host)
|
||||
if err == nil {
|
||||
if _, err := os.Stat(c.config.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
socket := c.config.Host
|
||||
if !strings.Contains(socket, "/.s.PGSQL.") {
|
||||
socket = filepath.Join(socket, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
||||
}
|
||||
|
||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at socket: %s", socket))
|
||||
c.conn, err = net.Dial("unix", socket)
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at host: %s:%d", c.config.Host, c.config.Port))
|
||||
d := net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
c.conn, err = d.Dial("tcp", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port))
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
return nil, err
|
||||
network = "unix"
|
||||
address = c.config.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
||||
}
|
||||
}
|
||||
if c.config.Dial == nil {
|
||||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||
}
|
||||
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
|
||||
c.conn, err = c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if c != nil && err != nil {
|
||||
c.conn.Close()
|
||||
|
||||
Reference in New Issue
Block a user