2
0

Merge branch 'master' into backend-unexpected-eof

This commit is contained in:
Jack Christensen
2021-07-06 20:07:55 -05:00
committed by GitHub
2 changed files with 71 additions and 0 deletions
+9
View File
@@ -35,6 +35,11 @@ type Backend struct {
authType uint32
}
const (
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
)
// NewBackend creates a new Backend.
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
@@ -56,6 +61,10 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
}
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
}
buf, err = b.cr.Next(msgSize)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
+62
View File
@@ -4,8 +4,10 @@ import (
"io"
"testing"
"github.com/jackc/pgio"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBackendReceiveInterrupted(t *testing.T) {
@@ -54,3 +56,63 @@ func TestBackendReceiveUnexpectedEOF(t *testing.T) {
assert.Nil(t, msg)
assert.Equal(t, io.ErrUnexpectedEOF, err)
}
func TestStartupMessage(t *testing.T) {
t.Parallel()
t.Run("valid StartupMessage", func(t *testing.T) {
want := &pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
Parameters: map[string]string{
"username": "tester",
},
}
dst := []byte{}
dst = want.Encode(dst)
server := &interruptReader{}
server.push(dst)
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
msg, err := backend.ReceiveStartupMessage()
require.NoError(t, err)
require.Equal(t, want, msg)
})
t.Run("invalid packet length", func(t *testing.T) {
wantErr := "invalid length of startup packet"
tests := []struct {
name string
packetLen uint32
}{
{
name: "large packet length",
// Since the StartupMessage contains the "Length of message contents
// in bytes, including self", the max startup packet length is actually
// 10000+4. Therefore, let's go past the limit with 10005
packetLen: 10005,
},
{
name: "short packet length",
packetLen: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &interruptReader{}
dst := []byte{}
dst = pgio.AppendUint32(dst, tt.packetLen)
dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
server.push(dst)
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
msg, err := backend.ReceiveStartupMessage()
require.Error(t, err)
require.Nil(t, msg)
require.Contains(t, err.Error(), wantErr)
})
}
})
}