# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import shutil
import tempfile
import unittest

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.arm_tester import ArmTester

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Linear(torch.nn.Module):
    """A single ``torch.nn.Linear`` layer bundled with its own example inputs.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        bias: Whether the linear layer has a bias term.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int = 3,
        bias: bool = True,
    ):
        super().__init__()
        # Leading dims (5, 10, 25) are fixed; only the last must match the layer.
        self.inputs = (torch.randn(5, 10, 25, in_features),)
        self.fc = torch.nn.Linear(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
        )

    def get_inputs(self):
        """Return the example input tuple passed to ``ArmTester``."""
        return self.inputs

    def forward(self, x):
        return self.fc(x)


class TestDumpPartitionedArtifact(unittest.TestCase):
    """Tests dumping the partition artifact in ArmTester.

    The artifact is dumped both to a file and to stdout.
    """

    def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None):
        """Export and partition *module* with the TOSA MI spec.

        The artifact is dumped to *dump_file* and then to stdout.
        """
        (
            ArmTester(
                module,
                example_inputs=module.get_inputs(),
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
            .partition()
            .dump_artifact(dump_file)
            .dump_artifact()
        )

    def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None):
        """Quantize and lower *module* with the TOSA BI spec.

        The artifact is dumped to *dump_file* and then to stdout.
        """
        (
            ArmTester(
                module,
                example_inputs=module.get_inputs(),
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .to_edge_transform_and_lower()
            .dump_artifact(dump_file)
            .dump_artifact()
        )

    def _is_tosa_marker_in_file(self, tmp_file):
        """Return True if any line of *tmp_file* contains ``'name': 'main'``."""
        # The context manager closes the handle deterministically, even if
        # reading raises.
        with open(tmp_file) as dump:
            return any("'name': 'main'" in line for line in dump)

    def _check_dump_artifact(self, pipeline, file_name):
        """Run *pipeline* with a dump file in a temporary directory and verify it.

        Args:
            pipeline: One of the ``_tosa_*_pipeline`` methods.
            file_name: Name of the dump file inside the temporary directory.

        The temporary directory is removed afterwards, whether the checks
        pass or fail.
        """
        model = Linear(20, 30)
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = os.path.join(tmp_dir, file_name)
            pipeline(model, dump_file=tmp_file)
            self.assertTrue(
                os.path.exists(tmp_file), f"File {tmp_file} was not created"
            )
            self.assertTrue(
                self._is_tosa_marker_in_file(tmp_file),
                "File does not contain TOSA dump!",
            )

    def test_MI_artifact(self):
        self._check_dump_artifact(self._tosa_MI_pipeline, "tosa_dump_MI.txt")

    def test_BI_artifact(self):
        self._check_dump_artifact(self._tosa_BI_pipeline, "tosa_dump_BI.txt")


class TestNumericalDiffPrints(unittest.TestCase):
    """Tests triggering the exception printout from ArmTester's run-and-compare function."""

    def test_numerical_diff_prints(self):
        model = Linear(20, 30)
        tester = (
            ArmTester(
                model,
                example_inputs=model.get_inputs(),
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+MI", permute_memory_to_nhwc=True
                ),
            )
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
        )
        # Zero tolerance on every axis is meant to force a numerical diff,
        # which we expect to surface as an AssertionError. Any other exception,
        # or no exception at all, fails the test.
        with self.assertRaises(AssertionError):
            tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0)


def test_dump_ops_and_dtypes():
    """Smoke-test the dtype/operator distribution table dumps at each stage."""
    model = Linear(20, 30)
    (
        ArmTester(
            model,
            example_inputs=model.get_inputs(),
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        )
        .quantize()
        .dump_dtype_distribution()
        .dump_operator_distribution()
        .export()
        .dump_dtype_distribution()
        .dump_operator_distribution()
        .to_edge_transform_and_lower()
        .dump_dtype_distribution()
        .dump_operator_distribution()
    )
    # Just test that there are no exceptions.


def test_dump_ops_and_dtypes_parseable():
    """Smoke-test the distribution dumps with ``print_table=False``."""
    model = Linear(20, 30)
    (
        ArmTester(
            model,
            example_inputs=model.get_inputs(),
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        )
        .quantize()
        .dump_dtype_distribution(print_table=False)
        .dump_operator_distribution(print_table=False)
        .export()
        .dump_dtype_distribution(print_table=False)
        .dump_operator_distribution(print_table=False)
        .to_edge_transform_and_lower()
        .dump_dtype_distribution(print_table=False)
        .dump_operator_distribution(print_table=False)
    )
    # Just test that there are no exceptions.
class TestCollateTosaTests(unittest.TestCase):
    """Tests the collation of TOSA tests through setting the environment variable TOSA_TESTCASES_BASE_PATH."""

    def test_collate_tosa_BI_tests(self):
        env_var = "TOSA_TESTCASES_BASE_PATH"
        base_path = "test_collate_tosa_tests"
        test_dir = (
            f"{base_path}/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
        )
        # Clear out stale output from an earlier (possibly aborted) run so the
        # existence checks below cannot pass on leftover files.
        shutil.rmtree(base_path, ignore_errors=True)
        previous = os.environ.get(env_var)
        # Set the environment variable to trigger the collation of TOSA tests.
        os.environ[env_var] = base_path
        try:
            model = Linear(20, 30)
            (
                ArmTester(
                    model,
                    example_inputs=model.get_inputs(),
                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
                )
                .quantize()
                .export()
                .to_edge_transform_and_lower()
                .to_executorch()
            )
            # The output directory must be created and contain the expected files.
            for path in (
                test_dir,
                f"{test_dir}/output_tag5.tosa",
                f"{test_dir}/desc_tag5.json",
            ):
                self.assertTrue(os.path.exists(path), f"{path} was not created")
        finally:
            # Restore the environment and remove artifacts even when an
            # assertion fails, so later tests do not inherit either.
            if previous is None:
                os.environ.pop(env_var, None)
            else:
                os.environ[env_var] = previous
            shutil.rmtree(base_path, ignore_errors=True)


def test_dump_tosa_ops(caplog):
    """Check that dumping the operator distribution logs the TOSA operators."""
    caplog.set_level(logging.INFO)
    model = Linear(20, 30)
    (
        ArmTester(
            model,
            example_inputs=model.get_inputs(),
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        )
        .quantize()
        .export()
        .to_edge_transform_and_lower()
        .dump_operator_distribution()
    )
    assert "TOSA operators:" in caplog.text


def test_fail_dump_tosa_ops(caplog):
    """Check the log message when the operator distribution is unavailable.

    With the U55 compile spec, dumping the operator distribution is expected
    to log that it cannot be obtained for a Vela command stream.
    """
    caplog.set_level(logging.INFO)

    class Add(torch.nn.Module):
        def forward(self, x):
            return x + x

    model = Add()
    compile_spec = common.get_u55_compile_spec()
    (
        ArmTester(model, example_inputs=(torch.ones(5),), compile_spec=compile_spec)
        .quantize()
        .export()
        .to_edge_transform_and_lower()
        .dump_operator_distribution()
    )
    assert "Can not get operator distribution for Vela command stream." in caplog.text