From 1ba5dcbe01a089d1f2bd3822e7bc7a0c913c44cc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 31 Aug 2019 11:48:01 -0500 Subject: [PATCH] Support SSLRequest and CancelRequest --- backend.go | 38 +++++++++++++++++++++++------- cancel_request.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++ ssl_request.go | 49 +++++++++++++++++++++++++++++++++++++++ startup_message.go | 9 +------ 4 files changed, 138 insertions(+), 16 deletions(-) create mode 100644 cancel_request.go create mode 100644 ssl_request.go diff --git a/backend.go b/backend.go index 2e2f5eea..be8f3bdb 100644 --- a/backend.go +++ b/backend.go @@ -2,6 +2,7 @@ package pgproto3 import ( "encoding/binary" + "fmt" "io" "github.com/pkg/errors" @@ -14,6 +15,7 @@ type Backend struct { // Frontend message flyweights bind Bind + cancelRequest CancelRequest _close Close copyFail CopyFail describe Describe @@ -22,6 +24,7 @@ type Backend struct { parse Parse passwordMessage PasswordMessage query Query + sslRequest SSLRequest startupMessage StartupMessage sync Sync terminate Terminate @@ -42,9 +45,10 @@ func (b *Backend) Send(msg BackendMessage) error { return err } -// ReceiveStartupMessage receives the initial startup message. This method is used of the normal Receive method -// because StartupMessage and SSLRequest are "special" and do not include the message type as the first byte. -func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { +// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method +// because the initial connection message is "special" and does not include the message type as the first byte. This +// will return either a StartupMessage, SSLRequest, or CancelRequest. +func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { buf, err := b.cr.Next(4) if err != nil { return nil, err @@ -56,12 +60,30 @@ func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) { return nil, err } - err = b.startupMessage.Decode(buf) - if err != nil { - return nil, err - } + code := binary.BigEndian.Uint32(buf) - return &b.startupMessage, nil + switch code { + case ProtocolVersionNumber: + err = b.startupMessage.Decode(buf) + if err != nil { + return nil, err + } + return &b.startupMessage, nil + case sslRequestNumber: + err = b.sslRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.sslRequest, nil + case cancelRequestCode: + err = b.cancelRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.cancelRequest, nil + default: + return nil, fmt.Errorf("unknown startup message code: %d", code) + } } // Receive receives a message from the frontend. diff --git a/cancel_request.go b/cancel_request.go new file mode 100644 index 00000000..ec1d8606 --- /dev/null +++ b/cancel_request.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" + "github.com/pkg/errors" +) + +const cancelRequestCode = 80877102 + +type CancelRequest struct { + ProcessID uint32 + SecretKey uint32 +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*CancelRequest) Frontend() {} + +func (dst *CancelRequest) Decode(src []byte) error { + if len(src) != 12 { + return errors.Errorf("bad cancel request size") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != cancelRequestCode { + return errors.Errorf("bad cancel request code") + } + + dst.ProcessID = binary.BigEndian.Uint32(src[4:]) + dst.SecretKey = binary.BigEndian.Uint32(src[8:]) + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *CancelRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 16) + dst = pgio.AppendInt32(dst, cancelRequestCode) + dst = pgio.AppendUint32(dst, src.ProcessID) + dst = pgio.AppendUint32(dst, src.SecretKey) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src CancelRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProcessID uint32 + SecretKey uint32 + }{ + Type: "CancelRequest", + ProcessID: src.ProcessID, + SecretKey: src.SecretKey, + }) +} diff --git a/ssl_request.go b/ssl_request.go new file mode 100644 index 00000000..2f4b378a --- /dev/null +++ b/ssl_request.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + + "github.com/jackc/pgio" + "github.com/pkg/errors" +) + +const sslRequestNumber = 80877103 + +type SSLRequest struct { +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*SSLRequest) Frontend() {} + +func (dst *SSLRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.Errorf("ssl request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != sslRequestNumber { + return errors.Errorf("bad ssl request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *SSLRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, sslRequestNumber) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src SSLRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "SSLRequest", + }) +} diff --git a/startup_message.go b/startup_message.go index 0c5c961d..5be42500 100644 --- a/startup_message.go +++ b/startup_message.go @@ -9,10 +9,7 @@ import ( "github.com/pkg/errors" ) -const ( - ProtocolVersionNumber = 196608 // 3.0 - sslRequestNumber = 80877103 -) +const ProtocolVersionNumber = 196608 // 3.0 type StartupMessage struct { ProtocolVersion uint32 @@ -32,10 +29,6 @@ func (dst *StartupMessage) Decode(src []byte) error { dst.ProtocolVersion = binary.BigEndian.Uint32(src) rp := 4 - if dst.ProtocolVersion == sslRequestNumber { - return errors.Errorf("can't handle ssl connection request") - } - if dst.ProtocolVersion != ProtocolVersionNumber { return errors.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion) }