diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 09f04141..83dea963 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -223,7 +223,13 @@ func (f *Frontend) Receive() (BackendMessage, error) { } f.msgType = header[0] - f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgLength := int(binary.BigEndian.Uint32(header[1:])) + if msgLength < 4 { + return nil, fmt.Errorf("invalid message length: %d", msgLength) + } + + f.bodyLen = msgLength - 4 f.partialMsg = true } diff --git a/pgproto3/fuzz_test.go b/pgproto3/fuzz_test.go new file mode 100644 index 00000000..84ea8430 --- /dev/null +++ b/pgproto3/fuzz_test.go @@ -0,0 +1,29 @@ +package pgproto3_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func FuzzFrontend(f *testing.F) { + testcases := [][]byte{ + {'Z', 0, 0, 0, 5}, + } + for _, tc := range testcases { + f.Add(tc) + } + f.Fuzz(func(t *testing.T, encodedMsg []byte) { + r := &bytes.Buffer{} + w := &bytes.Buffer{} + fe := pgproto3.NewFrontend(r, w) + + _, err := r.Write(encodedMsg) + require.NoError(t, err) + + // Not checking anything other than no panic. + fe.Receive() + }) +} diff --git a/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e b/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e new file mode 100644 index 00000000..4db40929 --- /dev/null +++ b/pgproto3/testdata/fuzz/FuzzFrontend/65d91093341a68b16f04605e392b0501847a9b35d3857e67872046dbdc04913e @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("0\x00\x00\x00\x02")