diff --git a/backend.go b/backend.go index c9fa87ff..e9ba38fc 100644 --- a/backend.go +++ b/backend.go @@ -35,6 +35,11 @@ type Backend struct { authType uint32 } +const ( + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. +) + // NewBackend creates a new Backend. func NewBackend(cr ChunkReader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} @@ -56,6 +61,10 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { } msgSize := int(binary.BigEndian.Uint32(buf) - 4) + if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { + return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) + } + buf, err = b.cr.Next(msgSize) if err != nil { return nil, translateEOFtoErrUnexpectedEOF(err) diff --git a/backend_test.go b/backend_test.go index 19970c34..5e9a2ac5 100644 --- a/backend_test.go +++ b/backend_test.go @@ -4,8 +4,10 @@ import ( "io" "testing" + "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBackendReceiveInterrupted(t *testing.T) { @@ -54,3 +56,63 @@ func TestBackendReceiveUnexpectedEOF(t *testing.T) { assert.Nil(t, msg) assert.Equal(t, io.ErrUnexpectedEOF, err) } + +func TestStartupMessage(t *testing.T) { + t.Parallel() + + t.Run("valid StartupMessage", func(t *testing.T) { + want := &pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: map[string]string{ + "username": "tester", + }, + } + dst := []byte{} + dst = want.Encode(dst) + + server := &interruptReader{} + server.push(dst) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + msg, err := backend.ReceiveStartupMessage() + require.NoError(t, err) + require.Equal(t, want, msg) + }) + + t.Run("invalid packet length", func(t *testing.T) { + wantErr := "invalid length of startup packet" + tests := []struct { + name string + packetLen uint32 + }{ + { + name: "large packet length", + // Since the StartupMessage contains the "Length of message contents + // in bytes, including self", the max startup packet length is actually + // 10000+4. Therefore, let's go past the limit with 10005 + packetLen: 10005, + }, + { + name: "short packet length", + packetLen: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &interruptReader{} + dst := []byte{} + dst = pgio.AppendUint32(dst, tt.packetLen) + dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) + server.push(dst) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + msg, err := backend.ReceiveStartupMessage() + require.Error(t, err) + require.Nil(t, msg) + require.Contains(t, err.Error(), wantErr) + }) + } + }) +}