2
0

Add support for identifying authentication messages

The pgprotocol overloads 'p' messages with PasswordMessage,
SASLInitialResponse, SASLResponse, and GSSResponse. This patch allows
contextual identification of the message by setting the authType in the
frontend and then setting this value in the backend when a
AuthenticationResponseMessage is received.
This commit is contained in:
Yuli Khodorkovskiy
2021-05-27 14:48:11 -04:00
committed by Jack Christensen
parent 28c20e93c0
commit 7c9e840726
10 changed files with 113 additions and 21 deletions
+3
View File
@@ -15,6 +15,9 @@ type AuthenticationCleartextPassword struct {
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {} func (*AuthenticationCleartextPassword) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
+3
View File
@@ -16,6 +16,9 @@ type AuthenticationMD5Password struct {
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationMD5Password) Backend() {} func (*AuthenticationMD5Password) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationMD5Password) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *AuthenticationMD5Password) Decode(src []byte) error { func (dst *AuthenticationMD5Password) Decode(src []byte) error {
+3
View File
@@ -15,6 +15,9 @@ type AuthenticationOk struct {
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {} func (*AuthenticationOk) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationOk) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *AuthenticationOk) Decode(src []byte) error { func (dst *AuthenticationOk) Decode(src []byte) error {
+3
View File
@@ -17,6 +17,9 @@ type AuthenticationSASL struct {
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASL) Backend() {} func (*AuthenticationSASL) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASL) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *AuthenticationSASL) Decode(src []byte) error { func (dst *AuthenticationSASL) Decode(src []byte) error {
+3
View File
@@ -16,6 +16,9 @@ type AuthenticationSASLContinue struct {
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLContinue) Backend() {} func (*AuthenticationSASLContinue) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *AuthenticationSASLContinue) Decode(src []byte) error { func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
+3
View File
@@ -16,6 +16,9 @@ type AuthenticationSASLFinal struct {
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLFinal) Backend() {} func (*AuthenticationSASLFinal) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *AuthenticationSASLFinal) Decode(src []byte) error { func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
+63 -18
View File
@@ -12,27 +12,27 @@ type Backend struct {
w io.Writer w io.Writer
// Frontend message flyweights // Frontend message flyweights
bind Bind bind Bind
cancelRequest CancelRequest cancelRequest CancelRequest
_close Close _close Close
copyFail CopyFail copyFail CopyFail
copyData CopyData copyData CopyData
copyDone CopyDone copyDone CopyDone
describe Describe describe Describe
execute Execute execute Execute
flush Flush flush Flush
gssEncRequest GSSEncRequest gssEncRequest GSSEncRequest
parse Parse parse Parse
passwordMessage PasswordMessage query Query
query Query sslRequest SSLRequest
sslRequest SSLRequest startupMessage StartupMessage
startupMessage StartupMessage sync Sync
sync Sync terminate Terminate
terminate Terminate
bodyLen int bodyLen int
msgType byte msgType byte
partialMsg bool partialMsg bool
authType uint32
} }
// NewBackend creates a new Backend. // NewBackend creates a new Backend.
@@ -127,7 +127,19 @@ func (b *Backend) Receive() (FrontendMessage, error) {
case 'P': case 'P':
msg = &b.parse msg = &b.parse
case 'p': case 'p':
msg = &b.passwordMessage switch b.authType {
case AuthTypeSASL:
msg = &SASLInitialResponse{}
case AuthTypeSASLContinue:
msg = &SASLResponse{}
case AuthTypeSASLFinal:
msg = &SASLResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:
// to maintain backwards compatability
msg = &PasswordMessage{}
}
case 'Q': case 'Q':
msg = &b.query msg = &b.query
case 'S': case 'S':
@@ -148,3 +160,36 @@ func (b *Backend) Receive() (FrontendMessage, error) {
err = msg.Decode(msgBody) err = msg.Decode(msgBody)
return msg, err return msg, err
} }
// SetAuthType sets the authentication type in the backend.
// Since multiple message types can start with 'p', SetAuthType allows
// contextual identification of FrontendMessages. For example, in the
// PG message flow documentation for PasswordMessage:
//
// Byte1('p')
//
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
//
// Since the Frontend does not know about the state of a backend, it is important
// to call SetAuthType() after an authentication request is received by the Frontend.
func (b *Backend) SetAuthType(authType uint32) error {
switch authType {
case AuthTypeOk,
AuthTypeCleartextPassword,
AuthTypeMD5Password,
AuthTypeSCMCreds,
AuthTypeGSS,
AuthTypeGSSCont,
AuthTypeSSPI,
AuthTypeSASL,
AuthTypeSASLContinue,
AuthTypeSASLFinal:
b.authType = authType
default:
return fmt.Errorf("authType not recognized: %d", authType)
}
return nil
}
+24 -3
View File
@@ -45,6 +45,7 @@ type Frontend struct {
bodyLen int bodyLen int
msgType byte msgType byte
partialMsg bool partialMsg bool
authType uint32
} }
// NewFrontend creates a new Frontend. // NewFrontend creates a new Frontend.
@@ -146,10 +147,16 @@ func (f *Frontend) Receive() (BackendMessage, error) {
} }
// Authentication message type constants. // Authentication message type constants.
// See src/include/libpq/pqcomm.h for all
// constants.
const ( const (
AuthTypeOk = 0 AuthTypeOk = 0
AuthTypeCleartextPassword = 3 AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5 AuthTypeMD5Password = 5
AuthTypeSCMCreds = 6
AuthTypeGSS = 7
AuthTypeGSSCont = 8
AuthTypeSSPI = 9
AuthTypeSASL = 10 AuthTypeSASL = 10
AuthTypeSASLContinue = 11 AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12 AuthTypeSASLFinal = 12
@@ -159,15 +166,23 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
if len(src) < 4 { if len(src) < 4 {
return nil, errors.New("authentication message too short") return nil, errors.New("authentication message too short")
} }
authType := binary.BigEndian.Uint32(src[:4]) f.authType = binary.BigEndian.Uint32(src[:4])
switch authType { switch f.authType {
case AuthTypeOk: case AuthTypeOk:
return &f.authenticationOk, nil return &f.authenticationOk, nil
case AuthTypeCleartextPassword: case AuthTypeCleartextPassword:
return &f.authenticationCleartextPassword, nil return &f.authenticationCleartextPassword, nil
case AuthTypeMD5Password: case AuthTypeMD5Password:
return &f.authenticationMD5Password, nil return &f.authenticationMD5Password, nil
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return nil, errors.New("AuthTypeGSS is unimplemented")
case AuthTypeGSSCont:
return nil, errors.New("AuthTypeGSSCont is unimplemented")
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL: case AuthTypeSASL:
return &f.authenticationSASL, nil return &f.authenticationSASL, nil
case AuthTypeSASLContinue: case AuthTypeSASLContinue:
@@ -175,6 +190,12 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
case AuthTypeSASLFinal: case AuthTypeSASLFinal:
return &f.authenticationSASLFinal, nil return &f.authenticationSASLFinal, nil
default: default:
return nil, fmt.Errorf("unknown authentication type: %d", authType) return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
} }
} }
// GetAuthType returns the authType used in the current state of the frontend.
// See SetAuthType for more information.
func (f *Frontend) GetAuthType() uint32 {
return f.authType
}
+3
View File
@@ -14,6 +14,9 @@ type PasswordMessage struct {
// Frontend identifies this message as sendable by a PostgreSQL frontend. // Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*PasswordMessage) Frontend() {} func (*PasswordMessage) Frontend() {}
// Frontend identifies this message as an authentication response.
func (*PasswordMessage) InitialResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *PasswordMessage) Decode(src []byte) error { func (dst *PasswordMessage) Decode(src []byte) error {
+5
View File
@@ -27,6 +27,11 @@ type BackendMessage interface {
Backend() // no-op method to distinguish frontend from backend methods Backend() // no-op method to distinguish frontend from backend methods
} }
type AuthenticationResponseMessage interface {
BackendMessage
AuthenticationResponse() // no-op method to distinguish authentication responses
}
type invalidMessageLenErr struct { type invalidMessageLenErr struct {
messageType string messageType string
expectedLen int expectedLen int