From b2a540ca814e2103bdb45188625a035e2ef3f6b6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 16 Apr 2019 20:30:55 -0500 Subject: [PATCH] Add sufficient support for SCRAM --- authentication.go | 30 +++++++++++++++++++ sasl_initial_response.go | 64 ++++++++++++++++++++++++++++++++++++++++ sasl_response.go | 38 ++++++++++++++++++++++++ 3 files changed, 132 insertions(+) create mode 100644 sasl_initial_response.go create mode 100644 sasl_response.go diff --git a/authentication.go b/authentication.go index 14275a86..2078c87c 100644 --- a/authentication.go +++ b/authentication.go @@ -1,6 +1,7 @@ package pgproto3 import ( + "bytes" "encoding/binary" "github.com/jackc/pgio" @@ -11,6 +12,9 @@ const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 ) type Authentication struct { @@ -18,6 +22,12 @@ type Authentication struct { // MD5Password fields Salt [4]byte + + // SASL fields + SASLAuthMechanisms []string + + // SASLContinue and SASLFinal data + SASLData []byte } func (*Authentication) Backend() {} @@ -30,6 +40,17 @@ func (dst *Authentication) Decode(src []byte) error { case AuthTypeCleartextPassword: case AuthTypeMD5Password: copy(dst.Salt[:], src[4:8]) + case AuthTypeSASL: + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx > 0 { + dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + } + case AuthTypeSASLContinue, AuthTypeSASLFinal: + dst.SASLData = src[4:] default: return errors.Errorf("unknown authentication type: %d", dst.Type) } @@ -46,6 +67,15 @@ func (src *Authentication) Encode(dst []byte) []byte { switch src.Type { case AuthTypeMD5Password: dst = append(dst, src.Salt[:]...) + case AuthTypeSASL: + for _, s := range src.SASLAuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + case AuthTypeSASLContinue: + dst = pgio.AppendInt32(dst, int32(len(src.SASLData))) + dst = append(dst, src.SASLData...) } pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) diff --git a/sasl_initial_response.go b/sasl_initial_response.go new file mode 100644 index 00000000..63766131 --- /dev/null +++ b/sasl_initial_response.go @@ -0,0 +1,64 @@ +package pgproto3 + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +type SASLInitialResponse struct { + AuthMechanism string + Data []byte +} + +func (*SASLInitialResponse) Frontend() {} + +func (dst *SASLInitialResponse) Decode(src []byte) error { + *dst = SASLInitialResponse{} + + rp := 0 + + idx := bytes.IndexByte(src, 0) + if idx < 0 { + return errors.New("invalid SASLInitialResponse") + } + + dst.AuthMechanism = string(src[rp:idx]) + rp = idx + 1 + + rp += 4 // The rest of the message is data so we can just skip the size + dst.Data = src[rp:] + + return nil +} + +func (src *SASLInitialResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + + dst = append(dst, []byte(src.AuthMechanism)...) + dst = append(dst, 0) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} + +func (src *SASLInitialResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + AuthMechanism string + Data string + }{ + Type: "SASLInitialResponse", + AuthMechanism: src.AuthMechanism, + Data: hex.EncodeToString(src.Data), + }) +} diff --git a/sasl_response.go b/sasl_response.go new file mode 100644 index 00000000..1e8d3bd3 --- /dev/null +++ b/sasl_response.go @@ -0,0 +1,38 @@ +package pgproto3 + +import ( + "encoding/hex" + "encoding/json" + + "github.com/jackc/pgio" +) + +type SASLResponse struct { + Data []byte +} + +func (*SASLResponse) Frontend() {} + +func (dst *SASLResponse) Decode(src []byte) error { + *dst = SASLResponse{Data: src} + return nil +} + +func (src *SASLResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) + + dst = append(dst, src.Data...) + + return dst +} + +func (src *SASLResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data string + }{ + Type: "SASLResponse", + Data: hex.EncodeToString(src.Data), + }) +}