From e6b823d64953284cfbc874de358700902129cdae Mon Sep 17 00:00:00 2001 From: Yuli Khodorkovskiy Date: Tue, 17 Dec 2019 20:03:55 -0500 Subject: [PATCH] Add missing GSSEncRequest --- backend.go | 9 ++++++++- gss_enc_request.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 gss_enc_request.go diff --git a/backend.go b/backend.go index 5741647f..cd7e8ce2 100644 --- a/backend.go +++ b/backend.go @@ -19,6 +19,7 @@ type Backend struct { describe Describe execute Execute flush Flush + gssEncRequest GSSEncRequest parse Parse passwordMessage PasswordMessage query Query @@ -45,7 +46,7 @@ func (b *Backend) Send(msg BackendMessage) error { // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method // because the initial connection message is "special" and does not include the message type as the first byte. This -// will return either a StartupMessage, SSLRequest, or CancelRequest. +// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { buf, err := b.cr.Next(4) if err != nil { @@ -79,6 +80,12 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { return nil, err } return &b.cancelRequest, nil + case gssEncReqNumber: + err = b.gssEncRequest.Decode(buf) + if err != nil { + return nil, err + } + return &b.gssEncRequest, nil default: return nil, fmt.Errorf("unknown startup message code: %d", code) } diff --git a/gss_enc_request.go b/gss_enc_request.go new file mode 100644 index 00000000..cf405a3e --- /dev/null +++ b/gss_enc_request.go @@ -0,0 +1,49 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + + "github.com/jackc/pgio" +) + +const gssEncReqNumber = 80877104 + +type GSSEncRequest struct { +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (*GSSEncRequest) Frontend() {} + +func (dst *GSSEncRequest) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("gss encoding request too short") + } + + requestCode := binary.BigEndian.Uint32(src) + + if requestCode != gssEncReqNumber { + return errors.New("bad gss encoding request code") + } + + return nil +} + +// Encode encodes src into dst. dst will include the 4 byte message length. +func (src *GSSEncRequest) Encode(dst []byte) []byte { + dst = pgio.AppendInt32(dst, 8) + dst = pgio.AppendInt32(dst, gssEncReqNumber) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (src GSSEncRequest) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + ProtocolVersion uint32 + Parameters map[string]string + }{ + Type: "GSSEncRequest", + }) +}