diff --git a/backend.go b/backend.go index 232aa11d..6944f80d 100644 --- a/backend.go +++ b/backend.go @@ -35,6 +35,11 @@ type Backend struct { authType uint32 } +const ( + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. +) + // NewBackend creates a new Backend. func NewBackend(cr ChunkReader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} @@ -56,6 +61,10 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { } msgSize := int(binary.BigEndian.Uint32(buf) - 4) + if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { + return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) + } + buf, err = b.cr.Next(msgSize) if err != nil { return nil, err diff --git a/backend_test.go b/backend_test.go index 43a3f76c..3cfde003 100644 --- a/backend_test.go +++ b/backend_test.go @@ -3,7 +3,9 @@ package pgproto3_test import ( "testing" + "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/require" ) func TestBackendReceiveInterrupted(t *testing.T) { @@ -32,3 +34,64 @@ func TestBackendReceiveInterrupted(t *testing.T) { t.Fatalf("unexpected msg: %v", msg) } } + +func TestStartupMessage(t *testing.T) { + t.Parallel() + + t.Run("valid StartupMessage", func(t *testing.T) { + want := &pgproto3.StartupMessage{ + ProtocolVersion: pgproto3.ProtocolVersionNumber, + Parameters: map[string]string{ + "username": "tester", + }, + } + dst := []byte{} + dst = want.Encode(dst) + + server := &interruptReader{} + server.push(dst) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + msg, err := backend.ReceiveStartupMessage() + require.NoError(t, err) + require.Equal(t, want, msg) + }) + + t.Run("invalid packet length", func(t *testing.T) { + wantErr := "invalid length of startup packet" + tests := []struct { + name string + packetLen uint32 + }{ + { + name: "large packet length", + // Since the StartupMessage contains the "Length of message contents + // in bytes, including self", the max startup packet length is actually + // 10000+4. Therefore, let's go past the limit with 10005 + packetLen: 10005, + }, + { + name: "short packet length", + packetLen: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := &interruptReader{} + dst := []byte{} + dst = pgio.AppendUint32(dst, tt.packetLen) + dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) + server.push(dst) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(server), nil) + + msg, err := backend.ReceiveStartupMessage() + require.Error(t, err) + require.Nil(t, msg) + require.Contains(t, err.Error(), wantErr) + }) + } + }) + +}