From 6413491657556ae5457ce9b1557440ecb816bb95 Mon Sep 17 00:00:00 2001 From: Rick Snyder Date: Tue, 18 Aug 2015 16:00:37 -0400 Subject: [PATCH] Add support for specifying sslmode in connection strings Add tests for sslmode parameter when calling ParseURI. Fix existing tests to work since default sslmode is 'prefer' Make sure we default to prefer if sslmode is not provided in ParseDSN Fix existing tests for ParseDSN to expect TLS configuration for prefer since prefer is the default sslmode; also, add tests for ParseDSN when specifying sslmode parameter on connection string --- conn.go | 29 ++++++++++++++-- conn_test.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++++-- stdlib/sql.go | 5 +-- 3 files changed, 120 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 2af381b7..7d5e7f66 100644 --- a/conn.go +++ b/conn.go @@ -296,6 +296,11 @@ func ParseURI(uri string) (ConnConfig, error) { } cp.Database = strings.TrimLeft(url.Path, "/") + err = configSSL(url.Query().Get("sslmode"), &cp) + if err != nil { + return cp, err + } + return cp, nil } @@ -303,12 +308,14 @@ var dsn_regexp = regexp.MustCompile(`([a-z]+)=((?:"[^"]+")|(?:[^ ]+))`) // ParseDSN parses a database DSN (data source name) into a ConnConfig // -// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb") +// e.g. ParseDSN("user=username password=password host=1.2.3.4 port=5432 dbname=mydb sslmode=disable") func ParseDSN(s string) (ConnConfig, error) { var cp ConnConfig m := dsn_regexp.FindAllStringSubmatch(s, -1) + var sslmode string + for _, b := range m { switch b[1] { case "user": @@ -325,9 +332,16 @@ func ParseDSN(s string) (ConnConfig, error) { } case "dbname": cp.Database = b[2] + case "sslmode": + sslmode = b[2] } } + err := configSSL(sslmode, &cp) + if err != nil { + return cp, err + } + return cp, nil } @@ -380,6 +394,15 @@ func ParseEnvLibpq() (ConnConfig, error) { sslmode := os.Getenv("PGSSLMODE") + err := configSSL(sslmode, &cc) + if err != nil { + return cc, err + } + + return cc, nil +} + +func configSSL(sslmode string, cc *ConnConfig) error { // Match libpq default behavior if sslmode == "" { sslmode = "prefer" @@ -399,10 +422,10 @@ func ParseEnvLibpq() (ConnConfig, error) { ServerName: cc.Host, } default: - return cc, errors.New("sslmode is invalid") + return errors.New("sslmode is invalid") } - return cc, nil + return nil } // Prepare creates a prepared statement with name and sql. sql can contain placeholders diff --git a/conn_test.go b/conn_test.go index 32f58e03..3cc9aa5b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,7 +3,6 @@ package pgx_test import ( "crypto/tls" "fmt" - "github.com/jackc/pgx" "net" "os" "reflect" @@ -12,6 +11,8 @@ import ( "sync" "testing" "time" + + "github.com/jackc/pgx" ) func TestConnect(t *testing.T) { @@ -264,6 +265,34 @@ func TestParseURI(t *testing.T) { url string connParams pgx.ConnConfig }{ + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, + }, + { + url: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: nil, + UseFallbackTLS: false, + FallbackTLSConfig: nil, + }, + }, { url: "postgres://jack:secret@localhost:5432/mydb", connParams: pgx.ConnConfig{ @@ -272,6 +301,11 @@ func TestParseURI(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, { @@ -282,6 +316,11 @@ func TestParseURI(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, { @@ -291,6 +330,11 @@ func TestParseURI(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, { @@ -299,6 +343,11 @@ func TestParseURI(t *testing.T) { User: "jack", Host: "localhost", Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, } @@ -324,7 +373,7 @@ func TestParseDSN(t *testing.T) { connParams pgx.ConnConfig }{ { - url: "user=jack password=secret host=localhost port=5432 dbname=mydb", + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable", connParams: pgx.ConnConfig{ User: "jack", Password: "secret", @@ -333,6 +382,36 @@ func TestParseDSN(t *testing.T) { Database: "mydb", }, }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=prefer", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, + }, + { + url: "user=jack password=secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + }, + }, { url: "user=jack host=localhost port=5432 dbname=mydb", connParams: pgx.ConnConfig{ @@ -340,6 +419,11 @@ func TestParseDSN(t *testing.T) { Host: "localhost", Port: 5432, Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, { @@ -348,6 +432,11 @@ func TestParseDSN(t *testing.T) { User: "jack", Host: "localhost", Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, }, }, } diff --git a/stdlib/sql.go b/stdlib/sql.go index 2292f1c3..048e6d04 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -2,7 +2,7 @@ // // A database/sql connection can be established through sql.Open. // -// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test") +// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") // if err != nil { // return err // } @@ -48,8 +48,9 @@ import ( "database/sql/driver" "errors" "fmt" - "github.com/jackc/pgx" "io" + + "github.com/jackc/pgx" ) var openFromConnPoolCount int