2
0

Refactor authentication message handling

This commit is contained in:
Jack Christensen
2019-08-31 15:43:07 -05:00
parent 439ea11d47
commit 0d1ceed7a6
8 changed files with 408 additions and 178 deletions
-93
View File
@@ -1,93 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"fmt"
"github.com/jackc/pgio"
)
// Authentication message type constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
// Authentication is a message sent from the backend during the authentication process.
//
// There are multiple authentication messages that each begin with 'R'. This structure represents all such
// authentication messages.
type Authentication struct {
Type uint32
// MD5Password fields
Salt [4]byte
// SASL fields
SASLAuthMechanisms []string
// SASLContinue and SASLFinal data
SASLData []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*Authentication) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Authentication) Decode(src []byte) error {
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
switch dst.Type {
case AuthTypeOk:
case AuthTypeCleartextPassword:
case AuthTypeMD5Password:
copy(dst.Salt[:], src[4:8])
case AuthTypeSASL:
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
case AuthTypeSASLContinue, AuthTypeSASLFinal:
dst.SASLData = src[4:]
default:
return fmt.Errorf("unknown authentication type: %d", dst.Type)
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Authentication) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, src.Type)
switch src.Type {
case AuthTypeMD5Password:
dst = append(dst, src.Salt[:]...)
case AuthTypeSASL:
for _, s := range src.SASLAuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
case AuthTypeSASLContinue:
dst = pgio.AppendInt32(dst, int32(len(src.SASLData)))
dst = append(dst, src.SASLData...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+39
View File
@@ -0,0 +1,39 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
type AuthenticationCleartextPassword struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeCleartextPassword {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
}
+43
View File
@@ -0,0 +1,43 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
type AuthenticationMD5Password struct {
Salt [4]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationMD5Password) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
if len(src) != 8 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeMD5Password {
return errors.New("bad auth type")
}
copy(dst.Salt[:], src[4:8])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
dst = pgio.AppendUint32(dst, AuthTypeOk)
dst = append(dst, src.Salt[:]...)
return dst
}
+39
View File
@@ -0,0 +1,39 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
type AuthenticationOk struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationOk) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeOk {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
}
+60
View File
@@ -0,0 +1,60 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
type AuthenticationSASL struct {
AuthMechanisms []string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASL) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASL) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASL {
return errors.New("bad auth type")
}
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+49
View File
@@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
type AuthenticationSASLContinue struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLContinue) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLContinue {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+49
View File
@@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
type AuthenticationSASLFinal struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLFinal) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLFinal {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
+91 -47
View File
@@ -2,6 +2,7 @@ package pgproto3
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
) )
@@ -12,7 +13,12 @@ type Frontend struct {
w io.Writer w io.Writer
// Backend message flyweights // Backend message flyweights
authentication Authentication authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
authenticationMD5Password AuthenticationMD5Password
authenticationSASL AuthenticationSASL
authenticationSASLContinue AuthenticationSASLContinue
authenticationSASLFinal AuthenticationSASLFinal
backendKeyData BackendKeyData backendKeyData BackendKeyData
bindComplete BindComplete bindComplete BindComplete
closeComplete CloseComplete closeComplete CloseComplete
@@ -47,83 +53,121 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
} }
// Send sends a message to the backend. // Send sends a message to the backend.
func (b *Frontend) Send(msg FrontendMessage) error { func (f *Frontend) Send(msg FrontendMessage) error {
_, err := b.w.Write(msg.Encode(nil)) _, err := f.w.Write(msg.Encode(nil))
return err return err
} }
// Receive receives a message from the backend. // Receive receives a message from the backend.
func (b *Frontend) Receive() (BackendMessage, error) { func (f *Frontend) Receive() (BackendMessage, error) {
if !b.partialMsg { if !f.partialMsg {
header, err := b.cr.Next(5) header, err := f.cr.Next(5)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b.msgType = header[0] f.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true f.partialMsg = true
} }
msgBody, err := f.cr.Next(f.bodyLen)
if err != nil {
return nil, err
}
f.partialMsg = false
var msg BackendMessage var msg BackendMessage
switch b.msgType { switch f.msgType {
case '1': case '1':
msg = &b.parseComplete msg = &f.parseComplete
case '2': case '2':
msg = &b.bindComplete msg = &f.bindComplete
case '3': case '3':
msg = &b.closeComplete msg = &f.closeComplete
case 'A': case 'A':
msg = &b.notificationResponse msg = &f.notificationResponse
case 'c': case 'c':
msg = &b.copyDone msg = &f.copyDone
case 'C': case 'C':
msg = &b.commandComplete msg = &f.commandComplete
case 'd': case 'd':
msg = &b.copyData msg = &f.copyData
case 'D': case 'D':
msg = &b.dataRow msg = &f.dataRow
case 'E': case 'E':
msg = &b.errorResponse msg = &f.errorResponse
case 'G': case 'G':
msg = &b.copyInResponse msg = &f.copyInResponse
case 'H': case 'H':
msg = &b.copyOutResponse msg = &f.copyOutResponse
case 'I': case 'I':
msg = &b.emptyQueryResponse msg = &f.emptyQueryResponse
case 'K': case 'K':
msg = &b.backendKeyData msg = &f.backendKeyData
case 'n': case 'n':
msg = &b.noData msg = &f.noData
case 'N': case 'N':
msg = &b.noticeResponse msg = &f.noticeResponse
case 'R': case 'R':
msg = &b.authentication var err error
case 's': msg, err = f.findAuthenticationMessageType(msgBody)
msg = &b.portalSuspended
case 'S':
msg = &b.parameterStatus
case 't':
msg = &b.parameterDescription
case 'T':
msg = &b.rowDescription
case 'V':
msg = &b.functionCallResponse
case 'W':
msg = &b.copyBothResponse
case 'Z':
msg = &b.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
}
msgBody, err := b.cr.Next(b.bodyLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
case 's':
b.partialMsg = false msg = &f.portalSuspended
case 'S':
msg = &f.parameterStatus
case 't':
msg = &f.parameterDescription
case 'T':
msg = &f.rowDescription
case 'V':
msg = &f.functionCallResponse
case 'W':
msg = &f.copyBothResponse
case 'Z':
msg = &f.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
}
err = msg.Decode(msgBody) err = msg.Decode(msgBody)
return msg, err return msg, err
} }
// Authentication message type constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
if len(src) < 4 {
return nil, errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src[:4])
switch authType {
case AuthTypeOk:
return &f.authenticationOk, nil
case AuthTypeCleartextPassword:
return &f.authenticationCleartextPassword, nil
case AuthTypeMD5Password:
return &f.authenticationMD5Password, nil
case AuthTypeSASL:
return &f.authenticationSASL, nil
case AuthTypeSASLContinue:
return &f.authenticationSASLContinue, nil
case AuthTypeSASLFinal:
return &f.authenticationSASLFinal, nil
default:
return nil, fmt.Errorf("unknown authentication type: %d", authType)
}
}