From 0d1ceed7a6902fbf9ae6c0c43f00174400547d5d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 15:43:07 -0500 Subject: [PATCH] Refactor authentication message handling --- authentication.go | 93 ------------ authentication_cleartext_password.go | 39 +++++ authentication_md5_password.go | 43 ++++++ authentication_ok.go | 39 +++++ authentication_sasl.go | 60 ++++++++ authentication_sasl_continue.go | 49 ++++++ authentication_sasl_final.go | 49 ++++++ frontend.go | 214 ++++++++++++++++----------- 8 files changed, 408 insertions(+), 178 deletions(-) delete mode 100644 authentication.go create mode 100644 authentication_cleartext_password.go create mode 100644 authentication_md5_password.go create mode 100644 authentication_ok.go create mode 100644 authentication_sasl.go create mode 100644 authentication_sasl_continue.go create mode 100644 authentication_sasl_final.go diff --git a/authentication.go b/authentication.go deleted file mode 100644 index 5ff05d96..00000000 --- a/authentication.go +++ /dev/null @@ -1,93 +0,0 @@ -package pgproto3 - -import ( - "bytes" - "encoding/binary" - "fmt" - - "github.com/jackc/pgio" -) - -// Authentication message type constants. -const ( - AuthTypeOk = 0 - AuthTypeCleartextPassword = 3 - AuthTypeMD5Password = 5 - AuthTypeSASL = 10 - AuthTypeSASLContinue = 11 - AuthTypeSASLFinal = 12 -) - -// Authentication is a message sent from the backend during the authentication process. -// -// There are multiple authentication messages that each begin with 'R'. This structure represents all such -// authentication messages. -type Authentication struct { - Type uint32 - - // MD5Password fields - Salt [4]byte - - // SASL fields - SASLAuthMechanisms []string - - // SASLContinue and SASLFinal data - SASLData []byte -} - -// Backend identifies this message as sendable by the PostgreSQL backend. -func (*Authentication) Backend() {} - -// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message -// type identifier and 4 byte message length. -func (dst *Authentication) Decode(src []byte) error { - *dst = Authentication{Type: binary.BigEndian.Uint32(src[:4])} - - switch dst.Type { - case AuthTypeOk: - case AuthTypeCleartextPassword: - case AuthTypeMD5Password: - copy(dst.Salt[:], src[4:8]) - case AuthTypeSASL: - authMechanisms := src[4:] - for len(authMechanisms) > 1 { - idx := bytes.IndexByte(authMechanisms, 0) - if idx > 0 { - dst.SASLAuthMechanisms = append(dst.SASLAuthMechanisms, string(authMechanisms[:idx])) - authMechanisms = authMechanisms[idx+1:] - } - } - case AuthTypeSASLContinue, AuthTypeSASLFinal: - dst.SASLData = src[4:] - default: - return fmt.Errorf("unknown authentication type: %d", dst.Type) - } - - return nil -} - -// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Authentication) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - dst = pgio.AppendUint32(dst, src.Type) - - switch src.Type { - case AuthTypeMD5Password: - dst = append(dst, src.Salt[:]...) - case AuthTypeSASL: - for _, s := range src.SASLAuthMechanisms { - dst = append(dst, []byte(s)...) - dst = append(dst, 0) - } - dst = append(dst, 0) - case AuthTypeSASLContinue: - dst = pgio.AppendInt32(dst, int32(len(src.SASLData))) - dst = append(dst, src.SASLData...) - } - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst -} diff --git a/authentication_cleartext_password.go b/authentication_cleartext_password.go new file mode 100644 index 00000000..dd82c7a7 --- /dev/null +++ b/authentication_cleartext_password.go @@ -0,0 +1,39 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. +type AuthenticationCleartextPassword struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationCleartextPassword) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeCleartextPassword { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) + return dst +} diff --git a/authentication_md5_password.go b/authentication_md5_password.go new file mode 100644 index 00000000..4680db5a --- /dev/null +++ b/authentication_md5_password.go @@ -0,0 +1,43 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required. +type AuthenticationMD5Password struct { + Salt [4]byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationMD5Password) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationMD5Password) Decode(src []byte) error { + if len(src) != 8 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeMD5Password { + return errors.New("bad auth type") + } + + copy(dst.Salt[:], src[4:8]) + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 12) + dst = pgio.AppendUint32(dst, AuthTypeOk) + dst = append(dst, src.Salt[:]...) + return dst +} diff --git a/authentication_ok.go b/authentication_ok.go new file mode 100644 index 00000000..7b13c6e0 --- /dev/null +++ b/authentication_ok.go @@ -0,0 +1,39 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationOk is a message sent from the backend indicating that authentication was successful. +type AuthenticationOk struct { +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationOk) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationOk) Decode(src []byte) error { + if len(src) != 4 { + return errors.New("bad authentication message size") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeOk { + return errors.New("bad auth type") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationOk) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendUint32(dst, AuthTypeOk) + return dst +} diff --git a/authentication_sasl.go b/authentication_sasl.go new file mode 100644 index 00000000..c57ae32d --- /dev/null +++ b/authentication_sasl.go @@ -0,0 +1,60 @@ +package pgproto3 + +import ( + "bytes" + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required. +type AuthenticationSASL struct { + AuthMechanisms []string +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASL) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASL) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASL { + return errors.New("bad auth type") + } + + authMechanisms := src[4:] + for len(authMechanisms) > 1 { + idx := bytes.IndexByte(authMechanisms, 0) + if idx > 0 { + dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) + authMechanisms = authMechanisms[idx+1:] + } + } + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASL) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASL) + + for _, s := range src.AuthMechanisms { + dst = append(dst, []byte(s)...) + dst = append(dst, 0) + } + dst = append(dst, 0) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} diff --git a/authentication_sasl_continue.go b/authentication_sasl_continue.go new file mode 100644 index 00000000..a393ae10 --- /dev/null +++ b/authentication_sasl_continue.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge. +type AuthenticationSASLContinue struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLContinue) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLContinue { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} diff --git a/authentication_sasl_final.go b/authentication_sasl_final.go new file mode 100644 index 00000000..b8f89d59 --- /dev/null +++ b/authentication_sasl_final.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "errors" + + "github.com/jackc/pgio" +) + +// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed. +type AuthenticationSASLFinal struct { + Data []byte +} + +// Backend identifies this message as sendable by the PostgreSQL backend. +func (*AuthenticationSASLFinal) Backend() {} + +// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message +// type identifier and 4 byte message length. +func (dst *AuthenticationSASLFinal) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeSASLFinal { + return errors.New("bad auth type") + } + + dst.Data = src[4:] + + return nil +} + +// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. +func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { + dst = append(dst, 'R') + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) + + dst = pgio.AppendInt32(dst, int32(len(src.Data))) + dst = append(dst, src.Data...) + + pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + + return dst +} diff --git a/frontend.go b/frontend.go index a67b6670..0826685b 100644 --- a/frontend.go +++ b/frontend.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "errors" "fmt" "io" ) @@ -12,29 +13,34 @@ type Frontend struct { w io.Writer // Backend message flyweights - authentication Authentication - backendKeyData BackendKeyData - bindComplete BindComplete - closeComplete CloseComplete - commandComplete CommandComplete - copyBothResponse CopyBothResponse - copyData CopyData - copyInResponse CopyInResponse - copyOutResponse CopyOutResponse - copyDone CopyDone - dataRow DataRow - emptyQueryResponse EmptyQueryResponse - errorResponse ErrorResponse - functionCallResponse FunctionCallResponse - noData NoData - noticeResponse NoticeResponse - notificationResponse NotificationResponse - parameterDescription ParameterDescription - parameterStatus ParameterStatus - parseComplete ParseComplete - readyForQuery ReadyForQuery - rowDescription RowDescription - portalSuspended PortalSuspended + authenticationOk AuthenticationOk + authenticationCleartextPassword AuthenticationCleartextPassword + authenticationMD5Password AuthenticationMD5Password + authenticationSASL AuthenticationSASL + authenticationSASLContinue AuthenticationSASLContinue + authenticationSASLFinal AuthenticationSASLFinal + backendKeyData BackendKeyData + bindComplete BindComplete + closeComplete CloseComplete + commandComplete CommandComplete + copyBothResponse CopyBothResponse + copyData CopyData + copyInResponse CopyInResponse + copyOutResponse CopyOutResponse + copyDone CopyDone + dataRow DataRow + emptyQueryResponse EmptyQueryResponse + errorResponse ErrorResponse + functionCallResponse FunctionCallResponse + noData NoData + noticeResponse NoticeResponse + notificationResponse NotificationResponse + parameterDescription ParameterDescription + parameterStatus ParameterStatus + parseComplete ParseComplete + readyForQuery ReadyForQuery + rowDescription RowDescription + portalSuspended PortalSuspended bodyLen int msgType byte @@ -47,83 +53,121 @@ func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { } // Send sends a message to the backend. -func (b *Frontend) Send(msg FrontendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) +func (f *Frontend) Send(msg FrontendMessage) error { + _, err := f.w.Write(msg.Encode(nil)) return err } // Receive receives a message from the backend. -func (b *Frontend) Receive() (BackendMessage, error) { - if !b.partialMsg { - header, err := b.cr.Next(5) +func (f *Frontend) Receive() (BackendMessage, error) { + if !f.partialMsg { + header, err := f.cr.Next(5) if err != nil { return nil, err } - b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 - b.partialMsg = true + f.msgType = header[0] + f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + f.partialMsg = true } - var msg BackendMessage - switch b.msgType { - case '1': - msg = &b.parseComplete - case '2': - msg = &b.bindComplete - case '3': - msg = &b.closeComplete - case 'A': - msg = &b.notificationResponse - case 'c': - msg = &b.copyDone - case 'C': - msg = &b.commandComplete - case 'd': - msg = &b.copyData - case 'D': - msg = &b.dataRow - case 'E': - msg = &b.errorResponse - case 'G': - msg = &b.copyInResponse - case 'H': - msg = &b.copyOutResponse - case 'I': - msg = &b.emptyQueryResponse - case 'K': - msg = &b.backendKeyData - case 'n': - msg = &b.noData - case 'N': - msg = &b.noticeResponse - case 'R': - msg = &b.authentication - case 's': - msg = &b.portalSuspended - case 'S': - msg = &b.parameterStatus - case 't': - msg = &b.parameterDescription - case 'T': - msg = &b.rowDescription - case 'V': - msg = &b.functionCallResponse - case 'W': - msg = &b.copyBothResponse - case 'Z': - msg = &b.readyForQuery - default: - return nil, fmt.Errorf("unknown message type: %c", b.msgType) - } - - msgBody, err := b.cr.Next(b.bodyLen) + msgBody, err := f.cr.Next(f.bodyLen) if err != nil { return nil, err } - b.partialMsg = false + f.partialMsg = false + + var msg BackendMessage + switch f.msgType { + case '1': + msg = &f.parseComplete + case '2': + msg = &f.bindComplete + case '3': + msg = &f.closeComplete + case 'A': + msg = &f.notificationResponse + case 'c': + msg = &f.copyDone + case 'C': + msg = &f.commandComplete + case 'd': + msg = &f.copyData + case 'D': + msg = &f.dataRow + case 'E': + msg = &f.errorResponse + case 'G': + msg = &f.copyInResponse + case 'H': + msg = &f.copyOutResponse + case 'I': + msg = &f.emptyQueryResponse + case 'K': + msg = &f.backendKeyData + case 'n': + msg = &f.noData + case 'N': + msg = &f.noticeResponse + case 'R': + var err error + msg, err = f.findAuthenticationMessageType(msgBody) + if err != nil { + return nil, err + } + case 's': + msg = &f.portalSuspended + case 'S': + msg = &f.parameterStatus + case 't': + msg = &f.parameterDescription + case 'T': + msg = &f.rowDescription + case 'V': + msg = &f.functionCallResponse + case 'W': + msg = &f.copyBothResponse + case 'Z': + msg = &f.readyForQuery + default: + return nil, fmt.Errorf("unknown message type: %c", f.msgType) + } err = msg.Decode(msgBody) return msg, err } + +// Authentication message type constants. +const ( + AuthTypeOk = 0 + AuthTypeCleartextPassword = 3 + AuthTypeMD5Password = 5 + AuthTypeSASL = 10 + AuthTypeSASLContinue = 11 + AuthTypeSASLFinal = 12 +) + +func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { + if len(src) < 4 { + return nil, errors.New("authentication message too short") + } + authType := binary.BigEndian.Uint32(src[:4]) + + switch authType { + case AuthTypeOk: + return &f.authenticationOk, nil + case AuthTypeCleartextPassword: + return &f.authenticationCleartextPassword, nil + case AuthTypeMD5Password: + return &f.authenticationMD5Password, nil + case AuthTypeSASL: + return &f.authenticationSASL, nil + case AuthTypeSASLContinue: + return &f.authenticationSASLContinue, nil + case AuthTypeSASLFinal: + return &f.authenticationSASLFinal, nil + default: + return nil, fmt.Errorf("unknown authentication type: %d", authType) + } +}