# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import os
import tempfile
import unittest
from typing import Dict, Tuple

import torch

from executorch.devtools import generate_etrecord, parse_etrecord

from executorch.devtools.debug_format.base_schema import (
    OperatorGraph,
    OperatorNode,
    ValueNode,
)

from executorch.devtools.debug_format.et_schema import FXOperatorGraph
from executorch.devtools.etdump import schema_flatcc as flatcc

from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
from executorch.devtools.inspector._inspector_utils import (
    calculate_time_scale_factor,
    create_debug_handle_to_op_node_mapping,
    EDGE_DIALECT_GRAPH_KEY,
    find_populated_event,
    gen_graphs_from_etrecord,
    is_inference_output_equal,
    TimeScale,
)


class TestInspectorUtils(unittest.TestCase):
    """Unit tests for helpers in ``executorch.devtools.inspector._inspector_utils``."""

    def test_gen_graphs_from_etrecord(self):
        """Graphs built from a round-tripped ETRecord contain the ATen-dialect
        ``forward`` graph and the edge-dialect graph, both as FXOperatorGraph."""
        captured_output, edge_output, et_output = TestETRecord().get_test_model()
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Build the path once so the write and the read cannot drift apart.
            etrecord_path = os.path.join(tmpdirname, "etrecord.bin")
            generate_etrecord(
                etrecord_path,
                edge_output,
                et_output,
                {
                    "aten_dialect_output": captured_output,
                },
            )

            etrecord = parse_etrecord(etrecord_path)

            graphs = gen_graphs_from_etrecord(etrecord)

            self.assertIn("aten_dialect_output/forward", graphs)
            self.assertIn(EDGE_DIALECT_GRAPH_KEY, graphs)

            self.assertIsInstance(
                graphs["aten_dialect_output/forward"], FXOperatorGraph
            )
            self.assertIsInstance(graphs[EDGE_DIALECT_GRAPH_KEY], FXOperatorGraph)

    def test_create_debug_handle_to_op_node_mapping(self):
        """The mapping built from a mock graph matches the expected
        debug-handle -> OperatorNode mapping produced alongside it."""
        graph, expected_mapping = gen_mock_operator_graph_with_expected_map()
        debug_handle_to_op_node_map = create_debug_handle_to_op_node_mapping(graph)

        self.assertEqual(debug_handle_to_op_node_map, expected_mapping)

    def test_find_populated_event(self):
        """find_populated_event returns the populated sub-event.

        It prefers the profile event when both are set and raises ValueError
        when neither the profile event nor the debug event is set.
        """
        profile_event = flatcc.ProfileEvent(
            name="test_profile_event",
            chain_index=1,
            instruction_id=1,
            delegate_debug_id_str="",
            delegate_debug_id_int=-1,
            delegate_debug_metadata="",
            start_time=1001,
            end_time=2002,
        )
        debug_event = flatcc.DebugEvent(
            name="test_debug_event",
            chain_index=1,
            instruction_id=0,
            delegate_debug_id_str="56",
            delegate_debug_id_int=-1,
            debug_entry=flatcc.Value(
                val=flatcc.ValueType.TENSOR.value,
                tensor=flatcc.Tensor(
                    scalar_type=flatcc.ScalarType.INT,
                    sizes=[1],
                    strides=[1],
                    offset=12345,
                ),
                tensor_list=[
                    flatcc.TensorList(
                        tensors=[
                            flatcc.Tensor(
                                scalar_type=flatcc.ScalarType.INT,
                                sizes=[1],
                                strides=[1],
                                offset=12345,
                            )
                        ]
                    )
                ],
                int_value=flatcc.Int(1),
                float_value=flatcc.Float(1.0),
                double_value=flatcc.Double(1.0),
                bool_value=flatcc.Bool(False),
                output=flatcc.Bool(True),
            ),
        )

        # Profile Populated
        event = flatcc.Event(
            profile_event=profile_event, debug_event=None, allocation_event=None
        )
        self.assertEqual(find_populated_event(event), profile_event)

        # Debug Populated
        event = flatcc.Event(
            profile_event=None, debug_event=debug_event, allocation_event=None
        )
        self.assertEqual(find_populated_event(event), debug_event)

        # Neither Populated: the call itself must raise; no result to compare.
        event = flatcc.Event(
            profile_event=None, debug_event=None, allocation_event=None
        )
        with self.assertRaises(ValueError):
            find_populated_event(event)

        # Both Populated (Returns Profile Event)
        event = flatcc.Event(
            profile_event=profile_event, debug_event=debug_event, allocation_event=None
        )
        self.assertEqual(find_populated_event(event), profile_event)

    def test_is_inference_output_equal_returns_false_for_different_tensor_values(self):
        """Tensors with the same shape but different values are not equal."""
        self.assertFalse(
            is_inference_output_equal(
                torch.tensor([[2, 1], [4, 3]]),
                torch.tensor([[5, 6], [7, 8]]),
            )
        )

    def test_is_inference_output_equal_returns_false_for_different_tensor_lists(self):
        """Outputs holding different numbers of tensors are not equal."""
        # NOTE(review): the outer parentheses plus trailing comma make
        # tensor_list_1 a 1-tuple wrapping a list of 3 tensors, so this
        # actually compares a 1-element tuple against a 2-element list, not
        # 3 tensors vs 2 tensors as the comment below says. Left unchanged
        # because un-wrapping it depends on how is_inference_output_equal
        # handles lists of unequal length — TODO confirm intent and fix.
        tensor_list_1 = (
            [
                torch.tensor([[1, 2], [3, 4]]),
                torch.tensor([[1, 2], [3, 4]]),
                torch.tensor([[1, 2], [3, 4]]),
            ],
        )
        tensor_list_2 = [
            torch.tensor([[1, 2], [3, 4]]),
            torch.tensor([[1, 2], [3, 4]]),
        ]
        # Not equal because of different number of tensors
        self.assertFalse(is_inference_output_equal(tensor_list_1, tensor_list_2))

    def test_is_inference_output_equal_returns_true_for_same_tensor_values(self):
        """Tensors with identical values compare equal."""
        self.assertTrue(
            is_inference_output_equal(
                torch.tensor([[2, 1], [4, 3]]),
                torch.tensor([[2, 1], [4, 3]]),
            )
        )

    def test_is_inference_output_equal_returns_true_for_same_strs(self):
        """Identical non-tensor (string) outputs compare equal."""
        self.assertTrue(
            is_inference_output_equal(
                "value_string",
                "value_string",
            )
        )

    def test_calculate_time_scale_factor_second_based(self):
        """NS -> MS yields 1e6 and MS -> NS yields its reciprocal."""
        self.assertEqual(
            calculate_time_scale_factor(TimeScale.NS, TimeScale.MS), 1000000
        )
        self.assertEqual(
            calculate_time_scale_factor(TimeScale.MS, TimeScale.NS), 1 / 1000000
        )

    def test_calculate_time_scale_factor_cycles(self):
        """Converting CYCLES to CYCLES is the identity factor."""
        self.assertEqual(
            calculate_time_scale_factor(TimeScale.CYCLES, TimeScale.CYCLES), 1
        )


