Refactor authentication message handling
This commit is contained in:
@@ -1,93 +0,0 @@
|
|||||||
package pgproto3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/jackc/pgio"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Authentication message type constants.
|
|
||||||
const (
|
|
||||||
AuthTypeOk = 0
|
|
||||||
AuthTypeCleartextPassword = 3
|
|
||||||
AuthTypeMD5Password = 5
|
|
||||||
AuthTypeSASL = 10
|
|
||||||
AuthTypeSASLContinue = 11
|
|
||||||
AuthTypeSASLFinal = 12
|
|
||||||
)
|
|
||||||
|
|
||||||
// Authentication is a message sent from the backend during the authentication process.
|
|
||||||
//
|
|
||||||
// There are multiple authentication messages that each begin with 'R'. This structure represents all such
|
|
||||||
// authentication messages.
|
|
||||||
type Authentication struct {
|
|
||||||
Type uint32
|
|
||||||
|
|
||||||
// MD5Password fields
|
|
||||||
Salt [4]byte
|
|
||||||
|
|
||||||
// SASL fields
|
|
||||||
SASLAuthMechanisms []string
|
|
||||||
|
|
||||||
// SASLContinue and SASLFinal data
|
|
||||||
SASLData []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
|
||||||
func (*Authentication) Backend() {}
|
|
||||||
|
|
||||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
|
||||||
// type identifier and 4 byte message length.
|
|
||||||
func (dst *Authentication) Decode(src []byte) error {
|
|
||||||
*dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])}
|
|
||||||
|
|
||||||
switch dst.Type {
|
|
||||||
case AuthTypeOk:
|
|
||||||
case AuthTypeCleartextPassword:
|
|
||||||
case AuthTypeMD5Password:
|
|
||||||
copy(dst.Salt[:], src[4:8])
|
|
||||||
case AuthTypeSASL:
|
|
||||||
authMechanisms := src[4:]
|
|
||||||
for len(authMechanisms) > 1 {
|
|
||||||
idx := bytes.IndexByte(authMechanisms, 0)
|
|
||||||
if idx > 0 {
|
|
||||||
dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx]))
|
|
||||||
authMechanisms = authMechanisms[idx+1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case AuthTypeSASLContinue, AuthTypeSASLFinal:
|
|
||||||
dst.SASLData = src[4:]
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown authentication type: %d", dst.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
|
||||||
func (src *Authentication) Encode(dst []byte) []byte {
|
|
||||||
dst = append(dst, 'R')
|
|
||||||
sp := len(dst)
|
|
||||||
dst = pgio.AppendInt32(dst, -1)
|
|
||||||
dst = pgio.AppendUint32(dst, src.Type)
|
|
||||||
|
|
||||||
switch src.Type {
|
|
||||||
case AuthTypeMD5Password:
|
|
||||||
dst = append(dst, src.Salt[:]...)
|
|
||||||
case AuthTypeSASL:
|
|
||||||
for _, s := range src.SASLAuthMechanisms {
|
|
||||||
dst = append(dst, []byte(s)...)
|
|
||||||
dst = append(dst, 0)
|
|
||||||
}
|
|
||||||
dst = append(dst, 0)
|
|
||||||
case AuthTypeSASLContinue:
|
|
||||||
dst = pgio.AppendInt32(dst, int32(len(src.SASLData)))
|
|
||||||
dst = append(dst, src.SASLData...)
|
|
||||||
}
|
|
||||||
|
|
||||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
|
||||||
|
type AuthenticationCleartextPassword struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationCleartextPassword) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
||||||
|
if len(src) != 4 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeCleartextPassword {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
|
||||||
|
type AuthenticationMD5Password struct {
|
||||||
|
Salt [4]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationMD5Password) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
||||||
|
if len(src) != 8 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeMD5Password {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(dst.Salt[:], src[4:8])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 12)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||||
|
dst = append(dst, src.Salt[:]...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
|
||||||
|
type AuthenticationOk struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationOk) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationOk) Decode(src []byte) error {
|
||||||
|
if len(src) != 4 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeOk {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
|
||||||
|
type AuthenticationSASL struct {
|
||||||
|
AuthMechanisms []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASL) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASL {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
authMechanisms := src[4:]
|
||||||
|
for len(authMechanisms) > 1 {
|
||||||
|
idx := bytes.IndexByte(authMechanisms, 0)
|
||||||
|
if idx > 0 {
|
||||||
|
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
|
||||||
|
authMechanisms = authMechanisms[idx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||||
|
|
||||||
|
for _, s := range src.AuthMechanisms {
|
||||||
|
dst = append(dst, []byte(s)...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
}
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
|
||||||
|
type AuthenticationSASLContinue struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASLContinue) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASLContinue {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = src[4:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
|
||||||
|
type AuthenticationSASLFinal struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASLFinal) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASLFinal {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = src[4:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
+129
-85
@@ -2,6 +2,7 @@ package pgproto3
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
@@ -12,29 +13,34 @@ type Frontend struct {
|
|||||||
w io.Writer
|
w io.Writer
|
||||||
|
|
||||||
// Backend message flyweights
|
// Backend message flyweights
|
||||||
authentication Authentication
|
authenticationOk AuthenticationOk
|
||||||
backendKeyData BackendKeyData
|
authenticationCleartextPassword AuthenticationCleartextPassword
|
||||||
bindComplete BindComplete
|
authenticationMD5Password AuthenticationMD5Password
|
||||||
closeComplete CloseComplete
|
authenticationSASL AuthenticationSASL
|
||||||
commandComplete CommandComplete
|
authenticationSASLContinue AuthenticationSASLContinue
|
||||||
copyBothResponse CopyBothResponse
|
authenticationSASLFinal AuthenticationSASLFinal
|
||||||
copyData CopyData
|
backendKeyData BackendKeyData
|
||||||
copyInResponse CopyInResponse
|
bindComplete BindComplete
|
||||||
copyOutResponse CopyOutResponse
|
closeComplete CloseComplete
|
||||||
copyDone CopyDone
|
commandComplete CommandComplete
|
||||||
dataRow DataRow
|
copyBothResponse CopyBothResponse
|
||||||
emptyQueryResponse EmptyQueryResponse
|
copyData CopyData
|
||||||
errorResponse ErrorResponse
|
copyInResponse CopyInResponse
|
||||||
functionCallResponse FunctionCallResponse
|
copyOutResponse CopyOutResponse
|
||||||
noData NoData
|
copyDone CopyDone
|
||||||
noticeResponse NoticeResponse
|
dataRow DataRow
|
||||||
notificationResponse NotificationResponse
|
emptyQueryResponse EmptyQueryResponse
|
||||||
parameterDescription ParameterDescription
|
errorResponse ErrorResponse
|
||||||
parameterStatus ParameterStatus
|
functionCallResponse FunctionCallResponse
|
||||||
parseComplete ParseComplete
|
noData NoData
|
||||||
readyForQuery ReadyForQuery
|
noticeResponse NoticeResponse
|
||||||
rowDescription RowDescription
|
notificationResponse NotificationResponse
|
||||||
portalSuspended PortalSuspended
|
parameterDescription ParameterDescription
|
||||||
|
parameterStatus ParameterStatus
|
||||||
|
parseComplete ParseComplete
|
||||||
|
readyForQuery ReadyForQuery
|
||||||
|
rowDescription RowDescription
|
||||||
|
portalSuspended PortalSuspended
|
||||||
|
|
||||||
bodyLen int
|
bodyLen int
|
||||||
msgType byte
|
msgType byte
|
||||||
@@ -47,83 +53,121 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send sends a message to the backend.
|
// Send sends a message to the backend.
|
||||||
func (b *Frontend) Send(msg FrontendMessage) error {
|
func (f *Frontend) Send(msg FrontendMessage) error {
|
||||||
_, err := b.w.Write(msg.Encode(nil))
|
_, err := f.w.Write(msg.Encode(nil))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Receive receives a message from the backend.
|
// Receive receives a message from the backend.
|
||||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
func (f *Frontend) Receive() (BackendMessage, error) {
|
||||||
if !b.partialMsg {
|
if !f.partialMsg {
|
||||||
header, err := b.cr.Next(5)
|
header, err := f.cr.Next(5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
b.msgType = header[0]
|
f.msgType = header[0]
|
||||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||||
b.partialMsg = true
|
f.partialMsg = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var msg BackendMessage
|
msgBody, err := f.cr.Next(f.bodyLen)
|
||||||
switch b.msgType {
|
|
||||||
case '1':
|
|
||||||
msg = &b.parseComplete
|
|
||||||
case '2':
|
|
||||||
msg = &b.bindComplete
|
|
||||||
case '3':
|
|
||||||
msg = &b.closeComplete
|
|
||||||
case 'A':
|
|
||||||
msg = &b.notificationResponse
|
|
||||||
case 'c':
|
|
||||||
msg = &b.copyDone
|
|
||||||
case 'C':
|
|
||||||
msg = &b.commandComplete
|
|
||||||
case 'd':
|
|
||||||
msg = &b.copyData
|
|
||||||
case 'D':
|
|
||||||
msg = &b.dataRow
|
|
||||||
case 'E':
|
|
||||||
msg = &b.errorResponse
|
|
||||||
case 'G':
|
|
||||||
msg = &b.copyInResponse
|
|
||||||
case 'H':
|
|
||||||
msg = &b.copyOutResponse
|
|
||||||
case 'I':
|
|
||||||
msg = &b.emptyQueryResponse
|
|
||||||
case 'K':
|
|
||||||
msg = &b.backendKeyData
|
|
||||||
case 'n':
|
|
||||||
msg = &b.noData
|
|
||||||
case 'N':
|
|
||||||
msg = &b.noticeResponse
|
|
||||||
case 'R':
|
|
||||||
msg = &b.authentication
|
|
||||||
case 's':
|
|
||||||
msg = &b.portalSuspended
|
|
||||||
case 'S':
|
|
||||||
msg = &b.parameterStatus
|
|
||||||
case 't':
|
|
||||||
msg = &b.parameterDescription
|
|
||||||
case 'T':
|
|
||||||
msg = &b.rowDescription
|
|
||||||
case 'V':
|
|
||||||
msg = &b.functionCallResponse
|
|
||||||
case 'W':
|
|
||||||
msg = &b.copyBothResponse
|
|
||||||
case 'Z':
|
|
||||||
msg = &b.readyForQuery
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
|
|
||||||
}
|
|
||||||
|
|
||||||
msgBody, err := b.cr.Next(b.bodyLen)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
b.partialMsg = false
|
f.partialMsg = false
|
||||||
|
|
||||||
|
var msg BackendMessage
|
||||||
|
switch f.msgType {
|
||||||
|
case '1':
|
||||||
|
msg = &f.parseComplete
|
||||||
|
case '2':
|
||||||
|
msg = &f.bindComplete
|
||||||
|
case '3':
|
||||||
|
msg = &f.closeComplete
|
||||||
|
case 'A':
|
||||||
|
msg = &f.notificationResponse
|
||||||
|
case 'c':
|
||||||
|
msg = &f.copyDone
|
||||||
|
case 'C':
|
||||||
|
msg = &f.commandComplete
|
||||||
|
case 'd':
|
||||||
|
msg = &f.copyData
|
||||||
|
case 'D':
|
||||||
|
msg = &f.dataRow
|
||||||
|
case 'E':
|
||||||
|
msg = &f.errorResponse
|
||||||
|
case 'G':
|
||||||
|
msg = &f.copyInResponse
|
||||||
|
case 'H':
|
||||||
|
msg = &f.copyOutResponse
|
||||||
|
case 'I':
|
||||||
|
msg = &f.emptyQueryResponse
|
||||||
|
case 'K':
|
||||||
|
msg = &f.backendKeyData
|
||||||
|
case 'n':
|
||||||
|
msg = &f.noData
|
||||||
|
case 'N':
|
||||||
|
msg = &f.noticeResponse
|
||||||
|
case 'R':
|
||||||
|
var err error
|
||||||
|
msg, err = f.findAuthenticationMessageType(msgBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case 's':
|
||||||
|
msg = &f.portalSuspended
|
||||||
|
case 'S':
|
||||||
|
msg = &f.parameterStatus
|
||||||
|
case 't':
|
||||||
|
msg = &f.parameterDescription
|
||||||
|
case 'T':
|
||||||
|
msg = &f.rowDescription
|
||||||
|
case 'V':
|
||||||
|
msg = &f.functionCallResponse
|
||||||
|
case 'W':
|
||||||
|
msg = &f.copyBothResponse
|
||||||
|
case 'Z':
|
||||||
|
msg = &f.readyForQuery
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
|
||||||
|
}
|
||||||
|
|
||||||
err = msg.Decode(msgBody)
|
err = msg.Decode(msgBody)
|
||||||
return msg, err
|
return msg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Authentication message type constants.
|
||||||
|
const (
|
||||||
|
AuthTypeOk = 0
|
||||||
|
AuthTypeCleartextPassword = 3
|
||||||
|
AuthTypeMD5Password = 5
|
||||||
|
AuthTypeSASL = 10
|
||||||
|
AuthTypeSASLContinue = 11
|
||||||
|
AuthTypeSASLFinal = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return nil, errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
authType := binary.BigEndian.Uint32(src[:4])
|
||||||
|
|
||||||
|
switch authType {
|
||||||
|
case AuthTypeOk:
|
||||||
|
return &f.authenticationOk, nil
|
||||||
|
case AuthTypeCleartextPassword:
|
||||||
|
return &f.authenticationCleartextPassword, nil
|
||||||
|
case AuthTypeMD5Password:
|
||||||
|
return &f.authenticationMD5Password, nil
|
||||||
|
case AuthTypeSASL:
|
||||||
|
return &f.authenticationSASL, nil
|
||||||
|
case AuthTypeSASLContinue:
|
||||||
|
return &f.authenticationSASLContinue, nil
|
||||||
|
case AuthTypeSASLFinal:
|
||||||
|
return &f.authenticationSASLFinal, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown authentication type: %d", authType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user