Merge branch 'master' into backend-unexpected-eof
This commit is contained in:
@@ -35,6 +35,11 @@ type Backend struct {
|
|||||||
authType uint32
|
authType uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||||
|
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||||
|
)
|
||||||
|
|
||||||
// NewBackend creates a new Backend.
|
// NewBackend creates a new Backend.
|
||||||
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
|
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
|
||||||
return &Backend{cr: cr, w: w}
|
return &Backend{cr: cr, w: w}
|
||||||
@@ -56,6 +61,10 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
|||||||
}
|
}
|
||||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||||
|
|
||||||
|
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
|
||||||
|
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
|
||||||
|
}
|
||||||
|
|
||||||
buf, err = b.cr.Next(msgSize)
|
buf, err = b.cr.Next(msgSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
"github.com/jackc/pgproto3/v2"
|
"github.com/jackc/pgproto3/v2"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBackendReceiveInterrupted(t *testing.T) {
|
func TestBackendReceiveInterrupted(t *testing.T) {
|
||||||
@@ -54,3 +56,63 @@ func TestBackendReceiveUnexpectedEOF(t *testing.T) {
|
|||||||
assert.Nil(t, msg)
|
assert.Nil(t, msg)
|
||||||
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStartupMessage(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("valid StartupMessage", func(t *testing.T) {
|
||||||
|
want := &pgproto3.StartupMessage{
|
||||||
|
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||||
|
Parameters: map[string]string{
|
||||||
|
"username": "tester",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dst := []byte{}
|
||||||
|
dst = want.Encode(dst)
|
||||||
|
|
||||||
|
server := &interruptReader{}
|
||||||
|
server.push(dst)
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
msg, err := backend.ReceiveStartupMessage()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, want, msg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid packet length", func(t *testing.T) {
|
||||||
|
wantErr := "invalid length of startup packet"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
packetLen uint32
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "large packet length",
|
||||||
|
// Since the StartupMessage contains the "Length of message contents
|
||||||
|
// in bytes, including self", the max startup packet length is actually
|
||||||
|
// 10000+4. Therefore, let's go past the limit with 10005
|
||||||
|
packetLen: 10005,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "short packet length",
|
||||||
|
packetLen: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := &interruptReader{}
|
||||||
|
dst := []byte{}
|
||||||
|
dst = pgio.AppendUint32(dst, tt.packetLen)
|
||||||
|
dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber)
|
||||||
|
server.push(dst)
|
||||||
|
|
||||||
|
backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil)
|
||||||
|
|
||||||
|
msg, err := backend.ReceiveStartupMessage()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, msg)
|
||||||
|
require.Contains(t, err.Error(), wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user