diff --git a/authentication_cleartext_password.go b/authentication_cleartext_password.go index 1b87a718..241fa600 100644 --- a/authentication_cleartext_password.go +++ b/authentication_cleartext_password.go @@ -15,6 +15,9 @@ type AuthenticationCleartextPassword struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationCleartextPassword) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationCleartextPassword) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { diff --git a/authentication_md5_password.go b/authentication_md5_password.go index 95795b31..32ec0390 100644 --- a/authentication_md5_password.go +++ b/authentication_md5_password.go @@ -16,6 +16,9 @@ type AuthenticationMD5Password struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationMD5Password) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationMD5Password) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationMD5Password) Decode(src []byte) error { diff --git a/authentication_ok.go b/authentication_ok.go index ad69b907..2b476fe5 100644 --- a/authentication_ok.go +++ b/authentication_ok.go @@ -15,6 +15,9 @@ type AuthenticationOk struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationOk) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationOk) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationOk) Decode(src []byte) error { diff --git a/authentication_sasl.go b/authentication_sasl.go index d2b09750..bdcb2c36 100644 --- a/authentication_sasl.go +++ b/authentication_sasl.go @@ -17,6 +17,9 @@ type AuthenticationSASL struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASL) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASL) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASL) Decode(src []byte) error { diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go index d258065f..7f4a9c23 100644 --- a/authentication_sasl_continue.go +++ b/authentication_sasl_continue.go @@ -16,6 +16,9 @@ type AuthenticationSASLContinue struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASLContinue) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLContinue) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASLContinue) Decode(src []byte) error { diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go index 6a681d73..d82b9ee4 100644 --- a/authentication_sasl_final.go +++ b/authentication_sasl_final.go @@ -16,6 +16,9 @@ type AuthenticationSASLFinal struct { // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationSASLFinal) Backend() {} +// Backend identifies this message as an authentication response. +func (*AuthenticationSASLFinal) AuthenticationResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *AuthenticationSASLFinal) Decode(src []byte) error { diff --git a/backend.go b/backend.go index cc6f1f03..232aa11d 100644 --- a/backend.go +++ b/backend.go @@ -12,27 +12,27 @@ type Backend struct { w io.Writer // Frontend message flyweights - bind Bind - cancelRequest CancelRequest - _close Close - copyFail CopyFail - copyData CopyData - copyDone CopyDone - describe Describe - execute Execute - flush Flush - gssEncRequest GSSEncRequest - parse Parse - passwordMessage PasswordMessage - query Query - sslRequest SSLRequest - startupMessage StartupMessage - sync Sync - terminate Terminate + bind Bind + cancelRequest CancelRequest + _close Close + copyFail CopyFail + copyData CopyData + copyDone CopyDone + describe Describe + execute Execute + flush Flush + gssEncRequest GSSEncRequest + parse Parse + query Query + sslRequest SSLRequest + startupMessage StartupMessage + sync Sync + terminate Terminate bodyLen int msgType byte partialMsg bool + authType uint32 } // NewBackend creates a new Backend. @@ -127,7 +127,19 @@ func (b *Backend) Receive() (FrontendMessage, error) { case 'P': msg = &b.parse case 'p': - msg = &b.passwordMessage + switch b.authType { + case AuthTypeSASL: + msg = &SASLInitialResponse{} + case AuthTypeSASLContinue: + msg = &SASLResponse{} + case AuthTypeSASLFinal: + msg = &SASLResponse{} + case AuthTypeCleartextPassword, AuthTypeMD5Password: + fallthrough + default: + // to maintain backwards compatability + msg = &PasswordMessage{} + } case 'Q': msg = &b.query case 'S': @@ -148,3 +160,36 @@ func (b *Backend) Receive() (FrontendMessage, error) { err = msg.Decode(msgBody) return msg, err } + +// SetAuthType sets the authentication type in the backend. +// Since multiple message types can start with 'p', SetAuthType allows +// contextual identification of FrontendMessages. For example, in the +// PG message flow documentation for PasswordMessage: +// +// Byte1('p') +// +// Identifies the message as a password response. Note that this is also used for +// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from +// the context. +// +// Since the Frontend does not know about the state of a backend, it is important +// to call SetAuthType() after an authentication request is received by the Frontend. +func (b *Backend) SetAuthType(authType uint32) error { + switch authType { + case AuthTypeOk, + AuthTypeCleartextPassword, + AuthTypeMD5Password, + AuthTypeSCMCreds, + AuthTypeGSS, + AuthTypeGSSCont, + AuthTypeSSPI, + AuthTypeSASL, + AuthTypeSASLContinue, + AuthTypeSASLFinal: + b.authType = authType + default: + return fmt.Errorf("authType not recognized: %d", authType) + } + + return nil +} diff --git a/frontend.go b/frontend.go index b8f545ca..c33dfb08 100644 --- a/frontend.go +++ b/frontend.go @@ -45,6 +45,7 @@ type Frontend struct { bodyLen int msgType byte partialMsg bool + authType uint32 } // NewFrontend creates a new Frontend. @@ -146,10 +147,16 @@ func (f *Frontend) Receive() (BackendMessage, error) { } // Authentication message type constants. +// See src/include/libpq/pqcomm.h for all +// constants. const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 + AuthTypeSCMCreds = 6 + AuthTypeGSS = 7 + AuthTypeGSSCont = 8 + AuthTypeSSPI = 9 AuthTypeSASL = 10 AuthTypeSASLContinue = 11 AuthTypeSASLFinal = 12 @@ -159,15 +166,23 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er if len(src) < 4 { return nil, errors.New("authentication message too short") } - authType := binary.BigEndian.Uint32(src[:4]) + f.authType = binary.BigEndian.Uint32(src[:4]) - switch authType { + switch f.authType { case AuthTypeOk: return &f.authenticationOk, nil case AuthTypeCleartextPassword: return &f.authenticationCleartextPassword, nil case AuthTypeMD5Password: return &f.authenticationMD5Password, nil + case AuthTypeSCMCreds: + return nil, errors.New("AuthTypeSCMCreds is unimplemented") + case AuthTypeGSS: + return nil, errors.New("AuthTypeGSS is unimplemented") + case AuthTypeGSSCont: + return nil, errors.New("AuthTypeGSSCont is unimplemented") + case AuthTypeSSPI: + return nil, errors.New("AuthTypeSSPI is unimplemented") case AuthTypeSASL: return &f.authenticationSASL, nil case AuthTypeSASLContinue: @@ -175,6 +190,12 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er case AuthTypeSASLFinal: return &f.authenticationSASLFinal, nil default: - return nil, fmt.Errorf("unknown authentication type: %d", authType) + return nil, fmt.Errorf("unknown authentication type: %d", f.authType) } } + +// GetAuthType returns the authType used in the current state of the frontend. +// See SetAuthType for more information. +func (f *Frontend) GetAuthType() uint32 { + return f.authType +} diff --git a/password_message.go b/password_message.go index 4b68b31a..cae76c50 100644 --- a/password_message.go +++ b/password_message.go @@ -14,6 +14,9 @@ type PasswordMessage struct { // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} +// Frontend identifies this message as an authentication response. +func (*PasswordMessage) InitialResponse() {} + // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *PasswordMessage) Decode(src []byte) error { diff --git a/pgproto3.go b/pgproto3.go index fb0782cf..70c825e3 100644 --- a/pgproto3.go +++ b/pgproto3.go @@ -27,6 +27,11 @@ type BackendMessage interface { Backend() // no-op method to distinguish frontend from backend methods } +type AuthenticationResponseMessage interface { + BackendMessage + AuthenticationResponse() // no-op method to distinguish authentication responses +} + type invalidMessageLenErr struct { messageType string expectedLen int