# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

# Each entry is (test name, input tensor); the list is handed to
# ``parameterized.expand`` so every test method below runs once per entry.
test_data_t = tuple[str, torch.Tensor]
test_data_suite: list[test_data_t] = [
    ("op_reciprocal_rank1_ones", torch.ones(5)),
    ("op_reciprocal_rank1_rand", torch.rand(5) * 5),
    ("op_reciprocal_rank1_negative_ones", torch.ones(5) * (-1)),
    ("op_reciprocal_rank4_ones", torch.ones(5, 10, 25, 20)),
    ("op_reciprocal_rank4_negative_ones", (-1) * torch.ones(5, 10, 25, 20)),
    # NOTE(review): this case uses the same all-ones input as
    # "op_reciprocal_rank4_ones" although its name mentions "negative" --
    # confirm whether a different input was intended.
    ("op_reciprocal_rank4_ones_reciprocal_negative", torch.ones(5, 10, 25, 20)),
    ("op_reciprocal_rank4_large_rand", 200 * torch.rand(5, 10, 25, 20)),
    (
        "op_reciprocal_rank4_negative_large_rand",
        (-200) * torch.rand(5, 10, 25, 20),
    ),
    ("op_reciprocal_rank4_large_randn", 200 * torch.randn(5, 10, 25, 20) + 1),
]


class TestReciprocal(unittest.TestCase):
    """Run a reciprocal module through the ArmTester pipelines.

    Three pipelines are exercised for every entry of ``test_data_suite``:
    the "TOSA-0.80.0+MI" compile spec, the "TOSA-0.80.0+BI" compile spec
    (with a quantize stage), and the u55 compile spec (with a quantize
    stage, without an output comparison).
    """

    class Reciprocal(torch.nn.Module):
        """Module whose forward pass returns ``input_.reciprocal()``."""

        def forward(self, input_: torch.Tensor):
            return input_.reciprocal()

    def _test_reciprocal_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
    ):
        """Export, partition and run ``module`` with the TOSA MI compile spec.

        Args:
            module: Module under test.
            test_data: Inputs used both as example inputs for the tester and
                as the inputs for the final output comparison.
        """
        compile_spec = common.get_tosa_compile_spec("TOSA-0.80.0+MI")
        reciprocal_count = {"torch.ops.aten.reciprocal.default": 1}
        delegate_count = {"torch.ops.higher_order.executorch_call_delegate": 1}

        tester = ArmTester(
            module, example_inputs=test_data, compile_spec=compile_spec
        )
        tester = tester.export()
        tester = tester.check_count(reciprocal_count)
        # No quantize stage here, so no quantized_decomposed ops may appear.
        tester = tester.check_not(["torch.ops.quantized_decomposed"])
        tester = tester.to_edge()
        tester = tester.partition()
        tester = tester.check_count(delegate_count)
        tester = tester.to_executorch()
        tester.run_method_and_compare_outputs(inputs=test_data)

    def _test_reciprocal_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
    ):
        """Quantize, export, partition and run ``module`` with the TOSA BI spec.

        Args:
            module: Module under test.
            test_data: Inputs used both as example inputs for the tester and
                as the inputs for the final output comparison.
        """
        compile_spec = common.get_tosa_compile_spec("TOSA-0.80.0+BI")
        reciprocal_count = {"torch.ops.aten.reciprocal.default": 1}
        delegate_count = {"torch.ops.higher_order.executorch_call_delegate": 1}

        tester = ArmTester(
            module, example_inputs=test_data, compile_spec=compile_spec
        )
        tester = tester.quantize()
        tester = tester.export()
        tester = tester.check_count(reciprocal_count)
        # After the quantize stage, quantized_decomposed ops must be present.
        tester = tester.check(["torch.ops.quantized_decomposed"])
        tester = tester.to_edge()
        tester = tester.partition()
        tester = tester.check_count(delegate_count)
        tester = tester.to_executorch()
        tester.run_method_and_compare_outputs(inputs=test_data)

    def _test_reciprocal_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
    ):
        """Quantize, export and partition ``module`` with the u55 compile spec.

        Unlike the TOSA pipelines, this one stops after ``to_executorch()``
        and does not run the method or compare outputs.

        Args:
            module: Module under test.
            test_data: Inputs used as example inputs for the tester.
        """
        compile_spec = common.get_u55_compile_spec()
        reciprocal_count = {"torch.ops.aten.reciprocal.default": 1}
        delegate_count = {"torch.ops.higher_order.executorch_call_delegate": 1}

        tester = ArmTester(
            module, example_inputs=test_data, compile_spec=compile_spec
        )
        tester = tester.quantize()
        tester = tester.export()
        tester = tester.check_count(reciprocal_count)
        tester = tester.check(["torch.ops.quantized_decomposed"])
        tester = tester.to_edge()
        tester = tester.partition()
        tester = tester.check_count(delegate_count)
        tester.to_executorch()

    # The MI pipeline, once per entry of test_data_suite.
    @parameterized.expand(test_data_suite)
    def test_reciprocal_tosa_MI(self, test_name: str, input_: torch.Tensor):
        self._test_reciprocal_tosa_MI_pipeline(self.Reciprocal(), (input_,))

    # Expected to fail since ArmQuantizer cannot quantize a Reciprocal layer
    # TODO(MLETORCH-129)
    # NOTE(review): no unittest.expectedFailure marker is applied to this
    # test -- confirm whether the comment above is still current.
    @parameterized.expand(test_data_suite)
    def test_reciprocal_tosa_BI(self, test_name: str, input_: torch.Tensor):
        self._test_reciprocal_tosa_BI_pipeline(self.Reciprocal(), (input_,))

    # The u55 pipeline, once per entry of test_data_suite.
    @parameterized.expand(test_data_suite)
    def test_reciprocal_u55_BI(self, test_name: str, input_: torch.Tensor):
        self._test_reciprocal_u55_BI_pipeline(self.Reciprocal(), (input_,))