Refactor authentication message handling
This commit is contained in:
+129
-85
@@ -2,6 +2,7 @@ package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
@@ -12,29 +13,34 @@ type Frontend struct {
|
||||
w io.Writer
|
||||
|
||||
// Backend message flyweights
|
||||
authentication Authentication
|
||||
backendKeyData BackendKeyData
|
||||
bindComplete BindComplete
|
||||
closeComplete CloseComplete
|
||||
commandComplete CommandComplete
|
||||
copyBothResponse CopyBothResponse
|
||||
copyData CopyData
|
||||
copyInResponse CopyInResponse
|
||||
copyOutResponse CopyOutResponse
|
||||
copyDone CopyDone
|
||||
dataRow DataRow
|
||||
emptyQueryResponse EmptyQueryResponse
|
||||
errorResponse ErrorResponse
|
||||
functionCallResponse FunctionCallResponse
|
||||
noData NoData
|
||||
noticeResponse NoticeResponse
|
||||
notificationResponse NotificationResponse
|
||||
parameterDescription ParameterDescription
|
||||
parameterStatus ParameterStatus
|
||||
parseComplete ParseComplete
|
||||
readyForQuery ReadyForQuery
|
||||
rowDescription RowDescription
|
||||
portalSuspended PortalSuspended
|
||||
authenticationOk AuthenticationOk
|
||||
authenticationCleartextPassword AuthenticationCleartextPassword
|
||||
authenticationMD5Password AuthenticationMD5Password
|
||||
authenticationSASL AuthenticationSASL
|
||||
authenticationSASLContinue AuthenticationSASLContinue
|
||||
authenticationSASLFinal AuthenticationSASLFinal
|
||||
backendKeyData BackendKeyData
|
||||
bindComplete BindComplete
|
||||
closeComplete CloseComplete
|
||||
commandComplete CommandComplete
|
||||
copyBothResponse CopyBothResponse
|
||||
copyData CopyData
|
||||
copyInResponse CopyInResponse
|
||||
copyOutResponse CopyOutResponse
|
||||
copyDone CopyDone
|
||||
dataRow DataRow
|
||||
emptyQueryResponse EmptyQueryResponse
|
||||
errorResponse ErrorResponse
|
||||
functionCallResponse FunctionCallResponse
|
||||
noData NoData
|
||||
noticeResponse NoticeResponse
|
||||
notificationResponse NotificationResponse
|
||||
parameterDescription ParameterDescription
|
||||
parameterStatus ParameterStatus
|
||||
parseComplete ParseComplete
|
||||
readyForQuery ReadyForQuery
|
||||
rowDescription RowDescription
|
||||
portalSuspended PortalSuspended
|
||||
|
||||
bodyLen int
|
||||
msgType byte
|
||||
@@ -47,83 +53,121 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
|
||||
}
|
||||
|
||||
// Send sends a message to the backend.
|
||||
func (b *Frontend) Send(msg FrontendMessage) error {
|
||||
_, err := b.w.Write(msg.Encode(nil))
|
||||
func (f *Frontend) Send(msg FrontendMessage) error {
|
||||
_, err := f.w.Write(msg.Encode(nil))
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive receives a message from the backend.
|
||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
||||
if !b.partialMsg {
|
||||
header, err := b.cr.Next(5)
|
||||
func (f *Frontend) Receive() (BackendMessage, error) {
|
||||
if !f.partialMsg {
|
||||
header, err := f.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.msgType = header[0]
|
||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
b.partialMsg = true
|
||||
f.msgType = header[0]
|
||||
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
f.partialMsg = true
|
||||
}
|
||||
|
||||
var msg BackendMessage
|
||||
switch b.msgType {
|
||||
case '1':
|
||||
msg = &b.parseComplete
|
||||
case '2':
|
||||
msg = &b.bindComplete
|
||||
case '3':
|
||||
msg = &b.closeComplete
|
||||
case 'A':
|
||||
msg = &b.notificationResponse
|
||||
case 'c':
|
||||
msg = &b.copyDone
|
||||
case 'C':
|
||||
msg = &b.commandComplete
|
||||
case 'd':
|
||||
msg = &b.copyData
|
||||
case 'D':
|
||||
msg = &b.dataRow
|
||||
case 'E':
|
||||
msg = &b.errorResponse
|
||||
case 'G':
|
||||
msg = &b.copyInResponse
|
||||
case 'H':
|
||||
msg = &b.copyOutResponse
|
||||
case 'I':
|
||||
msg = &b.emptyQueryResponse
|
||||
case 'K':
|
||||
msg = &b.backendKeyData
|
||||
case 'n':
|
||||
msg = &b.noData
|
||||
case 'N':
|
||||
msg = &b.noticeResponse
|
||||
case 'R':
|
||||
msg = &b.authentication
|
||||
case 's':
|
||||
msg = &b.portalSuspended
|
||||
case 'S':
|
||||
msg = &b.parameterStatus
|
||||
case 't':
|
||||
msg = &b.parameterDescription
|
||||
case 'T':
|
||||
msg = &b.rowDescription
|
||||
case 'V':
|
||||
msg = &b.functionCallResponse
|
||||
case 'W':
|
||||
msg = &b.copyBothResponse
|
||||
case 'Z':
|
||||
msg = &b.readyForQuery
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
|
||||
}
|
||||
|
||||
msgBody, err := b.cr.Next(b.bodyLen)
|
||||
msgBody, err := f.cr.Next(f.bodyLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.partialMsg = false
|
||||
f.partialMsg = false
|
||||
|
||||
var msg BackendMessage
|
||||
switch f.msgType {
|
||||
case '1':
|
||||
msg = &f.parseComplete
|
||||
case '2':
|
||||
msg = &f.bindComplete
|
||||
case '3':
|
||||
msg = &f.closeComplete
|
||||
case 'A':
|
||||
msg = &f.notificationResponse
|
||||
case 'c':
|
||||
msg = &f.copyDone
|
||||
case 'C':
|
||||
msg = &f.commandComplete
|
||||
case 'd':
|
||||
msg = &f.copyData
|
||||
case 'D':
|
||||
msg = &f.dataRow
|
||||
case 'E':
|
||||
msg = &f.errorResponse
|
||||
case 'G':
|
||||
msg = &f.copyInResponse
|
||||
case 'H':
|
||||
msg = &f.copyOutResponse
|
||||
case 'I':
|
||||
msg = &f.emptyQueryResponse
|
||||
case 'K':
|
||||
msg = &f.backendKeyData
|
||||
case 'n':
|
||||
msg = &f.noData
|
||||
case 'N':
|
||||
msg = &f.noticeResponse
|
||||
case 'R':
|
||||
var err error
|
||||
msg, err = f.findAuthenticationMessageType(msgBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case 's':
|
||||
msg = &f.portalSuspended
|
||||
case 'S':
|
||||
msg = &f.parameterStatus
|
||||
case 't':
|
||||
msg = &f.parameterDescription
|
||||
case 'T':
|
||||
msg = &f.rowDescription
|
||||
case 'V':
|
||||
msg = &f.functionCallResponse
|
||||
case 'W':
|
||||
msg = &f.copyBothResponse
|
||||
case 'Z':
|
||||
msg = &f.readyForQuery
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
|
||||
}
|
||||
|
||||
err = msg.Decode(msgBody)
|
||||
return msg, err
|
||||
}
|
||||
|
||||
// Authentication message type constants.
|
||||
const (
|
||||
AuthTypeOk = 0
|
||||
AuthTypeCleartextPassword = 3
|
||||
AuthTypeMD5Password = 5
|
||||
AuthTypeSASL = 10
|
||||
AuthTypeSASLContinue = 11
|
||||
AuthTypeSASLFinal = 12
|
||||
)
|
||||
|
||||
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
|
||||
if len(src) < 4 {
|
||||
return nil, errors.New("authentication message too short")
|
||||
}
|
||||
authType := binary.BigEndian.Uint32(src[:4])
|
||||
|
||||
switch authType {
|
||||
case AuthTypeOk:
|
||||
return &f.authenticationOk, nil
|
||||
case AuthTypeCleartextPassword:
|
||||
return &f.authenticationCleartextPassword, nil
|
||||
case AuthTypeMD5Password:
|
||||
return &f.authenticationMD5Password, nil
|
||||
case AuthTypeSASL:
|
||||
return &f.authenticationSASL, nil
|
||||
case AuthTypeSASLContinue:
|
||||
return &f.authenticationSASLContinue, nil
|
||||
case AuthTypeSASLFinal:
|
||||
return &f.authenticationSASLFinal, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown authentication type: %d", authType)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user