diff --git a/connection.go b/connection.go index e2a3ff79..0c44c0dc 100644 --- a/connection.go +++ b/connection.go @@ -7,6 +7,7 @@ import ( "bufio" "bytes" "crypto/md5" + "crypto/tls" "encoding/binary" "encoding/hex" "errors" @@ -24,7 +25,8 @@ type ConnectionParameters struct { Database string User string Password string - MsgBufSize int // Size of work buffer used for transcoding messages. For optimal performance, it should be large enough to store a single row from any result set. Default: 1024 + MsgBufSize int // Size of work buffer used for transcoding messages. For optimal performance, it should be large enough to store a single row from any result set. Default: 1024 + SSL bool // Require SSL connection } // Connection is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -117,13 +119,20 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) { } }() - c.writer = bufio.NewWriter(c.conn) c.bufSize = c.parameters.MsgBufSize c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize)) c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*preparedStatement) c.alive = true + if parameters.SSL { + if err = c.startSSL(); err != nil { + return + } + } + + c.writer = bufio.NewWriter(c.conn) + msg := newStartupMessage() msg.options["user"] = c.parameters.User if c.parameters.Database != "" { @@ -885,6 +894,28 @@ func (c *Connection) rxNotificationResponse(r *MessageReader) (err error) { return } +func (c *Connection) startSSL() (err error) { + err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103}) + if err != nil { + return + } + + response := make([]byte, 1) + if _, err = io.ReadFull(c.conn, response); err != nil { + return + } + + if response[0] != 'S' { + err = errors.New("Could not use SSL") + return + } + + config := &tls.Config{InsecureSkipVerify: true} + c.conn = tls.Client(c.conn, config) + + return nil +} + func (c *Connection) txStartupMessage(msg *startupMessage) (err error) { _, err = c.writer.Write(msg.Bytes()) if err != nil { diff --git a/connection_test.go b/connection_test.go index 38d55285..9d32f561 100644 --- a/connection_test.go +++ b/connection_test.go @@ -77,6 +77,22 @@ func TestConnectWithTcp(t *testing.T) { } } +func TestConnectWithSSL(t *testing.T) { + if sslConnectionParameters == nil { + return + } + + conn, err := pgx.Connect(*sslConnectionParameters) + if err != nil { + t.Fatal("Unable to establish connection: " + err.Error()) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithInvalidUser(t *testing.T) { if invalidUserConnectionParameters == nil { return