def gen_mock_operator_graph_with_expected_map() -> (
    Tuple[OperatorGraph, Dict[int, OperatorNode]]
):
    """Build a small linear mock OperatorGraph plus its expected handle map.

    The graph is the chain
    ``input -> fused_conv_relu -> sin -> cos -> div -> output``. Each operator
    node carries a ``debug_handle`` (111, 222, 333, 444) in its metadata.

    Returns:
        A tuple ``(graph, mapping)``. ``mapping`` maps each operator node's
        debug handle to that OperatorNode instance.
    """
    # Make a mock OperatorGraph instance for testing
    node_input = ValueNode("input")
    mapping: Dict[int, OperatorNode] = {}
    node_fused_conv_relu = OperatorNode(
        "fused_conv_relu",
        [node_input],
        None,
        metadata={
            "debug_handle": 111,
            "stack_trace": "stack_trace_relu",
            "nn_module_stack": "module_hierarchy_relu",
        },
    )
    mapping[111] = node_fused_conv_relu
    node_sin = OperatorNode(
        "sin",
        [node_fused_conv_relu],
        None,
        metadata={
            "debug_handle": 222,
            "stack_trace": "stack_trace_sin",
            "nn_module_stack": "module_hierarchy_sin",
        },
    )
    mapping[222] = node_sin
    node_cos = OperatorNode(
        "cos",
        [node_sin],
        None,
        metadata={
            "debug_handle": 333,
            "stack_trace": "stack_trace_cos",
            "nn_module_stack": "module_hierarchy_cos",
        },
    )
    mapping[333] = node_cos
    node_div = OperatorNode(
        "div",
        [node_cos],
        None,
        metadata={
            "debug_handle": 444,
            "stack_trace": "stack_trace_div",
            "nn_module_stack": "module_hierarchy_div",
        },
    )
    mapping[444] = node_div
    node_output = ValueNode("output", [node_div])
    return (
        OperatorGraph(
            "mock_et_model",
            [
                node_input,
                node_fused_conv_relu,
                node_sin,
                node_cos,
                node_div,
                node_output,
            ],
        ),
        mapping,
    )