2
0

Refactor authentication message handling

This commit is contained in:
Jack Christensen
2019-08-31 15:43:07 -05:00
parent 439ea11d47
commit 0d1ceed7a6
8 changed files with 408 additions and 178 deletions
-93
View File
@@ -1,93 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"fmt"
"github.com/jackc/pgio"
)
// Authentication message type constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
// Authentication is a message sent from the backend during the authentication process.
//
// There are multiple authentication messages that each begin with 'R'. This structure represents all such
// authentication messages.
type Authentication struct {
Type uint32
// MD5Password fields
Salt [4]byte
// SASL fields
SASLAuthMechanisms []string
// SASLContinue and SASLFinal data
SASLData []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*Authentication) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Authentication) Decode(src []byte) error {
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
switch dst.Type {
case AuthTypeOk:
case AuthTypeCleartextPassword:
case AuthTypeMD5Password:
copy(dst.Salt[:], src[4:8])
case AuthTypeSASL:
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
case AuthTypeSASLContinue, AuthTypeSASLFinal:
dst.SASLData = src[4:]
default:
return fmt.Errorf("unknown authentication type: %d", dst.Type)
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Authentication) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, src.Type)
switch src.Type {
case AuthTypeMD5Password:
dst = append(dst, src.Salt[:]...)
case AuthTypeSASL:
for _, s := range src.SASLAuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
case AuthTypeSASLContinue:
dst = pgio.AppendInt32(dst, int32(len(src.SASLData)))
dst = append(dst, src.SASLData...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+39
View File
@@ -0,0 +1,39 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
type AuthenticationCleartextPassword struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeCleartextPassword {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
}
+43
View File
@@ -0,0 +1,43 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
type AuthenticationMD5Password struct {
Salt [4]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationMD5Password) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
if len(src) != 8 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeMD5Password {
return errors.New("bad auth type")
}
copy(dst.Salt[:], src[4:8])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
dst = pgio.AppendUint32(dst, AuthTypeOk)
dst = append(dst, src.Salt[:]...)
return dst
}
+39
View File
@@ -0,0 +1,39 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
type AuthenticationOk struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationOk) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeOk {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
}
+60
View File
@@ -0,0 +1,60 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
type AuthenticationSASL struct {
AuthMechanisms []string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASL) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASL) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASL {
return errors.New("bad auth type")
}
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+49
View File
@@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
type AuthenticationSASLContinue struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLContinue) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLContinue {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+49
View File
@@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
type AuthenticationSASLFinal struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLFinal) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLFinal {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+129 -85
View File
@@ -2,6 +2,7 @@ package pgproto3
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
) )
@@ -12,29 +13,34 @@ type Frontend struct {
w io.Writer w io.Writer
// Backend message flyweights // Backend message flyweights
authentication Authentication authenticationOk AuthenticationOk
backendKeyData BackendKeyData authenticationCleartextPassword AuthenticationCleartextPassword
bindComplete BindComplete authenticationMD5Password AuthenticationMD5Password
closeComplete CloseComplete authenticationSASL AuthenticationSASL
commandComplete CommandComplete authenticationSASLContinue AuthenticationSASLContinue
copyBothResponse CopyBothResponse authenticationSASLFinal AuthenticationSASLFinal
copyData CopyData backendKeyData BackendKeyData
copyInResponse CopyInResponse bindComplete BindComplete
copyOutResponse CopyOutResponse closeComplete CloseComplete
copyDone CopyDone commandComplete CommandComplete
dataRow DataRow copyBothResponse CopyBothResponse
emptyQueryResponse EmptyQueryResponse copyData CopyData
errorResponse ErrorResponse copyInResponse CopyInResponse
functionCallResponse FunctionCallResponse copyOutResponse CopyOutResponse
noData NoData copyDone CopyDone
noticeResponse NoticeResponse dataRow DataRow
notificationResponse NotificationResponse emptyQueryResponse EmptyQueryResponse
parameterDescription ParameterDescription errorResponse ErrorResponse
parameterStatus ParameterStatus functionCallResponse FunctionCallResponse
parseComplete ParseComplete noData NoData
readyForQuery ReadyForQuery noticeResponse NoticeResponse
rowDescription RowDescription notificationResponse NotificationResponse
portalSuspended PortalSuspended parameterDescription ParameterDescription
parameterStatus ParameterStatus
parseComplete ParseComplete
readyForQuery ReadyForQuery
rowDescription RowDescription
portalSuspended PortalSuspended
bodyLen int bodyLen int
msgType byte msgType byte
@@ -47,83 +53,121 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
} }
// Send sends a message to the backend. // Send sends a message to the backend.
func (b *Frontend) Send(msg FrontendMessage) error { func (f *Frontend) Send(msg FrontendMessage) error {
_, err := b.w.Write(msg.Encode(nil)) _, err := f.w.Write(msg.Encode(nil))
return err return err
} }
// Receive receives a message from the backend. // Receive receives a message from the backend.
func (b *Frontend) Receive() (BackendMessage, error) { func (f *Frontend) Receive() (BackendMessage, error) {
if !b.partialMsg { if !f.partialMsg {
header, err := b.cr.Next(5) header, err := f.cr.Next(5)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b.msgType = header[0] f.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true f.partialMsg = true
} }
var msg BackendMessage msgBody, err := f.cr.Next(f.bodyLen)
switch b.msgType {
case '1':
msg = &b.parseComplete
case '2':
msg = &b.bindComplete
case '3':
msg = &b.closeComplete
case 'A':
msg = &b.notificationResponse
case 'c':
msg = &b.copyDone
case 'C':
msg = &b.commandComplete
case 'd':
msg = &b.copyData
case 'D':
msg = &b.dataRow
case 'E':
msg = &b.errorResponse
case 'G':
msg = &b.copyInResponse
case 'H':
msg = &b.copyOutResponse
case 'I':
msg = &b.emptyQueryResponse
case 'K':
msg = &b.backendKeyData
case 'n':
msg = &b.noData
case 'N':
msg = &b.noticeResponse
case 'R':
msg = &b.authentication
case 's':
msg = &b.portalSuspended
case 'S':
msg = &b.parameterStatus
case 't':
msg = &b.parameterDescription
case 'T':
msg = &b.rowDescription
case 'V':
msg = &b.functionCallResponse
case 'W':
msg = &b.copyBothResponse
case 'Z':
msg = &b.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
}
msgBody, err := b.cr.Next(b.bodyLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b.partialMsg = false f.partialMsg = false
var msg BackendMessage
switch f.msgType {
case '1':
msg = &f.parseComplete
case '2':
msg = &f.bindComplete
case '3':
msg = &f.closeComplete
case 'A':
msg = &f.notificationResponse
case 'c':
msg = &f.copyDone
case 'C':
msg = &f.commandComplete
case 'd':
msg = &f.copyData
case 'D':
msg = &f.dataRow
case 'E':
msg = &f.errorResponse
case 'G':
msg = &f.copyInResponse
case 'H':
msg = &f.copyOutResponse
case 'I':
msg = &f.emptyQueryResponse
case 'K':
msg = &f.backendKeyData
case 'n':
msg = &f.noData
case 'N':
msg = &f.noticeResponse
case 'R':
var err error
msg, err = f.findAuthenticationMessageType(msgBody)
if err != nil {
return nil, err
}
case 's':
msg = &f.portalSuspended
case 'S':
msg = &f.parameterStatus
case 't':
msg = &f.parameterDescription
case 'T':
msg = &f.rowDescription
case 'V':
msg = &f.functionCallResponse
case 'W':
msg = &f.copyBothResponse
case 'Z':
msg = &f.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
}
err = msg.Decode(msgBody) err = msg.Decode(msgBody)
return msg, err return msg, err
} }
// Authentication message type constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
if len(src) < 4 {
return nil, errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src[:4])
switch authType {
case AuthTypeOk:
return &f.authenticationOk, nil
case AuthTypeCleartextPassword:
return &f.authenticationCleartextPassword, nil
case AuthTypeMD5Password:
return &f.authenticationMD5Password, nil
case AuthTypeSASL:
return &f.authenticationSASL, nil
case AuthTypeSASLContinue:
return &f.authenticationSASLContinue, nil
case AuthTypeSASLFinal:
return &f.authenticationSASLFinal, nil
default:
return nil, fmt.Errorf("unknown authentication type: %d", authType)
}
}