• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8import os
9import shutil
10import tempfile
11import unittest
12
13import torch
14from executorch.backends.arm.test import common
15
16from executorch.backends.arm.test.tester.arm_tester import ArmTester
17
18logger = logging.getLogger(__name__)
19logger.setLevel(logging.INFO)
20
21
22class Linear(torch.nn.Module):
23    def __init__(
24        self,
25        in_features: int,
26        out_features: int = 3,
27        bias: bool = True,
28    ):
29        super().__init__()
30        self.inputs = (torch.randn(5, 10, 25, in_features),)
31        self.fc = torch.nn.Linear(
32            in_features=in_features,
33            out_features=out_features,
34            bias=bias,
35        )
36
37    def get_inputs(self):
38        return self.inputs
39
40    def forward(self, x):
41        return self.fc(x)
42
43
44class TestDumpPartitionedArtifact(unittest.TestCase):
45    """Tests dumping the partition artifact in ArmTester. Both to file and to stdout."""
46
47    def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None):
48        (
49            ArmTester(
50                module,
51                example_inputs=module.get_inputs(),
52                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
53            )
54            .export()
55            .to_edge()
56            .partition()
57            .dump_artifact(dump_file)
58            .dump_artifact()
59        )
60
61    def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None):
62        (
63            ArmTester(
64                module,
65                example_inputs=module.get_inputs(),
66                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
67            )
68            .quantize()
69            .export()
70            .to_edge_transform_and_lower()
71            .dump_artifact(dump_file)
72            .dump_artifact()
73        )
74
75    def _is_tosa_marker_in_file(self, tmp_file):
76        for line in open(tmp_file).readlines():
77            if "'name': 'main'" in line:
78                return True
79        return False
80
81    def test_MI_artifact(self):
82        model = Linear(20, 30)
83        tmp_file = os.path.join(tempfile.mkdtemp(), "tosa_dump_MI.txt")
84        self._tosa_MI_pipeline(model, dump_file=tmp_file)
85        assert os.path.exists(tmp_file), f"File {tmp_file} was not created"
86        if self._is_tosa_marker_in_file(tmp_file):
87            return  # Implicit pass test
88        self.fail("File does not contain TOSA dump!")
89
90    def test_BI_artifact(self):
91        model = Linear(20, 30)
92        tmp_file = os.path.join(tempfile.mkdtemp(), "tosa_dump_BI.txt")
93        self._tosa_BI_pipeline(model, dump_file=tmp_file)
94        assert os.path.exists(tmp_file), f"File {tmp_file} was not created"
95        if self._is_tosa_marker_in_file(tmp_file):
96            return  # Implicit pass test
97        self.fail("File does not contain TOSA dump!")
98
99
100class TestNumericalDiffPrints(unittest.TestCase):
101    """Tests trigging the exception printout from the ArmTester's run and compare function."""
102
103    def test_numerical_diff_prints(self):
104        model = Linear(20, 30)
105        tester = (
106            ArmTester(
107                model,
108                example_inputs=model.get_inputs(),
109                compile_spec=common.get_tosa_compile_spec(
110                    "TOSA-0.80.0+MI", permute_memory_to_nhwc=True
111                ),
112            )
113            .export()
114            .to_edge_transform_and_lower()
115            .to_executorch()
116        )
117        # We expect an assertion error here. Any other issues will cause the
118        # test to fail. Likewise the test will fail if the assertion error is
119        # not present.
120        try:
121            # Tolerate 0 difference => we want to trigger a numerical diff
122            tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0)
123        except AssertionError:
124            pass  # Implicit pass test
125        else:
126            self.fail()
127
128
129def test_dump_ops_and_dtypes():
130    model = Linear(20, 30)
131    (
132        ArmTester(
133            model,
134            example_inputs=model.get_inputs(),
135            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
136        )
137        .quantize()
138        .dump_dtype_distribution()
139        .dump_operator_distribution()
140        .export()
141        .dump_dtype_distribution()
142        .dump_operator_distribution()
143        .to_edge_transform_and_lower()
144        .dump_dtype_distribution()
145        .dump_operator_distribution()
146    )
147    # Just test that there are no execptions.
148
149
150def test_dump_ops_and_dtypes_parseable():
151    model = Linear(20, 30)
152    (
153        ArmTester(
154            model,
155            example_inputs=model.get_inputs(),
156            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
157        )
158        .quantize()
159        .dump_dtype_distribution(print_table=False)
160        .dump_operator_distribution(print_table=False)
161        .export()
162        .dump_dtype_distribution(print_table=False)
163        .dump_operator_distribution(print_table=False)
164        .to_edge_transform_and_lower()
165        .dump_dtype_distribution(print_table=False)
166        .dump_operator_distribution(print_table=False)
167    )
168    # Just test that there are no execptions.
169
170
171class TestCollateTosaTests(unittest.TestCase):
172    """Tests the collation of TOSA tests through setting the environment variable TOSA_TESTCASE_BASE_PATH."""
173
174    def test_collate_tosa_BI_tests(self):
175        # Set the environment variable to trigger the collation of TOSA tests
176        os.environ["TOSA_TESTCASES_BASE_PATH"] = "test_collate_tosa_tests"
177        # Clear out the directory
178
179        model = Linear(20, 30)
180        (
181            ArmTester(
182                model,
183                example_inputs=model.get_inputs(),
184                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
185            )
186            .quantize()
187            .export()
188            .to_edge_transform_and_lower()
189            .to_executorch()
190        )
191        # test that the output directory is created and contains the expected files
192        assert os.path.exists(
193            "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
194        )
195        assert os.path.exists(
196            "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag5.tosa"
197        )
198        assert os.path.exists(
199            "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag5.json"
200        )
201
202        os.environ.pop("TOSA_TESTCASES_BASE_PATH")
203        shutil.rmtree("test_collate_tosa_tests", ignore_errors=True)
204
205
206def test_dump_tosa_ops(caplog):
207    caplog.set_level(logging.INFO)
208    model = Linear(20, 30)
209    (
210        ArmTester(
211            model,
212            example_inputs=model.get_inputs(),
213            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
214        )
215        .quantize()
216        .export()
217        .to_edge_transform_and_lower()
218        .dump_operator_distribution()
219    )
220    assert "TOSA operators:" in caplog.text
221
222
223def test_fail_dump_tosa_ops(caplog):
224    caplog.set_level(logging.INFO)
225
226    class Add(torch.nn.Module):
227        def forward(self, x):
228            return x + x
229
230    model = Add()
231    compile_spec = common.get_u55_compile_spec()
232    (
233        ArmTester(model, example_inputs=(torch.ones(5),), compile_spec=compile_spec)
234        .quantize()
235        .export()
236        .to_edge_transform_and_lower()
237        .dump_operator_distribution()
238    )
239    assert "Can not get operator distribution for Vela command stream." in caplog.text
240