2
0

Refactor authentication message handling

This commit is contained in:
Jack Christensen
2019-08-31 15:43:07 -05:00
parent 439ea11d47
commit 0d1ceed7a6
8 changed files with 408 additions and 178 deletions
-93
View File
@@ -1,93 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"fmt"
"github.com/jackc/pgio"
)
// Authentication message type constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
// Authentication is a message sent from the backend during the authentication process.
//
// There are multiple authentication messages that each begin with 'R'. This structure represents all such
// authentication messages.
type Authentication struct {
Type uint32
// MD5Password fields
Salt [4]byte
// SASL fields
SASLAuthMechanisms []string
// SASLContinue and SASLFinal data
SASLData []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*Authentication) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Authentication) Decode(src []byte) error {
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
switch dst.Type {
case AuthTypeOk:
case AuthTypeCleartextPassword:
case AuthTypeMD5Password:
copy(dst.Salt[:], src[4:8])
case AuthTypeSASL:
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
case AuthTypeSASLContinue, AuthTypeSASLFinal:
dst.SASLData = src[4:]
default:
return fmt.Errorf("unknown authentication type: %d", dst.Type)
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Authentication) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, src.Type)
switch src.Type {
case AuthTypeMD5Password:
dst = append(dst, src.Salt[:]...)
case AuthTypeSASL:
for _, s := range src.SASLAuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
case AuthTypeSASLContinue:
dst = pgio.AppendInt32(dst, int32(len(src.SASLData)))
dst = append(dst, src.SASLData...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+39
View File
@@ -0,0 +1,39 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
type AuthenticationCleartextPassword struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeCleartextPassword {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
}
+43
View File
@@ -0,0 +1,43 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
type AuthenticationMD5Password struct {
Salt [4]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationMD5Password) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
if len(src) != 8 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeMD5Password {
return errors.New("bad auth type")
}
copy(dst.Salt[:], src[4:8])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
dst = pgio.AppendUint32(dst, AuthTypeOk)
dst = append(dst, src.Salt[:]...)
return dst
}
+39
View File
@@ -0,0 +1,39 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
type AuthenticationOk struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationOk) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeOk {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
}
+60
View File
@@ -0,0 +1,60 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
type AuthenticationSASL struct {
AuthMechanisms []string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASL) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASL) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASL {
return errors.New("bad auth type")
}
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+49
View File
@@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
type AuthenticationSASLContinue struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLContinue) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLContinue {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+49
View File
@@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
type AuthenticationSASLFinal struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLFinal) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLFinal {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+129 -85
View File
@@ -2,6 +2,7 @@ package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
@@ -12,29 +13,34 @@ type Frontend struct {
w io.Writer
// Backend message flyweights
authentication Authentication
backendKeyData BackendKeyData
bindComplete BindComplete
closeComplete CloseComplete
commandComplete CommandComplete
copyBothResponse CopyBothResponse
copyData CopyData
copyInResponse CopyInResponse
copyOutResponse CopyOutResponse
copyDone CopyDone
dataRow DataRow
emptyQueryResponse EmptyQueryResponse
errorResponse ErrorResponse
functionCallResponse FunctionCallResponse
noData NoData
noticeResponse NoticeResponse
notificationResponse NotificationResponse
parameterDescription ParameterDescription
parameterStatus ParameterStatus
parseComplete ParseComplete
readyForQuery ReadyForQuery
rowDescription RowDescription
portalSuspended PortalSuspended
authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
authenticationMD5Password AuthenticationMD5Password
authenticationSASL AuthenticationSASL
authenticationSASLContinue AuthenticationSASLContinue
authenticationSASLFinal AuthenticationSASLFinal
backendKeyData BackendKeyData
bindComplete BindComplete
closeComplete CloseComplete
commandComplete CommandComplete
copyBothResponse CopyBothResponse
copyData CopyData
copyInResponse CopyInResponse
copyOutResponse CopyOutResponse
copyDone CopyDone
dataRow DataRow
emptyQueryResponse EmptyQueryResponse
errorResponse ErrorResponse
functionCallResponse FunctionCallResponse
noData NoData
noticeResponse NoticeResponse
notificationResponse NotificationResponse
parameterDescription ParameterDescription
parameterStatus ParameterStatus
parseComplete ParseComplete
readyForQuery ReadyForQuery
rowDescription RowDescription
portalSuspended PortalSuspended
bodyLen int
msgType byte
@@ -47,83 +53,121 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
}
// Send sends a message to the backend.
func (b *Frontend) Send(msg FrontendMessage) error {
_, err := b.w.Write(msg.Encode(nil))
func (f *Frontend) Send(msg FrontendMessage) error {
_, err := f.w.Write(msg.Encode(nil))
return err
}
// Receive receives a message from the backend.
func (b *Frontend) Receive() (BackendMessage, error) {
if !b.partialMsg {
header, err := b.cr.Next(5)
func (f *Frontend) Receive() (BackendMessage, error) {
if !f.partialMsg {
header, err := f.cr.Next(5)
if err != nil {
return nil, err
}
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
f.msgType = header[0]
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
f.partialMsg = true
}
var msg BackendMessage
switch b.msgType {
case '1':
msg = &b.parseComplete
case '2':
msg = &b.bindComplete
case '3':
msg = &b.closeComplete
case 'A':
msg = &b.notificationResponse
case 'c':
msg = &b.copyDone
case 'C':
msg = &b.commandComplete
case 'd':
msg = &b.copyData
case 'D':
msg = &b.dataRow
case 'E':
msg = &b.errorResponse
case 'G':
msg = &b.copyInResponse
case 'H':
msg = &b.copyOutResponse
case 'I':
msg = &b.emptyQueryResponse
case 'K':
msg = &b.backendKeyData
case 'n':
msg = &b.noData
case 'N':
msg = &b.noticeResponse
case 'R':
msg = &b.authentication
case 's':
msg = &b.portalSuspended
case 'S':
msg = &b.parameterStatus
case 't':
msg = &b.parameterDescription
case 'T':
msg = &b.rowDescription
case 'V':
msg = &b.functionCallResponse
case 'W':
msg = &b.copyBothResponse
case 'Z':
msg = &b.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
}
msgBody, err := b.cr.Next(b.bodyLen)
msgBody, err := f.cr.Next(f.bodyLen)
if err != nil {
return nil, err
}
b.partialMsg = false
f.partialMsg = false
var msg BackendMessage
switch f.msgType {
case '1':
msg = &f.parseComplete
case '2':
msg = &f.bindComplete
case '3':
msg = &f.closeComplete
case 'A':
msg = &f.notificationResponse
case 'c':
msg = &f.copyDone
case 'C':
msg = &f.commandComplete
case 'd':
msg = &f.copyData
case 'D':
msg = &f.dataRow
case 'E':
msg = &f.errorResponse
case 'G':
msg = &f.copyInResponse
case 'H':
msg = &f.copyOutResponse
case 'I':
msg = &f.emptyQueryResponse
case 'K':
msg = &f.backendKeyData
case 'n':
msg = &f.noData
case 'N':
msg = &f.noticeResponse
case 'R':
var err error
msg, err = f.findAuthenticationMessageType(msgBody)
if err != nil {
return nil, err
}
case 's':
msg = &f.portalSuspended
case 'S':
msg = &f.parameterStatus
case 't':
msg = &f.parameterDescription
case 'T':
msg = &f.rowDescription
case 'V':
msg = &f.functionCallResponse
case 'W':
msg = &f.copyBothResponse
case 'Z':
msg = &f.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
}
err = msg.Decode(msgBody)
return msg, err
}
// Authentication message type constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
if len(src) < 4 {
return nil, errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src[:4])
switch authType {
case AuthTypeOk:
return &f.authenticationOk, nil
case AuthTypeCleartextPassword:
return &f.authenticationCleartextPassword, nil
case AuthTypeMD5Password:
return &f.authenticationMD5Password, nil
case AuthTypeSASL:
return &f.authenticationSASL, nil
case AuthTypeSASLContinue:
return &f.authenticationSASLContinue, nil
case AuthTypeSASLFinal:
return &f.authenticationSASLFinal, nil
default:
return nil, fmt.Errorf("unknown authentication type: %d", authType)
}
}