# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import tempfile
import unittest
from typing import Dict, Tuple

import torch

from executorch.devtools import generate_etrecord, parse_etrecord

from executorch.devtools.debug_format.base_schema import (
    OperatorGraph,
    OperatorNode,
    ValueNode,
)

from executorch.devtools.debug_format.et_schema import FXOperatorGraph
from executorch.devtools.etdump import schema_flatcc as flatcc

from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
from executorch.devtools.inspector._inspector_utils import (
    calculate_time_scale_factor,
    create_debug_handle_to_op_node_mapping,
    EDGE_DIALECT_GRAPH_KEY,
    find_populated_event,
    gen_graphs_from_etrecord,
    is_inference_output_equal,
    TimeScale,
)


class TestInspectorUtils(unittest.TestCase):
    def test_gen_graphs_from_etrecord(self):
        captured_output, edge_output, et_output = TestETRecord().get_test_model()
        with tempfile.TemporaryDirectory() as tmpdirname:
            generate_etrecord(
                tmpdirname + "/etrecord.bin",
                edge_output,
                et_output,
                {
                    "aten_dialect_output": captured_output,
                },
            )

            etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")

            graphs = gen_graphs_from_etrecord(etrecord)

            self.assertTrue("aten_dialect_output/forward" in graphs)
            self.assertTrue(EDGE_DIALECT_GRAPH_KEY in graphs)

            self.assertTrue(
                isinstance(graphs["aten_dialect_output/forward"], FXOperatorGraph)
            )
            self.assertTrue(isinstance(graphs[EDGE_DIALECT_GRAPH_KEY], FXOperatorGraph))

    def test_create_debug_handle_to_op_node_mapping(self):
        graph, expected_mapping = gen_mock_operator_graph_with_expected_map()
        debug_handle_to_op_node_map = create_debug_handle_to_op_node_mapping(graph)

        self.assertEqual(debug_handle_to_op_node_map, expected_mapping)

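    # Additional sketch of a check on the same helper: it assumes only what the
    # mock graph below constructs (operator nodes with debug handles 111, 222,
    # 333, and 444), so the mapping keys should be exactly that set.
    def test_create_debug_handle_to_op_node_mapping_covers_all_handles(self):
        graph, _ = gen_mock_operator_graph_with_expected_map()
        debug_handle_to_op_node_map = create_debug_handle_to_op_node_mapping(graph)

        self.assertEqual(
            set(debug_handle_to_op_node_map.keys()), {111, 222, 333, 444}
        )
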
    def test_find_populated_event(self):
        profile_event = flatcc.ProfileEvent(
            name="test_profile_event",
            chain_index=1,
            instruction_id=1,
            delegate_debug_id_str="",
            delegate_debug_id_int=-1,
            delegate_debug_metadata="",
            start_time=1001,
            end_time=2002,
        )
        debug_event = flatcc.DebugEvent(
            name="test_debug_event",
            chain_index=1,
            instruction_id=0,
            delegate_debug_id_str="56",
            delegate_debug_id_int=-1,
            debug_entry=flatcc.Value(
                val=flatcc.ValueType.TENSOR.value,
                tensor=flatcc.Tensor(
                    scalar_type=flatcc.ScalarType.INT,
                    sizes=[1],
                    strides=[1],
                    offset=12345,
                ),
                tensor_list=[
                    flatcc.TensorList(
                        tensors=[
                            flatcc.Tensor(
                                scalar_type=flatcc.ScalarType.INT,
                                sizes=[1],
                                strides=[1],
                                offset=12345,
                            )
                        ]
                    )
                ],
                int_value=flatcc.Int(1),
                float_value=flatcc.Float(1.0),
                double_value=flatcc.Double(1.0),
                bool_value=flatcc.Bool(False),
                output=flatcc.Bool(True),
            ),
        )

        # Profile Populated
        event = flatcc.Event(
            profile_event=profile_event, debug_event=None, allocation_event=None
        )
        self.assertEqual(find_populated_event(event), profile_event)

        # Debug Populated
        event = flatcc.Event(
            profile_event=None, debug_event=debug_event, allocation_event=None
        )
        self.assertEqual(find_populated_event(event), debug_event)

        # Neither Populated
        event = flatcc.Event(
            profile_event=None, debug_event=None, allocation_event=None
        )
        with self.assertRaises(ValueError):
            self.assertEqual(find_populated_event(event), profile_event)

        # Both Populated (Returns Profile Event)
        event = flatcc.Event(
            profile_event=profile_event, debug_event=debug_event, allocation_event=None
        )
        self.assertEqual(find_populated_event(event), profile_event)

    def test_is_inference_output_equal_returns_false_for_different_tensor_values(self):
        self.assertFalse(
            is_inference_output_equal(
                torch.tensor([[2, 1], [4, 3]]),
                torch.tensor([[5, 6], [7, 8]]),
            )
        )

    def test_is_inference_output_equal_returns_false_for_different_tensor_lists(self):
        tensor_list_1 = [
            torch.tensor([[1, 2], [3, 4]]),
            torch.tensor([[1, 2], [3, 4]]),
            torch.tensor([[1, 2], [3, 4]]),
        ]
        tensor_list_2 = [
            torch.tensor([[1, 2], [3, 4]]),
            torch.tensor([[1, 2], [3, 4]]),
        ]
        # Not equal because of different number of tensors
        self.assertFalse(is_inference_output_equal(tensor_list_1, tensor_list_2))

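    # Illustrative sketch (assumption): lists of element-wise identical tensors
    # are expected to compare equal, mirroring the single-tensor case below.
    def test_is_inference_output_equal_returns_true_for_same_tensor_lists(self):
        tensor_list_1 = [
            torch.tensor([[1, 2], [3, 4]]),
            torch.tensor([[5, 6], [7, 8]]),
        ]
        tensor_list_2 = [
            torch.tensor([[1, 2], [3, 4]]),
            torch.tensor([[5, 6], [7, 8]]),
        ]
        self.assertTrue(is_inference_output_equal(tensor_list_1, tensor_list_2))
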
    def test_is_inference_output_equal_returns_true_for_same_tensor_values(self):
        self.assertTrue(
            is_inference_output_equal(
                torch.tensor([[2, 1], [4, 3]]),
                torch.tensor([[2, 1], [4, 3]]),
            )
        )

    def test_is_inference_output_equal_returns_true_for_same_strs(self):
        self.assertTrue(
            is_inference_output_equal(
                "value_string",
                "value_string",
            )
        )

    def test_calculate_time_scale_factor_second_based(self):
        self.assertEqual(
            calculate_time_scale_factor(TimeScale.NS, TimeScale.MS), 1000000
        )
        self.assertEqual(
            calculate_time_scale_factor(TimeScale.MS, TimeScale.NS), 1 / 1000000
        )

    def test_calculate_time_scale_factor_cycles(self):
        self.assertEqual(
            calculate_time_scale_factor(TimeScale.CYCLES, TimeScale.CYCLES), 1
        )

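    # Illustrative sketch: composing the NS -> MS and MS -> NS factors asserted
    # above (1000000 and 1 / 1000000) should round-trip to the identity.
    def test_calculate_time_scale_factor_round_trip(self):
        self.assertEqual(
            calculate_time_scale_factor(TimeScale.NS, TimeScale.MS)
            * calculate_time_scale_factor(TimeScale.MS, TimeScale.NS),
            1,
        )
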

def gen_mock_operator_graph_with_expected_map() -> (
    Tuple[OperatorGraph, Dict[int, OperatorNode]]
):
    # Make a mock OperatorGraph instance for testing
    node_input = ValueNode("input")
    mapping = {}
    node_fused_conv_relu = OperatorNode(
        "fused_conv_relu",
        [node_input],
        None,
        metadata={
            "debug_handle": 111,
            "stack_trace": "stack_trace_relu",
            "nn_module_stack": "module_hierarchy_relu",
        },
    )
    mapping[111] = node_fused_conv_relu
    node_sin = OperatorNode(
        "sin",
        [node_fused_conv_relu],
        None,
        metadata={
            "debug_handle": 222,
            "stack_trace": "stack_trace_sin",
            "nn_module_stack": "module_hierarchy_sin",
        },
    )
    mapping[222] = node_sin
    node_cos = OperatorNode(
        "cos",
        [node_sin],
        None,
        metadata={
            "debug_handle": 333,
            "stack_trace": "stack_trace_cos",
            "nn_module_stack": "module_hierarchy_cos",
        },
    )
    mapping[333] = node_cos
    node_div = OperatorNode(
        "div",
        [node_cos],
        None,
        metadata={
            "debug_handle": 444,
            "stack_trace": "stack_trace_div",
            "nn_module_stack": "module_hierarchy_div",
        },
    )
    mapping[444] = node_div
    node_output = ValueNode("output", [node_div])
    return (
        OperatorGraph(
            "mock_et_model",
            [
                node_input,
                node_fused_conv_relu,
                node_sin,
                node_cos,
                node_div,
                node_output,
            ],
        ),
        mapping,
    )
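

if __name__ == "__main__":
    # Convenience entry point (sketch): allows running this module directly;
    # the project may normally discover these tests through its own runner.
    unittest.main()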