• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.arm.test import common
11from executorch.backends.arm.test.tester.arm_tester import ArmTester
12from parameterized import parameterized
13
test_data_t = tuple[str, torch.Tensor]

# Random inputs are drawn from a dedicated generator with a fixed seed so the
# suite is identical on every run (a failing case can be reproduced) and the
# global torch RNG state is left untouched.
_test_data_generator = torch.Generator().manual_seed(0)

# (test name, input tensor) pairs consumed by @parameterized.expand below.
test_data_suite: list[test_data_t] = [
    (
        "op_reciprocal_rank1_ones",
        torch.ones(5),
    ),
    (
        "op_reciprocal_rank1_rand",
        torch.rand(5, generator=_test_data_generator) * 5,
    ),
    ("op_reciprocal_rank1_negative_ones", torch.ones(5) * (-1)),
    ("op_reciprocal_rank4_ones", torch.ones(5, 10, 25, 20)),
    ("op_reciprocal_rank4_negative_ones", (-1) * torch.ones(5, 10, 25, 20)),
    # NOTE(review): this input is identical to op_reciprocal_rank4_ones even
    # though the name suggests a negative input -- confirm what was intended.
    ("op_reciprocal_rank4_ones_reciprocal_negative", torch.ones(5, 10, 25, 20)),
    (
        "op_reciprocal_rank4_large_rand",
        200 * torch.rand(5, 10, 25, 20, generator=_test_data_generator),
    ),
    (
        "op_reciprocal_rank4_negative_large_rand",
        (-200) * torch.rand(5, 10, 25, 20, generator=_test_data_generator),
    ),
    # 200 * randn + 1 spans both signs and can land close to zero, which
    # exercises very large reciprocal magnitudes.
    (
        "op_reciprocal_rank4_large_randn",
        200 * torch.randn(5, 10, 25, 20, generator=_test_data_generator) + 1,
    ),
]
32
33
class TestReciprocal(unittest.TestCase):
    """Tests lowering of ``Tensor.reciprocal`` through the Arm backend.

    Every pipeline exports the ``Reciprocal`` module, checks that exactly one
    ``aten.reciprocal`` node is present, partitions the edge program and checks
    that it became a single ``executorch_call_delegate`` call. The TOSA MI and
    BI pipelines then run the lowered program via
    ``run_method_and_compare_outputs``; the Ethos-U55 pipeline stops after
    ``to_executorch()`` and does not execute the program.
    """

    # Operator targets asserted on by the pipelines below.
    _RECIPROCAL_OP = "torch.ops.aten.reciprocal.default"
    _QUANTIZED_OPS = "torch.ops.quantized_decomposed"
    _DELEGATE_OP = "torch.ops.higher_order.executorch_call_delegate"

    class Reciprocal(torch.nn.Module):
        """Module whose forward pass is a single ``reciprocal`` (1 / x) op."""

        def forward(self, input_: torch.Tensor):
            return input_.reciprocal()

    def _test_reciprocal_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
    ):
        """Lower ``module`` for TOSA MI and compare outputs on ``test_data``.

        No quantization step is run, and the exported graph is asserted to
        contain no ``quantized_decomposed`` ops.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({self._RECIPROCAL_OP: 1})
            .check_not([self._QUANTIZED_OPS])
            .to_edge()
            .partition()
            .check_count({self._DELEGATE_OP: 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _lower_quantized(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor], compile_spec
    ):
        """Quantize and lower ``module`` with ``compile_spec``; return the tester.

        Shared by the BI pipelines. After export it asserts one reciprocal op
        and the presence of ``quantized_decomposed`` ops; after partitioning it
        asserts a single delegate call. The returned tester has already run
        ``to_executorch()``.
        """
        return (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({self._RECIPROCAL_OP: 1})
            .check([self._QUANTIZED_OPS])
            .to_edge()
            .partition()
            .check_count({self._DELEGATE_OP: 1})
            .to_executorch()
        )

    def _test_reciprocal_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` for TOSA BI, then compare outputs."""
        self._lower_quantized(
            module, test_data, common.get_tosa_compile_spec("TOSA-0.80.0+BI")
        ).run_method_and_compare_outputs(inputs=test_data)

    def _test_reciprocal_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` for Ethos-U55 (program is not run)."""
        self._lower_quantized(module, test_data, common.get_u55_compile_spec())

    # ``test_name`` is supplied by parameterized.expand (first tuple element of
    # each test_data_suite entry); it only names the generated test case.
    @parameterized.expand(test_data_suite)
    def test_reciprocal_tosa_MI(self, test_name: str, input_: torch.Tensor):
        test_data = (input_,)
        self._test_reciprocal_tosa_MI_pipeline(self.Reciprocal(), test_data)

    # NOTE(review): this test was previously commented "Expected to fail since
    # ArmQuantizer cannot quantize a Reciprocal layer" (TODO(MLETORCH-129)),
    # but no @unittest.expectedFailure is applied, so it runs as a normal test.
    # Confirm it passes and drop this note, or restore the decorator if the
    # quantizer limitation still applies.
    @parameterized.expand(test_data_suite)
    def test_reciprocal_tosa_BI(self, test_name: str, input_: torch.Tensor):
        test_data = (input_,)
        self._test_reciprocal_tosa_BI_pipeline(self.Reciprocal(), test_data)

    @parameterized.expand(test_data_suite)
    def test_reciprocal_u55_BI(self, test_name: str, input_: torch.Tensor):
        test_data = (input_,)
        self._test_reciprocal_u55_BI_pipeline(self.Reciprocal(), test_data)
117