# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized


class TestMeanDim(unittest.TestCase):
    """Tests lowering of mean-style reductions through the Arm backend.

    Two modules are exercised:

    * ``AdaptiveAveragePool2d`` -- ``torch.nn.AdaptiveAvgPool2d`` with
      ``output_size=(1, 1)``.
    * ``MeanDim`` -- ``Tensor.mean`` over a configurable ``dim``/``keepdim``.

    Each module is run through a TOSA MI pipeline, a TOSA BI (quantized)
    pipeline, and Ethos-U55 / Ethos-U85 BI pipelines.
    """

    class AdaptiveAveragePool2d(torch.nn.Module):
        """Wraps ``torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))``."""

        # (test_name, test_data) tuples consumed by ``parameterized.expand``.
        # The rand/randn tensors are drawn once, when the class body executes,
        # and without an explicit seed.
        test_data_suite = [
            ("zeros", torch.zeros(1, 1280, 7, 7)),
            ("ones", torch.ones(1, 1280, 7, 7)),
            ("rand", torch.rand(1, 1280, 7, 7)),
            ("randn", torch.randn(1, 1280, 7, 7)),
        ]

        def __init__(self):
            super().__init__()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))

        def forward(self, x):
            """Apply the adaptive average pool to ``x``."""
            return self.adaptive_avg_pool2d(x)

    class MeanDim(torch.nn.Module):
        """Computes ``x.mean(dim=self.dim, keepdim=self.keepdim)``."""

        # (test_name, test_data, dim, keepdim) tuples consumed by
        # ``parameterized.expand``; ``dim`` and ``keepdim`` are forwarded to
        # the ``MeanDim`` constructor by the test methods.
        test_data_suite = [
            ("zeros", torch.zeros(1, 1280, 7, 7), -1, True),
            ("ones", torch.ones(1, 1280, 7, 7), (-1, 2), True),
            # NOTE(review): this dim is the plain int -1 (same as "zeros");
            # use ``(-1,)`` instead if a one-element tuple dim is intended.
            ("rand", torch.rand(1, 1280, 7, 7), -1, True),
            ("randn", torch.randn(1, 1280, 7, 7), (-1, -2, -3), True),
        ]

        def __init__(
            self,
            dim: int | tuple[int, ...] | list[int] = -1,
            keepdim: bool = True,
        ):
            """Store the reduction arguments.

            Args:
                dim: Dimension(s) passed to ``Tensor.mean``.
                keepdim: Passed to ``Tensor.mean``.
            """
            super().__init__()
            self.dim = dim
            self.keepdim = keepdim

        def forward(self, x: torch.Tensor):
            """Return the mean of ``x`` over ``self.dim``."""
            return x.mean(dim=self.dim, keepdim=self.keepdim)

    def _test_adaptive_avg_pool2d_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec and compare outputs.

        After export, checks that ``adaptive_avg_pool2d`` is present and that
        no quantized_decomposed ops were inserted. After partitioning, checks
        that no edge ``mean_dim`` op remains outside the delegate and that
        exactly one delegate call exists. Then lowers to ExecuTorch and runs
        ``run_method_and_compare_outputs`` on ``test_data``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.adaptive_avg_pool2d.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_adaptive_avg_pool2d_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the TOSA-0.80.0+BI spec.

        After export, checks for exactly one ``adaptive_avg_pool2d`` op and
        for quantized_decomposed ops. After partitioning, checks that no edge
        ``mean_dim`` op remains and exactly one delegate call exists, then
        lowers to ExecuTorch and runs ``run_method_and_compare_outputs``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.adaptive_avg_pool2d.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_adaptive_avg_pool2d_tosa_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` for an Ethos-U target.

        Checks that neither edge ``mean_dim`` nor edge ``avg_pool2d`` ops
        remain after partitioning and that exactly one delegate call exists.
        The pipeline stops at ``to_executorch``: outputs are not compared
        here.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check(["torch.ops.aten.adaptive_avg_pool2d.default"])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mean_dim",
                    "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_meandim_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec and compare outputs.

        Checks that no quantized_decomposed ops were inserted, that no edge
        ``mean_dim`` op remains after partitioning, and that exactly one
        delegate call exists. Then lowers to ExecuTorch and runs
        ``run_method_and_compare_outputs`` on ``test_data``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_meandim_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the TOSA-0.80.0+BI spec.

        Checks for quantized_decomposed ops after export, that no edge
        ``mean_dim`` op remains after partitioning, and that exactly one
        delegate call exists. Then lowers to ExecuTorch and compares outputs
        with ``qtol=1.0``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            # NOTE(review): qtol=1.0 presumably allows an off-by-one
            # quantization step in the quantized mean -- confirm against
            # ArmTester.run_method_and_compare_outputs.
            .run_method_and_compare_outputs(inputs=test_data, qtol=1.0)
        )

    def _test_meandim_tosa_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` for an Ethos-U target.

        Checks that neither edge ``mean_dim`` nor edge ``avg_pool2d`` ops
        remain after partitioning and that exactly one delegate call exists.
        The pipeline stops at ``to_executorch``: outputs are not compared
        here.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mean_dim",
                    "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(AdaptiveAveragePool2d.test_data_suite)
    def test_adaptive_avg_pool2d_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_adaptive_avg_pool2d_tosa_MI_pipeline(
            self.AdaptiveAveragePool2d(), (test_data,)
        )

    @parameterized.expand(AdaptiveAveragePool2d.test_data_suite)
    def test_adaptive_avg_pool2d_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_adaptive_avg_pool2d_tosa_BI_pipeline(
            self.AdaptiveAveragePool2d(), (test_data,)
        )

    @parameterized.expand(AdaptiveAveragePool2d.test_data_suite)
    def test_adaptive_avg_pool2d_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_adaptive_avg_pool2d_tosa_ethosu_BI_pipeline(
            self.AdaptiveAveragePool2d(), common.get_u55_compile_spec(), (test_data,)
        )

    @parameterized.expand(AdaptiveAveragePool2d.test_data_suite)
    def test_adaptive_avg_pool2d_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_adaptive_avg_pool2d_tosa_ethosu_BI_pipeline(
            self.AdaptiveAveragePool2d(), common.get_u85_compile_spec(), (test_data,)
        )

    @parameterized.expand(MeanDim.test_data_suite)
    def test_meandim_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int | tuple[int, ...] | list[int] = -1,
        keepdim: bool = True,
    ):
        self._test_meandim_tosa_MI_pipeline(self.MeanDim(dim, keepdim), (test_data,))

    @parameterized.expand(MeanDim.test_data_suite)
    def test_meandim_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int | tuple[int, ...] | list[int] = -1,
        keepdim: bool = True,
    ):
        self._test_meandim_tosa_BI_pipeline(self.MeanDim(dim, keepdim), (test_data,))

    @parameterized.expand(MeanDim.test_data_suite)
    def test_meandim_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int | tuple[int, ...] | list[int] = -1,
        keepdim: bool = True,
    ):
        self._test_meandim_tosa_ethosu_BI_pipeline(
            self.MeanDim(dim, keepdim),
            common.get_u55_compile_spec(),
            (test_data,),
        )

    @parameterized.expand(MeanDim.test_data_suite)
    def test_meandim_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int | tuple[int, ...] | list[int] = -1,
        keepdim: bool = True,
    ):
        self._test_meandim_tosa_ethosu_BI_pipeline(
            self.MeanDim(dim, keepdim),
            common.get_u85_compile_spec(),
            (test_data,),
        )