Add support for SslPassword
This commit is contained in:
committed by
Jack Christensen
parent
a18df2374a
commit
32ec44f726
Generated
+9
@@ -0,0 +1,9 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="WEB_MODULE" version="4">
|
||||||
|
<component name="Go" enabled="true" />
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -60,6 +61,9 @@ type Config struct {
|
|||||||
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||||
OnNotification NotificationHandler
|
OnNotification NotificationHandler
|
||||||
|
|
||||||
|
// SslPasswordCallback is a callback function to handle Auth callback for SSL Password
|
||||||
|
SslPasswordCallback SslPasswordCallbackHandler
|
||||||
|
|
||||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,6 +136,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
return network, address
|
return network, address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseConfig builds a *Config when sslpasswordcallback function is not provided
|
||||||
|
func ParseConfig(connString string) (*Config, error) {
|
||||||
|
return ParseConfigWithSslPasswordCallback(connString, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
|
// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
|
||||||
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches
|
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches
|
||||||
// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See
|
// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See
|
||||||
@@ -171,6 +180,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
// PGSSLCERT
|
// PGSSLCERT
|
||||||
// PGSSLKEY
|
// PGSSLKEY
|
||||||
// PGSSLROOTCERT
|
// PGSSLROOTCERT
|
||||||
|
// PGSSLPASSWORD
|
||||||
// PGAPPNAME
|
// PGAPPNAME
|
||||||
// PGCONNECT_TIMEOUT
|
// PGCONNECT_TIMEOUT
|
||||||
// PGTARGETSESSIONATTRS
|
// PGTARGETSESSIONATTRS
|
||||||
@@ -194,6 +204,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
|
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
|
||||||
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
|
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
|
||||||
// TLCConfig.
|
// TLCConfig.
|
||||||
|
// sslPasswordCallback function provide a callback function for sslpassword
|
||||||
//
|
//
|
||||||
// Other known differences with libpq:
|
// Other known differences with libpq:
|
||||||
//
|
//
|
||||||
@@ -207,7 +218,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
// servicefile
|
// servicefile
|
||||||
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||||
// part of the connection string.
|
// part of the connection string.
|
||||||
func ParseConfig(connString string) (*Config, error) {
|
func ParseConfigWithSslPasswordCallback(connString string, sslPasswordCallback SslPasswordCallbackHandler) (*Config, error) {
|
||||||
defaultSettings := defaultSettings()
|
defaultSettings := defaultSettings()
|
||||||
envSettings := parseEnvSettings()
|
envSettings := parseEnvSettings()
|
||||||
|
|
||||||
@@ -278,6 +289,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
"sslkey": {},
|
"sslkey": {},
|
||||||
"sslcert": {},
|
"sslcert": {},
|
||||||
"sslrootcert": {},
|
"sslrootcert": {},
|
||||||
|
"sslpassword": {},
|
||||||
"krbspn": {},
|
"krbspn": {},
|
||||||
"krbsrvname": {},
|
"krbsrvname": {},
|
||||||
"target_session_attrs": {},
|
"target_session_attrs": {},
|
||||||
@@ -326,7 +338,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
tlsConfigs = append(tlsConfigs, nil)
|
tlsConfigs = append(tlsConfigs, nil)
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
tlsConfigs, err = configTLS(settings, host)
|
tlsConfigs, err = configTLS(settings, host, sslPasswordCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||||
}
|
}
|
||||||
@@ -406,6 +418,7 @@ func parseEnvSettings() map[string]string {
|
|||||||
"PGSSLKEY": "sslkey",
|
"PGSSLKEY": "sslkey",
|
||||||
"PGSSLCERT": "sslcert",
|
"PGSSLCERT": "sslcert",
|
||||||
"PGSSLROOTCERT": "sslrootcert",
|
"PGSSLROOTCERT": "sslrootcert",
|
||||||
|
"PGSSLPASSWORD": "sslpassword",
|
||||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||||
"PGSERVICE": "service",
|
"PGSERVICE": "service",
|
||||||
"PGSERVICEFILE": "servicefile",
|
"PGSERVICEFILE": "servicefile",
|
||||||
@@ -592,12 +605,13 @@ func parseServiceSettings(servicefilePath, serviceName string) (map[string]strin
|
|||||||
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||||
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||||
// "prefer" allow fallback.
|
// "prefer" allow fallback.
|
||||||
func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) {
|
func configTLS(settings map[string]string, thisHost string, sslPasswordCallback SslPasswordCallbackHandler) ([]*tls.Config, error) {
|
||||||
host := thisHost
|
host := thisHost
|
||||||
sslmode := settings["sslmode"]
|
sslmode := settings["sslmode"]
|
||||||
sslrootcert := settings["sslrootcert"]
|
sslrootcert := settings["sslrootcert"]
|
||||||
sslcert := settings["sslcert"]
|
sslcert := settings["sslcert"]
|
||||||
sslkey := settings["sslkey"]
|
sslkey := settings["sslkey"]
|
||||||
|
sslpassword := settings["sslpassword"]
|
||||||
|
|
||||||
// Match libpq default behavior
|
// Match libpq default behavior
|
||||||
if sslmode == "" {
|
if sslmode == "" {
|
||||||
@@ -685,11 +699,43 @@ func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
if sslcert != "" && sslkey != "" {
|
if sslcert != "" && sslkey != "" {
|
||||||
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
|
buf, err := ioutil.ReadFile(sslkey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||||
|
}
|
||||||
|
block, _ := pem.Decode(buf)
|
||||||
|
var pemKey []byte
|
||||||
|
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||||
|
if x509.IsEncryptedPEMBlock(block) {
|
||||||
|
if sslpassword == "" {
|
||||||
|
if sslPasswordCallback == nil {
|
||||||
|
return nil, fmt.Errorf("unable to find sslpassword: %w", err)
|
||||||
|
}
|
||||||
|
sslpassword = sslPasswordCallback()
|
||||||
|
}
|
||||||
|
// Attempt decryption with pass phrase
|
||||||
|
// NOTE: only supports RSA (PKCS#1)
|
||||||
|
decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
|
// Should we also provide warning for PKCS#1 needed?
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||||
|
}
|
||||||
|
pemBytes := pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: decryptedKey,
|
||||||
|
}
|
||||||
|
pemKey = pem.EncodeToMemory(&pemBytes)
|
||||||
|
} else {
|
||||||
|
pemKey = pem.EncodeToMemory(block)
|
||||||
|
}
|
||||||
|
certfile, err := ioutil.ReadFile(sslcert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to read cert: %w", err)
|
return nil, fmt.Errorf("unable to read cert: %w", err)
|
||||||
}
|
}
|
||||||
|
cert, err := tls.X509KeyPair(certfile, pemKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to load cert: %w", err)
|
||||||
|
}
|
||||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ type NoticeHandler func(*PgConn, *Notice)
|
|||||||
// notice event.
|
// notice event.
|
||||||
type NotificationHandler func(*PgConn, *Notification)
|
type NotificationHandler func(*PgConn, *Notification)
|
||||||
|
|
||||||
|
type SslPasswordCallbackHandler func() (string)
|
||||||
|
|
||||||
// Frontend used to receive messages from backend.
|
// Frontend used to receive messages from backend.
|
||||||
type Frontend interface {
|
type Frontend interface {
|
||||||
Receive() (pgproto3.BackendMessage, error)
|
Receive() (pgproto3.BackendMessage, error)
|
||||||
|
|||||||
+54
-1
@@ -1,6 +1,7 @@
|
|||||||
package pgconn_test
|
package pgconn_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"context"
|
"context"
|
||||||
@@ -63,7 +64,59 @@ func TestConnectTLS(t *testing.T) {
|
|||||||
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pgconn.Connect(context.Background(), connString)
|
var conn *pgconn.PgConn
|
||||||
|
var err error
|
||||||
|
|
||||||
|
isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=")
|
||||||
|
|
||||||
|
if isSslPasswrodEmpty {
|
||||||
|
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
conn, err = pgconn.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
conn, err = pgconn.Connect(context.Background(), connString)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := conn.Conn().(*tls.Conn); !ok {
|
||||||
|
t.Error("not a TLS connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
closeConn(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSslPassword() string {
|
||||||
|
readFile, err := os.Open("data.txt")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
fileScanner := bufio.NewScanner(readFile)
|
||||||
|
fileScanner.Split(bufio.ScanLines)
|
||||||
|
for fileScanner.Scan() {
|
||||||
|
line := fileScanner.Text()
|
||||||
|
if strings.HasPrefix(line, "sslpassword=") {
|
||||||
|
index := len("sslpassword=")
|
||||||
|
line := line[index:]
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectTLSCallback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
|
||||||
|
if connString == "" {
|
||||||
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
conn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if _, ok := conn.Conn().(*tls.Conn); !ok {
|
if _, ok := conn.Conn().(*tls.Conn); !ok {
|
||||||
|
|||||||
Reference in New Issue
Block a user