# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch
from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
    FuseBatchNormWithConvPass,
)
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestBatchNormFusion(unittest.TestCase):
    """Tests for ``FuseBatchNormWithConvPass`` in the XNNPACK backend."""

    # Tester stage that runs only the batch-norm/conv fusion pass.
    PassStage = RunPasses([FuseBatchNormWithConvPass])
    # Edge-dialect name of the inference-mode batch norm op
    # (``_native_batch_norm_legit_no_training``) whose occurrences are counted.
    bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"

    class ModelConvBN(torch.nn.Module):
        """Conv -> BN -> Conv -> add -> BN.

        Only the first batch norm directly consumes a conv output; the second
        one consumes the result of an add. The fusion tests therefore expect
        exactly one batch norm to remain after the pass runs.
        """

        def __init__(
            self, in_features: int, out_features: int, kernel_size: Tuple[int, int]
        ):
            super().__init__()
            self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
            self.bn = torch.nn.BatchNorm2d(out_features)

        def forward(self, x):
            y = self.conv2d(x)
            y = self.bn(y)
            # The same conv is applied to its own output, so the model only
            # works when in_features == out_features.
            y = self.conv2d(y)
            y = y + y
            return self.bn(y)

    def test_fp32_batch_norm_fusion(self):
        (
            Tester(self.ModelConvBN(2, 2, (2, 2)).eval(), (torch.randn(2, 2, 4, 4),))
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            # The conv -> bn pair is fused; the bn that follows the add remains.
            .check_count({self.bn_name: 1})
            .run_method_and_compare_outputs()
        )

    def test_q8_batch_norm_fusion(self):
        (
            Tester(self.ModelConvBN(2, 2, (2, 2)).eval(), (torch.randn(2, 2, 4, 4),))
            .quantize()
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            # Same expectation as the fp32 case, on the quantized graph.
            .check_count({self.bn_name: 1})
            .run_method_and_compare_outputs()
        )

    def test_fp32_batch_norm_no_fusion_doesnt_partition(self):
        """
        We do not currently support standalone batch norms (i.e. batch norms that
        are not fused with a conv). This is planned, but until it is implemented,
        this test ensures that we do not partition the standalone batch norm and
        then fail to lower.
        """

        class BN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                return self.bn(x)

        (
            # eval() matches the other tests: bn_name counts the inference-mode
            # (no_training) batch norm op, so the module must not be in
            # training mode when exported.
            Tester(BN().eval(), (torch.randn(2, 2, 4, 4),))
            .export()
            .to_edge()
            .check_count({self.bn_name: 1})
            .partition()
            # The standalone batch norm must still be in the graph, i.e. it
            # was not claimed by the XNNPACK partitioner.
            .check_count({self.bn_name: 1})
        )