# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestHardswish(unittest.TestCase):
    """Check that hardswish is lowered to a single XNNPACK delegate call and
    that the serialized program runs and its outputs compare successfully."""

    class Hardswish(torch.nn.Module):
        """Applies hardswish through a ``torch.nn.Hardswish`` submodule."""

        def __init__(self):
            super().__init__()
            self.hardswish = torch.nn.Hardswish()

        def forward(self, x):
            return self.hardswish(x)

    class HardswishFunctional(torch.nn.Module):
        """Applies hardswish through ``torch.nn.functional.hardswish``."""

        def forward(self, x):
            return torch.nn.functional.hardswish(x)

    def _test_hardswish(self, inputs, module=None):
        """Export, lower, serialize and run ``module``, checking the graph.

        Args:
            inputs: Tuple of example input tensors handed to ``Tester``.
            module: The ``torch.nn.Module`` under test. Defaults to a fresh
                ``Hardswish()`` instance, so callers that pass only ``inputs``
                behave exactly as before.
        """
        if module is None:
            module = self.Hardswish()
        (
            Tester(module, inputs)
            .export()
            # The exported graph must contain exactly one aten hardswish op.
            .check_count({"torch.ops.aten.hardswish.default": 1})
            .to_edge_transform_and_lower()
            # After lowering, the op should sit inside one delegate call...
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # ...and no undelegated edge-dialect hardswish may remain.
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_hardswish_default",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_hardswish(self):
        inputs = (torch.randn(1, 3, 3).to(torch.float16),)
        self._test_hardswish(inputs)

    def test_fp32_hardswish(self):
        inputs = (torch.randn(1, 3, 3),)
        self._test_hardswish(inputs)

    def test_fp32_hardswish_functional(self):
        # Same pipeline and checks as the module variant; only the way
        # hardswish is invoked in forward() differs.
        inputs = (torch.randn(1, 3, 3),)
        self._test_hardswish(inputs, self.HardswishFunctional())