diff --git a/connection.go b/connection.go index 12965bce..6249080f 100644 --- a/connection.go +++ b/connection.go @@ -16,12 +16,15 @@ import ( "io/ioutil" "net" "os/user" + "path/filepath" + "strconv" + "strings" "time" ) // ConnectionParameters contains all the options used to establish a connection. type ConnectionParameters struct { - Socket string // path to unix domain socket (e.g. /private/tmp/.s.PGSQL.5432) + Socket string // path to unix domain socket directory (e.g. /private/tmp) Host string // url (e.g. localhost) Port uint16 // default: 5432 Database string @@ -126,8 +129,14 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { } if c.parameters.Socket != "" { - c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at socket: %s", c.parameters.Socket)) - c.conn, err = net.Dial("unix", c.parameters.Socket) + // For backward compatibility accept socket file paths -- but directories are now preferred + socket := c.parameters.Socket + if !strings.Contains(socket, "/.s.PGSQL.") { + socket = filepath.Join(socket, ".s.PGSQL.") + strconv.FormatInt(int64(c.parameters.Port), 10) + } + + c.logger.Info(fmt.Sprintf("Dialing PostgreSQL server at socket: %s", socket)) + c.conn, err = net.Dial("unix", socket) if err != nil { c.logger.Error(fmt.Sprintf("Connection failed: %v", err)) return nil, err diff --git a/connection_settings_test.go.example b/connection_settings_test.go.example index 3ce4a19f..b54f0eae 100644 --- a/connection_settings_test.go.example +++ b/connection_settings_test.go.example @@ -15,7 +15,7 @@ var noPasswordConnectionParameters *pgx.ConnectionParameters = nil var invalidUserConnectionParameters *pgx.ConnectionParameters = nil // var tcpConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -// var unixSocketConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Socket: "/private/tmp/.s.PGSQL.5432", User: "pgx_none", Database: "pgx_test"} +// var unixSocketConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Socket: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var plainPasswordConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} // var noPasswordConnectionParameters *pgx.ConnectionParameters = &pgx.ConnectionParameters{Host: "127.0.0.1", User: "pgx_none", Database: "pgx_test"} diff --git a/connection_test.go b/connection_test.go index c8cd22b8..2e424a97 100644 --- a/connection_test.go +++ b/connection_test.go @@ -44,7 +44,8 @@ func TestConnect(t *testing.T) { } } -func TestConnectWithUnixSocket(t *testing.T) { +func TestConnectWithUnixSocketDirectory(t *testing.T) { + // /.s.PGSQL.5432 if unixSocketConnectionParameters == nil { return } @@ -60,6 +61,24 @@ func TestConnectWithUnixSocket(t *testing.T) { } } +func TestConnectWithUnixSocketFile(t *testing.T) { + if unixSocketConnectionParameters == nil { + return + } + + connParams := *unixSocketConnectionParameters + connParams.Socket = connParams.Socket + "/.s.PGSQL.5432" + conn, err := pgx.Connect(connParams) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithTcp(t *testing.T) { if tcpConnectionParameters == nil { return