# Owner(s): ["module: inductor"]
import contextlib
import functools
import importlib
import itertools
import os
import sys
import unittest
import weakref

import torch
from torch import nn
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import override_lowering, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import skipIfRocm

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM

importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

aten = torch.ops.aten
prims = torch.ops.prims
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


class TestCase(InductorTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "debug": True,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                    "freezing": True,
                    "freezing_discard_parameters": True,
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()


class ConvBN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)

    def forward(self, x):
        return self.bn(self.conv(x))


class ConvBNHardswish(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)
        self.hardswish = nn.Hardswish(inplace=True)

    def forward(self, x):
        return self.hardswish(self.bn(self.conv(x)))


class ConvFunctionalBN(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        kernel_size=3,
        stride=2,
        running_mean=None,
        running_var=None,
        weight=None,
        bn_bias=None,
    ):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, out_channels, bias=bias, kernel_size=kernel_size, stride=stride
        )
        self.running_mean = running_mean
        self.running_var = running_var
        self.weight = weight
        self.bias = bn_bias

    def forward(self, x):
        return torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )


class ConvMultiBN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)
        self.bn2 = torch.nn.BatchNorm2d(out_channels, eps=0.1, dtype=torch.float)

    def forward(self, x):
        tmp = self.bn(self.conv(x))
        tmp2 = self.bn2(self.conv(x))
        return tmp + tmp2


class ConvMultiFunctionalBN(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        kernel_size=3,
        stride=2,
        running_mean=None,
        running_var=None,
        weight=None,
        bn_bias=None,
        running_mean2=None,
    ):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, out_channels, bias=bias, kernel_size=kernel_size, stride=stride
        )
        self.running_mean = running_mean
        self.running_var = running_var
        self.weight = weight
        self.bias = bn_bias
        self.running_mean2 = running_mean2

    def forward(self, x):
        tmp = torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )
        tmp2 = torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean2,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )
        return tmp + tmp2


