• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11
12
class TestHardswish(unittest.TestCase):
    """Run hardswish modules through the XNNPACK ``Tester`` pipeline.

    Each test exports a small module, lowers it, and compares the outputs
    produced by the serialized program against the reference outputs.
    """

    class Hardswish(torch.nn.Module):
        """Applies hardswish through a ``torch.nn.Hardswish`` submodule."""

        def __init__(self):
            super().__init__()
            self.hardswish = torch.nn.Hardswish()

        def forward(self, x):
            return self.hardswish(x)

    class HardswishFunctional(torch.nn.Module):
        """Applies hardswish through ``torch.nn.functional.hardswish``."""

        def forward(self, x):
            return torch.nn.functional.hardswish(x)

    def _test_hardswish(self, inputs, module=None):
        """Export, lower, serialize and run ``module`` on ``inputs``.

        The pipeline checks that the exported graph contains exactly one
        ``aten.hardswish`` call, that after lowering there is exactly one
        ``executorch_call_delegate`` call and no remaining edge-dialect
        hardswish op, and finally that the runtime outputs match.

        Args:
            inputs: Tuple of example tensors passed to ``Tester``.
            module: Module under test. Defaults to a fresh ``Hardswish``
                instance, so existing ``_test_hardswish(inputs)`` callers
                keep their behavior.
        """
        if module is None:
            module = self.Hardswish()
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.hardswish.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_hardswish_default",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_hardswish(self):
        """Module-based hardswish on a float16 input of shape (1, 3, 3)."""
        inputs = (torch.randn(1, 3, 3).to(torch.float16),)
        self._test_hardswish(inputs)

    def test_fp32_hardswish(self):
        """Module-based hardswish on a default-dtype input of shape (1, 3, 3)."""
        inputs = (torch.randn(1, 3, 3),)
        self._test_hardswish(inputs)

    def test_fp32_hardswish_functional(self):
        """Functional hardswish on a default-dtype input of shape (1, 3, 3)."""
        inputs = (torch.randn(1, 3, 3),)
        # Reuse the shared pipeline so the module and functional variants
        # are always held to the same set of checks.
        self._test_hardswish(inputs, module=self.HardswishFunctional())
68