# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

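# Test cases with track_running_stats=True.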
test_data_suite = [
    # (test_name, test_data, [num_features, affine, track_running_stats, weight, bias, running_mean, running_var])
    (
        "zeros_affineT_runStatsT_default_weight_bias_mean_var",
        torch.zeros(1, 32, 112, 112),
        [
            32,
            True,
            True,
        ],
    ),
    (
        "zeros_affineF_runStatsT_default_weight_bias_mean_var",
        torch.zeros(1, 32, 112, 112),
        [
            32,
            False,
            True,
        ],
    ),
    (
        "zeros_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.zeros(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "zeros_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.zeros(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "ones_affineT_runStatsT_default_weight_bias_mean_var",
        torch.ones(1, 32, 112, 112),
        [
            32,
            True,
            True,
        ],
    ),
    (
        "ones_affineF_runStatsT_default_weight_bias_mean_var",
        torch.ones(1, 32, 112, 112),
        [
            32,
            False,
            True,
        ],
    ),
    (
        "ones_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.ones(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "ones_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.ones(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "rand_affineT_runStatsT_default_weight_bias_mean_var",
        torch.rand(1, 32, 112, 112),
        [
            32,
            True,
            True,
        ],
    ),
    (
        "rand_affineF_runStatsT_default_weight_bias_mean_var",
        torch.rand(1, 32, 112, 112),
        [
            32,
            False,
            True,
        ],
    ),
    (
        "rand_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.rand(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "rand_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.rand(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "randn_affineT_runStatsT_default_weight_bias_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            True,
            True,
        ],
    ),
    (
        "randn_affineF_runStatsT_default_weight_bias_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            False,
            True,
        ],
    ),
    (
        "randn_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "randn_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    # Test some different sizes
    (
        "size_3_4_5_6_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.rand(3, 4, 5, 6),
        [4, True, True, torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)],
    ),
    (
        "size_3_4_5_6_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.rand(3, 4, 5, 6),
        [4, False, True, torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)],
    ),
    (
        "size_1_3_254_254_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.rand(1, 3, 254, 254),
        [3, True, True, torch.rand(3), torch.rand(3), torch.rand(3), torch.rand(3)],
    ),
    (
        "size_1_3_254_254_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.rand(1, 3, 254, 254),
        [3, False, True, torch.rand(3), torch.rand(3), torch.rand(3), torch.rand(3)],
    ),
    # Test combination of weight and bias
    (
        "check_weight_bias_affineT_runStatsT_none_none",
        torch.rand(1, 32, 112, 112),
        [32, True, True, None, None],
    ),
    (
        "check_weight_bias_affineF_runStatsT_none_none",
        torch.rand(1, 32, 112, 112),
        [32, False, True, None, None],
    ),
    (
        "check_weight_bias_affineT_runStatsT_weight_none",
        torch.rand(1, 32, 112, 112),
        [32, True, True, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsT_weight_none",
        torch.rand(1, 32, 112, 112),
        [32, False, True, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineT_runStatsT_none_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, True, None, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsT_none_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, True, None, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineT_runStatsT_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, True, torch.rand(32), torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsT_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, True, torch.rand(32), torch.rand(32)],
    ),
    # Test combination of running_mean and running_var
    (
        "check_mean_var_affineT_runStatsT_none_none",
        torch.randn(1, 32, 112, 112),
        [32, True, True, torch.rand(32), torch.rand(32), None, None],
    ),
    (
        "check_mean_var_affineF_runStatsT_none_none",
        torch.randn(1, 32, 112, 112),
        [32, False, True, torch.rand(32), torch.rand(32), None, None],
    ),
    (
        "check_mean_var_affineT_runStatsT_mean_none",
        torch.randn(1, 32, 112, 112),
        [32, True, True, torch.rand(32), torch.rand(32), torch.rand(32), None],
    ),
    (
        "check_mean_var_affineF_runStatsT_mean_none",
        torch.randn(1, 32, 112, 112),
        [32, False, True, torch.rand(32), torch.rand(32), torch.rand(32), None],
    ),
    (
        "check_mean_var_affineT_runStatsT_none_var",
        torch.randn(1, 32, 112, 112),
        [32, True, True, torch.rand(32), torch.rand(32), None, torch.rand(32)],
    ),
    (
        "check_mean_var_affineF_runStatsT_none_var",
        torch.randn(1, 32, 112, 112),
        [32, False, True, torch.rand(32), torch.rand(32), None, torch.rand(32)],
    ),
    (
        "check_mean_var_affineT_runStatsT_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "check_mean_var_affineF_runStatsT_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
]

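# Test cases with track_running_stats=False (no running statistics).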
test_no_stats_data_suite = [
    # (test_name, test_data, [num_features, affine, track_running_stats, weight, bias, running_mean, running_var])
    (
        "zeros_affineT_runStatsF_default_weight_bias",
        torch.zeros(1, 32, 112, 112),
        [
            32,
            True,
            False,
        ],
    ),
    (
        "zeros_affineF_runStatsF_default_weight_bias",
        torch.zeros(1, 32, 112, 112),
        [
            32,
            False,
            False,
        ],
    ),
    (
        "zeros_affineT_runStatsF_rand_weight_bias",
        torch.zeros(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "zeros_affineF_runStatsF_rand_weight_bias",
        torch.zeros(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "ones_affineT_runStatsF_default_weight_bias",
        torch.ones(1, 32, 112, 112),
        [
            32,
            True,
            False,
        ],
    ),
    (
        "ones_affineF_runStatsF_default_weight_bias",
        torch.ones(1, 32, 112, 112),
        [
            32,
            False,
            False,
        ],
    ),
    (
        "ones_affineT_runStatsF_rand_weight_bias",
        torch.ones(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "ones_affineF_runStatsF_rand_weight_bias",
        torch.ones(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "rand_affineT_runStatsF_default_weight_bias",
        torch.rand(1, 32, 112, 112),
        [
            32,
            True,
            False,
        ],
    ),
    (
        "rand_affineF_runStatsF_default_weight_bias",
        torch.rand(1, 32, 112, 112),
        [
            32,
            False,
            False,
        ],
    ),
    (
        "rand_affineT_runStatsF_rand_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "rand_affineF_runStatsF_rand_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "randn_affineT_runStatsF_default_weight_bias",
        torch.randn(1, 32, 112, 112),
        [
            32,
            True,
            False,
        ],
    ),
    (
        "randn_affineF_runStatsF_default_weight_bias",
        torch.randn(1, 32, 112, 112),
        [
            32,
            False,
            False,
        ],
    ),
    (
        "randn_affineT_runStatsF_rand_weight_bias",
        torch.randn(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "randn_affineF_runStatsF_rand_weight_bias",
        torch.randn(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
    # Test some different sizes
    (
        "size_3_4_5_6_affineT_runStatsF_rand_weight_bias",
        torch.rand(3, 4, 5, 6),
        [4, True, False, torch.rand(4), torch.rand(4)],
    ),
    (
        "size_3_4_5_6_affineF_runStatsF_rand_weight_bias",
        torch.rand(3, 4, 5, 6),
        [4, False, False, torch.rand(4), torch.rand(4)],
    ),
    (
        "size_1_3_254_254_affineT_runStatsF_rand_weight_bias",
        torch.rand(1, 3, 254, 254),
        [3, True, False, torch.rand(3), torch.rand(3)],
    ),
    (
        "size_1_3_254_254_affineF_runStatsF_rand_weight_bias",
        torch.rand(1, 3, 254, 254),
        [3, False, False, torch.rand(3), torch.rand(3)],
    ),
    # Test combination of weight and bias
    (
        "check_weight_bias_affineT_runStatsF_none_none",
        torch.rand(1, 32, 112, 112),
        [32, True, False, None, None],
    ),
    (
        "check_weight_bias_affineF_runStatsF_none_none",
        torch.rand(1, 32, 112, 112),
        [32, False, False, None, None],
    ),
    (
        "check_weight_bias_affineT_runStatsF_weight_none",
        torch.rand(1, 32, 112, 112),
        [32, True, False, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsF_weight_none",
        torch.rand(1, 32, 112, 112),
        [32, False, False, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineT_runStatsF_none_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, False, None, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsF_none_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, False, None, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineT_runStatsF_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsF_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
]


class TestBatchNorm2d(unittest.TestCase):
    """Tests BatchNorm2d."""

    class BatchNorm2d(torch.nn.Module):
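        """Wraps torch.nn.BatchNorm2d so the test can inject weight, bias and running statistics."""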
        def __init__(
            self,
            num_features: int = 32,
            affine: bool = False,
            track_running_stats: bool = True,
            weights: torch.Tensor = None,
            bias: torch.Tensor = None,
            running_mean: torch.Tensor = None,
            running_var: torch.Tensor = None,
        ):
            super().__init__()
            self.batch_norm_2d = torch.nn.BatchNorm2d(
                num_features, affine=affine, track_running_stats=track_running_stats
            )
            if weights is not None:
                self.batch_norm_2d.weight = torch.nn.Parameter(weights)
            if bias is not None:
                self.batch_norm_2d.bias = torch.nn.Parameter(bias)
            if running_mean is not None:
                self.batch_norm_2d.running_mean = running_mean
            if running_var is not None:
                self.batch_norm_2d.running_var = running_var

        def forward(self, x):
            return self.batch_norm_2d(x)

    def _test_batchnorm2d_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
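        """Export with TOSA-0.80.0+MI, lower and partition to the Arm delegate, then compare outputs with eager mode."""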
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1
                }
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"
                ]
            )
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_batchnorm2d_no_stats_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
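        """Same flow as the MI pipeline above, but for modules built without running stats (aten._native_batch_norm_legit.no_stats)."""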
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten._native_batch_norm_legit.no_stats": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_stats": 1
                }
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_stats"
                ]
            )
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_batchnorm2d_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
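        """Quantize, export with TOSA-0.80.0+BI, lower and partition to the Arm delegate, then compare outputs."""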
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count(
                {"torch.ops.aten._native_batch_norm_legit_no_training.default": 1}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1
                }
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"
                ]
            )
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_batchnorm2d_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
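        """Quantize and lower for Ethos-U55; only checks the lowering steps, the delegated model is not executed here."""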
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count(
                {"torch.ops.aten._native_batch_norm_legit_no_training.default": 1}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1
                }
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"
                ]
            )
            .to_executorch()
        )

    @parameterized.expand(test_data_suite)
    def test_batchnorm2d_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: (
            int
            | Tuple[
                int, bool, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
            ]
        ),
    ):
        self._test_batchnorm2d_tosa_MI_pipeline(
            self.BatchNorm2d(*model_params), (test_data,)
        )

    # Expected to fail since not implemented
    @parameterized.expand(test_no_stats_data_suite)
    @unittest.expectedFailure
    def test_batchnorm2d_no_stats_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: (
            int
            | Tuple[
                int, bool, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
            ]
        ),
    ):
        self._test_batchnorm2d_no_stats_tosa_MI_pipeline(
            self.BatchNorm2d(*model_params), (test_data,)
        )

    # Expected to fail since ArmQuantizer cannot quantize a BatchNorm layer
    # TODO(MLETORCH-100)
    @parameterized.expand(test_data_suite)
    @unittest.skip(
        reason="Expected to fail since ArmQuantizer cannot quantize a BatchNorm layer"
    )
    def test_batchnorm2d_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: (
            int
            | Tuple[
                int, bool, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
            ]
        ),
    ):
        self._test_batchnorm2d_tosa_BI_pipeline(
            self.BatchNorm2d(*model_params), (test_data,)
        )

    # Expected to fail since ArmQuantizer cannot quantize a BatchNorm layer
    # TODO(MLETORCH-100)
    @parameterized.expand(test_data_suite)
    @unittest.skip(
        reason="Expected to fail since ArmQuantizer cannot quantize a BatchNorm layer"
    )
    @unittest.expectedFailure
    def test_batchnorm2d_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: (
            int
            | Tuple[
                int, bool, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
            ]
        ),
    ):
        self._test_batchnorm2d_u55_BI_pipeline(
            self.BatchNorm2d(*model_params), (test_data,)
        )