Add SCRAM authentication
This commit is contained in:
+255
@@ -0,0 +1,255 @@
|
||||
// SCRAM-SHA-256 authentication
|
||||
//
|
||||
// Resources:
|
||||
// https://tools.ietf.org/html/rfc5802
|
||||
// https://tools.ietf.org/html/rfc8265
|
||||
// https://www.postgresql.org/docs/current/sasl-authentication.html
|
||||
//
|
||||
// Inspiration drawn from other implementations:
|
||||
// https://github.com/lib/pq/pull/608
|
||||
// https://github.com/lib/pq/pull/788
|
||||
// https://github.com/lib/pq/pull/833
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/text/secure/precis"
|
||||
)
|
||||
|
||||
const clientNonceLen = 18
|
||||
|
||||
// Perform SCRAM authentication.
|
||||
func (c *Conn) scramAuth(serverAuthMechanisms []string) error {
|
||||
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send client-first-message in a SASLInitialResponse
|
||||
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
||||
AuthMechanism: "SCRAM-SHA-256",
|
||||
Data: sc.clientFirstMessage(),
|
||||
}
|
||||
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||
authMsg, err := c.rxAuthMsg(pgproto3.AuthTypeSASLContinue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sc.recvServerFirstMessage(authMsg.SASLData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send client-final-message in a SASLResponse
|
||||
saslResponse := &pgproto3.SASLResponse{
|
||||
Data: []byte(sc.clientFinalMessage()),
|
||||
}
|
||||
_, err = c.conn.Write(saslResponse.Encode(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||
authMsg, err = c.rxAuthMsg(pgproto3.AuthTypeSASLFinal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sc.recvServerFinalMessage(authMsg.SASLData)
|
||||
}
|
||||
|
||||
func (c *Conn) rxAuthMsg(typ uint32) (*pgproto3.Authentication, error) {
|
||||
msg, err := c.rxMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authMsg, ok := msg.(*pgproto3.Authentication)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected message type")
|
||||
}
|
||||
if authMsg.Type != typ {
|
||||
return nil, errors.New("unexpected auth type")
|
||||
}
|
||||
|
||||
return authMsg, nil
|
||||
}
|
||||
|
||||
type scramClient struct {
|
||||
serverAuthMechanisms []string
|
||||
password []byte
|
||||
clientNonce []byte
|
||||
|
||||
clientFirstMessageBare []byte
|
||||
|
||||
serverFirstMessage []byte
|
||||
clientAndServerNonce []byte
|
||||
salt []byte
|
||||
iterations int
|
||||
|
||||
saltedPassword []byte
|
||||
authMessage []byte
|
||||
}
|
||||
|
||||
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
|
||||
sc := &scramClient{
|
||||
serverAuthMechanisms: serverAuthMechanisms,
|
||||
}
|
||||
|
||||
// Ensure server supports SCRAM-SHA-256
|
||||
hasScramSHA256 := false
|
||||
for _, mech := range sc.serverAuthMechanisms {
|
||||
if mech == "SCRAM-SHA-256" {
|
||||
hasScramSHA256 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasScramSHA256 {
|
||||
return nil, errors.New("server does not support SCRAM-SHA-256")
|
||||
}
|
||||
|
||||
// precis.OpaqueString is equivalent to SASLprep for password.
|
||||
var err error
|
||||
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
|
||||
if err != nil {
|
||||
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
|
||||
sc.password = []byte(password)
|
||||
}
|
||||
|
||||
buf := make([]byte, clientNonceLen)
|
||||
_, err = rand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
|
||||
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
|
||||
|
||||
return sc, nil
|
||||
}
|
||||
|
||||
func (sc *scramClient) clientFirstMessage() []byte {
|
||||
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
|
||||
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
|
||||
}
|
||||
|
||||
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||
sc.serverFirstMessage = serverFirstMessage
|
||||
buf := serverFirstMessage
|
||||
if !bytes.HasPrefix(buf, []byte("r=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
|
||||
idx := bytes.IndexByte(buf, ',')
|
||||
if idx == -1 {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||
}
|
||||
sc.clientAndServerNonce = buf[:idx]
|
||||
buf = buf[idx+1:]
|
||||
|
||||
if !bytes.HasPrefix(buf, []byte("s=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
|
||||
idx = bytes.IndexByte(buf, ',')
|
||||
if idx == -1 {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||
}
|
||||
saltStr := buf[:idx]
|
||||
buf = buf[idx+1:]
|
||||
|
||||
if !bytes.HasPrefix(buf, []byte("i=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
iterationsStr := buf
|
||||
|
||||
var err error
|
||||
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid SCRAM salt received from server: %v", err)
|
||||
}
|
||||
|
||||
sc.iterations, err = strconv.Atoi(string(iterationsStr))
|
||||
if err != nil || sc.iterations <= 0 {
|
||||
return fmt.Errorf("invalid SCRAM iteration count received from server: %s", iterationsStr)
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
|
||||
return errors.New("invalid SCRAM nonce: did not start with client nonce")
|
||||
}
|
||||
|
||||
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
|
||||
return errors.New("invalid SCRAM nonce: did not include server nonce")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sc *scramClient) clientFinalMessage() string {
|
||||
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
|
||||
|
||||
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
|
||||
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
|
||||
|
||||
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
|
||||
|
||||
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
|
||||
}
|
||||
|
||||
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
|
||||
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
|
||||
return errors.New("invalid SCRAM server-final-message received from server")
|
||||
}
|
||||
|
||||
serverSignature := serverFinalMessage[2:]
|
||||
|
||||
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
|
||||
return errors.New("invalid SCRAM ServerSignature received from server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeHMAC(key, msg []byte) []byte {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(msg)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
||||
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||
storedKey := sha256.Sum256(clientKey)
|
||||
clientSignature := computeHMAC(storedKey[:], authMessage)
|
||||
|
||||
clientProof := make([]byte, len(clientSignature))
|
||||
for i := 0; i < len(clientSignature); i++ {
|
||||
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||
}
|
||||
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
|
||||
base64.StdEncoding.Encode(buf, clientProof)
|
||||
return buf
|
||||
}
|
||||
|
||||
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||
serverSignature := computeHMAC(serverKey[:], authMessage)
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||
base64.StdEncoding.Encode(buf, serverSignature)
|
||||
return buf
|
||||
}
|
||||
@@ -1425,6 +1425,8 @@ func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
||||
case pgproto3.AuthTypeMD5Password:
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(c.config.Password+c.config.User)+string(msg.Salt[:]))
|
||||
err = c.txPasswordMessage(digestedPassword)
|
||||
case pgproto3.AuthTypeSASL:
|
||||
err = c.scramAuth(msg.SASLAuthMechanisms)
|
||||
default:
|
||||
err = errors.New("Received unknown authentication message")
|
||||
}
|
||||
|
||||
@@ -212,6 +212,56 @@ func TestConnectWithMD5Password(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithSCRAMPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connString := os.Getenv("PGX_TEST_SCRAM_PASSWORD_CONN_STRING")
|
||||
if connString == "" {
|
||||
t.Skip("Skipping due to missing PGX_TEST_SCRAM_PASSWORD_CONN_STRING env var")
|
||||
}
|
||||
|
||||
connConfig, err := pgx.ParseConnectionString(connString)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to parse config: %v", err)
|
||||
}
|
||||
|
||||
conn, err := pgx.Connect(connConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %v", err)
|
||||
}
|
||||
|
||||
if _, present := conn.RuntimeParams["server_version"]; !present {
|
||||
t.Error("Runtime parameters not stored")
|
||||
}
|
||||
|
||||
if conn.PID() == 0 {
|
||||
t.Error("Backend PID not stored")
|
||||
}
|
||||
|
||||
var currentDB string
|
||||
err = conn.QueryRow("select current_database()").Scan(¤tDB)
|
||||
if err != nil {
|
||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
if currentDB != defaultConnConfig.Database {
|
||||
t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database)
|
||||
}
|
||||
|
||||
var user string
|
||||
err = conn.QueryRow("select current_user").Scan(&user)
|
||||
if err != nil {
|
||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
if user != defaultConnConfig.User {
|
||||
t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal("Unable to close connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithTLSFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
@@ -11,6 +12,9 @@ const (
|
||||
AuthTypeOk = 0
|
||||
AuthTypeCleartextPassword = 3
|
||||
AuthTypeMD5Password = 5
|
||||
AuthTypeSASL = 10
|
||||
AuthTypeSASLContinue = 11
|
||||
AuthTypeSASLFinal = 12
|
||||
)
|
||||
|
||||
type Authentication struct {
|
||||
@@ -18,6 +22,12 @@ type Authentication struct {
|
||||
|
||||
// MD5Password fields
|
||||
Salt [4]byte
|
||||
|
||||
// SASL fields
|
||||
SASLAuthMechanisms []string
|
||||
|
||||
// SASLContinue and SASLFinal data
|
||||
SASLData []byte
|
||||
}
|
||||
|
||||
func (*Authentication) Backend() {}
|
||||
@@ -30,6 +40,17 @@ func (dst *Authentication) Decode(src []byte) error {
|
||||
case AuthTypeCleartextPassword:
|
||||
case AuthTypeMD5Password:
|
||||
copy(dst.Salt[:], src[4:8])
|
||||
case AuthTypeSASL:
|
||||
authMechanisms := src[4:]
|
||||
for len(authMechanisms) > 1 {
|
||||
idx := bytes.IndexByte(authMechanisms, 0)
|
||||
if idx > 0 {
|
||||
dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx]))
|
||||
authMechanisms = authMechanisms[idx+1:]
|
||||
}
|
||||
}
|
||||
case AuthTypeSASLContinue, AuthTypeSASLFinal:
|
||||
dst.SASLData = src[4:]
|
||||
default:
|
||||
return errors.Errorf("unknown authentication type: %d", dst.Type)
|
||||
}
|
||||
@@ -46,6 +67,15 @@ func (src *Authentication) Encode(dst []byte) []byte {
|
||||
switch src.Type {
|
||||
case AuthTypeMD5Password:
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
case AuthTypeSASL:
|
||||
for _, s := range src.SASLAuthMechanisms {
|
||||
dst = append(dst, []byte(s)...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
dst = append(dst, 0)
|
||||
case AuthTypeSASLContinue:
|
||||
dst = pgio.AppendInt32(dst, int32(len(src.SASLData)))
|
||||
dst = append(dst, src.SASLData...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type SASLInitialResponse struct {
|
||||
AuthMechanism string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (*SASLInitialResponse) Frontend() {}
|
||||
|
||||
func (dst *SASLInitialResponse) Decode(src []byte) error {
|
||||
*dst = SASLInitialResponse{}
|
||||
|
||||
rp := 0
|
||||
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx < 0 {
|
||||
return errors.New("invalid SASLInitialResponse")
|
||||
}
|
||||
|
||||
dst.AuthMechanism = string(src[rp:idx])
|
||||
rp = idx + 1
|
||||
|
||||
rp += 4 // The rest of the message is data so we can just skip the size
|
||||
dst.Data = src[rp:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, []byte(src.AuthMechanism)...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *SASLInitialResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
AuthMechanism string
|
||||
Data string
|
||||
}{
|
||||
Type: "SASLInitialResponse",
|
||||
AuthMechanism: src.AuthMechanism,
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgx/pgio"
|
||||
)
|
||||
|
||||
type SASLResponse struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (*SASLResponse) Frontend() {}
|
||||
|
||||
func (dst *SASLResponse) Decode(src []byte) error {
|
||||
*dst = SASLResponse{Data: src}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (src *SASLResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (src *SASLResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "SASLResponse",
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user