# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """End-to-end profiler tests. This must be built and run with `buck2 -c executorch.prof_enabled=true`. """ import unittest import torch from executorch.exir import to_edge from executorch.extension.pybindings.portable_lib import ( _create_profile_block, _dump_profile_results, _load_for_executorch_from_buffer, _reset_profile_results, ) from executorch.extension.pytree import tree_flatten from executorch.profiler.fb.parse_profiler_results import profile_table from executorch.profiler.parse_profiler_results import ( deserialize_profile_results, profile_aggregate_framework_tax, profile_framework_tax_table, ) from torch.export import export class Module(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("a", 3 * torch.ones(2, 2, dtype=torch.float)) self.register_buffer("b", 2 * torch.ones(2, 2, dtype=torch.float)) def forward(self, x): a = torch.mul(self.a, x) b = torch.add(a, self.b) return b class TestCustomOps(unittest.TestCase): @classmethod def setUpClass(cls) -> None: model = Module() inputs = (torch.ones(2, 2, dtype=torch.float),) # The serialized program file. This must live longer than cls.module, # because the C++ pybindings will have a pointer to it. But none of the # tests should need to touch it. cls.__buffer: bytes = to_edge(export(model, inputs)).to_executorch().buffer cls.module = _load_for_executorch_from_buffer(cls.__buffer) # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`. cls.inputs_flattened, _ = tree_flatten(inputs) cls.module.run_method("forward", tuple(cls.inputs_flattened)) prof_dump = _dump_profile_results() assert ( len(prof_dump) > 0 ), "prof_dump is empty; may need to build with `-c executorch.prof_enabled=true`" cls.prof_results, cls.mem_results = deserialize_profile_results(prof_dump) cls.expect_ops = ["native_call_add.out", "native_call_mul.out"] def test_profiler_new_block(self) -> None: block_names = ["block_1", "block_2"] _reset_profile_results() _create_profile_block(block_names[0]) self.module.run_method("forward", tuple(self.inputs_flattened)) _create_profile_block(block_names[1]) self.module.run_method("forward", tuple(self.inputs_flattened)) prof_dump = _dump_profile_results() self.assertGreater( len(prof_dump), 0, "prof_dump is empty; may need to build with `-c executorch.prof_enabled=true`", ) prof_results, mem_results = deserialize_profile_results(prof_dump) for i, (block_name_, _) in enumerate(prof_results.items()): self.assertTrue(block_names[i] == block_name_) self.assertEqual(len(prof_results), 2) def test_profiler_expected_ops(self) -> None: found_count = 0 for block_name, prof_data_list in self.prof_results.items(): for prof_event in prof_data_list: if prof_event.name in self.expect_ops: found_count += 1 self.assertTrue(block_name == "default") self.assertEqual(found_count, len(self.expect_ops)) def test_profile_framework_tax(self) -> None: prof_agg_data = profile_aggregate_framework_tax(self.prof_results) for name, framework_tax in prof_agg_data.items(): self.assertTrue(len(framework_tax.exec_time) == 1) self.assertTrue(len(framework_tax.kernel_and_delegate_time) == 1) self.assertTrue(len(framework_tax.framework_tax) == 1) self.assertTrue(float(framework_tax.framework_tax[0]) < 100) self.assertTrue(name == "default") def test_gen_profile_table(self) -> None: prof_table = profile_table(self.prof_results) found_count = 0 for table in prof_table: for entry in table: for op in self.expect_ops: found_count += 1 if op in entry.get_string() else 0 self.assertEqual(found_count, len(self.expect_ops)) def test_gen_profile_framework_tax_table(self) -> None: prof_agg_data = profile_aggregate_framework_tax(self.prof_results) prof_framework_tax_table = profile_framework_tax_table(prof_agg_data) expected_entries = [ "Model execution time", "Time spent in kernels", "Framework tax", ] found_count = 0 for table in prof_framework_tax_table: for entry in table: for expected_entry in expected_entries: found_count += 1 if expected_entry in entry.get_string() else 0 self.assertEqual(found_count, len(expected_entries)) def main() -> None: unittest.main() if __name__ == "__main__": main() # pragma: no cover