# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Each entry is unpacked by ``parameterized.expand`` into
# (test_name, test_data, model_params); model_params is star-unpacked into
# ``TestBatchNorm2d.BatchNorm2d.__init__`` so trailing arguments may be omitted.
test_data_suite = [
    # (test_name, test_data, [num_features, affine, track_running_stats, weight, bias, running_mean, running_var,] )
    (
        "zeros_affineT_runStatsT_default_weight_bias_mean_var",
        torch.zeros(1, 32, 112, 112),
        [32, True, True],
    ),
    (
        "zeros_affineF_runStatsT_default_weight_bias_mean_var",
        torch.zeros(1, 32, 112, 112),
        [32, False, True],
    ),
    (
        "zeros_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.zeros(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "zeros_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.zeros(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "ones_affineT_runStatsT_default_weight_bias_mean_var",
        torch.ones(1, 32, 112, 112),
        [32, True, True],
    ),
    (
        "ones_affineF_runStatsT_default_weight_bias_mean_var",
        torch.ones(1, 32, 112, 112),
        [32, False, True],
    ),
    (
        "ones_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.ones(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "ones_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.ones(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "rand_affineT_runStatsT_default_weight_bias_mean_var",
        torch.rand(1, 32, 112, 112),
        [32, True, True],
    ),
    (
        "rand_affineF_runStatsT_default_weight_bias_mean_var",
        torch.rand(1, 32, 112, 112),
        [32, False, True],
    ),
    (
        "rand_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.rand(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "rand_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.rand(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "randn_affineT_runStatsT_default_weight_bias_mean_var",
        torch.randn(1, 32, 112, 112),
        [32, True, True],
    ),
    (
        "randn_affineF_runStatsT_default_weight_bias_mean_var",
        torch.randn(1, 32, 112, 112),
        [32, False, True],
    ),
    (
        "randn_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "randn_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    # Test some different sizes
    (
        "size_3_4_5_6_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.rand(3, 4, 5, 6),
        [4, True, True, torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)],
    ),
    (
        # affine is False here so the parameters agree with the "affineF" name.
        "size_3_4_5_6_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.rand(3, 4, 5, 6),
        [4, False, True, torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)],
    ),
    (
        "size_1_3_254_254_affineT_runStatsT_rand_weight_bias_mean_var",
        torch.rand(1, 3, 254, 254),
        [3, True, True, torch.rand(3), torch.rand(3), torch.rand(3), torch.rand(3)],
    ),
    (
        # affine is False here so the parameters agree with the "affineF" name.
        "size_1_3_254_254_affineF_runStatsT_rand_weight_bias_mean_var",
        torch.rand(1, 3, 254, 254),
        [3, False, True, torch.rand(3), torch.rand(3), torch.rand(3), torch.rand(3)],
    ),
    # Test combination of weight and bias
    (
        "check_weight_bias_affineT_runStatsT_none_none",
        torch.rand(1, 32, 112, 112),
        [32, True, True, None, None],
    ),
    (
        "check_weight_bias_affineF_runStatsT_none_none",
        torch.rand(1, 32, 112, 112),
        [32, False, True, None, None],
    ),
    (
        "check_weight_bias_affineT_runStatsT_weight_none",
        torch.rand(1, 32, 112, 112),
        [32, True, True, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsT_weight_none",
        torch.rand(1, 32, 112, 112),
        [32, False, True, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineT_runStatsT_none_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, True, None, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsT_none_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, True, None, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineT_runStatsT_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, True, torch.rand(32), torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsT_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, True, torch.rand(32), torch.rand(32)],
    ),
    # Test combination of running_mean and running_var
    (
        "check_mean_var_affineT_runStatsT_none_none",
        torch.randn(1, 32, 112, 112),
        [32, True, True, torch.rand(32), torch.rand(32), None, None],
    ),
    (
        "check_mean_var_affineF_runStatsT_none_none",
        torch.randn(1, 32, 112, 112),
        [32, False, True, torch.rand(32), torch.rand(32), None, None],
    ),
    (
        "check_mean_var_affineT_runStatsT_mean_none",
        torch.randn(1, 32, 112, 112),
        [32, True, True, torch.rand(32), torch.rand(32), torch.rand(32), None],
    ),
    (
        "check_mean_var_affineF_runStatsT_mean_none",
        torch.randn(1, 32, 112, 112),
        [32, False, True, torch.rand(32), torch.rand(32), torch.rand(32), None],
    ),
    (
        "check_mean_var_affineT_runStatsT_none_var",
        torch.randn(1, 32, 112, 112),
        [32, True, True, torch.rand(32), torch.rand(32), None, torch.rand(32)],
    ),
    (
        "check_mean_var_affineF_runStatsT_none_var",
        torch.randn(1, 32, 112, 112),
        [32, False, True, torch.rand(32), torch.rand(32), None, torch.rand(32)],
    ),
    (
        "check_mean_var_affineT_runStatsT_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            True,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
    (
        "check_mean_var_affineF_runStatsT_mean_var",
        torch.randn(1, 32, 112, 112),
        [
            32,
            False,
            True,
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
            torch.rand(32),
        ],
    ),
]

# Same layout as ``test_data_suite`` but every entry sets
# track_running_stats=False and supplies no running_mean / running_var.
test_no_stats_data_suite = [
    # (test_name, test_data, [num_features, affine, track_running_stats, weight, bias, running_mean, running_var, ] )
    (
        "zeros_affineT_runStatsF_default_weight_bias",
        torch.zeros(1, 32, 112, 112),
        [32, True, False],
    ),
    (
        "zeros_affineF_runStatsF_default_weight_bias",
        torch.zeros(1, 32, 112, 112),
        [32, False, False],
    ),
    (
        "zeros_affineT_runStatsF_rand_weight_bias",
        torch.zeros(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "zeros_affineF_runStatsF_rand_weight_bias",
        torch.zeros(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "ones_affineT_runStatsF_default_weight_bias",
        torch.ones(1, 32, 112, 112),
        [32, True, False],
    ),
    (
        "ones_affineF_runStatsF_default_weight_bias",
        torch.ones(1, 32, 112, 112),
        [32, False, False],
    ),
    (
        "ones_affineT_runStatsF_rand_weight_bias",
        torch.ones(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "ones_affineF_runStatsF",
        torch.ones(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "rand_affineT_runStatsF_default_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, False],
    ),
    (
        "rand_affineF_runStatsF_default_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, False],
    ),
    (
        "rand_affineT_runStatsF_rand_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "rand_affineF_runStatsF_rand_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "randn_affineT_runStatsF_default_weight_bias",
        torch.randn(1, 32, 112, 112),
        [32, True, False],
    ),
    (
        "randn_affineF_runStatsF_default_weight_bias",
        torch.randn(1, 32, 112, 112),
        [32, False, False],
    ),
    (
        "randn_affineT_runStatsF_rand_weight_bias",
        torch.randn(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "randn_affineF_runStatsF_rand_weight_bias",
        torch.randn(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
    # Test some different sizes
    (
        "size_3_4_5_6_affineT_runStatsF_rand_weight_bias_mean_var",
        torch.rand(3, 4, 5, 6),
        [4, True, False, torch.rand(4), torch.rand(4)],
    ),
    (
        # affine is False here so the parameters agree with the "affineF" name.
        "size_3_4_5_6_affineF_runStatsF_rand_weight_bias_mean_var",
        torch.rand(3, 4, 5, 6),
        [4, False, False, torch.rand(4), torch.rand(4)],
    ),
    (
        "size_1_3_254_254_affineT_runStatsF_rand_weight_bias_mean_var",
        torch.rand(1, 3, 254, 254),
        [3, True, False, torch.rand(3), torch.rand(3)],
    ),
    (
        # affine is False here so the parameters agree with the "affineF" name.
        "size_1_3_254_254_affineF_runStatsF_rand_weight_bias_mean_var",
        torch.rand(1, 3, 254, 254),
        [3, False, False, torch.rand(3), torch.rand(3)],
    ),
    # Test combination of weight and bias
    (
        "check_weight_bias_affineT_runStatsF_none_none",
        torch.rand(1, 32, 112, 112),
        [32, True, False, None, None],
    ),
    (
        "check_weight_bias_affineF_runStatsF_none_none",
        torch.rand(1, 32, 112, 112),
        [32, False, False, None, None],
    ),
    (
        "check_weight_bias_affineT_runStatsF_weight_none",
        torch.rand(1, 32, 112, 112),
        [32, True, False, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsF_weight_none",
        torch.rand(1, 32, 112, 112),
        [32, False, False, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineT_runStatsF_none_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, False, None, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsF_none_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, False, None, torch.rand(32)],
    ),
    (
        "check_weight_bias_affineT_runStatsF_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, True, False, torch.rand(32), torch.rand(32)],
    ),
    (
        "check_weight_bias_affineF_runStatsF_weight_bias",
        torch.rand(1, 32, 112, 112),
        [32, False, False, torch.rand(32), torch.rand(32)],
    ),
]


class TestBatchNorm2d(unittest.TestCase):
    """Tests BatchNorm2d.

    Each public test is expanded by ``parameterized`` over one of the
    module-level data suites and delegates to a ``_test_*_pipeline`` helper
    that drives an ``ArmTester`` stage chain.
    """

    class BatchNorm2d(torch.nn.Module):
        """Thin wrapper around ``torch.nn.BatchNorm2d`` with optional overrides."""

        def __init__(
            self,
            num_features: int = 32,
            affine: bool = False,
            track_running_stats: bool = True,
            weights: torch.Tensor | None = None,
            bias: torch.Tensor | None = None,
            running_mean: torch.Tensor | None = None,
            running_var: torch.Tensor | None = None,
        ):
            """Build the wrapped layer and apply any non-``None`` overrides.

            Args:
                num_features: Forwarded to ``torch.nn.BatchNorm2d``.
                affine: Forwarded to ``torch.nn.BatchNorm2d``.
                track_running_stats: Forwarded to ``torch.nn.BatchNorm2d``.
                weights: If given, wrapped in a ``Parameter`` and assigned to
                    the layer's ``weight`` (regardless of ``affine``).
                bias: If given, wrapped in a ``Parameter`` and assigned to the
                    layer's ``bias`` (regardless of ``affine``).
                running_mean: If given, assigned to the layer's
                    ``running_mean``.
                running_var: If given, assigned to the layer's
                    ``running_var``.
            """
            super().__init__()
            self.batch_norm_2d = torch.nn.BatchNorm2d(
                num_features, affine=affine, track_running_stats=track_running_stats
            )
            if weights is not None:
                self.batch_norm_2d.weight = torch.nn.Parameter(weights)
            if bias is not None:
                self.batch_norm_2d.bias = torch.nn.Parameter(bias)
            if running_mean is not None:
                self.batch_norm_2d.running_mean = running_mean
            if running_var is not None:
                self.batch_norm_2d.running_var = running_var

        def forward(self, x):
            """Apply the wrapped batch-norm layer to ``x``."""
            return self.batch_norm_2d(x)

    def _test_batchnorm2d_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the TOSA-0.80.0+MI ArmTester stages.

        Checks that no quantized ops appear after export, that exactly one
        no-training batch-norm edge op exists before partitioning, and that
        it is replaced by a single delegate call afterwards; finally compares
        outputs for ``test_data``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1
                }
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"
                ]
            )
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_batchnorm2d_no_stats_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run a no-running-stats ``module`` through the TOSA-0.80.0+MI stages.

        Mirrors the MI pipeline but expects the ``no_stats`` batch-norm
        variant both after export and in the edge dialect.
        """
        (
            ArmTester(
                module,
                # Must be ``example_inputs`` like the other pipelines; a
                # misspelt keyword would make the test fail before reaching
                # the stage it is actually meant to exercise.
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten._native_batch_norm_legit.no_stats": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_stats": 1
                }
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_stats"
                ]
            )
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_batchnorm2d_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the quantized TOSA-0.80.0+BI stages.

        Quantizes before export, expects one no-training batch-norm op plus
        quantized-decomposed ops after export, then the same edge/partition
        checks as the MI pipeline, and finally compares outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count(
                {"torch.ops.aten._native_batch_norm_legit_no_training.default": 1}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1
                }
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"
                ]
            )
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_batchnorm2d_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the quantized U55 stages.

        Same checks as the TOSA BI pipeline but with the U55 compile spec;
        the chain stops at ``to_executorch()`` and does not compare outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count(
                {"torch.ops.aten._native_batch_norm_legit_no_training.default": 1}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1
                }
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"
                ]
            )
            .to_executorch()
        )

    @parameterized.expand(test_data_suite)
    def test_batchnorm2d_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list,
    ):
        """TOSA MI test; ``model_params`` is star-unpacked into ``BatchNorm2d``."""
        self._test_batchnorm2d_tosa_MI_pipeline(
            self.BatchNorm2d(*model_params), (test_data,)
        )

    # Expected to fail since not implemented
    @parameterized.expand(test_no_stats_data_suite)
    @unittest.expectedFailure
    def test_batchnorm2d_no_stats_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list,
    ):
        """TOSA MI test without running stats (currently an expected failure)."""
        self._test_batchnorm2d_no_stats_tosa_MI_pipeline(
            self.BatchNorm2d(*model_params), (test_data,)
        )

    # Expected to fail since ArmQuantizer cannot quantize a BatchNorm layer
    # TODO(MLETORCH-100)
    @parameterized.expand(test_data_suite)
    @unittest.skip(
        reason="Expected to fail since ArmQuantizer cannot quantize a BatchNorm layer"
    )
    def test_batchnorm2d_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list,
    ):
        """Quantized TOSA BI test (currently skipped)."""
        self._test_batchnorm2d_tosa_BI_pipeline(
            self.BatchNorm2d(*model_params), (test_data,)
        )

    # Expected to fail since ArmQuantizer cannot quantize a BatchNorm layer
    # TODO(MLETORCH-100)
    @parameterized.expand(test_data_suite)
    @unittest.skip(
        reason="Expected to fail since ArmQuantizer cannot quantize a BatchNorm layer"
    )
    @unittest.expectedFailure
    def test_batchnorm2d_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        model_params: list,
    ):
        """Quantized U55 test (currently skipped)."""
        self._test_batchnorm2d_u55_BI_pipeline(
            self.BatchNorm2d(*model_params), (test_data,)
        )