• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8from typing import Tuple
9
10import torch
11from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import (
12    FuseBatchNormWithConvPass,
13)
14from executorch.backends.xnnpack.test.tester import RunPasses, Tester
15
16
class TestBatchNormFusion(unittest.TestCase):
    """Tests for ``FuseBatchNormWithConvPass`` and standalone batch-norm handling.

    Every test counts occurrences of the edge-dialect batch-norm op named by
    ``bn_name`` in the graph produced by the ``Tester`` pipeline.
    """

    # Pass stage that runs only the conv + batch-norm fusion pass.
    PassStage = RunPasses([FuseBatchNormWithConvPass])
    # Edge-dialect name of the batch-norm op the tests count. As the name
    # indicates, this is the "no_training" (inference) variant, so the modules
    # under test are put in eval mode before export.
    bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default"

    class ModelConvBN(torch.nn.Module):
        """Model computing conv -> bn -> conv -> (y + y) -> bn.

        The same ``conv2d`` and ``bn`` submodules are applied twice. The first
        batch norm directly follows a convolution; the second follows an add.
        The tests below expect exactly one batch-norm op to remain after the
        fusion pass (presumably the one that follows the add).
        """

        def __init__(
            self, in_features: int, out_features: int, kernel_size: Tuple[int, int]
        ):
            """Create the conv and batch-norm layers.

            Args:
                in_features: Number of input channels of the convolution.
                out_features: Number of output channels of the convolution,
                    also used as the batch norm's feature count.
                kernel_size: Convolution kernel size.

            Note: ``forward`` feeds the conv output back into the same conv,
            so ``in_features`` must equal ``out_features``.
            """
            super().__init__()
            self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
            self.bn = torch.nn.BatchNorm2d(out_features)

        def forward(self, x):
            """Apply conv, bn, conv again, double the result, then bn again."""
            y = self.conv2d(x)
            y = self.bn(y)
            y = self.conv2d(y)
            y = y + y
            return self.bn(y)

    def test_fp32_batch_norm_fusion(self):
        """After the fusion pass, one batch-norm op remains in the fp32 graph."""
        (
            Tester(self.ModelConvBN(2, 2, (2, 2)).eval(), (torch.randn(2, 2, 4, 4),))
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_count({self.bn_name: 1})
            .run_method_and_compare_outputs()
        )

    def test_q8_batch_norm_fusion(self):
        """Same check as the fp32 test, with a ``quantize()`` stage before export."""
        (
            Tester(self.ModelConvBN(2, 2, (2, 2)).eval(), (torch.randn(2, 2, 4, 4),))
            .quantize()
            .export()
            .to_edge()
            .run_passes(self.PassStage)
            .check_count({self.bn_name: 1})
            .run_method_and_compare_outputs()
        )

    def test_fp32_batch_norm_no_fusion_doesnt_partition(self):
        """
        We do not currently support standalone batch norms (i.e. batch norms that are
        not fused with a conv). This is planned, but until implemented, this test ensures
        that we do not partition the standalone batch norm and then fail to lower.
        """

        class BN(torch.nn.Module):
            """Module consisting of a single standalone ``BatchNorm2d``."""

            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                return self.bn(x)

        (
            # Put the module in eval mode explicitly, as the fusion tests do:
            # the op counted below is the inference ("no_training") batch norm,
            # so the test should not rely on the Tester to switch modes.
            Tester(BN().eval(), (torch.randn(2, 2, 4, 4),))
            .export()
            .to_edge()
            # The batch norm is present before partitioning ...
            .check_count({self.bn_name: 1})
            .partition()
            # ... and is still present afterwards, i.e. it was not claimed by
            # the partitioner.
            .check_count({self.bn_name: 1})
        )
80