From 7c9e8407262f7bfb750aef36c3e49bbff6596d35 Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Thu, 27 May 2021 14:48:11 -0400 Subject: [PATCH] Add support for identifying authentication messages The pgprotocol overloads 'p' messages with PasswordMessage, SASLInitialResponse, SASLResponse, and GSSResponse. This patch allows contextual identification of the message by setting the authType in the frontend and then setting this value in the backend when a AuthenticationResponseMessage is received. --- authentication_cleartext_password.go | 3 ++ authentication_md5_password.go | 3 ++ authentication_ok.go | 3 ++ authentication_sasl.go | 3 ++ authentication_sasl_continue.go | 3 ++ authentication_sasl_final.go | 3 ++ backend.go | 81 +++++++++++++++++++++------- frontend.go | 27 ++++++++-- password_message.go | 3 ++ pgproto3.go | 5 ++ 10 files changed, 113 insertions(+), 21 deletions(-) diff --git a/authentication_cleartext_password.go b/authentication_cleartext_password.go index 1b87a718..241fa600 100644 --- a/authentication_cleartext_password.go +++ b/authentication_cleartext_password.go @@ -15,6 +15,9 @@ type AuthenticationCleartextPassword struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationCleartextPassword) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationCleartextPassword) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { diff --git a/authentication_md5_password.go b/authentication_md5_password.go index 95795b31..32ec0390 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -16,6 +16,9 @@ type AuthenticationMD5Password struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationMD5Password) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationMD5Password) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationMD5Password) Decode(src []byte) error { diff --git a/authentication_ok.go b/authentication_ok.go index ad69b907..2b476fe5 100644 --- a/authentication_ok.go +++ b/authentication_ok.go @@ -15,6 +15,9 @@ type AuthenticationOk struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationOk) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationOk) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationOk) Decode(src []byte) error { diff --git a/authentication_sasl.go b/authentication_sasl.go index d2b09750..bdcb2c36 100644 --- a/authentication_sasl.go +++ b/authentication_sasl.go @@ -17,6 +17,9 @@ type AuthenticationSASL struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASL) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASL) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASL) Decode(src []byte) error { diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go index d258065f..7f4a9c23 100644 --- a/authentication_sasl_continue.go +++ b/authentication_sasl_continue.go @@ -16,6 +16,9 @@ type AuthenticationSASLContinue struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASLContinue) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLContinue) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASLContinue) Decode(src []byte) error { diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go index 6a681d73..d82b9ee4 100644 --- a/authentication_sasl_final.go +++ b/authentication_sasl_final.go @@ -16,6 +16,9 @@ type AuthenticationSASLFinal struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASLFinal) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLFinal) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASLFinal) Decode(src []byte) error { diff --git a/backend.go b/backend.go index cc6f1f03..232aa11d 100644 --- a/backend.go +++ b/backend.go @@ -12,27 +12,27 @@ type Backend struct { w io.Writer // Frontend message flyweights - bind Bind - cancelRequest CancelRequest - _close Close - copyFail CopyFail - copyData CopyData - copyDone CopyDone - describe Describe - execute Execute - flush Flush - gssEncRequest GSSEncRequest - parse Parse - passwordMessage PasswordMessage - query Query - sslRequest SSLRequest - startupMessage StartupMessage - sync Sync - terminate Terminate + bind Bind + cancelRequest CancelRequest + _close Close + copyFail CopyFail + copyData CopyData + copyDone CopyDone + describe Describe + execute Execute + flush Flush + gssEncRequest GSSEncRequest + parse Parse + query Query + sslRequest SSLRequest + startupMessage StartupMessage + sync Sync + terminate Terminate bodyLen int msgType byte partialMsg bool + authType uint32 } // NewBackend creates a new Backend. @@ -127,7 +127,19 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'P': msg = &b.parse case 'p': - msg = &b.passwordMessage + switch b.authType { + case AuthTypeSASL: + msg = &SASLInitialResponse{} + case AuthTypeSASLContinue: + msg = &SASLResponse{} + case AuthTypeSASLFinal: + msg = &SASLResponse{} + case AuthTypeCleartextPassword, AuthTypeMD5Password: + fallthrough + default: + // to maintain backwards compatability + msg = &PasswordMessage{} + } case 'Q': msg = &b.query case 'S': @@ -148,3 +160,36 @@ func (b *Backend) Receive() (FrontendMessage, error) { err = msg.Decode(msgBody) return msg, err } + +// SetAuthType sets the authentication type in the backend. +// Since multiple message types can start with 'p', SetAuthType allows +// contextual identification of FrontendMessages. For example, in the +// PG message flow documentation for PasswordMessage: +// +// Byte1('p') +// +// Identifies the message as a password response. Note that this is also used for +// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from +// the context. +// +// Since the Frontend does not know about the state of a backend, it is important +// to call SetAuthType() after an authentication request is received by the Frontend. +func (b *Backend) SetAuthType(authType uint32) error { + switch authType { + case AuthTypeOk, + AuthTypeCleartextPassword, + AuthTypeMD5Password, + AuthTypeSCMCreds, + AuthTypeGSS, + AuthTypeGSSCont, + AuthTypeSSPI, + AuthTypeSASL, + AuthTypeSASLContinue, + AuthTypeSASLFinal: + b.authType = authType + default: + return fmt.Errorf("authType not recognized: %d", authType) + } + + return nil +} diff --git a/frontend.go b/frontend.go index b8f545ca..c33dfb08 100644 --- a/frontend.go +++ b/frontend.go @@ -45,6 +45,7 @@ type Frontend struct { bodyLen int msgType byte partialMsg bool + authType uint32 } // NewFrontend creates a new Frontend. @@ -146,10 +147,16 @@ func (f *Frontend) Receive() (BackendMessage, error) { } // Authentication message type constants. +// See src/include/libpq/pqcomm.h for all +// constants. const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 + AuthTypeSCMCreds = 6 + AuthTypeGSS = 7 + AuthTypeGSSCont = 8 + AuthTypeSSPI = 9 AuthTypeSASL = 10 AuthTypeSASLContinue = 11 AuthTypeSASLFinal = 12 @@ -159,15 +166,23 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er if len(src) < 4 { return nil, errors.New("authentication message too short") } - authType := binary.BigEndian.Uint32(src[:4]) + f.authType = binary.BigEndian.Uint32(src[:4]) - switch authType { + switch f.authType { case AuthTypeOk: return &f.authenticationOk, nil case AuthTypeCleartextPassword: return &f.authenticationCleartextPassword, nil case AuthTypeMD5Password: return &f.authenticationMD5Password, nil + case AuthTypeSCMCreds: + return nil, errors.New("AuthTypeSCMCreds is unimplemented") + case AuthTypeGSS: + return nil, errors.New("AuthTypeGSS is unimplemented") + case AuthTypeGSSCont: + return nil, errors.New("AuthTypeGSSCont is unimplemented") + case AuthTypeSSPI: + return nil, errors.New("AuthTypeSSPI is unimplemented") case AuthTypeSASL: return &f.authenticationSASL, nil case AuthTypeSASLContinue: @@ -175,6 +190,12 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er case AuthTypeSASLFinal: return &f.authenticationSASLFinal, nil default: - return nil, fmt.Errorf("unknown authentication type: %d", authType) + return nil, fmt.Errorf("unknown authentication type: %d", f.authType) } } + +// GetAuthType returns the authType used in the current state of the frontend. +// See SetAuthType for more information. +func (f *Frontend) GetAuthType() uint32 { + return f.authType +} diff --git a/password_message.go b/password_message.go index 4b68b31a..cae76c50 100644 --- a/password_message.go +++ b/password_message.go @@ -14,6 +14,9 @@ type PasswordMessage struct { // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} +// Frontend identifies this message as an authentication response. +func (*PasswordMessage) InitialResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *PasswordMessage) Decode(src []byte) error { diff --git a/pgproto3.go b/pgproto3.go index fb0782cf..70c825e3 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -27,6 +27,11 @@ type BackendMessage interface { Backend() // no-op method to distinguish frontend from backend methods } +type AuthenticationResponseMessage interface { + BackendMessage + AuthenticationResponse() // no-op method to distinguish authentication responses +} + type invalidMessageLenErr struct { messageType string expectedLen int