class OptimizeForInferenceTemplate(TestCase):
    def test_mutation(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mutated_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                self.mutated_param.add_(10)
                return self.mutated_param

        with torch.no_grad():
            mod = Mod().to(self.device)
            out_eager = mod()
            out_eager2 = mod()

            mod = Mod().to(self.device)

            @torch.compile
            def foo(mod):
                return mod()

            out_comp = foo(mod)
            out_comp2 = foo(mod)

            self.assertEqual(out_eager, out_comp)
            self.assertEqual(out_eager2, out_comp2)

    def test_aliased_param_return(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.aliased_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                return self.aliased_param[1:], self.aliased_param

        mod = Mod().to(self.device).eval()

        @torch.compile()
        def foo(mod):
            return mod()

        with torch.no_grad():
            mod_eager = mod()
            self.assertEqual(foo(mod), mod_eager)

    def test_autocast(self):
        if self.device == "cpu":
            raise unittest.SkipTest("MKLDNN bug")

        mod = torch.nn.Linear(10, 10).to(self.device).eval()
        inp = torch.rand([10, 10]).to(self.device).to(torch.half)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            with self.autocast():
                out_eager = mod(inp)
                out_compiled, code = run_and_get_code(foo, mod, inp)

                FileCheck().check_not("@triton.jit").run(code[0])
                self.assertEqual(out_eager, out_compiled)

    def test_mm_concat(self):
        # The CPU path will replace mm with mkl._linear;
        # skip this case for now.
        if self.device == "cpu":
            raise unittest.SkipTest("NYI CPU")

        class MM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.t1 = torch.nn.Parameter(torch.rand(10, 10))
                self.t2 = torch.nn.Parameter(torch.rand(10, 10))
                self.t3 = torch.nn.Parameter(torch.rand(10, 10))

            def forward(self, x):
                return x @ self.t1, x @ self.t2, x @ self.t3

        class MM2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.t1 = torch.nn.Parameter(torch.rand(10, 10))
                self.t2 = torch.nn.Parameter(torch.rand(10, 10))

            def forward(self, x):
                return x @ self.t1, x @ self.t2

        class AddMM(MM):
            def __init__(self) -> None:
                super().__init__()
                self.b1 = torch.nn.Parameter(torch.rand([10]))
                self.b2 = torch.nn.Parameter(torch.rand([10]))
                self.b3 = torch.nn.Parameter(torch.rand([10]))

            def forward(self, x):
                return [
                    aten.addmm(b, x, p)
                    for b, p in [
                        (self.b1, self.t1),
                        (self.b2, self.t2),
                        (self.b3, self.t3),
                    ]
                ]

        for mod_fn in [
            lambda: MM().to(self.device),
            lambda: MM2().to(self.device),
            lambda: AddMM().to(self.device),
        ]:
            mod = mod_fn()
            inp = torch.rand([10, 10]).to(self.device)

            @torch.compile()
            def foo(mod, inp):
                return mod(inp)

            kernel_invoke = "kernel_cpp_0" if self.device == "cpu" else "triton.jit"

            with torch.no_grad():
                out_eager = mod(inp)
                out, code = run_and_get_code(foo, mod, inp)
                FileCheck().check_not(kernel_invoke).check_count(
                    "mm(", count=1, exactly=True
                ).run(code[0])
                self.assertEqual(out_eager, out)

            mod2 = mod_fn()
            mod2.t1 = torch.nn.Parameter(torch.rand([10, 15], device=self.device))
            mod2.t2 = torch.nn.Parameter(torch.rand([10, 20], device=self.device))
            if hasattr(mod2, "b1"):
                mod2.b1 = torch.nn.Parameter(torch.rand([15], device=self.device))
                mod2.b2 = torch.nn.Parameter(torch.rand([20], device=self.device))

            # not fused
            count = 3 if hasattr(mod2, "t3") else 2
            with torch.no_grad():
                out_eager = mod2(inp)
                out, code = run_and_get_code(foo, mod2, inp)
                FileCheck().check_not(kernel_invoke).check_count(
                    "mm(", count=count, exactly=True
                ).run(code[0])
                self.assertEqual(out_eager, out)

    # With inlining of inbuilt nn modules, Dynamo traces the innards of the
    # inbuilt module and does not modify the eager module.
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
    def test_error_on_eager(self):
        mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

        x = torch.rand(3, 3, 32, 32).to(self.device)

        @torch.compile()
        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            foo(mod, x)

        with self.assertRaisesRegex(
            RuntimeError, "Trying to run Pytorch Eager Module after Dynamo Freezing"
        ):
            mod(x)

    def test_rng_op(self):
        @torch.compile()
        def foo():
            return torch.rand([4, 4], device=self.device) + 1

        with torch.no_grad():
            o1 = foo()
            o2 = foo()
            self.assertNotEqual(o1, o2)

    def test_symint_not_folded(self):
        def fn(a):
            return a.cos(), torch.zeros(a.shape[0], a.shape[1])

        fn_opt = torch._dynamo.optimize("inductor", dynamic=True)(fn)
        inp = torch.randn(2, 4, 6).to(self.device)
        torch._dynamo.mark_dynamic(inp, 0)
        torch._dynamo.mark_dynamic(inp, 1)

        with torch.no_grad():
            self.assertEqual(fn(inp), fn_opt(inp))

            inp2 = torch.randn(3, 5, 6).to(self.device)
            torch._dynamo.mark_dynamic(inp2, 0)
            torch._dynamo.mark_dynamic(inp2, 1)
            self.assertEqual(fn(inp2), fn_opt(inp2))

    @requires_cuda
    def test_conv_multiple_uses(self):
        from torch import nn

        class ToyModel(nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.bn1 = nn.BatchNorm2d(1)
                self.bn1.weight.data.normal_()

            def forward(self, x, y):
                return self.conv1(x) + self.bn1(self.conv1(y))

        model = ToyModel()
        model.eval().cuda()

        a = torch.rand(64, 1, 32, 32).cuda()
        b = torch.rand(64, 1, 32, 32).cuda()

        output = model(a, b)

        with torch.no_grad():
            output2 = torch.compile(model)(a, b)

        self.assertEqual(output, output2)

    def test_unfolded_bn(self):
        x = torch.rand([3, 32, 15, 15]).to(self.device)

        mod = torch.nn.BatchNorm2d(32, eps=0.001).eval().to(self.device)

        @torch.compile()
        def foo(mod, x):
            return mod(x) + 10

        out_compiled_no_inference = foo(mod, x)

        # would error if not decomposed
        with torch.no_grad():
            out_compiled = foo(mod, x)

            self.assertEqual(out_compiled_no_inference, out_compiled)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn(self):
        for use_bias, dtype in itertools.product(
            [True, False], [torch.float16, torch.bfloat16, torch.float32]
        ):
            if self.device == "cpu" and dtype == torch.float16:
                continue

            if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
                continue

            mod = (
                ConvBN(3, 32, bias=use_bias, kernel_size=3, stride=2)
                .eval()
                .to(self.device)
                .to(dtype)
            )

            x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

            torch._dynamo.reset()
            counters.clear()

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            # TODO - bias is a separate kernel right now; we should only unfuse it
            # from the conv if it can be fused
            with torch.no_grad():
                out_eager = mod(x)
                out_optimized_for_inference, code = run_and_get_code(foo, mod, x)

            # we unfuse the conv bias, but it should only have one constant in the kernel
            if self.device == "cuda":
                FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
                    "frozen_param"
                ).check_not("frozen_param").check_next("return").run(code[0])

            self.assertEqual(
                out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2
            )

            self.assertEqual(counters["inductor"]["binary_folding"], 4)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn_hardswish(self):
        for use_bias, dtype in itertools.product(
            [True, False], [torch.float16, torch.bfloat16, torch.float32]
        ):
            if self.device == "cpu" and dtype == torch.float16:
                continue

            if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
                continue

            mod = (
                ConvBNHardswish(3, 32, bias=use_bias, kernel_size=3, stride=2)
                .eval()
                .to(self.device)
                .to(dtype)
            )

            x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

            torch._dynamo.reset()
            counters.clear()

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            # TODO - bias is a separate kernel right now; we should only unfuse it
            # from the conv if it can be fused
            with torch.no_grad():
                out_eager = mod(x)
                out_optimized_for_inference, code = run_and_get_code(foo, mod, x)

            # we unfuse the conv bias, but it should only have one constant in the kernel
            if self.device == "cuda":
                FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
                    "frozen_param"
                ).check_not("frozen_param").check_next("return").run(code[0])

            self.assertEqual(
                out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2
            )

            self.assertEqual(counters["inductor"]["binary_folding"], 4)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn_with_module_sharing(self):
        mod = (
            ConvBN(32, 32, bias=True, kernel_size=3, stride=2)
            .to(self.device)
            .to(torch.float32)
        )

        # Update the default parameters of the BN module
        for _ in range(10):
            mod(torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32))

        mod.eval()
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)

        def foo(mod, x):
            mod(x)
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_functional_bn_with_module_sharing(self):
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)
        running_mean = torch.mean(x, dim=(0, 2, 3)).to(self.device)
        running_var = torch.var(x, dim=(0, 2, 3)).to(self.device)

        mod = (
            ConvFunctionalBN(
                32,
                32,
                bias=True,
                kernel_size=3,
                stride=2,
                running_mean=running_mean,
                running_var=running_var,
                weight=torch.ones(32).to(self.device),
                bn_bias=torch.zeros(32).to(self.device),
            )
            .eval()
            .to(self.device)
            .to(torch.float32)
        )

        def foo(mod, x):
            mod(x)
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_conv_bn_with_multi_bn_share_conv(self):
        mod = (
            ConvMultiBN(32, 32, bias=True, kernel_size=3, stride=2)
            .to(self.device)
            .to(torch.float32)
        )

        # Update the default parameters of the BN module
        for _ in range(10):
            mod(torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32))

        mod.eval()
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)

        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_conv_functional_bn_with_multi_bn_share_conv(self):
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)
        running_mean = torch.mean(x, dim=(0, 2, 3)).to(self.device)
        running_var = torch.var(x, dim=(0, 2, 3)).to(self.device)
        running_mean2 = torch.mean(x, dim=(0, 2, 3)).to(self.device)

        mod = (
            ConvMultiFunctionalBN(
                32,
                32,
                bias=True,
                kernel_size=3,
                stride=2,
                running_mean=running_mean,
                running_var=running_var,
                weight=torch.ones(32).to(self.device),
                bn_bias=torch.zeros(32).to(self.device),
                running_mean2=running_mean2,
            )
            .eval()
            .to(self.device)
            .to(torch.float32)
        )

        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_dont_change_dtype_folding(self):
        dtype = torch.float16 if self.device == "cuda" else torch.bfloat16

        mod = (
            torch.nn.Conv2d(3, 32, bias=None, kernel_size=3, stride=2)
            .eval()
            .to(self.device)
            .to(dtype)
        )
        x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

        def foo(mod, x):
            return mod(x) * torch.full([1], 2.0, device=self.device)

        foo_c = torch.compile(foo)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_compiled = foo_c(mod, x)
            self.assertEqual(out_eager, out_compiled)

    def test_param_deallocated(self):
        # TODO: the cpu path keeps an extra copy of the graph around somewhere;
        # memory is not as important for cpu
        if self.device == "cpu":
            raise unittest.SkipTest("NYI CPU")

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self, x):
                return (self.param + 10) + x

        mod = Mod().eval().to(self.device)
        inp = torch.rand([10], device=self.device)

        with torch.no_grad():
            eager = mod(inp)

        weight_ref = weakref.ref(mod.param)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            compiled = foo(mod, inp)

        self.assertEqual(eager, compiled)
        self.assertTrue(weight_ref() is None)

    @skipIfRocm
    def test_conv_with_as_strided(self):
        class Model(nn.Module):
            def __init__(self, groups):
                super().__init__()
                self.kv = torch.nn.Conv2d(
                    256,
                    384,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    bias=False,
                    groups=groups,
                )

            def forward(self, x):
                convolution = self.kv(x)
                constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                    convolution, [2, 2, 2, 2], 0.0
                )
                # as_strided inputs depend on the input's size and stride
                as_strided = torch.ops.aten.as_strided.default(
                    constant_pad_nd, [8, 384, 2, 20, 12], [153600, 400, 160, 1, 20]
                )
                as_strided_1 = torch.ops.aten.as_strided.default(
                    as_strided, [8, 384, 2, 2, 12, 12], [153600, 400, 160, 8, 20, 1]
                )
                clone = torch.ops.aten.clone.default(
                    as_strided_1, memory_format=torch.contiguous_format
                )
                return clone

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            x = torch.randn(8, 256, 16, 16).to(self.device)
            for groups in [1, 2]:
                mod = Model(groups).to(self.device).eval()
                mod_eager = mod(x)
                self.assertEqual(foo(mod, x), mod_eager)

    def test_cpp_wrapper(self):
        mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

        x = torch.rand(3, 3, 32, 32).to(self.device)

        @torch.compile(options={"cpp_wrapper": True})
        def foo(mod, x):
            return mod(x)

        out_eager = mod(x)
        with torch.no_grad():
            self.assertEqual(foo(mod, x), out_eager)
            self.assertEqual(foo(mod, x), out_eager)

    def test_conv_layout_convert_with_view(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )
                self.bn = nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.bn(x)
                x = self.conv(x)
                return torch.flatten(x, 1)

        mod = Model().to(self.device).eval()

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            x = torch.rand(2, 3, 5, 5).to(self.device)
            mod_eager = mod(x)
            self.assertEqual(foo(mod, x), mod_eager)

    @skipIfRocm
    def test_conv_weight_layout_convert(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                return self.conv(x)

            @staticmethod
            def get_example_inputs():
                return (torch.rand(2, 3, 5, 5).to(self.device),)

        from torch._inductor.compile_fx import compile_fx, compile_fx_inner

        nconv = 0

        def my_inner_compile(gm, example_inputs, *args, **kwargs):
            out = compile_fx_inner(gm, example_inputs, *args, **kwargs)

            nonlocal nconv
            convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
            nconv += len(convs)
            for conv in convs:
                weight_node = conv.args[1]
                weight_const_tensor = getattr(gm, weight_node.target)
                self.assertTrue(
                    weight_const_tensor.is_contiguous(memory_format=torch.channels_last)
                )
                self.assertTrue(
                    weight_node.meta["val"].is_contiguous(
                        memory_format=torch.channels_last
                    )
                )

            return out

        mod = torch.compile(
            Model().eval().to(self.device),
            backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
        )
        inp = mod.get_example_inputs()
        with torch.no_grad():
            mod(*inp)

        # Only check the assertion for CUDA.
        # For CPU, we may get torch.ops.mkldnn._convolution_pointwise.default
        # in the joint graph rather than torch.ops.aten.convolution.default.
        # Currently we only handle aten.convolution.default in layout
        # optimization. That's why the count may be 0 here for CPU.
        if self.device == "cuda":
            self.assertTrue(nconv == 1)

    def test_unequal_bias_horizontal_addmm_fusion(self):
        device = self.device

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.tensor(
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], device=device
                )
                self.b1 = torch.zeros(3, device=device)
                self.w2 = torch.tensor(
                    [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]], device=device
                )
                self.b2 = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
                self.w3 = torch.tensor(
                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device
                )
                self.b3 = torch.tensor([1.0, 2.0, 3.0], device=device)

            def forward(self, x):
                out1 = torch.nn.functional.linear(x, self.w1, self.b1)
                out2 = torch.nn.functional.linear(x, self.w2, self.b2)
                out3 = torch.nn.functional.linear(x, self.w3, self.b3)
                return (out1, out2, out3)

        func = Model().to(device).eval()
        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device=device)

        with torch.no_grad():
            out_eager = func(x.clone())

            func1 = torch.compile(func)
            out_compiled = func1(x.clone())

            self.assertEqual(out_eager, out_compiled)

    @skipIfRocm
    def test_redundant_clone_for_layout_convert(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                y = x + 1
                return self.conv(x), y

            @staticmethod
            def get_example_inputs():
                return (torch.rand(2, 3, 5, 5).to(self.device),)

        mod = Model().eval().to(self.device)
        inp = mod.get_example_inputs()
        with torch.no_grad():
            expected_outputs = mod(*inp)

        num_same_stride = 0
        num_diff_stride = 0

        def debug_inductor_force_stride_order(orig_fn, input_tensor, stride):
            nonlocal num_same_stride, num_diff_stride
            input_tensor.realize()
            if tuple(input_tensor.get_stride()) == tuple(stride):
                num_same_stride += 1
            else:
                num_diff_stride += 1
            return orig_fn(input_tensor, stride)

        with override_lowering(
            prims.inductor_force_stride_order.default, debug_inductor_force_stride_order
        ):
            opt_mod = torch.compile(mod)
            with torch.no_grad():
                actual_outputs = opt_mod(*inp)

        self.assertEqual(len(actual_outputs), len(expected_outputs))
        self.assertEqual(2, len(actual_outputs))
        for i, actual, expected in zip(
            itertools.count(), actual_outputs, expected_outputs
        ):
            self.assertTrue(
                torch.allclose(expected, actual, atol=1e-4, rtol=1e-4),
                f"{i}th output: expected {expected}, actual {actual}",
            )

        if self.device == "cpu":
            # CPU uses a different convolution implementation; skip the checks below
            return

        self.assertTrue(
            actual_outputs[0].is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            actual_outputs[1].is_contiguous(memory_format=torch.contiguous_format)
        )

        # we don't change the stride of y returned by forward, so there will
        # be no extra copy
        self.assertTrue(num_same_stride == 1, f"num_same_stride is {num_same_stride}")

        # we changed the stride of self.conv(x) returned by forward, so there
        # may be an extra copy
        self.assertTrue(num_diff_stride == 1, f"num_diff_stride is {num_diff_stride}")


if TEST_WITH_ROCM:
    torch._inductor.config.force_layout_optimization = 1
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"


if HAS_CPU and not torch.backends.mps.is_available():

    class FreezingCpuTests(TestCase):
        common = check_model
        device = "cpu"
        autocast = torch.cpu.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, FreezingCpuTests, "cpu")

if HAS_CUDA and not TEST_WITH_ASAN:

    class FreezingCudaTests(TestCase):
        common = check_model_cuda
        device = "cuda"
        autocast = torch.cuda.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, FreezingCudaTests, "cuda")


del OptimizeForInferenceTemplate

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")