• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import executorch.exir.tests.models as models
10
11import torch
12from executorch import exir
13from executorch.exir.backend.backend_api import to_backend
14from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
15from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
16from executorch.exir.delegate import executorch_call_delegate
17from hypothesis import given, settings, strategies as st
18
19
class TestBackendDebugHandle(unittest.TestCase):
    """Checks the debug-handle metadata produced when lowering to a backend.

    After delegation, each lowered module is expected to carry a
    ``debug_handle_map`` in its ``meta``, and each ``executorch_call_delegate``
    node is expected to carry a ``debug_handle`` in its ``meta``.
    """

    @staticmethod
    def _get_lowered_modules(exported_program):
        """Return the submodules referenced by the graph's ``get_attr`` nodes.

        The tests in this class assume every ``get_attr`` node refers to a
        lowered backend module (they read ``meta["debug_handle_map"]`` from
        each one).
        """
        return [
            getattr(exported_program.graph_module, node.target)
            for node in exported_program.graph.nodes
            if node.op == "get_attr"
        ]

    def _assert_single_debug_handle_entry(self, lowered_module):
        """Assert that ``lowered_module`` has one debug-handle entry for its whole graph.

        QnnBackend compiles all nodes into one node, so the
        ``debug_handle_map`` is expected to look like
        ``{1: (debug handles from all nodes)}``. That means exactly one key,
        whose value holds as many handles as there are nodes in
        ``lowered_module.original_module.graph``.
        """
        debug_handle_map = lowered_module.meta["debug_handle_map"]
        # Ensure there is only one debug identifier.
        self.assertEqual(len(debug_handle_map), 1)

        (all_debug_handles,) = debug_handle_map.values()
        self.assertEqual(
            len(all_debug_handles),
            len(lowered_module.original_module.graph.nodes),
        )

    def test_add_mul_partitioner(self):
        class Model(torch.nn.Module):
            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

        executorch_prog = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
        # Replace the exported program with one in which the nodes chosen by
        # the partitioner are delegated.
        executorch_prog.exported_program = to_backend(
            executorch_prog.exported_program, AddMulPartitionerDemo()
        )

        lowered_modules = self._get_lowered_modules(executorch_prog.exported_program)
        # Without this, an empty result would make the loop below a no-op and
        # the test would pass without checking anything.
        self.assertGreater(len(lowered_modules), 0)
        for lowered_module in lowered_modules:
            self.assertEqual(len(lowered_module.meta["debug_handle_map"]), 2)

        call_delegate_nodes = [
            node
            for node in executorch_prog.exported_program.graph.nodes
            if node.target == executorch_call_delegate
        ]
        # Same non-vacuity guard for the delegate call sites.
        self.assertGreater(len(call_delegate_nodes), 0)
        for call_delegate_node in call_delegate_nodes:
            self.assertIsNotNone(call_delegate_node.meta["debug_handle"])

    @given(
        unlift=st.booleans(),  # verify both lifted and unlifted graph
    )
    @settings(deadline=500000)
    def test_lowered_the_whole_model(self, unlift):
        module_list = [
            models.Emformer(),
            models.Repeat(),
            models.ElementwiseAdd(),
            models.MLP(),
            models.ModelWithUnusedArg(),
        ]

        capture_config = (
            exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
        )

        edge_compile_config = exir.EdgeCompileConfig(
            _check_ir_validity=False, _use_edge_ops=True
        )

        # Wraps an already-lowered module so it can be captured again as part
        # of a larger program. It is defined once because it does not depend
        # on the loop variable.
        class ComposedModel(torch.nn.Module):
            def __init__(self, lowered_model):
                super().__init__()
                self.back_bone = lowered_model

            def forward(self, *args):
                return self.back_bone(*args)

        for model in module_list:
            model_inputs = model.get_random_inputs()

            edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
                edge_compile_config
            )
            # Lower the entire program to QnnBackend, with an empty list of
            # compile specs.
            lowered_model = to_backend(
                QnnBackend.__name__, edgeir_m.exported_program, []
            )
            self._assert_single_debug_handle_entry(lowered_model)

            # Re-capture the lowered module inside a wrapper and check that the
            # debug-handle metadata survives.
            edge = exir.capture(
                ComposedModel(lowered_model), model_inputs, capture_config
            ).to_edge(edge_compile_config)
            lowered_modules = self._get_lowered_modules(edge.exported_program)
            # Guard against a vacuous pass if the lowered module was not found.
            self.assertGreater(len(lowered_modules), 0)
            for lowered_module in lowered_modules:
                self._assert_single_debug_handle_entry(lowered_module)
132