2
0

Support using a custom dialer

For example I may want to use a dialer which retries transient network
errors (e.g. DNS issues).

Signed-off-by: Lewis Marshall <lewis@lmars.net>
This commit is contained in:
Lewis Marshall
2015-04-18 22:38:15 +01:00
parent d46a762159
commit 784d12cbbc
4 changed files with 54 additions and 22 deletions
+19 -20
View File
@@ -20,6 +20,8 @@ import (
"time"
)
type DialFunc func(network, addr string) (net.Conn, error)
// ConnConfig contains all the options used to establish a connection.
type ConnConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
@@ -29,6 +31,7 @@ type ConnConfig struct {
Password string
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
Logger Logger
Dial DialFunc
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
@@ -122,30 +125,26 @@ func Connect(config ConnConfig) (c *Conn, err error) {
c.logger.Debug("Using default connection config", "Port", c.config.Port)
}
network := "tcp"
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
// See if host is a valid path, if yes connect with a socket
_, err = os.Stat(c.config.Host)
if err == nil {
if _, err := os.Stat(c.config.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
socket := c.config.Host
if !strings.Contains(socket, "/.s.PGSQL.") {
socket = filepath.Join(socket, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
}
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at socket: %s", socket))
c.conn, err = net.Dial("unix", socket)
if err != nil {
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
return nil, err
}
} else {
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at host: %s:%d", c.config.Host, c.config.Port))
d := net.Dialer{KeepAlive: 5 * time.Minute}
c.conn, err = d.Dial("tcp", fmt.Sprintf("%s:%d", c.config.Host, c.config.Port))
if err != nil {
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
return nil, err
network = "unix"
address = c.config.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
}
}
if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
}
c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at %s address: %s", network, address))
c.conn, err = c.config.Dial(network, address)
if err != nil {
c.logger.Error(fmt.Sprintf("Connection failed: %v", err))
return nil, err
}
defer func() {
if c != nil && err != nil {
c.conn.Close()