• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import unittest
9from typing import Tuple
10
11import torch
12
13from executorch.backends.arm.quantizer.arm_quantizer import (
14    ArmQuantizer,
15    get_symmetric_quantization_config,
16)
17
18from executorch.backends.arm.test import common
19from executorch.backends.arm.test.tester.arm_tester import ArmTester
20from executorch.backends.xnnpack.test.tester.tester import Quantize
21from parameterized import parameterized
22
23
# Parameterized cases for the Hardtanh tests, as (test_name, test_data) pairs.
# The dict keeps names next to their tensors; insertion order (and thus the
# order in which the random tensors are drawn) matches the listing below.
test_data_suite = list(
    {
        # 4-D all-zero input.
        "zeros": torch.zeros(1, 10, 10, 10),
        # 3-D all-one input.
        "ones": torch.ones(10, 10, 10),
        # 2-D uniform samples shifted from [0, 1) to [-0.5, 0.5).
        "rand": torch.rand(10, 10) - 0.5,
        # 1-D normal samples shifted far above zero.
        "randn_pos": torch.randn(10) + 10,
        # 1-D normal samples shifted far below zero.
        "randn_neg": torch.randn(10) - 10,
        # 1-D ramp from -16 toward 16 in steps of 0.2.
        "ramp": torch.arange(-16, 16, 0.2),
    }.items()
)
33
34
class TestHardTanh(unittest.TestCase):
    """Tests lowering of ``torch.nn.Hardtanh`` through the Arm backend flows.

    Each case in ``test_data_suite`` is run through three pipelines:
    the "TOSA-0.80.0+MI" compile spec without quantization, the
    "TOSA-0.80.0+BI" compile spec with symmetric quantization, and the
    compile spec from ``common.get_u55_compile_spec()`` with the same
    quantization.
    """

    class HardTanh(torch.nn.Module):
        """Minimal module wrapping a default-constructed ``torch.nn.Hardtanh``."""

        def __init__(self):
            super().__init__()

            self.hardTanh = torch.nn.Hardtanh()

        def forward(self, x):
            return self.hardTanh(x)

    def _test_hardtanh_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the unquantized MI spec and compare outputs.

        Args:
            module: Module under test; expected to contain one Hardtanh call.
            test_data: 1-tuple of example inputs, used both for export and for
                the runtime output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            # Exact count, consistent with the quantized pipelines below.
            .check_count({"torch.ops.aten.hardtanh.default": 1})
            # No quantization was requested, so no (de)quantize ops may appear.
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # The hardtanh op must have been absorbed into the single delegate.
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _quantize_and_lower_hardtanh(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        compile_spec,
    ):
        """Quantize, export, partition and lower ``module``; return the tester.

        Shared by the BI pipelines, which differ only in ``compile_spec`` and
        in whether outputs are compared afterwards.

        Args:
            module: Module under test; expected to contain one Hardtanh call.
            test_data: 1-tuple of example inputs used for quantization/export.
            compile_spec: Compile spec passed through to ``ArmTester``.

        Returns:
            The tester after ``to_executorch()``, so callers can chain further
            stages such as ``run_method_and_compare_outputs``.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        return (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.hardtanh.default": 1})
            # Quantization must have inserted quantize/dequantize ops.
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # The hardtanh op must have been absorbed into the single delegate.
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_hardtanh_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the BI spec, then compare outputs."""
        self._quantize_and_lower_hardtanh(
            module,
            test_data,
            common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        ).run_method_and_compare_outputs(inputs=test_data)

    def _test_hardtanh_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the U55 compile spec.

        Stops after ``to_executorch()`` without comparing outputs.
        NOTE(review): presumably because no U55 runtime is available in this
        test environment — confirm before adding an output comparison.
        """
        self._quantize_and_lower_hardtanh(
            module,
            test_data,
            common.get_u55_compile_spec(),
        )

    # In the tests below, ``test_name`` labels the case (first element of each
    # ``test_data_suite`` entry); it is not used in the test bodies.

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_hardtanh_tosa_MI_pipeline(self.HardTanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_hardtanh_tosa_BI_pipeline(self.HardTanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_hardtanh_tosa_u55_BI_pipeline(self.HardTanh(), (test_data,))
126