Merge branch 'master' into backend-unexpected-eof
This commit is contained in:
@@ -35,6 +35,11 @@ type Backend struct {
|
||||
authType uint32
|
||||
}
|
||||
|
||||
const (
|
||||
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||
)
|
||||
|
||||
// NewBackend creates a new Backend.
|
||||
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
|
||||
return &Backend{cr: cr, w: w}
|
||||
@@ -56,6 +61,10 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||
}
|
||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||
|
||||
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
|
||||
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
|
||||
}
|
||||
|
||||
buf, err = b.cr.Next(msgSize)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBackendReceiveInterrupted(t *testing.T) {
|
||||
@@ -54,3 +56,63 @@ func TestBackendReceiveUnexpectedEOF(t *testing.T) {
|
||||
assert.Nil(t, msg)
|
||||
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||
}
|
||||
|
||||
func TestStartupMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("valid StartupMessage", func(t *testing.T) {
|
||||
want := &pgproto3.StartupMessage{
|
||||
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||
Parameters: map[string]string{
|
||||
"username": "tester",
|
||||
},
|
||||
}
|
||||
dst := []byte{}
|
||||
dst = want.Encode(dst)
|
||||
|
||||
server := &interruptReader{}
|
||||
server.push(dst)
|
||||
|
||||
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
|
||||
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, msg)
|
||||
})
|
||||
|
||||
t.Run("invalid packet length", func(t *testing.T) {
|
||||
wantErr := "invalid length of startup packet"
|
||||
tests := []struct {
|
||||
name string
|
||||
packetLen uint32
|
||||
}{
|
||||
{
|
||||
name: "large packet length",
|
||||
// Since the StartupMessage contains the "Length of message contents
|
||||
// in bytes, including self", the max startup packet length is actually
|
||||
// 10000+4. Therefore, let's go past the limit with 10005
|
||||
packetLen: 10005,
|
||||
},
|
||||
{
|
||||
name: "short packet length",
|
||||
packetLen: 3,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &interruptReader{}
|
||||
dst := []byte{}
|
||||
dst = pgio.AppendUint32(dst, tt.packetLen)
|
||||
dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
|
||||
server.push(dst)
|
||||
|
||||
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
|
||||
|
||||
msg, err := backend.ReceiveStartupMessage()
|
||||
require.Error(t, err)
|
||||
require.Nil(t, msg)
|
||||
require.Contains(t, err.Error(), wantErr)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user