1# Owner(s): ["module: inductor"]
2import copy
3import itertools
4import os
5import unittest
6
7import torch
8import torch._dynamo.config as dynamo_config
9import torch._inductor.config as inductor_config
10import torch._inductor.fx_passes.post_grad
11import torch.nn.functional as F
12from torch._dynamo.utils import count_calls, counters
13from torch._higher_order_ops.out_dtype import out_dtype
14from torch._inductor.fx_passes import joint_graph
15from torch._inductor.pattern_matcher import (
16    Arg,
17    CallFunction,
18    gen_pattern,
19    is_mutation_op,
20    KeywordArg,
21    Match,
22    PatternMatcherPass,
23    PatternPrettyPrinter,
24    register_graph_pattern,
25    stable_topological_sort,
26)
27from torch._inductor.test_case import run_tests, TestCase
28from torch._inductor.utils import run_and_get_code
29from torch._inductor.virtualized import V
30from torch.testing import FileCheck
31from torch.testing._internal.common_cuda import SM80OrLater
32from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
33from torch.testing._internal.inductor_utils import HAS_CUDA, IS_A100, IS_BIG_GPU
34from torch.utils import _pytree as pytree
35
36
37class TestPatternMatcher(TestCase):
38    def common(
39        self,
40        fn,
41        args,
42        expected_matches,
43        expected_nodes,
44        additional_check=lambda code: None,
45        reference_in_float=False,
46    ):
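        """Run fn eagerly and compiled, compare the results, and check that the
        pattern matcher reported the expected number of matches and replaced
        nodes.  If reference_in_float is set, the eager reference is computed
        in float32 and dtypes are not compared."""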
47        counters.clear()
48        torch.manual_seed(42)
49        if reference_in_float:
50            ref_inputs = pytree.tree_map_only(
51                torch.Tensor, lambda x: x.to(torch.float32), args
52            )
53        else:
54            ref_inputs = args
55        expected = fn(*ref_inputs)
56        torch.manual_seed(42)
57        actual, codes = run_and_get_code(torch.compile(fn), *args)
58        if len(codes) == 1:
59            codes = codes[0]
60        torch.testing.assert_close(actual, expected, check_dtype=not reference_in_float)
61
62        self.assertEqual(
63            counters["inductor"]["pattern_matcher_count"], expected_matches
64        )
65        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], expected_nodes)
66        additional_check(codes)
67        counters.clear()
68
69    def test_mm_plus_mm(self):
70        def fn(a, b, c, d):
71            return torch.add(torch.mm(a, b), torch.mm(c, d))
72
73        # when the two mm outputs have the same shape, mm_plus_mm can be matched to a fused op
74        fusible_args_list = [
75            (
76                torch.randn(16, 16, device="cuda"),
77                torch.randn(16, 16, device="cuda"),
78                torch.randn(16, 16, device="cuda"),
79                torch.randn(16, 16, device="cuda"),
80            ),
81            (
82                torch.randn(1, 4, device="cuda"),
83                torch.randn(4, 2, device="cuda"),
84                torch.randn(1, 5, device="cuda"),
85                torch.randn(5, 2, device="cuda"),
86            ),
87        ]
88        for args in fusible_args_list:
89            self.common(fn, args, 1, 3)
90
91        # if not fusible, it can only match add(mm())
92        unfusible_args_list = [
93            # https://github.com/pytorch/pytorch/issues/100670.
94            (
95                torch.randn(1, 4, device="cuda"),
96                torch.randn(4, 2, device="cuda"),
97                torch.randn(1, 2, device="cuda"),
98                torch.randn(2, 1, device="cuda"),
99            ),
100            (
101                torch.randn(1, 2, device="cuda"),
102                torch.randn(2, 1, device="cuda"),
103                torch.randn(1, 4, device="cuda"),
104                torch.randn(4, 2, device="cuda"),
105            ),
106        ]
107        for args in unfusible_args_list:
108            self.common(fn, args, 1, 2)
109
110    def _test_fused_int_mm_mul_impl(self, fn, args, fused_int_mm_mul_expected=True):
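        """Compile fn with mode="max-autotune" and check whether the generated
        code contains the fused_int_mm_mul kernel; when it should, also compare
        the finite values against eager."""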
111        torch._dynamo.reset()
112        counters.clear()
113        ref = fn(*args)
114        test, (code,) = run_and_get_code(torch.compile(fn, mode="max-autotune"), *args)
115        self.assertEqual("fused_int_mm_mul" in code, fused_int_mm_mul_expected)
116        if fused_int_mm_mul_expected:
117            indices = ~ref.isinf()
118            torch.testing.assert_close(
119                ref[indices], test[indices]
120            )  # also checks that dtype is correct
121
122    @skipIfRocm
123    @unittest.skipIf(not SM80OrLater, "need sm_80")
124    @inductor_config.patch(force_fuse_int_mm_with_mul=True)
125    def test_fused_int_mm_mul(self):
126        def fn1(a, b, c):
127            return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c
128
129        def fn2(a, b, c):
130            return (out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c).to(
131                torch.bfloat16
132            )
133
134        args_list = [
135            (
136                torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
137                torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
138                torch.randn((32, 1), dtype=torch.float16, device="cuda") * 0 + 0.5,
139            ),
140            (
141                torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
142                torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
143                torch.randn((1, 8), dtype=torch.bfloat16, device="cuda"),
144            ),
145            (
146                torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
147                torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
148                torch.randn((1, 8), dtype=torch.float32, device="cuda"),
149            ),
150        ]
151
152        for args in args_list:
153            self._test_fused_int_mm_mul_impl(fn1, args, True)
154            self._test_fused_int_mm_mul_impl(fn2, args, True)
155
156    @skipIfRocm
157    @unittest.skipIf(not SM80OrLater, "need sm_80")
158    @inductor_config.patch(force_fuse_int_mm_with_mul=True)
159    def test_fused_int_mm_mul_gating(self):
160        def fn1(a, b, c):
161            return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c
162
163        args1 = (
164            torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
165            torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
166            torch.randn((8), dtype=torch.float32, device="cuda"),
167        )
168
169        args2 = (
170            torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
171            torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
172            torch.randn((32, 1), dtype=torch.float16, device="cuda"),
173        )
174        self._test_fused_int_mm_mul_impl(fn1, args1, False)
175        self._test_fused_int_mm_mul_impl(fn1, [arg.cpu() for arg in args2], False)
176        inductor_config.force_fuse_int_mm_with_mul = False
177        self._test_fused_int_mm_mul_impl(fn1, args2, False)
178
179    def _test_mixed_impl(
180        self,
181        fn,
182        args,
183        mixed_mm_expected,
184        fallback_mixed_mm_expected,
185        rtol=None,
186        atol=None,
187    ):
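        """Compile fn, compare against eager, and check whether the generated
        code contains the mixed_mm and/or fallback_mixed_mm kernels."""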
188        torch._dynamo.reset()
189        counters.clear()
190        ref = fn(*args)
191        test, (code,) = run_and_get_code(torch.compile(fn), *args)
192        torch.testing.assert_close(ref, test, rtol=rtol, atol=atol)
193        self.assertEqual("mixed_mm" in code, mixed_mm_expected)
194        self.assertEqual("fallback_mixed_mm" in code, fallback_mixed_mm_expected)
195
196    @unittest.skipIf(not SM80OrLater, "need sm_80")
197    @inductor_config.patch(mixed_mm_choice="triton")
198    def test_mixed_mm(self):
199        def fn(a, b):
200            return torch.mm(a, b.to(a.dtype))
201
202        args_list = [
203            (
204                torch.randn(8, 8, device="cuda"),
205                torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
206            ),
207            (
208                torch.randn(8, 2, device="cuda", dtype=torch.bfloat16),
209                torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda"),
210            ),
211            (
212                torch.randn(8, 5, device="cuda", dtype=torch.float16),
213                torch.randint(0, 255, (5, 2), dtype=torch.uint8, device="cuda"),
214            ),
215            (
216                torch.randn(8, 8, device="cuda", dtype=torch.float32),
217                torch.randn(8, 8, device="cuda", dtype=torch.bfloat16),
218            ),
219        ]
220
221        for args in args_list:
222            self._test_mixed_impl(fn, args, True, False)
223
224    @unittest.skipIf(not SM80OrLater, "need sm_80")
225    @inductor_config.patch(mixed_mm_choice="triton")
226    def test_mixed_mm_exhaustive_dtypes(self):
227        def fn(a, b):
228            return torch.mm(a, b.to(a.dtype))
229
230        dtypes_left = [torch.float16, torch.float32, torch.bfloat16]
231        dtypes_right = [torch.int8, torch.uint8]
232        dtype_ranges = {torch.uint8: (0, 255), torch.int8: (-128, 127)}
233        for dtype_left, dtype_right in itertools.product(dtypes_left, dtypes_right):
234            low, high = dtype_ranges[dtype_right]
235            args = (
236                torch.randn(256, 256, dtype=dtype_left, device="cuda"),
237                torch.randint(low, high, (256, 256), dtype=dtype_right, device="cuda"),
238            )
239            fallback_mixed_mm_expected = (
240                dtype_left == torch.bfloat16 and dtype_right == torch.uint8
241            )
242            self._test_mixed_impl(
243                fn, args, True, fallback_mixed_mm_expected, rtol=0.16, atol=1e-4
244            )
245
246    @unittest.skipIf(not SM80OrLater, "need sm_80")
247    @inductor_config.patch(mixed_mm_choice="triton")
248    def test_mixed_mm_bad_cases(self):
249        def fn(a, b):
250            return torch.mm(a, b.to(a.dtype))
251
252        # when b is transposed and not contiguous, we skip triton and use fallback
253        args_list = [
254            (
255                torch.randn(8, 8, device="cuda", dtype=torch.float16),
256                torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda").t()[
257                    :, ::2
258                ],
259            ),
260            (
261                torch.randn(8, 8, device="cuda", dtype=torch.bfloat16),
262                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda").t()[
263                    :, ::2
264                ],
265            ),
266        ]
267
268        for args in args_list:
269            self._test_mixed_impl(fn, args, True, True)
270
271    @unittest.skipIf(not SM80OrLater, "need sm_80")
272    @inductor_config.patch(mixed_mm_choice="triton", max_autotune_gemm=True)
273    def test_mixed_mm_epi_works(self):
274        def fn(a, b, c, d):
275            return torch.mm(a, b.to(a.dtype)) * c + d
276
277        args_list = [
278            (
279                torch.randn(8, 8, device="cuda"),
280                torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
281                torch.randn(8, device="cuda"),
282                torch.randn(8, device="cuda"),
283            ),
284            (
285                torch.randn(8, 2, device="cuda", dtype=torch.bfloat16),
286                torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda"),
287                torch.randn(8, device="cuda", dtype=torch.bfloat16),
288                torch.randn(8, device="cuda", dtype=torch.bfloat16),
289            ),
290            (
291                torch.randn(8, 5, device="cuda", dtype=torch.float16),
292                torch.randint(0, 255, (5, 2), dtype=torch.uint8, device="cuda"),
293                torch.randn(2, device="cuda", dtype=torch.float16),
294                torch.randn(2, device="cuda", dtype=torch.float16),
295            ),
296        ]
297
298        for args in args_list:
299            self._test_mixed_impl(fn, args, True, False)
300
301    @unittest.skipIf(not SM80OrLater, "need sm_80")
302    @unittest.skipIf(not IS_A100, "heuristic only run on Linux A100")
303    @unittest.skipIf(not IS_BIG_GPU, "tests fail on small GPU")
304    @inductor_config.patch(
305        mixed_mm_choice="heuristic",
306        autoheuristic_use="",
307        fx_graph_cache=False,
308        fx_graph_remote_cache=False,
309        shape_padding=False,
310    )
311    def test_mixed_mm_heuristic_no(self):
312        def fn(a, b):
313            return torch.mm(a, b.to(a.dtype))
314
315        # examples that should not be selected by handwritten heuristic
316        mat1_dtype = torch.float16
317        dyn_tensor = torch.randn(4, 4096, dtype=mat1_dtype, device="cuda")
318        torch._dynamo.mark_dynamic(dyn_tensor, 0)
319        args_list = [
320            (
321                torch.randn(1, 4097, dtype=mat1_dtype, device="cuda"),
322                torch.randint(-128, 127, (4097, 4096), dtype=torch.int8, device="cuda"),
323            ),
324            (
325                torch.randn(1, 4096, dtype=mat1_dtype, device="cuda"),
326                torch.randint(-128, 127, (4096, 4097), dtype=torch.int8, device="cuda"),
327            ),
328            (
329                torch.randn(8, 8, dtype=mat1_dtype, device="cuda"),
330                torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
331            ),
332            (
333                torch.randn(8, 2048, dtype=mat1_dtype, device="cuda"),
334                torch.randint(-128, 127, (2048, 2048), dtype=torch.int8, device="cuda"),
335            ),
336            (
337                torch.randn(8, 2048, dtype=mat1_dtype, device="cuda"),
338                torch.randint(
339                    -128, 127, (2048, 2048), dtype=torch.int8, device="cuda"
340                ).t(),
341            ),
342            (
343                torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"),
344                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda")[
345                    :, ::2
346                ],
347            ),
348            (
349                torch.randn(1, 4096, dtype=torch.float32, device="cuda"),
350                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
351            ),
352            (
353                dyn_tensor,
354                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
355            ),
356        ]
357
358        for args in args_list:
359            self._test_mixed_impl(fn, args, True, True)
360
361    @unittest.skipIf(not SM80OrLater, "need sm_80")
362    @unittest.skipIf(not IS_A100, "heuristic only run on Linux A100")
363    @unittest.skipIf(not IS_BIG_GPU, "tests fail on small GPU")
364    @inductor_config.patch(
365        mixed_mm_choice="heuristic",
366        autoheuristic_use="",
367        fx_graph_cache=False,
368        fx_graph_remote_cache=False,
369        shape_padding=False,
370    )
371    def test_mixed_mm_heuristic_yes(self):
372        def fn(a, b):
373            return torch.mm(a, b.to(a.dtype))
374
375        mat1_dtype = torch.float16
376        # examples that should be selected by handwritten heuristic
377        args_list = [
378            (
379                torch.randn(1, 4096, dtype=mat1_dtype, device="cuda"),
380                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
381            ),
382            (
383                torch.randn(4, 4096, dtype=mat1_dtype, device="cuda"),
384                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
385            ),
386            (
387                torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"),
388                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
389            ),
390            (
391                torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"),
392                torch.randint(
393                    -128, 127, (4096, 4096), dtype=torch.int8, device="cuda"
394                ).t(),
395            ),
396            (
397                torch.randn(16, 4096, dtype=mat1_dtype, device="cuda"),
398                torch.randint(
399                    -128, 127, (8192, 4096), dtype=torch.int8, device="cuda"
400                ).t(),
401            ),
402            (
403                torch.randn(32, 4096, dtype=mat1_dtype, device="cuda"),
404                torch.randint(-128, 127, (4096, 8192), dtype=torch.int8, device="cuda"),
405            ),
406            (
407                torch.randn(64, 4096, dtype=mat1_dtype, device="cuda"),
408                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
409            ),
410        ]
411
412        for args in args_list:
413            self._test_mixed_impl(fn, args, True, False, rtol=0.01, atol=0.04)
414
415    @unittest.skipIf(not SM80OrLater, "need sm_80")
416    def test_mixed_mm_gating(self):
417        def fn(a, b):
418            return torch.mm(a, b.to(a.dtype))
419
420        args = (
421            torch.randn(8, 8, device="cuda"),
422            torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
423        )
424        # will ignore the mixed_mm code (including fallback)
425        with inductor_config.patch(
426            {"mixed_mm_choice": "default", "use_mixed_mm": False}
427        ):
428            self._test_mixed_impl(fn, args, False, False)
429
430        # will use the fallback_mixed_mm kernel since gemm autotuning (max_autotune_gemm) is off
431        with inductor_config.patch(
432            {"mixed_mm_choice": "default", "use_mixed_mm": True}
433        ):
434            self._test_mixed_impl(fn, args, True, True)
435
436        # will use mixed_mm kernel
437        with inductor_config.patch(
438            {"mixed_mm_choice": "triton", "use_mixed_mm": False}
439        ):
440            self._test_mixed_impl(fn, args, True, False)
441
442        # shows that use_mixed_mm has no effect once mixed_mm_choice is set to triton
443        with inductor_config.patch({"mixed_mm_choice": "triton", "use_mixed_mm": True}):
444            self._test_mixed_impl(fn, args, True, False)
445
446        # will use fallback_mixed_mm kernel
447        with inductor_config.patch({"mixed_mm_choice": "aten", "use_mixed_mm": False}):
448            self._test_mixed_impl(fn, args, True, True)
449
450        # will use fallback_mixed_mm kernel
451        with inductor_config.patch({"mixed_mm_choice": "aten", "use_mixed_mm": True}):
452            self._test_mixed_impl(fn, args, True, True)
453
454        # will use fallback_mixed_mm kernel because fallback is the only choice
455        with inductor_config.patch(
456            {"mixed_mm_choice": "aten", "use_mixed_mm": True, "max_autotune_gemm": True}
457        ):
458            self._test_mixed_impl(fn, args, True, True)
459
460    @inductor_config.patch(use_mixed_mm=True)
461    def test_mixed_mm_cpu(self):
462        def fn(a, b):
463            return torch.mm(a, b.to(a.dtype))
464
465        args = (
466            torch.randn(8, 8),
467            torch.randint(-128, 127, (8, 8), dtype=torch.int8),
468        )
469        self._test_mixed_impl(fn, args, False, False)
470
471    @unittest.skipIf(not SM80OrLater, "need sm_80")
472    @inductor_config.patch(use_mixed_mm=True)
473    def test_uint4x2_mixed_mm(self):
474        def fn(a, b):
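            # b packs two 4-bit values per uint8 byte: split out the low and
            # high nibbles, rebuild a matrix with twice as many rows, and shift
            # the [0, 15] range down to [-8, 7] before the matmul.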
475            return torch.mm(
476                a,
477                torch.cat((b & 0xF, b >> 4), 1)
478                .reshape(-1, b.shape[1])
479                .to(a.dtype)
480                .sub(8),
481            )
482
483        def check_uint4x2_mixed_mm(args, expect_mixed_mm):
484            torch._dynamo.reset()
485            counters.clear()
486            ref = fn(*args)
487            test, (code,) = run_and_get_code(torch.compile(fn), *args)
488            torch.testing.assert_close(ref, test)
489            self.assertEqual("uint4x2_mixed_mm" in code, expect_mixed_mm)
490
491        args_expect_mixed_mm = [
492            (
493                torch.randn(8, 8, device="cuda"),
494                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
495            ),
496            (
497                torch.randn(8, 8, device="cuda", dtype=torch.float16),
498                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda")
499                .t()
500                .contiguous()
501                .t(),
502            ),
503        ]
504
505        for args in args_expect_mixed_mm:
506            check_uint4x2_mixed_mm(args, True)
507
508        # mixed mm is only enabled when casting from a lower-bitwidth dtype to a higher one
509        args_expect_no_mixed_mm = [
510            (
511                torch.randn(8, 8, device="cuda"),
512                torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"),
513            ),
514            (
515                torch.randn(8, 8, device="cuda"),
516                torch.randint(0, 255, (4, 8), dtype=torch.int64, device="cuda"),
517            ),
518        ]
519
520        for args in args_expect_no_mixed_mm:
521            check_uint4x2_mixed_mm(args, False)
522
523    @unittest.skipIf(not SM80OrLater, "need sm_80")
524    @inductor_config.patch(use_mixed_mm=True)
525    def test_uint4x2_mixed_mm_epi(self):
526        def fn(a, b, c, d):
527            return (
528                torch.mm(
529                    a,
530                    torch.cat((b & 0xF, b >> 4), 1)
531                    .reshape(-1, b.shape[1])
532                    .to(a.dtype)
533                    .sub(8),
534                )
535                * c
536                + d
537            )
538
539        args_list = [
540            (
541                torch.randn(8, 8, device="cuda"),
542                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
543                torch.randn(8, device="cuda"),
544                torch.randn(8, device="cuda"),
545            ),
546        ]
547
548        for args in args_list:
549            torch._dynamo.reset()
550            counters.clear()
551            ref = fn(*args)
552            test, (code,) = run_and_get_code(torch.compile(fn), *args)
553            torch.testing.assert_close(ref, test)
554            self.assertTrue("uint4x2_mixed_mm" in code)
555            self.assertTrue("fused_add_mm_mul" in code)
556
557    @inductor_config.patch(use_mixed_mm=True)
558    def test_uint4x2_mixed_mm_fail_to_match(self):
559        def fn(a, b):
560            return torch.mm(
561                a,
562                torch.cat((b & 0xF, b >> 4), 1)
563                .reshape(-1, b.shape[1])
564                .to(a.dtype)
565                .sub(8),
566            )
567
568        args_list = [
569            (  # cpu
570                torch.randn(8, 8),
571                torch.randint(0, 255, (4, 8), dtype=torch.uint8),
572            ),
573            (  # int8
574                torch.randn(8, 8, device="cuda"),
575                torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda"),
576            ),  # we don't match for int8, since int8 bitshift numerics
577        ]  # differ between triton and pytorch
578
579        for args in args_list:
580            torch._dynamo.reset()
581            counters.clear()
582            ref = fn(*args)
583            test, (code,) = run_and_get_code(torch.compile(fn), *args)
584            torch.testing.assert_close(ref, test)
585            self.assertFalse("uint4x2_mixed_mm" in code)
586
587    @inductor_config.patch(mixed_mm_choice="default")
588    @inductor_config.patch(use_mixed_mm=False)
589    def test_uint4x2_mixed_mm_gating_works(self):
590        def fn(a, b):
591            return torch.mm(
592                a,
593                torch.cat((b & 0xF, b >> 4), 1)
594                .reshape(-1, b.shape[1])
595                .to(a.dtype)
596                .sub(8),
597            )
598
599        args_list = [
600            (
601                torch.randn(8, 8, device="cuda"),
602                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
603            ),
604        ]
605
606        for args in args_list:
607            torch._dynamo.reset()
608            counters.clear()
609            ref = fn(*args)
610            test, (code,) = run_and_get_code(torch.compile(fn), *args)
611            torch.testing.assert_close(ref, test)
612            self.assertFalse("uint4x2_mixed_mm" in code)
613
614    def test_addmm(self):
615        def fn(a, b, c):
616            return torch.add(a, torch.mm(b, c)), torch.mm(b, c) + a
617
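        # each tuple is (a, b, c, should_fuse); should_fuse marks whether the
        # addmm pattern is expected to fire for that shape combination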
618        args_list = [
619            (
620                torch.randn(16, 16, device="cuda"),
621                torch.randn(16, 16, device="cuda"),
622                torch.randn(16, 16, device="cuda"),
623                True,
624            ),
625            (
626                torch.randn(8, device="cuda"),
627                torch.randn(16, 16, device="cuda"),
628                torch.randn(16, 8, device="cuda"),
629                True,
630            ),
631            (
632                torch.randn(16, 16, device="cuda"),
633                torch.randn(1, 16, device="cuda"),
634                torch.randn(16, 16, device="cuda"),
635                False,
636            ),
637            (
638                torch.randn(1, 16, 16, device="cuda"),
639                torch.randn(16, 16, device="cuda"),
640                torch.randn(16, 16, device="cuda"),
641                False,
642            ),
643            (
644                4,
645                torch.randn(16, 16, device="cuda"),
646                torch.randn(16, 16, device="cuda"),
647                False,
648            ),
649        ]
650        for a, b, c, should_fuse in args_list:
651            torch._dynamo.reset()
652            counters.clear()
653            args = (a, b, c)
654            e1, e2 = fn(*args)
655            a1, a2 = torch.compile(fn)(*args)
656            torch.testing.assert_close(a1, e1)
657            torch.testing.assert_close(a2, e2)
658            count, nodes = (2, 4) if should_fuse else (0, 0)
659            self.assertEqual(counters["inductor"]["pattern_matcher_count"], count)
660            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], nodes)
661
662    def test_addmm_symbolic_scalar(self):
663        def fn(m1, m2):
664            bias = m1.size(0)
665            return torch.add(bias, torch.mm(m1, m2)), torch.mm(m1, m2) + bias
666
667        m1 = torch.randn(16, 16, device="cuda")
668        m2 = torch.randn(16, 16, device="cuda")
669
670        counters.clear()
671        expect = fn(m1, m2)
672        actual = torch.compile(fn, dynamic=True)(m1, m2)
673        self.assertEqual(expect, actual)
674        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
675
676    def test_addmm_broadcasting_bias(self):
677        class Model(torch.nn.Module):
678            def __init__(self) -> None:
679                super().__init__()
680                self.linear = torch.nn.functional.linear
681                self.linear_weight = torch.randn(4, 4).cuda()
682                self.bias = torch.randn(1, 4).cuda()
683
684            def forward(self, x):
685                x = self.linear(x, self.linear_weight, self.bias)
686                return x
687
688        input_tensor = torch.randn(1, 3, 4).cuda()
689
690        func = Model().cuda()
691
692        res1 = func(input_tensor)
693        jit_func = torch.compile(func)
694        res2 = jit_func(input_tensor)
695
696        self.assertEqual(res1, res2)
697
698    def test_cat_mm(self):
699        def fn(a, b, c):
700            return torch.cat(
701                [
702                    torch.mm(a, b),
703                    torch.mm(b, c),
704                    torch.mm(a, c),
705                ],
706                1,
707            )
708
709        args = [
710            torch.randn(16, 16, device="cuda"),
711            torch.randn(16, 16, device="cuda"),
712            torch.randn(16, 16, device="cuda"),
713        ]
714        self.common(fn, args, 1, 4)
715
716    def test_cat_addmm(self):
717        def fn(a, b, c):
718            return torch.cat(
719                [
720                    torch.addmm(a, b, c),
721                    torch.addmm(b, c, a),
722                    torch.addmm(c, a, b),
723                ],
724                1,
725            )
726
727        args = [
728            torch.randn(16, 16, device="cuda"),
729            torch.randn(16, 16, device="cuda"),
730            torch.randn(16, 16, device="cuda"),
731        ]
732        self.common(fn, args, 1, 4)
733
734    def test_cat_slice_cat_cuda(self):
735        def fn(a, b):
736            cat_1 = torch.ops.aten.cat.default([a, b], 1)
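            # 9223372036854775807 is INT64_MAX, i.e. "slice to the end"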
737            slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
738            slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
739            return torch.ops.aten.cat.default([cat_1, slice_2], 1)
740
741        args = [
742            torch.randn(2, 32, device="cuda"),
743            torch.randn(2, 16, device="cuda"),
744        ]
745        self.common(fn, args, 1, 3)
746
747        args = [
748            torch.randn(2, 8, device="cuda"),
749            torch.randn(2, 16, device="cuda"),
750        ]
751        counters.clear()
752        expected = fn(*args)
753        actual = torch.compile(fn)(*args)
754        torch.testing.assert_close(actual, expected)
755        # We don't recompile for dynamic-shape cases.
756        if dynamo_config.assume_static_by_default:
757            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
758            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
759
760        # Verify we fallback to non-optimal path for negative `end`.
761        def fn(a, b):
762            cat_1 = torch.ops.aten.cat.default([a, b], 1)
763            slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
764            slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, -1)
765            return torch.ops.aten.cat.default([cat_1, slice_2], 1)
766
767        args = [
768            torch.randn(2, 8, device="cuda"),
769            torch.randn(2, 16, device="cuda"),
770        ]
771        self.common(fn, args, 1, 3)
772
773    def test_pointless_convert(self):
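        # fn1's float -> float16 -> float32 conversion chain should collapse
        # into a single convert, while fn2's chain through int32 must be left
        # alone.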
774        def fn1(x):
775            x = torch.ops.prims.convert_element_type.default(x, torch.float16)
776            x = torch.ops.prims.convert_element_type.default(x, torch.float32)
777            return x
778
779        gm = torch.fx.symbolic_trace(fn1)
780        self.assertEqual(count_calls(gm.graph), 2)
781        joint_graph.joint_graph_passes(gm)
782        self.assertEqual(count_calls(gm.graph), 1)
783
784        def fn2(x):
785            x = torch.ops.prims.convert_element_type.default(x, torch.int32)
786            x = torch.ops.prims.convert_element_type.default(x, torch.float32)
787            return x
788
789        gm = torch.fx.symbolic_trace(fn2)
790        self.assertEqual(count_calls(gm.graph), 2)
791        joint_graph.joint_graph_passes(gm)
792        self.assertEqual(count_calls(gm.graph), 2)
793
794    # Constant folding was explicitly turned off due to issue #108388;
795    # turn it back on for this test.
796    @inductor_config.patch(joint_graph_constant_folding=True)
797    def test_pointless_cumsum(self):
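        # Every fn below takes a cumsum of a constant-filled tensor, so the
        # whole computation should be constant-folded and no aten.cumsum
        # should survive into the generated code.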
798        def fn1():
799            ones = torch.full(
800                [1, 128], 1, layout=torch.strided, dtype=torch.float32
801            ).to(torch.int64)
802            return torch.cumsum(ones, 1) * ones
803
804        def fn2():
805            ones = torch.full(
806                [55, 10], 1, layout=torch.strided, dtype=torch.float32
807            ).to(torch.int64)
808            return torch.cumsum(ones, 1)
809
810        def fn3():
811            twos = torch.full([5, 4, 3], 2, dtype=torch.int64)
812            return torch.cumsum(twos, 0)
813
814        def fn4():
815            x = torch.full([100], 0.1, dtype=torch.float32)
816            return torch.cumsum(x, 0)
817
818        def fn5():
819            t1 = torch.full([2, 4], 1)
820            t2 = t1.to(dtype=torch.bool)
821            return torch.cumsum(t2, 1)
822
823        def fn6():
824            x = torch.full([10, 10], True, dtype=torch.int32)
825            return torch.cumsum(x, 1)
826
827        for fn in (fn1, fn2, fn3, fn4, fn5, fn6):
828            result, (code,) = run_and_get_code(torch.compile(fn, fullgraph=True))
829            self.assertNotIn("aten.cumsum", code)
830            self.assertEqual(result, fn())
831            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
832            counters.clear()
833
834    def test_splitwithsizes_cat(self):
835        # Good case
836        def fn(a):
837            split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 24], 1)
838            getitem = split_with_sizes[0]
839            getitem_1 = split_with_sizes[1]
840            cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
841            return cat**2
842
843        args = [
844            torch.randn(2, 32, device="cuda"),
845        ]
846        self.common(fn, args, 1, 4)
847
848        # Not all getitems are passed to cat
849        def fn(a):
850            split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 8, 16], 1)
851            getitem = split_with_sizes[0]
852            getitem_1 = split_with_sizes[1]
853            getitem_2 = split_with_sizes[2]
854            cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
855            return cat**2 + getitem_2
856
857        args = [
858            torch.randn(2, 32, device="cuda"),
859        ]
860        self.common(fn, args, 0, 0)
861
862        # Split and cat dims differ (TODO: handle this case by replacing the split+cat with a reshape)
863        def fn(a):
864            split_with_sizes = torch.ops.aten.split_with_sizes.default(
865                a, [8, 8, 8, 8], 1
866            )
867            cat = torch.ops.aten.cat.default(split_with_sizes, 0)
868            return cat**2
869
870        args = [
871            torch.randn(2, 32, device="cuda"),
872        ]
873        self.common(fn, args, 0, 0)
874
875        # https://github.com/pytorch/pytorch/issues/99686.
876        def fn(a):
877            x = torch.ops.aten.split_with_sizes.default(a, [3, 2, 3], dim=1)
878            cat = torch.ops.aten.cat.default([x[1], x[0], x[2]], dim=1)
879            return cat
880
881        args = [
882            torch.randn(1, 8, device="cuda"),
883        ]
884        self.common(fn, args, 0, 0)
885
886    def test_cat_splitwithsizes(self):
887        # good case
888        def fn(a, b, c):
889            cat = torch.ops.aten.cat.default([a, b, c], 1)
890            split_with_sizes = torch.ops.aten.split_with_sizes.default(
891                cat, [2, 3, 5], 1
892            )
893            return [s**2 for s in split_with_sizes]
894
895        args = [
896            torch.randn(2, 2, device="cuda"),
897            torch.randn(2, 3, device="cuda"),
898            torch.randn(2, 5, device="cuda"),
899        ]
900        self.common(fn, args, 1, 2)
901
902        # cat node has other users
903        def fn(a, b, c):
904            cat = torch.ops.aten.cat.default([a, b, c], 1)
905            split_with_sizes = torch.ops.aten.split_with_sizes.default(
906                cat, [2, 3, 5], 1
907            )
908            return [s**2 for s in split_with_sizes] + [cat**3]
909
910        args = [
911            torch.randn(2, 2, device="cuda"),
912            torch.randn(2, 3, device="cuda"),
913            torch.randn(2, 5, device="cuda"),
914        ]
915        self.common(fn, args, 0, 0)
916
917        # cat and split dims are different
918        def fn(a, b, c):
919            cat = torch.ops.aten.cat.default([a, b, c], 1)
920            split_with_sizes = torch.ops.aten.split_with_sizes.default(
921                cat, [2, 3, 5], 0
922            )
923            return [s**2 for s in split_with_sizes]
924
925        args = [
926            torch.randn(10, 2, device="cuda"),
927            torch.randn(10, 3, device="cuda"),
928            torch.randn(10, 5, device="cuda"),
929        ]
930        self.common(fn, args, 0, 0)
931
932        # cat and split lengths are different
933        def fn(a, b, c):
934            cat = torch.ops.aten.cat.default([a, b, c], 1)
935            split_with_sizes = torch.ops.aten.split_with_sizes.default(cat, [5, 5], 1)
936            return [s**2 for s in split_with_sizes]
937
938        args = [
939            torch.randn(2, 2, device="cuda"),
940            torch.randn(2, 3, device="cuda"),
941            torch.randn(2, 5, device="cuda"),
942        ]
943        self.common(fn, args, 0, 0)
944
945        # cat input sizes and split sizes are different
946        def fn(a, b, c):
947            cat = torch.ops.aten.cat.default([a, b, c], 1)
948            split_with_sizes = torch.ops.aten.split_with_sizes.default(
949                cat, [2, 5, 3], 1
950            )
951            return [s**2 for s in split_with_sizes]
952
953        args = [
954            torch.randn(2, 2, device="cuda"),
955            torch.randn(2, 3, device="cuda"),
956            torch.randn(2, 5, device="cuda"),
957        ]
958        self.common(fn, args, 0, 0)
959
960    def test_symint_pattern_matching(self):
961        import torch._inductor.config as config
962        from torch._inductor.pattern_matcher import (
963            fwd_only,
964            PatternMatcherPass,
965            register_replacement,
966        )
967
968        saved_graph = None
969
970        class _CustomPass(PatternMatcherPass):
971            def __init__(self) -> None:
972                super().__init__()
973
974            def __call__(self, g: torch.fx.graph.Graph):
975                self.apply(g)
976                nonlocal saved_graph
977                saved_graph = g
978
979        with config.patch(
980            # leave custom pass only in post_grad_passes()
981            pattern_matcher=False,
982            # register the pattern match as a custom post-grad optimization pass
983            post_grad_custom_pre_pass=None,
984            post_grad_custom_post_pass=_CustomPass(),
985        ):
986
987            def add(x, y):
988                return x + y
989
990            # a replacement equivalent to add() that routes through y.size(0) as a SymInt
991            def sym_minus(x, y):
992                return (x - (-y.size(0))) - (y * -1) - y.size(0)
993
994            device = "cpu"
995            my_args = [
996                torch.empty([8, 1], device=device),
997                torch.empty([10], device=device),
998            ]
999
1000            invoked = False
1001
1002            def extra_check(match):
1003                nonlocal invoked
1004                invoked = True
1005                return True
1006
1007            register_replacement(
1008                add,
1009                sym_minus,
1010                my_args,
1011                fwd_only,
1012                [config.post_grad_custom_post_pass],
1013                extra_check=extra_check,
1014            )
1015
1016            @torch.compile(dynamic=True)
1017            def foo(x, y):
1018                return x + y
1019
1020            x = torch.rand([8, 1])
1021            y = torch.rand([10])
1022
1023            self.assertEqual(foo(x, y), x + y)
1024
1025            self.assertTrue(invoked)
1026            # we trace out the y.sym_size in replacement
1027            FileCheck().check("sym_size_int").check_same("num_users=2").check_same(
1028                "target=torch.ops.aten.sym_size"
1029            ).run(str(saved_graph))
1030
1031    @inductor_config.patch(fx_graph_remote_cache=False)
1032    def test_match_with_mutation(self):
1033        counter = 0
1034        test_pass = PatternMatcherPass(pass_name="test")
1035
1036        @register_graph_pattern(
1037            CallFunction(
1038                torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
1039            ),
1040            pass_dict=test_pass,
1041        )
1042        def _test(match, x):
1043            nonlocal counter
1044            counter += 1
1045
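        # fn0 contains the clean sin+add pattern; fn1-fn5 each interpose a
        # mutation or global-state op between the sin and the add, which
        # should keep the pattern from matching.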
1046        def fn0(x, y):
1047            a = torch.sin(x)
1048            b = torch.add(x, a)
1049            return b
1050
1051        def fn1(x, y):
1052            a = torch.sin(x)
1053            x.copy_(y)
1054            b = torch.add(x, a)
1055            return b
1056
1057        def fn2(x, y):
1058            a = torch.sin(x)
1059            with torch.no_grad():
1060                b = torch.add(x, a)
1061            return b
1062
1063        def fn3(x, y):
1064            a = torch.sin(x)
1065            with torch.autocast("cuda"):
1066                b = torch.add(x, a)
1067            return b
1068
1069        def fn4(x, y):
1070            a = torch.sin(x)
1071            torch.manual_seed(1234)
1072            b = torch.add(x, a)
1073            return b
1074
1075        def fn5(x, y):
1076            a = torch.sin(x)
1077            torch.add(y, 1, out=x)
1078            b = torch.add(x, a)
1079            return b
1080
1081        args = [
1082            torch.randn(5, 5, device="cuda"),
1083            torch.randn(5, 5, device="cuda"),
1084        ]
1085
1086        with unittest.mock.patch(
1087            "torch._inductor.fx_passes.pre_grad.config.pre_grad_fusion_options",
1088            {"test": {}},
1089        ), unittest.mock.patch(
1090            "torch._inductor.fx_passes.pre_grad.PRE_GRAD_FUSIONS",
1091            [],
1092        ), unittest.mock.patch(
1093            "torch._inductor.fx_passes.pre_grad.PRE_GRAD_PATTERNS",
1094            {"test": test_pass},
1095        ):
1096            for fn in (fn0, fn1, fn2, fn3, fn4, fn5):
1097                counter = 0
1098                expected = fn(*copy.deepcopy(args))
1099                actual = torch.compile(fn)(*copy.deepcopy(args))
1100                # only fn0 should match; the interposed mutation ops block the others
1101                self.assertEqual(counter, int(fn is fn0))
1102                torch.testing.assert_close(actual, expected)
1103
1104    def test_remove_pointless_clones(self):
1105        @torch.compile(fullgraph=True)
1106        def fn(a, b):
1107            return torch.mm(a, b).clone()
1108
1109        result, (code) = run_and_get_code(fn, torch.randn(8, 8), torch.randn(8, 8))
1110        # without the clone removal, a second buffer (buf1) would be created
1111        self.assertIn("return (buf0, )", code[0])
1112        self.assertNotIn("async_compile.cpp", code[0])
1113
1114    def test_unfuse_bias_addmm(self):
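        # A plain addmm should stay as extern_kernels.addmm, but when it feeds
        # a pointwise epilogue (gelu), inductor is expected to unfuse the bias
        # add so the epilogue can fuse, and the extern addmm call disappears.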
1115        args = [
1116            torch.randn(20, device="cuda"),
1117            torch.randn(10, 15, device="cuda"),
1118            torch.randn(15, 20, device="cuda"),
1119        ]
1120
1121        @torch.compile()
1122        def fn(inp, a, b):
1123            return torch.ops.aten.addmm(inp, a, b)
1124
1125        _, (code) = run_and_get_code(fn, args[0], args[1], args[2])
1126        FileCheck().check("extern_kernels.addmm(").run(code[0])
1127
1128        @torch.compile()
1129        def fn2(inp, a, b):
1130            return torch.nn.functional.gelu(torch.ops.aten.addmm(inp, a, b))
1131
1132        _, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
1133        FileCheck().check_not("extern_kernels.addmm(").run(code[0])
1134
1135        @torch.compile()
1136        def fn2(inp, a, b):
1137            return torch.nn.functional.gelu(
1138                torch.ops.aten.addmm(inp, a, b).unsqueeze(0)
1139            )
1140
1141        # hit the view path
1142        _, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
1143        FileCheck().check_not("extern_kernels.addmm(").run(code[0])
1144
1145    def test_serialized_patterns_up_to_date(self):
1146        import torch.utils._pytree as pytree
1147        from torch._inductor.fx_passes import joint_graph
1148        from torch._inductor.pattern_matcher import _known_precompiled_patterns
1149
1150        # Ensure the patterns are loaded
1151        os.environ.pop("PYTORCH_GEN_PATTERNS", None)
1152        joint_graph.lazy_init()
1153
1154        with torch._subclasses.FakeTensorMode() as mode:
1155            for (
1156                search_fn,
1157                example_inputs,
1158                trace_fn,
1159                scalar_workaround,
1160                search_fn_pattern,
1161            ) in _known_precompiled_patterns:
1162                # Because the example_inputs were saved as fake tensors in a
1163                # different FakeTensorMode, we need to update them to our
1164                # FakeTensorMode().
1165                def remap_fake_tensor(x):
1166                    if isinstance(x, torch.Tensor):
1167                        return torch._subclasses.FakeTensor.from_tensor(x, mode)
1168                    return x
1169
1170                example_inputs = pytree.tree_map(remap_fake_tensor, example_inputs)
1171
1172                pattern = gen_pattern(
1173                    search_fn, example_inputs, trace_fn, scalar_workaround
1174                )
1175                pattern_pp = PatternPrettyPrinter.run(pattern)
1176
1177                self.assertEqual(
1178                    pattern_pp,
1179                    PatternPrettyPrinter.run(search_fn_pattern),
1180                    msg=f"Found mismatched pattern {search_fn.__name__}. Run torchgen/fuse/gen_patterns.py",
1181                )
1182
1183                # Since we've already checked that the serialized patterns match
1184        # let's verify the serializer by ensuring the generated patterns
1185                # also match (since search_fn_pattern is the serialized version
1186                # of search_fn).
1187                self.assertTrue(pattern.pattern_eq(search_fn_pattern))
1188
1189    @inductor_config.patch(fx_graph_remote_cache=False)
1190    def test_match_equivalent_function_invocations1(self):
1191        counter = 0
1192        test_pass = PatternMatcherPass()
1193
1194        args = [
1195            torch.randn(20, device="cuda"),
1196            torch.randn(10, 15, device="cuda"),
1197            torch.randn(15, 20, device="cuda"),
1198        ]
1199
1200        def f0(inp, a, b):
1201            return torch.ops.aten.addmm(inp, a, b)
1202
1203        def f1(inp, a, b):
1204            return torch.ops.aten.addmm(inp, a, b, beta=1.0)
1205
1206        def f2(inp, a, b):
1207            return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)
1208
1209        # This graph pattern should successfully match all of the above functions
1210        @register_graph_pattern(
1211            CallFunction(
1212                torch.ops.aten.addmm,
1213                Arg(),
1214                Arg(),
1215                Arg(),
1216                beta=KeywordArg("beta"),
1217                alpha=KeywordArg("alpha"),
1218            ),
1219            pass_dict=test_pass,
1220        )
1221        def addmm_replacement(match: Match, inp, mat1, mat2, beta, alpha):
1222            nonlocal counter
1223            counter += 1
1224
1225            def repl(inp, x1, x2):
1226                return (x1 @ x2) * alpha + inp * beta
1227
1228            with V.fake_mode:
1229                match.replace_by_example(repl, [inp, mat1, mat2])
1230
1231        with unittest.mock.patch(
1232            "torch._inductor.fx_passes.post_grad.pass_patterns",
1233            torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
1234        ):
1235            for fn in (f0, f1, f2):
1236                counter = 0
1237                expected = fn(*copy.deepcopy(args))
1238                opt_fn = torch.compile(fn)
1239                actual, (code) = run_and_get_code(opt_fn, args[0], args[1], args[2])
1240                # pattern should match
1241                self.assertEqual(counter, 1)
1242                torch.testing.assert_close(actual, expected)
1243                # addmm should be replaced
1244                FileCheck().check_not("extern_kernels.addmm(").run(code[0])
1245
1246    @inductor_config.patch(fx_graph_remote_cache=False)
1247    def test_match_equivalent_function_invocations2(self):
1248        counter = 0
1249        test_pass = PatternMatcherPass()
1250
1251        args = [
1252            torch.randn(20, device="cuda"),
1253            torch.randn(10, 15, device="cuda"),
1254            torch.randn(15, 20, device="cuda"),
1255        ]
1256
1257        def f0(inp, a, b):
1258            return torch.ops.aten.addmm(inp, a, b)
1259
1260        def f1(inp, a, b):
1261            return torch.ops.aten.addmm(inp, a, b, beta=1.0)
1262
1263        def f2(inp, a, b):
1264            return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)
1265
1266        # This graph pattern should only match f0
1267        @register_graph_pattern(
1268            CallFunction(torch.ops.aten.addmm, Arg(), Arg(), Arg()),
1269            pass_dict=test_pass,
1270        )
1271        def addmm_replacement(match: Match, inp, mat1, mat2):
1272            nonlocal counter
1273            counter += 1
1274
1275            def repl(inp, x1, x2):
1276                return x1 @ x2 + inp
1277
1278            with V.fake_mode:
1279                match.replace_by_example(repl, [inp, mat1, mat2])
1280
1281        with unittest.mock.patch(
1282            "torch._inductor.fx_passes.post_grad.pass_patterns",
1283            torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
1284        ):
1285            for fn in (f0, f1, f2):
1286                counter = 0
1287                expected = fn(*copy.deepcopy(args))
1288                actual = torch.compile(fn)(*copy.deepcopy(args))
1289                self.assertEqual(counter, 1)
1290                torch.testing.assert_close(actual, expected)
1291
1292    @inductor_config.patch(fx_graph_remote_cache=False)
1293    def test_match_equivalent_function_invocations3(self):
1294        counter = 0
1295        test_pass = PatternMatcherPass()
1296
1297        args = [
1298            torch.randn(20, device="cuda"),
1299            torch.randn(10, 15, device="cuda"),
1300            torch.randn(15, 20, device="cuda"),
1301        ]
1302
1303        def f0(inp, a, b):
1304            return torch.ops.aten.addmm(inp, a, b)
1305
1306        def f1(inp, a, b):
1307            return torch.ops.aten.addmm(inp, a, b, beta=1.0)
1308
1309        def f2(inp, a, b):
1310            return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)
1311
1312        # This graph pattern should only match f1
1313        @register_graph_pattern(
1314            CallFunction(
1315                torch.ops.aten.addmm, Arg(), Arg(), Arg(), beta=KeywordArg("beta")
1316            ),
1317            pass_dict=test_pass,
1318        )
1319        def addmm_replacement(match: Match, inp, mat1, mat2, beta):
1320            nonlocal counter
1321            counter += 1
1322
1323            def repl(inp, x1, x2):
1324                return x1 @ x2 + inp
1325
1326            with V.fake_mode:
1327                match.replace_by_example(repl, [inp, mat1, mat2])
1328
1329        with unittest.mock.patch(
1330            "torch._inductor.fx_passes.post_grad.pass_patterns",
1331            torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
1332        ):
1333            for fn in (f0, f1, f2):
1334                counter = 0
1335                expected = fn(*copy.deepcopy(args))
1336                actual = torch.compile(fn)(*copy.deepcopy(args))
1337                self.assertEqual(counter, 1)
1338                torch.testing.assert_close(actual, expected)
1339
1340    def test_stable_topological_sort(self):
1341        def fn1(a, b):
1342            return a + b
1343
1344        graph = torch.fx.Graph()
1345        a = graph.placeholder("x")
1346        b = graph.placeholder("y")
1347        c = graph.call_function(fn1, (a, b))
1348        stable_topological_sort(graph)
1349        self.assertEqual(list(graph.nodes), [a, b, c])
1350
1351        graph = torch.fx.Graph()
1352        b = graph.placeholder("y")
1353        a = graph.placeholder("x")
1354        c = graph.call_function(fn1, (a, b))
1355        stable_topological_sort(graph)
1356        self.assertEqual(list(graph.nodes), [b, a, c])
1357
1358        graph = torch.fx.Graph()
1359        a = graph.placeholder("x")
1360        b = graph.placeholder("y")
1361        c = graph.call_function(fn1, (b, a))
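        # c.append(a) moves placeholder a to just after its user c; the sort
        # must restore a valid order while keeping the stable b-first ordering.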
1362        c.append(a)
1363        stable_topological_sort(graph)
1364        self.assertEqual(list(graph.nodes), [b, a, c])
1365
1366    def test_scaled_softmax(self):
1367        def mul_softmax(a, b):
1368            return F.softmax(a * b, dim=0)
1369
1370        def div_softmax(x, inv_scale):
1371            return F.softmax(x / inv_scale, dim=0)
1372
1373        x = torch.randn(10, 10)
1374        scale = 1e6
1375        inv_scale = 1 / scale
1376        self.common(mul_softmax, (x, scale), 1, 3)
1377        self.common(mul_softmax, (scale, x), 1, 3)
1378        self.common(div_softmax, (x, inv_scale), 1, 3)
1379
1380        scale = torch.randn(10) * 1e6
1381        inv_scale = 1 / scale
1382        self.common(mul_softmax, (x, scale), 1, 3)
1383        self.common(mul_softmax, (scale, x), 1, 3)
1384        self.common(div_softmax, (x, inv_scale), 1, 3)
1385
1386        scale = torch.randn(1, 10) * 1e6
1387        inv_scale = 1 / scale
1388        self.common(mul_softmax, (x, scale), 1, 3)
1389        self.common(mul_softmax, (scale, x), 1, 3)
1390        self.common(div_softmax, (x, inv_scale), 1, 3)
1391
1392        # Test matching with type promotion
1393        x = torch.randn(10, 10, dtype=torch.bfloat16)
1394        scale = torch.randn(10, dtype=torch.bfloat16) * 1e6
1395        inv_scale = 1 / scale
1396        self.common(mul_softmax, (x, scale), 1, 4, reference_in_float=True)
1397        self.common(mul_softmax, (scale, x), 1, 4, reference_in_float=True)
1398        self.common(div_softmax, (x, inv_scale), 1, 4, reference_in_float=True)
1399
1400        # No match if scale changes in softmax dim
1401        scale = torch.randn(10, 10)
1402        self.common(mul_softmax, (x, scale), 0, 0)
1403        self.common(mul_softmax, (scale, x), 0, 0)
1404        self.common(div_softmax, (x, scale), 0, 0)
1405
1406    def test_mutation_op_matching(self):
1407        def check(type, func_name, args, kwargs, expect=True):
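            # Build a single-node graph and check is_mutation_op's verdict on it.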
1408            assert type in ["call_function", "call_method"]
1409            graph = torch.fx.Graph()
1410            getattr(graph, type)(func_name, args, kwargs)
1411            res = is_mutation_op(next(iter(graph.nodes)))
1412            if expect:
1413                self.assertTrue(res)
1414            else:
1415                self.assertFalse(res)
1416
1417        t = torch.randn(1)
1418        check("call_function", torch._C._set_grad_enabled, (False,), {})
1419        check("call_method", "copy_", (t, t), {})
1420        check("call_method", "relu_", (t,), {})
1421        check("call_function", torch.manual_seed, (0,), {})
1422        check("call_function", torch.ops.aten.set_.source_Tensor, (t, t), {})
1423        check(
1424            "call_function",
1425            torch.amp.autocast_mode._enter_autocast,
1426            ("cuda", None, True, None),
1427            {},
1428        )
1429        check("call_function", torch.amp.autocast_mode._exit_autocast, (None,), {})
1430        check(
1431            "call_function",
1432            torch.ops._c10d_functional.all_gather_into_tensor_out,
1433            (t, 2, "0"),
1434            {"out": t},
1435        )
1436        check("call_function", torch.ops.inductor.resize_storage_bytes_, (t, 0), {})
1437        check(
1438            "call_function",
1439            torch.ops.inductor.resize_storage_bytes_.default,
1440            (t, 0),
1441            {},
1442        )
1443        check(
1444            "call_function",
1445            torch.ops.fsdp.split_with_sizes_copy,
1446            (t, [64, 128, 8, 8]),
1447            {"dim": 1, "out": [t, t, t, t]},
1448        )
1449        check("call_function", torch.ops.fsdp.set_, (t, t), {})
1450        check(
1451            "call_function", torch.ops.aten.__rshift__.Scalar, (t, 2), {}, expect=False
1452        )
1453        check(
1454            "call_function",
1455            torch.ops._c10d_functional.all_gather_into_tensor,
1456            (t, 2, "0"),
1457            {},
1458            expect=False,
1459        )
1460
1461
1462if __name__ == "__main__":
1463    if IS_LINUX and HAS_CUDA:
1464        run_tests()
1465