• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import itertools
7import unittest
8
9import torch
10from executorch.backends.arm.quantizer.arm_quantizer_utils import is_annotated
11from executorch.backends.arm.test import common
12from executorch.backends.arm.test.tester.arm_tester import ArmTester
13from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
14
15
class SingleOpModel(torch.nn.Module):
    """Expose a single callable ``op`` as an ``nn.Module`` for annotation tests.

    ``forward`` applies ``op`` to its one positional input, passing along the
    keyword arguments captured when the module was built.

    Args:
        op: Callable applied in ``forward``.
        example_input: Tuple of inputs handed back by ``example_inputs()``.
        **op_kwargs: Extra keyword arguments forwarded to ``op`` on every call.
    """

    def __init__(self, op, example_input, **op_kwargs) -> None:
        super().__init__()
        # Public on purpose: the test harness looks up ``model.op`` to find the
        # op's source partition in the quantized graph.
        self.op = op
        self.op_kwargs = op_kwargs
        self._example_input = example_input

    def forward(self, x):
        kwargs = self.op_kwargs
        return self.op(x, **kwargs)

    def example_inputs(self):
        """Return the example-input tuple given at construction time."""
        return self._example_input
29
class TestGenericAnnotator(unittest.TestCase):
    """Check that single-op models come out of Arm quantization fully annotated.

    Each test wraps one op in a ``SingleOpModel``, quantizes it through
    ``ArmTester`` with a TOSA-0.80.0+BI compile spec, and asserts that the op's
    source partition exists exactly once and that all of its nodes are
    annotated.
    """

    def check_annotation(self, model):
        """Quantize ``model`` and assert its op's partition is fully annotated.

        Args:
            model: A ``SingleOpModel``; its ``op`` attribute is used to locate
                the source partition and ``example_inputs()`` feeds the tester.

        Raises:
            AssertionError: If the op does not map to exactly one source
                partition, or if any node in that partition is unannotated.
        """
        tester = ArmTester(
            model,
            model.example_inputs(),
            common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        )
        quant_model = tester.quantize().get_artifact()
        # get_source_partitions returns a mapping of source -> partitions;
        # flatten it because only the total set of partitions matters here.
        partitions = get_source_partitions(quant_model.graph, [model.op])
        partitions = list(itertools.chain.from_iterable(partitions.values()))

        op_name = getattr(model.op, "__name__", repr(model.op))
        # unittest assertions (not bare ``assert``) so the checks still run
        # under ``python -O`` and failures say what went wrong.
        self.assertEqual(
            len(partitions),
            1,
            f"Expected exactly one source partition for {op_name}, "
            f"found {len(partitions)}",
        )
        unannotated = [
            node for node in partitions[0].nodes if not is_annotated(node)
        ]
        self.assertFalse(
            unannotated,
            f"Unannotated nodes in the {op_name} partition: {unannotated}",
        )

    def test_squeeze(self):
        # subTest keeps checking the second op even if the first one fails.
        for op in (torch.squeeze, torch.squeeze_copy):
            with self.subTest(op=op.__name__):
                self.check_annotation(SingleOpModel(op, (torch.rand(8, 8, 1),)))

    def test_unsqueeze(self):
        for op in (torch.unsqueeze, torch.unsqueeze_copy):
            with self.subTest(op=op.__name__):
                self.check_annotation(
                    SingleOpModel(op, (torch.rand(8, 8),), dim=0)
                )

    def test_reshape(self):
        self.check_annotation(
            SingleOpModel(torch.reshape, (torch.randn(8, 8),), shape=(64,)),
        )

    def test_view(self):
        self.check_annotation(
            SingleOpModel(torch.view_copy, (torch.randn(4, 4),), size=(2, 8)),
        )

    def test_slice(self):
        # slice_copy is called with its default arguments only.
        self.check_annotation(
            SingleOpModel(torch.slice_copy, (torch.randn(3, 4),)),
        )

    def test_transpose(self):
        for op in (torch.transpose, torch.transpose_copy):
            with self.subTest(op=op.__name__):
                self.check_annotation(
                    SingleOpModel(op, (torch.randn(2, 3),), dim0=0, dim1=1),
                )

    def test_tile(self):
        # ``dims`` is shorter than the input rank; torch.tile prepends ones,
        # so only the last dimension is repeated.
        self.check_annotation(
            SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
        )

    def test_flip(self):
        self.check_annotation(
            SingleOpModel(torch.flip, (torch.randn(2, 4),), dims=(0, 1)),
        )
89