# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import pytest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

"""
This file contains unit tests where convolutions are combined with other ops.
"""

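# Each Combo* module below lists, in `edge_op_list`, the edge ops it is expected
# to produce; the test pipelines check that these ops are no longer present in
# the graph after partitioning, i.e. that the whole combo has been delegated.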
class ComboBlockBottleneckResidual(torch.nn.Module):
    # This is the essence of MobileNetV2. Ref: https://arxiv.org/abs/1801.04381
    edge_op_list = [
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
    ]

    def __init__(self):
        super().__init__()
        # (t, c, n, s) = (6, 96, 1, 1)
        # 1. 1x1 CONV2d + ReLU6 (Pointwise)
        self.pointwise_conv2d = torch.nn.Conv2d(
            in_channels=64, out_channels=384, kernel_size=1, stride=1, groups=1
        )  ## (1, 384, 81, 81)
        self.batch_norm2d_16 = torch.nn.BatchNorm2d(384, affine=False)
        self.relu6 = torch.nn.ReLU6()

        # 2. 3x3 DepthwiseConv2d + ReLU6
        self.depthwise_conv2d = torch.nn.Conv2d(
            in_channels=384,
            out_channels=384,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=384,
        )  ## (1, 384, H, W)

        # 3. Linear 1x1 Conv2d
        self.pointwise_conv2d_linear = torch.nn.Conv2d(
            in_channels=384, out_channels=64, kernel_size=1, stride=1, groups=1
        )  ## (1, 64, 81, 81)

    def get_inputs(self) -> Tuple[torch.Tensor]:
        return (torch.randn(1, 64, 81, 81),)

    def forward(self, x):
        input = x
        # 1x1 CONV2d + ReLU6 (Pointwise)
        x = self.pointwise_conv2d(x)
        x = self.batch_norm2d_16(x)
        x = self.relu6(x)

        # 3x3 DepthwiseConv2d + ReLU6
        x = self.depthwise_conv2d(x)
        x = self.batch_norm2d_16(x)
        x = self.relu6(x)

        # Linear 1x1 Conv2d
        x = self.pointwise_conv2d_linear(x)

        # Final Residual Connection
        x = x + input

        return x

class ComboConv2dMeandim(torch.nn.Module):
    edge_op_list = [
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_mean_dim",
    ]

    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(
            in_channels=3, out_channels=10, kernel_size=5, stride=1, bias=False
        )
        # will be specialized to aten.mean.dim
        self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

    def get_inputs(self) -> Tuple[torch.Tensor]:
        return (torch.randn(1, 3, 128, 128),)

    def forward(self, x):
        x = self.conv2d(x)
        return self.adaptive_avg_pool2d(x)

class ComboConvBatchnormRelu6(torch.nn.Module):
    edge_op_list = [
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
        "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
    ]

    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
        )
        self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=False)
        self.relu6 = torch.nn.ReLU6()

    def get_inputs(self) -> Tuple[torch.Tensor]:
        return (torch.randn(1, 3, 256, 256),)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.batch_norm2d(x)
        x = self.relu6(x)
        return x

class ComboConvRelu6(torch.nn.Module):
    edge_op_list = [
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
    ]

    test_data = [
        (20 * torch.randn(1, 3, 256, 256),),
        (5 * torch.randn(1, 3, 256, 256),),
        (torch.randn(1, 3, 256, 256),),
        (-5 * torch.randn(1, 3, 256, 256),),
    ]

    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
        )
        self.relu6 = torch.nn.ReLU6()

    def forward(self, x):
        x = self.conv2d(x)
        x = self.relu6(x)
        return x

class ComboConvAvgPool2d(torch.nn.Module):
    edge_op_list = [
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
    ]

    test_data = [
        (20 * torch.randn(1, 3, 64, 32),),
        (torch.randn(1, 3, 100, 200),),
        (5 * torch.randn(1, 3, 256, 256),),
        (torch.rand(1, 3, 512, 128),),
    ]

    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(
            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
        )
        self.avg_pool2d = torch.nn.AvgPool2d(kernel_size=(2, 2))

    def forward(self, x):
        x = self.conv2d(x)
        x = self.avg_pool2d(x)
        return x

class TestConvCombos(unittest.TestCase):
    """Tests convolutions combined with other ops."""

    def _test_conv_combo_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+MI", permute_memory_to_nhwc=True
                ),
            )
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(list(module.edge_op_list))
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

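    # Unlike the MI pipeline above, the BI pipeline quantizes the module before
    # lowering and therefore compares outputs with an explicit atol/rtol and qtol=1.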
    def _test_conv_combo_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        atol: float = 1e-3,
        rtol: float = 1e-3,
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+BI", permute_memory_to_nhwc=True
                ),
            )
            .quantize()
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(list(module.edge_op_list))
            .to_executorch()
            .run_method_and_compare_outputs(
                inputs=test_data, atol=atol, rtol=rtol, qtol=1
            )
        )

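    # The Ethos-U pipeline only checks that the combo is fully delegated; the
    # lowered program is not executed here, so no output comparison is done.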
    def _test_conv_combo_ethos_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(list(module.edge_op_list))
            .to_executorch()
        )

    ####################
    ## Conv + meandim ##
    ####################
    def test_conv_meandim_tosa_MI(self):
        model = ComboConv2dMeandim()
        self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

    def test_conv_meandim_tosa_BI(self):
        model = ComboConv2dMeandim()
        self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

    def test_conv_meandim_u55_BI(self):
        model = ComboConv2dMeandim()
        self._test_conv_combo_ethos_BI_pipeline(
            model,
            common.get_u55_compile_spec(permute_memory_to_nhwc=True),
            model.get_inputs(),
        )

    def test_conv_meandim_u85_BI(self):
        model = ComboConv2dMeandim()
        self._test_conv_combo_ethos_BI_pipeline(
            model,
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            model.get_inputs(),
        )

    ##############################
    ## Conv + batch norm + relu ##
    ##############################
    def test_conv_batchnorm_relu6_tosa_MI(self):
        model = ComboConvBatchnormRelu6()
        self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

    def test_conv_batchnorm_relu6_tosa_BI(self):
        model = ComboConvBatchnormRelu6()
        self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

    def test_conv_batchnorm_relu6_u55_BI(self):
        model = ComboConvBatchnormRelu6()
        self._test_conv_combo_ethos_BI_pipeline(
            model, common.get_u55_compile_spec(), model.get_inputs()
        )

    def test_conv_batchnorm_relu_u85_BI(self):
        model = ComboConvBatchnormRelu6()
        self._test_conv_combo_ethos_BI_pipeline(
            model,
            common.get_u85_compile_spec(),
            model.get_inputs(),
        )

    ##################
    ## Conv + ReLU6 ##
    ##################
    @parameterized.expand(ComboConvRelu6.test_data)
    def test_conv_relu6_tosa_MI(self, test_data: torch.Tensor):
        model = ComboConvRelu6()
        test_data = (test_data,)
        self._test_conv_combo_tosa_MI_pipeline(model, test_data)

    @parameterized.expand(ComboConvRelu6.test_data)
    def test_conv_relu6_tosa_BI(self, test_data: torch.Tensor):
        model = ComboConvRelu6()
        test_data = (test_data,)
        self._test_conv_combo_tosa_BI_pipeline(model, test_data)

    @parameterized.expand(ComboConvRelu6.test_data)
    def test_conv_relu6_u55_BI(self, test_data: torch.Tensor):
        model = ComboConvRelu6()
        test_data = (test_data,)
        self._test_conv_combo_ethos_BI_pipeline(
            model, common.get_u55_compile_spec(permute_memory_to_nhwc=True), test_data
        )

    @parameterized.expand(ComboConvRelu6.test_data)
    def test_conv_relu6_u85_BI(self, test_data: torch.Tensor):
        model = ComboConvRelu6()
        test_data = (test_data,)
        self._test_conv_combo_ethos_BI_pipeline(
            model, common.get_u85_compile_spec(permute_memory_to_nhwc=True), test_data
        )

    ###############################
    ## Block bottleneck residual ##
    ###############################
    def test_block_bottleneck_residual_tosa_MI(self):
        model = ComboBlockBottleneckResidual()
        self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

    # TODO: Investigate flakiness (MLTORCH-307)
    @pytest.mark.flaky(reruns=3)
    def test_block_bottleneck_residual_tosa_BI(self):
        model = ComboBlockBottleneckResidual()
        self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

    def test_block_bottleneck_residual_u55_BI(self):
        model = ComboBlockBottleneckResidual()
        self._test_conv_combo_ethos_BI_pipeline(
            model,
            common.get_u55_compile_spec(permute_memory_to_nhwc=True),
            model.get_inputs(),
        )

    def test_block_bottleneck_residual_u85_BI(self):
        model = ComboBlockBottleneckResidual()
        self._test_conv_combo_ethos_BI_pipeline(
            model,
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            model.get_inputs(),
        )

    ######################
    ## Conv + AvgPool2d ##
    ######################
    @parameterized.expand(ComboConvAvgPool2d.test_data)
    def test_conv_avgpool2d_tosa_MI(self, test_data: torch.Tensor):
        model = ComboConvAvgPool2d()
        test_data = (test_data,)
        self._test_conv_combo_tosa_MI_pipeline(model, test_data)

    @parameterized.expand(ComboConvAvgPool2d.test_data)
    def test_conv_avgpool2d_tosa_BI(self, test_data: torch.Tensor):
        model = ComboConvAvgPool2d()
        test_data = (test_data,)
        self._test_conv_combo_tosa_BI_pipeline(model, test_data)

    @parameterized.expand(ComboConvAvgPool2d.test_data)
    def test_conv_avgpool2d_u55_BI(self, test_data: torch.Tensor):
        model = ComboConvAvgPool2d()
        test_data = (test_data,)
        self._test_conv_combo_ethos_BI_pipeline(
            model,
            common.get_u55_compile_spec(),
            test_data,
        )

    @parameterized.expand(ComboConvAvgPool2d.test_data)
    def test_conv_avgpool2d_u85_BI(self, test_data: torch.Tensor):
        model = ComboConvAvgPool2d()
        test_data = (test_data,)
        self._test_conv_combo_ethos_BI_pipeline(
            model,
            common.get_u85_compile_spec(),
            test_data,
        )