diff --git a/conn.go b/conn.go index 2af381b7..7d5e7f66 100644 --- a/conn.go +++ b/conn.go @@ -296,6 +296,11 @@ func ParseURI(uri string) (ConnConfig, error) { } cp.Database = strings.TrimLeft(url.Path, "/") + err = configSSL(url.Query().Get("sslmode"), &cp) + if err != nil { + return cp, err + } + return cp, nil } @@ -303,12 +308,14 @@ var dsn_regexp = regexp.MustCompile(`([a-z]+)=((?:"[^"]+")|(?:[^ ]+))`) // ParseDSN parses a database DSN (data source name) into a ConnConfig // -// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb") +// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable") func ParseDSN(s string) (ConnConfig, error) { var cp ConnConfig m := dsn_regexp.FindAllStringSubmatch(s, -1) + var sslmode string + for _, b := range m { switch b[1] { case "user": @@ -325,9 +332,16 @@ func ParseDSN(s string) (ConnConfig, error) { } case "dbname": cp.Database = b[2] + case "sslmode": + sslmode = b[2] } } + err := configSSL(sslmode, &cp) + if err != nil { + return cp, err + } + return cp, nil } @@ -380,6 +394,15 @@ func ParseEnvLibpq() (ConnConfig, error) { sslmode := os.Getenv("PGSSLMODE") + err := configSSL(sslmode, &cc) + if err != nil { + return cc, err + } + + return cc, nil +} + +func configSSL(sslmode string, cc *ConnConfig) error { // Match libpq default behavior if sslmode == "" { sslmode = "prefer" @@ -399,10 +422,10 @@ func ParseEnvLibpq() (ConnConfig, error) { ServerName: cc.Host, } default: - return cc, errors.New("sslmode is invalid") + return errors.New("sslmode is invalid") } - return cc, nil + return nil } // Prepare creates a prepared statement with name and sql. sql can contain placeholders diff --git a/conn_test.go b/conn_test.go index 32f58e03..3cc9aa5b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,7 +3,6 @@ package pgx_test import ( "crypto/tls" "fmt" - "github.com/jackc/pgx" "net" "os" "reflect" @@ -12,6 +11,8 @@ import ( "sync" "testing" "time" + + "github.com/jackc/pgx" ) func TestConnect(t *testing.T) { @@ -264,6 +265,34 @@ func TestParseURI(t *testing.T) { url string connParams pgx.ConnConfig }{ + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, + }, + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + UseFallbackTLS: false, + FallbackTLSConfig: nil, + }, + }, { url: "postgres://jack:secret@localhost:5432/mydb", connParams: pgx.ConnConfig{ @@ -272,6 +301,11 @@ func TestParseURI(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, { @@ -282,6 +316,11 @@ func TestParseURI(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, { @@ -291,6 +330,11 @@ func TestParseURI(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, { @@ -299,6 +343,11 @@ func TestParseURI(t *testing.T) { User: "jack", Host: "localhost", Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, } @@ -324,7 +373,7 @@ func TestParseDSN(t *testing.T) { connParams pgx.ConnConfig }{ { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb", + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", connParams: pgx.ConnConfig{ User: "jack", Password: "secret", @@ -333,6 +382,36 @@ func TestParseDSN(t *testing.T) { Database: "mydb", }, }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, + }, { url: "user=jack host=localhost port=5432 dbname=mydb", connParams: pgx.ConnConfig{ @@ -340,6 +419,11 @@ func TestParseDSN(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, { @@ -348,6 +432,11 @@ func TestParseDSN(t *testing.T) { User: "jack", Host: "localhost", Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, } diff --git a/stdlib/sql.go b/stdlib/sql.go index 2292f1c3..048e6d04 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -2,7 +2,7 @@ // // A database/sql connection can be established through sql.Open. // -// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test") +// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") // if err != nil { // return err // } @@ -48,8 +48,9 @@ import ( "database/sql/driver" "errors" "fmt" - "github.com/jackc/pgx" "io" + + "github.com/jackc/pgx" ) var openFromConnPoolCount int