# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from parameterized import parameterized


# Inputs for the parameterized tests. The values range from inside to well
# outside the Hardtanh clamp range (PyTorch's default is [-1, 1]), so both the
# pass-through and the saturating regions are exercised.
# NOTE: the rand/randn entries are drawn once at import time and are not seeded.
test_data_suite = [
    # (test_name, test_data)
    ("zeros", torch.zeros(1, 10, 10, 10)),
    ("ones", torch.ones(10, 10, 10)),
    ("rand", torch.rand(10, 10) - 0.5),
    ("randn_pos", torch.randn(10) + 10),
    ("randn_neg", torch.randn(10) - 10),
    ("ramp", torch.arange(-16, 16, 0.2)),
]


def _get_symmetric_quantize_stage() -> Quantize:
    """Build the Quantize stage shared by the quantized (BI / U55) pipelines.

    The ArmQuantizer is configured through ``set_io`` with the symmetric
    quantization config, and that same config is also passed to the
    ``Quantize`` stage.
    """
    quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
    return Quantize(quantizer, get_symmetric_quantization_config())


class TestHardTanh(unittest.TestCase):
    """Tests HardTanh Operator."""

    class HardTanh(torch.nn.Module):
        """Minimal module applying ``torch.nn.Hardtanh`` with default arguments."""

        def __init__(self):
            super().__init__()

            self.hardTanh = torch.nn.Hardtanh()

        def forward(self, x):
            return self.hardTanh(x)

    def _test_hardtanh_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec and check the result.

        The pipeline verifies four things:
        - The exported graph contains ``aten.hardtanh`` and no
          ``quantized_decomposed`` ops.
        - After partitioning, no edge-dialect hardtanh remains.
        - There is exactly one delegate call.
        - The ExecuTorch program's outputs pass
          ``run_method_and_compare_outputs``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.hardtanh.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_hardtanh_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the TOSA-0.80.0+BI spec.

        The pipeline verifies four things:
        - After quantization, exactly one ``aten.hardtanh`` is present,
          together with ``quantized_decomposed`` ops.
        - Partitioning leaves no edge-dialect hardtanh.
        - There is exactly one delegate call.
        - The outputs pass ``run_method_and_compare_outputs``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize(_get_symmetric_quantize_stage())
            .export()
            .check_count({"torch.ops.aten.hardtanh.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_hardtanh_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the U55 compile spec.

        Performs the same graph checks as the TOSA BI pipeline but stops after
        ``to_executorch()``; the program's outputs are not executed or compared.
        """
        # NOTE(review): presumably no run step because executing a U55 program
        # needs the target or a simulator - confirm before adding one.
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize(_get_symmetric_quantize_stage())
            .export()
            .check_count({"torch.ops.aten.hardtanh.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # ``test_name`` is unused in the bodies below; parameterized.expand passes
    # every tuple element from test_data_suite as an argument.
    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_hardtanh_tosa_MI_pipeline(self.HardTanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_hardtanh_tosa_BI_pipeline(self.HardTanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_hardtanh_tosa_u55_BI_pipeline(self.HardTanh(), (test_data,))