# Owner(s): ["module: inductor"] import copy import importlib import itertools import os import sys import torch from torch import nn # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import TEST_WITH_ASAN from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA importlib.import_module("functorch") importlib.import_module("filelock") from inductor.test_torchinductor import copy_tests class ConvOp(nn.Module): expected_optimization_count = 1 def __init__( self, conv_class, bn_class, use_bias, in_channels, out_channels, device, **kwargs, ): super().__init__() self.conv = conv_class(in_channels, out_channels, bias=use_bias, **kwargs).to( device ) self.bn = bn_class(out_channels).to(device) def forward(self, x): x = self.conv(x) return self.bn(x) class MultiUserConvOp(nn.Module): expected_optimization_count = 3 def __init__( self, conv_class, bn_class, use_bias, in_channels, out_channels, device, **kwargs, ): super().__init__() self.conv1 = conv_class(in_channels, out_channels, bias=use_bias, **kwargs).to( device ) self.bn1 = bn_class(out_channels).to(device) self.conv2 = conv_class(out_channels, out_channels, bias=use_bias, **kwargs).to( device ) self.bn2 = bn_class(out_channels).to(device) self.conv3 = conv_class(out_channels, out_channels, bias=use_bias, **kwargs).to( device ) self.bn3 = bn_class(out_channels).to(device) def forward(self, x): # this conv-bn pair can use efficient_conv_bn_eval x = self.bn1(self.conv1(input=x)) # this conv-bn pair cannot use efficient_conv_bn_eval feature # just for the second forward of the `self.conv2` x = self.bn2(input=self.conv2(self.conv2(x))) # this conv-bn pair can use efficient_conv_bn_eval feature # just for the first forward of the `self.bn3` # test for multiple users of one computation node x = self.bn3(input=self.conv3(input=x)) x = self.bn3(x) + x return x class EfficientConvBNEvalTemplate(TestCase): @inductor_config.patch({"efficient_conv_bn_eval_fx_passes": True}) def test_basic(self): def test_conv_bn_eval( test_class, use_bias, module, sync_bn, decompose_nn_module ): from functorch import make_fx from torch._dispatch.python import enable_python_dispatcher kwargs = {"kernel_size": 3, "stride": 2} if module[0] != nn.Linear else {} mod_eager = test_class( module[0], module[1], use_bias, 3, 32, self.device, **kwargs, ).eval() # Copy module to test backward mod_optimized = copy.deepcopy(mod_eager) if sync_bn: mod_eager = nn.SyncBatchNorm.convert_sync_batchnorm(mod_eager).eval() mod_optimized = nn.SyncBatchNorm.convert_sync_batchnorm( mod_optimized ).eval() torch._dynamo.reset() inps = [4, 3] # Conv shape goes from big to small, and ConvTranspose shape goes from small to big spatial_d = ( 4 if issubclass(module[0], nn.modules.conv._ConvTransposeNd) else 96 ) if module[0] == nn.Conv1d or module[0] == nn.ConvTranspose1d: inps += [spatial_d] * 1 if module[0] == nn.Conv2d or module[0] == nn.ConvTranspose2d: inps += [spatial_d] * 2 if module[0] == nn.Conv3d or module[0] == nn.ConvTranspose3d: inps += [spatial_d] * 3 inp = torch.rand(inps).to(self.device) if decompose_nn_module: with enable_python_dispatcher(): mod_optimized = make_fx(mod_optimized, pre_dispatch=True)(inp) mod_optimized = torch.compile(mod_optimized) original_value = 
counters["inductor"]["efficient_conv_bn_eval"] optim_eager = torch.optim.SGD(mod_eager.parameters(), lr=1e-3) optim_optimized = torch.optim.SGD(mod_optimized.parameters(), lr=1e-3) optim_eager.zero_grad() optim_optimized.zero_grad() # test forward out_eager = mod_eager(inp) out_optimized = mod_optimized(inp) self.assertEqual(out_optimized, out_eager, atol=3e-04, rtol=1e-5) out_eager.mean().backward() out_optimized.mean().backward() optim_eager.step() optim_optimized.step() # test forward (by testing forward again after one training iteration) inp_bw = torch.rand_like(inp) out_eager_bw = mod_eager(inp_bw) out_optimized_bw = mod_optimized(inp_bw) self.assertEqual(out_eager_bw, out_optimized_bw, atol=3e-04, rtol=1e-5) current_value = counters["inductor"]["efficient_conv_bn_eval"] self.assertEqual( current_value - original_value, test_class.expected_optimization_count ) conv_bias = [True, False] modules = [ (nn.Linear, nn.BatchNorm1d), (nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d), (nn.Conv3d, nn.BatchNorm3d), (nn.ConvTranspose1d, nn.BatchNorm1d), (nn.ConvTranspose2d, nn.BatchNorm2d), (nn.ConvTranspose3d, nn.BatchNorm3d), ] test_classes = [ConvOp, MultiUserConvOp] sync_bns = [False, True] decompose_nn_modules = [False, True] for ( test_class, use_bias, module, sync_bn, decompose_nn_module, ) in itertools.product( test_classes, conv_bias, modules, sync_bns, decompose_nn_modules, ): test_conv_bn_eval( test_class, use_bias, module, sync_bn, decompose_nn_module ) if HAS_CPU and not torch.backends.mps.is_available(): class EfficientConvBNEvalCpuTests(TestCase): device = "cpu" copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalCpuTests, "cpu") if HAS_CUDA and not TEST_WITH_ASAN: class EfficientConvBNEvalCudaTests(TestCase): device = "cuda" copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalCudaTests, "cuda") del EfficientConvBNEvalTemplate if __name__ == "__main__": from torch._inductor.test_case import run_tests if HAS_CPU or HAS_CUDA: run_tests(needs="filelock")