Update pgproto3 to enable pgmock
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package pgio
|
package pgio
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,3 +39,13 @@ func NextInt64(buf []byte) ([]byte, int64) {
|
|||||||
buf, n := NextUint64(buf)
|
buf, n := NextUint64(buf)
|
||||||
return buf, int64(n)
|
return buf, int64(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NextCString(buf []byte) ([]byte, string, bool) {
|
||||||
|
idx := bytes.IndexByte(buf, 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return buf, "", false
|
||||||
|
}
|
||||||
|
cstring := string(buf[:idx])
|
||||||
|
buf = buf[:idx+1]
|
||||||
|
return buf, cstring, true
|
||||||
|
}
|
||||||
|
|||||||
+28
-2
@@ -2,7 +2,6 @@ package pgproto3
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
@@ -20,6 +19,7 @@ type Backend struct {
|
|||||||
parse Parse
|
parse Parse
|
||||||
passwordMessage PasswordMessage
|
passwordMessage PasswordMessage
|
||||||
query Query
|
query Query
|
||||||
|
startupMessage StartupMessage
|
||||||
sync Sync
|
sync Sync
|
||||||
terminate Terminate
|
terminate Terminate
|
||||||
}
|
}
|
||||||
@@ -30,7 +30,33 @@ func NewBackend(r io.Reader, w io.Writer) (*Backend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Backend) Send(msg BackendMessage) error {
|
func (b *Backend) Send(msg BackendMessage) error {
|
||||||
return errors.New("not implemented")
|
buf, err := msg.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = b.w.Write(buf)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Backend) ReceiveStartupMessage() (*StartupMessage, error) {
|
||||||
|
buf, err := b.cr.Next(4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||||
|
|
||||||
|
buf, err = b.cr.Next(msgSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = b.startupMessage.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &b.startupMessage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Backend) Receive() (FrontendMessage, error) {
|
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package pgproto3
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
@@ -43,7 +42,13 @@ func NewFrontend(r io.Reader, w io.Writer) (*Frontend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Frontend) Send(msg FrontendMessage) error {
|
func (b *Frontend) Send(msg FrontendMessage) error {
|
||||||
return errors.New("not implemented")
|
buf, err := msg.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = b.w.Write(buf)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Frontend) Receive() (BackendMessage, error) {
|
func (b *Frontend) Receive() (BackendMessage, error) {
|
||||||
|
|||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
protocolVersionNumber = 196608 // 3.0
|
||||||
|
sslRequestNumber = 80877103
|
||||||
|
)
|
||||||
|
|
||||||
|
type StartupMessage struct {
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StartupMessage) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *StartupMessage) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return fmt.Errorf("startup message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||||
|
rp := 4
|
||||||
|
|
||||||
|
if dst.ProtocolVersion == sslRequestNumber {
|
||||||
|
return fmt.Errorf("can't handle ssl connection request")
|
||||||
|
}
|
||||||
|
|
||||||
|
if dst.ProtocolVersion != protocolVersionNumber {
|
||||||
|
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", protocolVersionNumber, dst.ProtocolVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Parameters = make(map[string]string)
|
||||||
|
for {
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||||
|
}
|
||||||
|
key := string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||||
|
}
|
||||||
|
value := string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
dst.Parameters[key] = value
|
||||||
|
|
||||||
|
if len(src[rp:]) == 1 {
|
||||||
|
if src[rp] != 0 {
|
||||||
|
return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *StartupMessage) MarshalBinary() ([]byte, error) {
|
||||||
|
var bigEndian BigEndianBuf
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
buf.Write(bigEndian.Uint32(0))
|
||||||
|
buf.Write(bigEndian.Uint32(src.ProtocolVersion))
|
||||||
|
for k, v := range src.Parameters {
|
||||||
|
buf.WriteString(k)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
buf.WriteString(v)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
buf.WriteByte(0)
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint32(buf.Bytes()[0:4], uint32(buf.Len()))
|
||||||
|
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *StartupMessage) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}{
|
||||||
|
Type: "StartupMessage",
|
||||||
|
ProtocolVersion: src.ProtocolVersion,
|
||||||
|
Parameters: src.Parameters,
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user