• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
11    ConvertMeanDimToAveragePool,
12)
13
14from executorch.backends.arm.test import common
15from executorch.backends.arm.test.tester.arm_tester import ArmTester
16
17from executorch.backends.xnnpack.test.tester.tester import RunPasses
18
19
class MeanDim(torch.nn.Module):
    """Model that averages over the two trailing axes, keeping them as size-1 dims.

    This is the positive case for ConvertMeanDimToAveragePool in the tests below:
    a mean over dim=[-1, -2] with keepdim=True.
    """

    def forward(self, x):
        # Keep the dim list exactly as written ([-1, -2]); the pass under test
        # may match on the literal argument, so do not reorder it.
        reduced = x.mean(dim=[-1, -2], keepdim=True)
        return reduced

    def get_inputs(self):
        """Return the example-input tuple: one random tensor of shape (1, 1280, 7, 7)."""
        shape = (1, 1280, 7, 7)
        return (torch.rand(shape),)
27
class MeanDim2(torch.nn.Module):
    """Model that averages over axis 1 without keepdim.

    This is the negative case in the tests below: its mean.dim is expected to
    survive ConvertMeanDimToAveragePool unchanged.
    """

    def forward(self, x):
        return x.mean(dim=1)

    def get_inputs(self):
        """Return the example-input tuple: one random tensor of shape (1, 1280, 7, 7)."""
        sample = torch.rand((1, 1280, 7, 7))
        return (sample,)
34
35
class TestMeandimToAveragePool2dPass(unittest.TestCase):
    """
    Tests the ConvertMeanDimToAveragePool pass, which rewrites mean.dim into
    avg_pool2d for the special case where dim is [-1, -2] and keepdim is True,
    and must leave any other mean.dim untouched.
    """

    # TOSA spec string shared by every test in this class.
    _TOSA_SPEC = "TOSA-0.80.0+BI"

    def _edge_tester(self, module):
        """Quantize, export and lower *module* to edge; assert mean.dim is present.

        Returns the ArmTester so each test can continue the chain with the pass
        under test and its own post-pass checks.
        """
        return (
            ArmTester(
                module,
                example_inputs=module.get_inputs(),
                compile_spec=common.get_tosa_compile_spec(self._TOSA_SPEC),
            )
            .quantize()
            .export()
            .to_edge()
            .check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
        )

    def test_tosa_BI_meandim_to_averagepool(self):
        """mean.dim over [-1, -2] with keepdim=True is replaced by avg_pool2d."""
        test_pass_stage = RunPasses([ConvertMeanDimToAveragePool])
        (
            self._edge_tester(MeanDim())
            .run_passes(test_pass_stage)
            .check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
            # The rewrite must replace the mean, not merely add a pool beside it.
            .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
        )

    def test_tosa_BI_meandim_no_modification(self):
        """mean.dim over a single axis without keepdim is left as-is."""
        test_pass_stage = RunPasses([ConvertMeanDimToAveragePool])
        (
            self._edge_tester(MeanDim2())
            .run_passes(test_pass_stage)
            .check(["executorch_exir_dialects_edge__ops_aten_mean_dim"])
            .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
        )
76