# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import executorch.exir.tests.models as models

import torch
from executorch import exir
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
from executorch.exir.delegate import executorch_call_delegate
from hypothesis import given, settings, strategies as st


class TestBackendDebugHandle(unittest.TestCase):
    def test_add_mul_partitioner(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

        ep = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
        executorch_prog = ep
        executorch_prog.exported_program = to_backend(
            ep.exported_program, AddMulPartitionerDemo()
        )
        # Each lowered submodule produced by AddMulPartitionerDemo contains two
        # delegated operators (mm and add), so each debug_handle_map has two entries.
        lowered_nodes = [
            getattr(executorch_prog.exported_program.graph_module, node.target)
            for node in executorch_prog.exported_program.graph.nodes
            if node.op == "get_attr"
        ]
        for lowered_node in lowered_nodes:
            self.assertEqual(len(lowered_node.meta["debug_handle_map"]), 2)

        # Every call_delegate node should carry a debug handle of its own.
        call_delegate_nodes = [
            node
            for node in executorch_prog.exported_program.graph.nodes
            if node.target == executorch_call_delegate
        ]

        for call_delegate_node in call_delegate_nodes:
            self.assertIsNotNone(call_delegate_node.meta["debug_handle"])

    @given(
        unlift=st.booleans(),  # verify both lifted and unlifted graphs
    )
    @settings(deadline=500000)
    def test_lowered_the_whole_model(self, unlift):
        module_list = [
            models.Emformer(),
            models.Repeat(),
            models.ElementwiseAdd(),
            models.MLP(),
            models.ModelWithUnusedArg(),
        ]

        capture_config = (
            exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
        )

        edge_compile_config = exir.EdgeCompileConfig(
            _check_ir_validity=False, _use_edge_ops=True
        )

        for model in module_list:
            model_inputs = model.get_random_inputs()

            edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
                edge_compile_config
            )
            lowered_model = to_backend(
                QnnBackend.__name__, edgeir_m.exported_program, []
            )

            # QnnBackend compiles all nodes into a single delegate, so the
            # debug_handle_map looks like {1: (debug handles from all nodes)}.
            # Ensure there is only one debug identifier.
            self.assertEqual(
                len(lowered_model.meta["debug_handle_map"].keys()),
                1,
            )

            all_debug_handles = list(lowered_model.meta["debug_handle_map"].values())[0]
            self.assertEqual(
                len(all_debug_handles),
                len(lowered_model.original_module.graph.nodes),
            )

            class ComposedModel(torch.nn.Module):
                def __init__(self, lowered_model):
                    super().__init__()
                    self.back_bone = lowered_model

                def forward(self, *args):
                    return self.back_bone(*args)

            edge = exir.capture(
                ComposedModel(lowered_model), model_inputs, capture_config
            ).to_edge(edge_compile_config)
            lowered_nodes = [
                getattr(edge.exported_program.graph_module, node.target)
                for node in edge.exported_program.graph.nodes
                if node.op == "get_attr"
            ]
            for lowered_node in lowered_nodes:
                self.assertEqual(
                    len(lowered_node.meta["debug_handle_map"].keys()),
                    1,
                )

                all_debug_handles = list(
                    lowered_node.meta["debug_handle_map"].values()
                )[0]
                self.assertEqual(
                    len(all_debug_handles),
                    len(lowered_node.original_module.graph.nodes),
                )
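

# A minimal sketch of a direct entry point, assuming the file may be run as a
# standalone script; in the ExecuTorch repo these tests are normally collected
# by the project's test runner, so this block is optional.
if __name__ == "__main__":
    unittest.main()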