# Owner(s): ["module: inductor"]
import contextlib
import functools
import importlib
import itertools
import os
import sys
import unittest
import weakref

import torch
from torch import nn
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import override_lowering, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import skipIfRocm


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM


importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


aten = torch.ops.aten
prims = torch.ops.prims
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


class TestCase(InductorTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
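        # Keep the freezing-related config patch below active for the lifetime
        # of the test class; freezing folds module parameters into compile-time
        # constants, which is what these tests exercise.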
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "debug": True,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                    "freezing": True,
                    "freezing_discard_parameters": True,
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()


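# Conv2d followed by BatchNorm2d; used throughout this file to exercise
# conv/bn folding during freezing.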
class ConvBN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)

    def forward(self, x):
        return self.bn(self.conv(x))


class ConvBNHardswish(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)
        self.hardswish = nn.Hardswish(inplace=True)

    def forward(self, x):
        return self.hardswish(self.bn(self.conv(x)))


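# Same conv/bn pattern, but calling torch.nn.functional.batch_norm directly with
# explicitly supplied running stats, to cover folding through the functional path.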
class ConvFunctionalBN(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        kernel_size=3,
        stride=2,
        running_mean=None,
        running_var=None,
        weight=None,
        bn_bias=None,
    ):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, out_channels, bias=bias, kernel_size=kernel_size, stride=stride
        )
        self.running_mean = running_mean
        self.running_var = running_var
        self.weight = weight
        self.bias = bn_bias

    def forward(self, x):
        return torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )


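# A single conv whose output feeds two different BatchNorm2d modules, to check
# correctness when one conv is shared by multiple bns.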
class ConvMultiBN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)
        self.bn2 = torch.nn.BatchNorm2d(out_channels, eps=0.1, dtype=torch.float)

    def forward(self, x):
        tmp = self.bn(self.conv(x))
        tmp2 = self.bn2(self.conv(x))
        return tmp + tmp2


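# Functional-batch-norm variant of ConvMultiBN: the same conv output is
# normalized twice, with different running means.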
class ConvMultiFunctionalBN(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        kernel_size=3,
        stride=2,
        running_mean=None,
        running_var=None,
        weight=None,
        bn_bias=None,
        running_mean2=None,
    ):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, out_channels, bias=bias, kernel_size=kernel_size, stride=stride
        )
        self.running_mean = running_mean
        self.running_var = running_var
        self.weight = weight
        self.bias = bn_bias
        self.running_mean2 = running_mean2

    def forward(self, x):
        tmp = torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )
        tmp2 = torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean2,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )
        return tmp + tmp2


class OptimizeForInferenceTemplate(TestCase):
    def test_mutation(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mutated_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                self.mutated_param.add_(10)
                return self.mutated_param

        with torch.no_grad():
            mod = Mod().to(self.device)
            out_eager = mod()
            out_eager2 = mod()

            mod = Mod().to(self.device)

            @torch.compile
            def foo(mod):
                return mod()

            out_comp = foo(mod)
            out_comp2 = foo(mod)

            self.assertEqual(out_eager, out_comp)
            self.assertEqual(out_eager2, out_comp2)

    def test_aliased_param_return(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.aliased_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                return self.aliased_param[1:], self.aliased_param

        mod = Mod().to(self.device).eval()

        @torch.compile()
        def foo(mod):
            return mod()

        with torch.no_grad():
            mod_eager = mod()
            self.assertEqual(foo(mod), mod_eager)

    def test_autocast(self):
        if self.device == "cpu":
            raise unittest.SkipTest("MKLDNN bug")

        mod = torch.nn.Linear(10, 10).to(self.device).eval()
        inp = torch.rand([10, 10]).to(self.device).to(torch.half)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

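        # With freezing under autocast, the weight conversion should be
        # constant-folded, so the compiled region should not need any Triton
        # kernel (checked below).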
        with torch.no_grad():
            with self.autocast():
                out_eager = mod(inp)
                out_compiled, code = run_and_get_code(foo, mod, inp)

                FileCheck().check_not("@triton.jit").run(code[0])
                self.assertEqual(out_eager, out_compiled)

    def test_mm_concat(self):
        # CPU path will replace mm with mkl._linear,
        # skip this case for now.
        if self.device == "cpu":
            raise unittest.SkipTest("NYI CPU")

        class MM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.t1 = torch.nn.Parameter(torch.rand(10, 10))
                self.t2 = torch.nn.Parameter(torch.rand(10, 10))
                self.t3 = torch.nn.Parameter(torch.rand(10, 10))

            def forward(self, x):
                return x @ self.t1, x @ self.t2, x @ self.t3

        class MM2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.t1 = torch.nn.Parameter(torch.rand(10, 10))
                self.t2 = torch.nn.Parameter(torch.rand(10, 10))

            def forward(self, x):
                return x @ self.t1, x @ self.t2

        class AddMM(MM):
            def __init__(self) -> None:
                super().__init__()

                self.b1 = torch.nn.Parameter(torch.rand([10]))
                self.b2 = torch.nn.Parameter(torch.rand([10]))
                self.b3 = torch.nn.Parameter(torch.rand([10]))

            def forward(self, x):
                return [
                    aten.addmm(b, x, p)
                    for b, p in [
                        (self.b1, self.t1),
                        (self.b2, self.t2),
                        (self.b3, self.t3),
                    ]
                ]

        for mod_fn in [
            lambda: MM().to(self.device),
            lambda: MM2().to(self.device),
            lambda: AddMM().to(self.device),
        ]:
            mod = mod_fn()
            inp = torch.rand([10, 10]).to(self.device)

            @torch.compile()
            def foo(mod, inp):
                return mod(inp)

            kernel_invoke = "kernel_cpp_0" if self.device == "cpu" else "triton.jit"

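            # With equal parameter shapes, the mms should be concatenated into
            # a single mm/addmm call, with no extra pointwise kernel emitted.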
            with torch.no_grad():
                out_eager = mod(inp)
                out, code = run_and_get_code(foo, mod, inp)
                FileCheck().check_not(kernel_invoke).check_count(
                    "mm(", count=1, exactly=True
                ).run(code[0])
                self.assertEqual(out_eager, out)

            mod2 = mod_fn()
            mod2.t1 = torch.nn.Parameter(torch.rand([10, 15], device=self.device))
            mod2.t2 = torch.nn.Parameter(torch.rand([10, 20], device=self.device))

            if hasattr(mod2, "b1"):
                mod2.b1 = torch.nn.Parameter(torch.rand([15], device=self.device))
                mod2.b2 = torch.nn.Parameter(torch.rand([20], device=self.device))

            # not fused
            count = 3 if hasattr(mod2, "t3") else 2

            with torch.no_grad():
                out_eager = mod2(inp)
                out, code = run_and_get_code(foo, mod2, inp)
                FileCheck().check_not(kernel_invoke).check_count(
                    "mm(", count=count, exactly=True
                ).run(code[0])
                self.assertEqual(out_eager, out)

    # With inlining of inbuilt nn modules, Dynamo traces the innards of the
    # inbuilt module and does not modify the eager module.
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
    def test_error_on_eager(self):
        mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

        x = torch.rand(3, 3, 32, 32).to(self.device)

        @torch.compile()
        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            foo(mod, x)

        with self.assertRaisesRegex(
            RuntimeError, "Trying to run Pytorch Eager Module after Dynamo Freezing"
        ):
            mod(x)

    def test_rng_op(self):
        @torch.compile()
        def foo():
            return torch.rand([4, 4], device=self.device) + 1

        with torch.no_grad():
            o1 = foo()
            o2 = foo()
            self.assertNotEqual(o1, o2)

    def test_symint_not_folded(self):
        def fn(a):
            return a.cos(), torch.zeros(a.shape[0], a.shape[1])

        fn_opt = torch._dynamo.optimize("inductor", dynamic=True)(fn)
        inp = torch.randn(2, 4, 6).to(self.device)
        torch._dynamo.mark_dynamic(inp, 0)
        torch._dynamo.mark_dynamic(inp, 1)

        with torch.no_grad():
            self.assertEqual(fn(inp), fn_opt(inp))
            inp2 = torch.randn(3, 5, 6).to(self.device)
            torch._dynamo.mark_dynamic(inp2, 0)
            torch._dynamo.mark_dynamic(inp2, 1)
            self.assertEqual(fn(inp2), fn_opt(inp2))

    @requires_cuda
    def test_conv_multiple_uses(self):
        from torch import nn

        class ToyModel(nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.bn1 = nn.BatchNorm2d(1)
                self.bn1.weight.data.normal_()

            def forward(self, x, y):
                return self.conv1(x) + self.bn1(self.conv1(y))

        model = ToyModel()
        model.eval().cuda()

        a = torch.rand(64, 1, 32, 32).cuda()
        b = torch.rand(64, 1, 32, 32).cuda()

        output = model(a, b)

        with torch.no_grad():
            output2 = torch.compile(model)(a, b)

        self.assertEqual(output, output2)

    def test_unfolded_bn(self):
        x = torch.rand([3, 32, 15, 15]).to(self.device)

        mod = torch.nn.BatchNorm2d(32, eps=0.001).eval().to(self.device)

        @torch.compile()
        def foo(mod, x):
            return mod(x) + 10

        out_compiled_no_inference = foo(mod, x)

        # would error if not decomposed
        with torch.no_grad():
            out_compiled = foo(mod, x)

            self.assertEqual(out_compiled_no_inference, out_compiled)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn(self):
        for use_bias, dtype in itertools.product(
            [True, False], [torch.float16, torch.bfloat16, torch.float32]
        ):
            if self.device == "cpu" and dtype == torch.float16:
                continue

            if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
                continue

            mod = (
                ConvBN(3, 32, bias=use_bias, kernel_size=3, stride=2)
                .eval()
                .to(self.device)
                .to(dtype)
            )

            x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

            torch._dynamo.reset()
            counters.clear()

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            # TODO - bias is a separate kernel right now; we should only unfuse
            # it from the conv if it can be fused

            with torch.no_grad():
                out_eager = mod(x)
                out_optimized_for_inference, code = run_and_get_code(foo, mod, x)

            # we unfuse the conv bias, but the kernel should only contain one frozen constant
            if self.device == "cuda":
                FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
                    "frozen_param"
                ).check_not("frozen_param").check_next("return").run(code[0])

            self.assertEqual(
                out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2
            )
            self.assertEqual(counters["inductor"]["binary_folding"], 4)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn_hardswish(self):
        for use_bias, dtype in itertools.product(
            [True, False], [torch.float16, torch.bfloat16, torch.float32]
        ):
            if self.device == "cpu" and dtype == torch.float16:
                continue

            if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
                continue

            mod = (
                ConvBNHardswish(3, 32, bias=use_bias, kernel_size=3, stride=2)
                .eval()
                .to(self.device)
                .to(dtype)
            )

            x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

            torch._dynamo.reset()
            counters.clear()

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            # TODO - bias is a separate kernel right now; we should only unfuse
            # it from the conv if it can be fused

            with torch.no_grad():
                out_eager = mod(x)
                out_optimized_for_inference, code = run_and_get_code(foo, mod, x)

            # we unfuse the conv bias, but the kernel should only contain one frozen constant
            if self.device == "cuda":
                FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
                    "frozen_param"
                ).check_not("frozen_param").check_next("return").run(code[0])

            self.assertEqual(
                out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2
            )
            self.assertEqual(counters["inductor"]["binary_folding"], 4)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn_with_module_sharing(self):
        mod = (
            ConvBN(32, 32, bias=True, kernel_size=3, stride=2)
            .to(self.device)
            .to(torch.float32)
        )

        # Update the default parameters of the BN module
        for _ in range(10):
            mod(torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32))

        mod.eval()
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)

        def foo(mod, x):
            mod(x)
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_functional_bn_with_module_sharing(self):
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)
        running_mean = torch.mean(x, dim=(0, 2, 3)).to(self.device)
        running_var = torch.var(x, dim=(0, 2, 3)).to(self.device)

        mod = (
            ConvFunctionalBN(
                32,
                32,
                bias=True,
                kernel_size=3,
                stride=2,
                running_mean=running_mean,
                running_var=running_var,
                weight=torch.ones(32).to(self.device),
                bn_bias=torch.zeros(32).to(self.device),
            )
            .eval()
            .to(self.device)
            .to(torch.float32)
        )

        def foo(mod, x):
            mod(x)
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_conv_bn_with_multi_bn_share_conv(self):
        mod = (
            ConvMultiBN(32, 32, bias=True, kernel_size=3, stride=2)
            .to(self.device)
            .to(torch.float32)
        )

        # Update the default parameters of the BN modules
        for _ in range(10):
            mod(torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32))

        mod.eval()
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)

        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_conv_functional_bn_with_multi_bn_share_conv(self):
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)
        running_mean = torch.mean(x, dim=(0, 2, 3)).to(self.device)
        running_var = torch.var(x, dim=(0, 2, 3)).to(self.device)
        running_mean2 = torch.mean(x, dim=(0, 2, 3)).to(self.device)

        mod = (
            ConvMultiFunctionalBN(
                32,
                32,
                bias=True,
                kernel_size=3,
                stride=2,
                running_mean=running_mean,
                running_var=running_var,
                weight=torch.ones(32).to(self.device),
                bn_bias=torch.zeros(32).to(self.device),
                running_mean2=running_mean2,
            )
            .eval()
            .to(self.device)
            .to(torch.float32)
        )

        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )
        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_dont_change_dtype_folding(self):
        dtype = torch.float16 if self.device == "cuda" else torch.bfloat16

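        # Folding the multiply by a constant should not change the dtype of the
        # frozen conv weight; run in reduced precision and compare against eager.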
        mod = (
            torch.nn.Conv2d(3, 32, bias=None, kernel_size=3, stride=2)
            .eval()
            .to(self.device)
            .to(dtype)
        )
        x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

        def foo(mod, x):
            return mod(x) * torch.full([1], 2.0, device=self.device)

        foo_c = torch.compile(foo)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_compiled = foo_c(mod, x)
            self.assertEqual(out_eager, out_compiled)

    def test_param_deallocated(self):
        # TODO: the cpu path keeps an extra copy of the graph around somewhere;
        # memory is not as important for cpu
        if self.device == "cpu":
            raise unittest.SkipTest("NYI CPU")

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self, x):
                return (self.param + 10) + x

        mod = Mod().eval().to(self.device)
        inp = torch.rand([10], device=self.device)

        with torch.no_grad():
            eager = mod(inp)

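        # With freezing_discard_parameters=True, the original parameter should
        # be freed once its value has been folded into the compiled graph; the
        # weakref below verifies that.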
        weight_ref = weakref.ref(mod.param)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            compiled = foo(mod, inp)

        self.assertEqual(eager, compiled)
        self.assertTrue(weight_ref() is None)

    @skipIfRocm
    def test_conv_with_as_strided(self):
        class Model(nn.Module):
            def __init__(self, groups):
                super().__init__()
                self.kv = torch.nn.Conv2d(
                    256,
                    384,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    bias=False,
                    groups=groups,
                )

            def forward(self, x):
                convolution = self.kv(x)
                constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                    convolution, [2, 2, 2, 2], 0.0
                )
                # the as_strided inputs depend on the input's size and stride.
                as_strided = torch.ops.aten.as_strided.default(
                    constant_pad_nd, [8, 384, 2, 20, 12], [153600, 400, 160, 1, 20]
                )
                as_strided_1 = torch.ops.aten.as_strided.default(
                    as_strided, [8, 384, 2, 2, 12, 12], [153600, 400, 160, 8, 20, 1]
                )
                clone = torch.ops.aten.clone.default(
                    as_strided_1, memory_format=torch.contiguous_format
                )
                return clone

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            x = torch.randn(8, 256, 16, 16).to(self.device)
            for groups in [1, 2]:
                mod = Model(groups).to(self.device).eval()
                mod_eager = mod(x)
                self.assertEqual(foo(mod, x), mod_eager)

    def test_cpp_wrapper(self):
        mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

        x = torch.rand(3, 3, 32, 32).to(self.device)

        @torch.compile(options={"cpp_wrapper": True})
        def foo(mod, x):
            return mod(x)

        out_eager = mod(x)

        with torch.no_grad():
            self.assertEqual(foo(mod, x), out_eager)
            self.assertEqual(foo(mod, x), out_eager)

    def test_conv_layout_convert_with_view(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )
                self.bn = nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.bn(x)
                x = self.conv(x)
                return torch.flatten(x, 1)

        mod = Model().to(self.device).eval()

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            x = torch.rand(2, 3, 5, 5).to(self.device)
            mod_eager = mod(x)
            self.assertEqual(foo(mod, x), mod_eager)

    @skipIfRocm
    def test_conv_weight_layout_convert(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                return self.conv(x)

            @staticmethod
            def get_example_inputs():
                return (torch.rand(2, 3, 5, 5).to(self.device),)

        from torch._inductor.compile_fx import compile_fx, compile_fx_inner

        nconv = 0

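        # Wrap the inner compile step so we can inspect the graph handed to
        # Inductor: every aten.convolution weight should have been converted to
        # channels_last by the layout optimization pass.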
        def my_inner_compile(gm, example_inputs, *args, **kwargs):
            out = compile_fx_inner(gm, example_inputs, *args, **kwargs)

            nonlocal nconv
            convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
            nconv += len(convs)
            for conv in convs:
                weight_node = conv.args[1]
                weight_const_tensor = getattr(gm, weight_node.target)
                self.assertTrue(
                    weight_const_tensor.is_contiguous(memory_format=torch.channels_last)
                )
                self.assertTrue(
                    weight_node.meta["val"].is_contiguous(
                        memory_format=torch.channels_last
                    )
                )

            return out

        mod = torch.compile(
            Model().eval().to(self.device),
            backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
        )
        inp = mod.get_example_inputs()
        with torch.no_grad():
            mod(*inp)

        # Only check the assertion for CUDA.
        # For CPU, we may get torch.ops.mkldnn._convolution_pointwise.default
        # in the joint graph rather than torch.ops.aten.convolution.default.
        # Currently we only handle aten.convolution.default in layout
        # optimization. That's why the count may be 0 here for CPU.
        if self.device == "cuda":
            self.assertTrue(nconv == 1)

    def test_unequal_bias_horizontal_addmm_fusion(self):
        device = self.device

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.tensor(
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], device=device
                )
                self.b1 = torch.zeros(3, device=device)
                self.w2 = torch.tensor(
                    [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]], device=device
                )
                self.b2 = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
                self.w3 = torch.tensor(
                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device
                )
                self.b3 = torch.tensor([1.0, 2.0, 3.0], device=device)

            def forward(self, x):
                out1 = torch.nn.functional.linear(x, self.w1, self.b1)
                out2 = torch.nn.functional.linear(x, self.w2, self.b2)
                out3 = torch.nn.functional.linear(x, self.w3, self.b3)
                return (out1, out2, out3)

        func = Model().to(device).eval()
        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device=device)

        with torch.no_grad():
            out_eager = func(x.clone())

            func1 = torch.compile(func)
            out_compiled = func1(x.clone())
            self.assertEqual(out_eager, out_compiled)

    @skipIfRocm
    def test_redundant_clone_for_layout_convert(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                y = x + 1
                return self.conv(x), y

            @staticmethod
            def get_example_inputs():
                return (torch.rand(2, 3, 5, 5).to(self.device),)

        mod = Model().eval().to(self.device)
        inp = mod.get_example_inputs()
        with torch.no_grad():
            expected_outputs = mod(*inp)

        num_same_stride = 0
        num_diff_stride = 0

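        # Count how often the inductor_force_stride_order lowering actually has
        # to change a tensor's strides; when the strides already match, no
        # extra copy should be needed.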
        def debug_inductor_force_stride_order(orig_fn, input_tensor, stride):
            nonlocal num_same_stride, num_diff_stride
            input_tensor.realize()
            if tuple(input_tensor.get_stride()) == tuple(stride):
                num_same_stride += 1
            else:
                num_diff_stride += 1
            return orig_fn(input_tensor, stride)

        with override_lowering(
            prims.inductor_force_stride_order.default, debug_inductor_force_stride_order
        ):
            opt_mod = torch.compile(mod)
            with torch.no_grad():
                actual_outputs = opt_mod(*inp)

        self.assertEqual(len(actual_outputs), len(expected_outputs))
        self.assertEqual(2, len(actual_outputs))
        for i, actual, expected in zip(
            itertools.count(), actual_outputs, expected_outputs
        ):
            self.assertTrue(
                torch.allclose(expected, actual, atol=1e-4, rtol=1e-4),
                f"{i}th output: expected {expected}, actual {actual}",
            )

        if self.device == "cpu":
            # CPU uses a different convolution implementation; skip the checks below
            return

        self.assertTrue(
            actual_outputs[0].is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            actual_outputs[1].is_contiguous(memory_format=torch.contiguous_format)
        )

        # we don't change the stride of y returned by forward. So there will
        # be no extra copy
        self.assertTrue(num_same_stride == 1, f"num_same_stride is {num_same_stride}")
        # we changed the stride of self.conv(x) returned by forward. So there
        # may be an extra copy
        self.assertTrue(num_diff_stride == 1, f"num_diff_stride is {num_diff_stride}")


if TEST_WITH_ROCM:
    torch._inductor.config.force_layout_optimization = 1
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"

if HAS_CPU and not torch.backends.mps.is_available():

    class FreezingCpuTests(TestCase):
        common = check_model
        device = "cpu"
        autocast = torch.cpu.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, FreezingCpuTests, "cpu")

if HAS_CUDA and not TEST_WITH_ASAN:

    class FreezingCudaTests(TestCase):
        common = check_model_cuda
        device = "cuda"
        autocast = torch.cuda.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, FreezingCudaTests, "cuda")


del OptimizeForInferenceTemplate


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")