# Owner(s): ["module: inductor"]
import copy
import itertools
import os
import unittest

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.fx_passes.post_grad
import torch.nn.functional as F
from torch._dynamo.utils import count_calls, counters
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.fx_passes import joint_graph
from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    gen_pattern,
    is_mutation_op,
    KeywordArg,
    Match,
    PatternMatcherPass,
    PatternPrettyPrinter,
    register_graph_pattern,
    stable_topological_sort,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA, IS_A100, IS_BIG_GPU
from torch.utils import _pytree as pytree


class TestPatternMatcher(TestCase):
    def common(
        self,
        fn,
        args,
        expected_matches,
        expected_nodes,
        additional_check=lambda code: None,
        reference_in_float=False,
    ):
        counters.clear()
        torch.manual_seed(42)
        if reference_in_float:
            ref_inputs = pytree.tree_map_only(
                torch.Tensor, lambda x: x.to(torch.float32), args
            )
        else:
            ref_inputs = args
        expected = fn(*ref_inputs)
        torch.manual_seed(42)
        actual, codes = run_and_get_code(torch.compile(fn), *args)
        if len(codes) == 1:
            codes = codes[0]
        torch.testing.assert_close(actual, expected, check_dtype=not reference_in_float)

        self.assertEqual(
            counters["inductor"]["pattern_matcher_count"], expected_matches
        )
        self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], expected_nodes)
        additional_check(codes)
        counters.clear()

    def test_mm_plus_mm(self):
        def fn(a, b, c, d):
            return torch.add(torch.mm(a, b), torch.mm(c, d))

        # when m1 == n1 and m2 == n2, mm_plus_mm can be matched to fused op
        fusible_args_list = [
            (
                torch.randn(16, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
            ),
            (
                torch.randn(1, 4, device="cuda"),
                torch.randn(4, 2, device="cuda"),
                torch.randn(1, 5, device="cuda"),
                torch.randn(5, 2, device="cuda"),
            ),
        ]
        for args in fusible_args_list:
            self.common(fn, args, 1, 3)

        # if not fusible, it can only match add(mm())
        unfusible_args_list = [
            # https://github.com/pytorch/pytorch/issues/100670.
            (
                torch.randn(1, 4, device="cuda"),
                torch.randn(4, 2, device="cuda"),
                torch.randn(1, 2, device="cuda"),
                torch.randn(2, 1, device="cuda"),
            ),
            (
                torch.randn(1, 2, device="cuda"),
                torch.randn(2, 1, device="cuda"),
                torch.randn(1, 4, device="cuda"),
                torch.randn(4, 2, device="cuda"),
            ),
        ]
        for args in unfusible_args_list:
            self.common(fn, args, 1, 2)

    def _test_fused_int_mm_mul_impl(self, fn, args, fused_int_mm_mul_expected=True):
        torch._dynamo.reset()
        counters.clear()
        ref = fn(*args)
        test, (code,) = run_and_get_code(torch.compile(fn, mode="max-autotune"), *args)
        self.assertEqual("fused_int_mm_mul" in code, fused_int_mm_mul_expected)
        if fused_int_mm_mul_expected:
            indices = ~ref.isinf()
            torch.testing.assert_close(
                ref[indices], test[indices]
            )  # also checks that dtype is correct

    @skipIfRocm
    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @inductor_config.patch(force_fuse_int_mm_with_mul=True)
    def test_fused_int_mm_mul(self):
        def fn1(a, b, c):
            return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c

        def fn2(a, b, c):
            return (out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c).to(
                torch.bfloat16
            )

        args_list = [
            (
                torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
                torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
                torch.randn((32, 1), dtype=torch.float16, device="cuda") * 0 + 0.5,
            ),
            (
                torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
                torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
                torch.randn((1, 8), dtype=torch.bfloat16, device="cuda"),
            ),
            (
                torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
                torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
                torch.randn((1, 8), dtype=torch.float32, device="cuda"),
            ),
        ]

        for args in args_list:
            self._test_fused_int_mm_mul_impl(fn1, args, True)
            self._test_fused_int_mm_mul_impl(fn2, args, True)

    @skipIfRocm
    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @inductor_config.patch(force_fuse_int_mm_with_mul=True)
    def test_fused_int_mm_mul_gating(self):
        def fn1(a, b, c):
            return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c

        args1 = (
            torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
            torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
            torch.randn((8), dtype=torch.float32, device="cuda"),
        )

        args2 = (
            torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
            torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
            torch.randn((32, 1), dtype=torch.float16, device="cuda"),
        )
        self._test_fused_int_mm_mul_impl(fn1, args1, False)
        self._test_fused_int_mm_mul_impl(fn1, [arg.cpu() for arg in args2], False)
        inductor_config.force_fuse_int_mm_with_mul = False
        self._test_fused_int_mm_mul_impl(fn1, args2, False)

    def _test_mixed_impl(
        self,
        fn,
        args,
        mixed_mm_expected,
        fallback_mixed_mm_expected,
        rtol=None,
        atol=None,
    ):
        torch._dynamo.reset()
        counters.clear()
        ref = fn(*args)
        test, (code,) = run_and_get_code(torch.compile(fn), *args)
        torch.testing.assert_close(ref, test, rtol=rtol, atol=atol)
        self.assertEqual("mixed_mm" in code, mixed_mm_expected)
        self.assertEqual("fallback_mixed_mm" in code, fallback_mixed_mm_expected)

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @inductor_config.patch(mixed_mm_choice="triton")
    def test_mixed_mm(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        args_list = [
            (
                torch.randn(8, 8, device="cuda"),
                torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(8, 2, device="cuda", dtype=torch.bfloat16),
                torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(8, 5, device="cuda", dtype=torch.float16),
                torch.randint(0, 255, (5, 2), dtype=torch.uint8, device="cuda"),
            ),
            (
                torch.randn(8, 8, device="cuda", dtype=torch.float32),
                torch.randn(8, 8, device="cuda", dtype=torch.bfloat16),
            ),
        ]

        for args in args_list:
            self._test_mixed_impl(fn, args, True, False)

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @inductor_config.patch(mixed_mm_choice="triton")
    def test_mixed_mm_exhaustive_dtypes(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        dtypes_left = [torch.float16, torch.float32, torch.bfloat16]
        dtypes_right = [torch.int8, torch.uint8]
        dtype_ranges = {torch.uint8: (0, 255), torch.int8: (-128, 127)}
        for dtype_left, dtype_right in itertools.product(dtypes_left, dtypes_right):
            low, high = dtype_ranges[dtype_right]
            args = (
                torch.randn(256, 256, dtype=dtype_left, device="cuda"),
                torch.randint(low, high, (256, 256), dtype=dtype_right, device="cuda"),
            )
            fallback_mixed_mm_expected = (
                dtype_left == torch.bfloat16 and dtype_right == torch.uint8
            )
            self._test_mixed_impl(
                fn, args, True, fallback_mixed_mm_expected, rtol=0.16, atol=1e-4
            )

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @inductor_config.patch(mixed_mm_choice="triton")
    def test_mixed_mm_bad_cases(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        # when b is transposed and not contiguous, we skip triton and use fallback
        args_list = [
            (
                torch.randn(8, 8, device="cuda", dtype=torch.float16),
                torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda").t()[
                    :, ::2
                ],
            ),
            (
                torch.randn(8, 8, device="cuda", dtype=torch.bfloat16),
                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda").t()[
                    :, ::2
                ],
            ),
        ]

        for args in args_list:
            self._test_mixed_impl(fn, args, True, True)

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @inductor_config.patch(mixed_mm_choice="triton", max_autotune_gemm=True)
    def test_mixed_mm_epi_works(self):
        def fn(a, b, c, d):
            return torch.mm(a, b.to(a.dtype)) * c + d

        args_list = [
            (
                torch.randn(8, 8, device="cuda"),
                torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
                torch.randn(8, device="cuda"),
                torch.randn(8, device="cuda"),
            ),
            (
                torch.randn(8, 2, device="cuda", dtype=torch.bfloat16),
                torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda"),
                torch.randn(8, device="cuda", dtype=torch.bfloat16),
                torch.randn(8, device="cuda", dtype=torch.bfloat16),
            ),
            (
                torch.randn(8, 5, device="cuda", dtype=torch.float16),
                torch.randint(0, 255, (5, 2), dtype=torch.uint8, device="cuda"),
                torch.randn(2, device="cuda", dtype=torch.float16),
                torch.randn(2, device="cuda", dtype=torch.float16),
            ),
        ]

        for args in args_list:
            self._test_mixed_impl(fn, args, True, False)

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @unittest.skipIf(not IS_A100, "heuristic only run on Linux A100")
    @unittest.skipIf(not IS_BIG_GPU, "tests fail on small GPU")
    @inductor_config.patch(
        mixed_mm_choice="heuristic",
        autoheuristic_use="",
        fx_graph_cache=False,
        fx_graph_remote_cache=False,
        shape_padding=False,
    )
    def test_mixed_mm_heuristic_no(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        # examples that should not be selected by handwritten heuristic
        mat1_dtype = torch.float16
        dyn_tensor = torch.randn(4, 4096, dtype=mat1_dtype, device="cuda")
        torch._dynamo.mark_dynamic(dyn_tensor, 0)
        args_list = [
            (
                torch.randn(1, 4097, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (4097, 4096), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(1, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (4096, 4097), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(8, 8, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(8, 2048, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (2048, 2048), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(8, 2048, dtype=mat1_dtype, device="cuda"),
                torch.randint(
                    -128, 127, (2048, 2048), dtype=torch.int8, device="cuda"
                ).t(),
            ),
            (
                torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda")[
                    :, ::2
                ],
            ),
            (
                torch.randn(1, 4096, dtype=torch.float32, device="cuda"),
                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
            ),
            (
                dyn_tensor,
                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
            ),
        ]

        for args in args_list:
            self._test_mixed_impl(fn, args, True, True)

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @unittest.skipIf(not IS_A100, "heuristic only run on Linux A100")
    @unittest.skipIf(not IS_BIG_GPU, "tests fail on small GPU")
    @inductor_config.patch(
        mixed_mm_choice="heuristic",
        autoheuristic_use="",
        fx_graph_cache=False,
        fx_graph_remote_cache=False,
        shape_padding=False,
    )
    def test_mixed_mm_heuristic_yes(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        mat1_dtype = torch.float16
        # examples that should be selected by handwritten heuristic
        args_list = [
            (
                torch.randn(1, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(4, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(
                    -128, 127, (4096, 4096), dtype=torch.int8, device="cuda"
                ).t(),
            ),
            (
                torch.randn(16, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(
                    -128, 127, (8192, 4096), dtype=torch.int8, device="cuda"
                ).t(),
            ),
            (
                torch.randn(32, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (4096, 8192), dtype=torch.int8, device="cuda"),
            ),
            (
                torch.randn(64, 4096, dtype=mat1_dtype, device="cuda"),
                torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"),
            ),
        ]

        for args in args_list:
            self._test_mixed_impl(fn, args, True, False, rtol=0.01, atol=0.04)

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    def test_mixed_mm_gating(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        args = (
            torch.randn(8, 8, device="cuda"),
            torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
        )
        # will ignore the mixed_mm code (including fallback)
        with inductor_config.patch(
            {"mixed_mm_choice": "default", "use_mixed_mm": False}
        ):
            self._test_mixed_impl(fn, args, False, False)

        # will use fallback_mixed_mm kernel because gemm autotuning is not enabled
        with inductor_config.patch(
            {"mixed_mm_choice": "default", "use_mixed_mm": True}
        ):
            self._test_mixed_impl(fn, args, True, True)

        # will use mixed_mm kernel
        with inductor_config.patch(
            {"mixed_mm_choice": "triton", "use_mixed_mm": False}
        ):
            self._test_mixed_impl(fn, args, True, False)

        # shows that use_mixed_mm doesn't do anything if mixed_mm_choice is set to triton
        with inductor_config.patch({"mixed_mm_choice": "triton", "use_mixed_mm": True}):
            self._test_mixed_impl(fn, args, True, False)

        # will use fallback_mixed_mm kernel
        with inductor_config.patch({"mixed_mm_choice": "aten", "use_mixed_mm": False}):
            self._test_mixed_impl(fn, args, True, True)

        # will use fallback_mixed_mm kernel
        with inductor_config.patch({"mixed_mm_choice": "aten", "use_mixed_mm": True}):
            self._test_mixed_impl(fn, args, True, True)

        # will use fallback_mixed_mm kernel because fallback is the only choice
        with inductor_config.patch(
            {"mixed_mm_choice": "aten", "use_mixed_mm": True, "max_autotune_gemm": True}
        ):
            self._test_mixed_impl(fn, args, True, True)

    @inductor_config.patch(use_mixed_mm=True)
    def test_mixed_mm_cpu(self):
        def fn(a, b):
            return torch.mm(a, b.to(a.dtype))

        args = (
            torch.randn(8, 8),
            torch.randint(-128, 127, (8, 8), dtype=torch.int8),
        )
        self._test_mixed_impl(fn, args, False, False)

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @inductor_config.patch(use_mixed_mm=True)
    def test_uint4x2_mixed_mm(self):
        def fn(a, b):
            return torch.mm(
                a,
                torch.cat((b & 0xF, b >> 4), 1)
                .reshape(-1, b.shape[1])
                .to(a.dtype)
                .sub(8),
            )

        def check_uint4x2_mixed_mm(args, expect_mixed_mm):
            torch._dynamo.reset()
            counters.clear()
            ref = fn(*args)
            test, (code,) = run_and_get_code(torch.compile(fn), *args)
            torch.testing.assert_close(ref, test)
            self.assertEqual("uint4x2_mixed_mm" in code, expect_mixed_mm)

        args_expect_mixed_mm = [
            (
                torch.randn(8, 8, device="cuda"),
                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
            ),
            (
                torch.randn(8, 8, device="cuda", dtype=torch.float16),
                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda")
                .t()
                .contiguous()
                .t(),
            ),
        ]

        for args in args_expect_mixed_mm:
            check_uint4x2_mixed_mm(args, True)

        # mixed mm is only enabled when casting from a lower-bitwidth dtype to a higher one
        args_expect_no_mixed_mm = [
            (
                torch.randn(8, 8, device="cuda"),
                torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"),
            ),
            (
                torch.randn(8, 8, device="cuda"),
                torch.randint(0, 255, (4, 8), dtype=torch.int64, device="cuda"),
            ),
        ]

        for args in args_expect_no_mixed_mm:
            check_uint4x2_mixed_mm(args, False)

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @inductor_config.patch(use_mixed_mm=True)
    def test_uint4x2_mixed_mm_epi(self):
        def fn(a, b, c, d):
            return (
                torch.mm(
                    a,
                    torch.cat((b & 0xF, b >> 4), 1)
                    .reshape(-1, b.shape[1])
                    .to(a.dtype)
                    .sub(8),
                )
                * c
                + d
            )

        args_list = [
            (
                torch.randn(8, 8, device="cuda"),
                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
                torch.randn(8, device="cuda"),
                torch.randn(8, device="cuda"),
            ),
        ]

        for args in args_list:
            torch._dynamo.reset()
            counters.clear()
            ref = fn(*args)
            test, (code,) = run_and_get_code(torch.compile(fn), *args)
            torch.testing.assert_close(ref, test)
            self.assertTrue("uint4x2_mixed_mm" in code)
            self.assertTrue("fused_add_mm_mul" in code)

    @inductor_config.patch(use_mixed_mm=True)
    def test_uint4x2_mixed_mm_fail_to_match(self):
        def fn(a, b):
            return torch.mm(
                a,
                torch.cat((b & 0xF, b >> 4), 1)
                .reshape(-1, b.shape[1])
                .to(a.dtype)
                .sub(8),
            )

        args_list = [
            (  # cpu
                torch.randn(8, 8),
                torch.randint(0, 255, (4, 8), dtype=torch.uint8),
            ),
            (  # int8
                torch.randn(8, 8, device="cuda"),
                torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda"),
            ),  # we don't match for int8 since numerics
        ]  # for int8 bitshifts don't match between triton and pytorch

        for args in args_list:
            torch._dynamo.reset()
            counters.clear()
            ref = fn(*args)
            test, (code,) = run_and_get_code(torch.compile(fn), *args)
            torch.testing.assert_close(ref, test)
            self.assertFalse("uint4x2_mixed_mm" in code)

    @inductor_config.patch(mixed_mm_choice="default")
    @inductor_config.patch(use_mixed_mm=False)
    def test_uint4x2_mixed_mm_gating_works(self):
        def fn(a, b):
            return torch.mm(
                a,
                torch.cat((b & 0xF, b >> 4), 1)
                .reshape(-1, b.shape[1])
                .to(a.dtype)
                .sub(8),
            )

        args_list = [
            (
                torch.randn(8, 8, device="cuda"),
                torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
            ),
        ]

        for args in args_list:
            torch._dynamo.reset()
            counters.clear()
            ref = fn(*args)
            test, (code,) = run_and_get_code(torch.compile(fn), *args)
            torch.testing.assert_close(ref, test)
            self.assertFalse("uint4x2_mixed_mm" in code)

    def test_addmm(self):
        def fn(a, b, c):
            return torch.add(a, torch.mm(b, c)), torch.mm(b, c) + a

        args_list = [
            (
                torch.randn(16, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                True,
            ),
            (
                torch.randn(8, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                torch.randn(16, 8, device="cuda"),
                True,
            ),
            (
                torch.randn(16, 16, device="cuda"),
                torch.randn(1, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                False,
            ),
            (
                torch.randn(1, 16, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                False,
            ),
            (
                4,
                torch.randn(16, 16, device="cuda"),
                torch.randn(16, 16, device="cuda"),
                False,
            ),
        ]
        for a, b, c, should_fuse in args_list:
            torch._dynamo.reset()
            counters.clear()
            args = (a, b, c)
            e1, e2 = fn(*args)
            a1, a2 = torch.compile(fn)(*args)
            torch.testing.assert_close(a1, e1)
            torch.testing.assert_close(a2, e2)
            count, nodes = (2, 4) if should_fuse else (0, 0)
            self.assertEqual(counters["inductor"]["pattern_matcher_count"], count)
            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], nodes)

    def test_addmm_symbolic_scalar(self):
        def fn(m1, m2):
            bias = m1.size(0)
            return torch.add(bias, torch.mm(m1, m2)), torch.mm(m1, m2) + bias

        m1 = torch.randn(16, 16, device="cuda")
        m2 = torch.randn(16, 16, device="cuda")

        counters.clear()
        expect = fn(m1, m2)
        actual = torch.compile(fn, dynamic=True)(m1, m2)
        self.assertEqual(expect, actual)
        self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)

    def test_addmm_broadcasting_bias(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.functional.linear
                self.linear_weight = torch.randn(4, 4).cuda()
                self.bias = torch.randn(1, 4).cuda()

            def forward(self, x):
                x = self.linear(x, self.linear_weight, self.bias)
                return x

        input_tensor = torch.randn(1, 3, 4).cuda()

        func = Model().cuda()

        res1 = func(input_tensor)
        jit_func = torch.compile(func)
        res2 = jit_func(input_tensor)

        self.assertEqual(res1, res2)

    def test_cat_mm(self):
        def fn(a, b, c):
            return torch.cat(
                [
                    torch.mm(a, b),
                    torch.mm(b, c),
                    torch.mm(a, c),
                ],
                1,
            )

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        self.common(fn, args, 1, 4)

    def test_cat_addmm(self):
        def fn(a, b, c):
            return torch.cat(
                [
                    torch.addmm(a, b, c),
                    torch.addmm(b, c, a),
                    torch.addmm(c, a, b),
                ],
                1,
            )

        args = [
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
            torch.randn(16, 16, device="cuda"),
        ]
        self.common(fn, args, 1, 4)

    def test_cat_slice_cat_cuda(self):
        def fn(a, b):
            cat_1 = torch.ops.aten.cat.default([a, b], 1)
            slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
            slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
            return torch.ops.aten.cat.default([cat_1, slice_2], 1)

        args = [
            torch.randn(2, 32, device="cuda"),
            torch.randn(2, 16, device="cuda"),
        ]
        self.common(fn, args, 1, 3)

        args = [
            torch.randn(2, 8, device="cuda"),
            torch.randn(2, 16, device="cuda"),
        ]
        counters.clear()
        expected = fn(*args)
        actual = torch.compile(fn)(*args)
        torch.testing.assert_close(actual, expected)
        # We don't recompile for dynamic-shape cases.
        if dynamo_config.assume_static_by_default:
            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)

        # Verify we fall back to the non-optimal path for a negative `end`.
        def fn(a, b):
            cat_1 = torch.ops.aten.cat.default([a, b], 1)
            slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
            slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, -1)
            return torch.ops.aten.cat.default([cat_1, slice_2], 1)

        args = [
            torch.randn(2, 8, device="cuda"),
            torch.randn(2, 16, device="cuda"),
        ]
        self.common(fn, args, 1, 3)

    def test_pointless_convert(self):
        def fn1(x):
            x = torch.ops.prims.convert_element_type.default(x, torch.float16)
            x = torch.ops.prims.convert_element_type.default(x, torch.float32)
            return x

        gm = torch.fx.symbolic_trace(fn1)
        self.assertEqual(count_calls(gm.graph), 2)
        joint_graph.joint_graph_passes(gm)
        self.assertEqual(count_calls(gm.graph), 1)

        def fn2(x):
            x = torch.ops.prims.convert_element_type.default(x, torch.int32)
            x = torch.ops.prims.convert_element_type.default(x, torch.float32)
            return x

        gm = torch.fx.symbolic_trace(fn2)
        self.assertEqual(count_calls(gm.graph), 2)
        joint_graph.joint_graph_passes(gm)
        self.assertEqual(count_calls(gm.graph), 2)

    # Constant folding was explicitly turned off due to issue #108388
    # Turn it back on for test
    @inductor_config.patch(joint_graph_constant_folding=True)
    def test_pointless_cumsum(self):
        def fn1():
            ones = torch.full(
                [1, 128], 1, layout=torch.strided, dtype=torch.float32
            ).to(torch.int64)
            return torch.cumsum(ones, 1) * ones

        def fn2():
            ones = torch.full(
                [55, 10], 1, layout=torch.strided, dtype=torch.float32
            ).to(torch.int64)
            return torch.cumsum(ones, 1)

        def fn3():
            twos = torch.full([5, 4, 3], 2, dtype=torch.int64)
            return torch.cumsum(twos, 0)

        def fn4():
            x = torch.full([100], 0.1, dtype=torch.float32)
            return torch.cumsum(x, 0)

        def fn5():
            t1 = torch.full([2, 4], 1)
            t2 = t1.to(dtype=torch.bool)
            return torch.cumsum(t2, 1)

        def fn6():
            x = torch.full([10, 10], True, dtype=torch.int32)
            return torch.cumsum(x, 1)

        for fn in (fn1, fn2, fn3, fn4, fn5, fn6):
            result, (code,) = run_and_get_code(torch.compile(fn, fullgraph=True))
            self.assertNotIn("aten.cumsum", code)
            self.assertEqual(result, fn())
            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
            counters.clear()

    def test_splitwithsizes_cat(self):
        # Good case
        def fn(a):
            split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 24], 1)
            getitem = split_with_sizes[0]
            getitem_1 = split_with_sizes[1]
            cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
            return cat**2

        args = [
            torch.randn(2, 32, device="cuda"),
        ]
        self.common(fn, args, 1, 4)

        # Not all getitems are passed to cat
        def fn(a):
            split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 8, 16], 1)
            getitem = split_with_sizes[0]
            getitem_1 = split_with_sizes[1]
            getitem_2 = split_with_sizes[2]
            cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
            return cat**2 + getitem_2

        args = [
            torch.randn(2, 32, device="cuda"),
        ]
        self.common(fn, args, 0, 0)

        # Different dimensions (TODO this case should be handled by replacing with a reshape)
        def fn(a):
            split_with_sizes = torch.ops.aten.split_with_sizes.default(
                a, [8, 8, 8, 8], 1
            )
            cat = torch.ops.aten.cat.default(split_with_sizes, 0)
            return cat**2

        args = [
            torch.randn(2, 32, device="cuda"),
        ]
        self.common(fn, args, 0, 0)

        # https://github.com/pytorch/pytorch/issues/99686.
        def fn(a):
            x = torch.ops.aten.split_with_sizes.default(a, [3, 2, 3], dim=1)
            cat = torch.ops.aten.cat.default([x[1], x[0], x[2]], dim=1)
            return cat

        args = [
            torch.randn(1, 8, device="cuda"),
        ]
        self.common(fn, args, 0, 0)

    def test_cat_splitwithsizes(self):
        # good case
        def fn(a, b, c):
            cat = torch.ops.aten.cat.default([a, b, c], 1)
            split_with_sizes = torch.ops.aten.split_with_sizes.default(
                cat, [2, 3, 5], 1
            )
            return [s**2 for s in split_with_sizes]

        args = [
            torch.randn(2, 2, device="cuda"),
            torch.randn(2, 3, device="cuda"),
            torch.randn(2, 5, device="cuda"),
        ]
        self.common(fn, args, 1, 2)

        # cat node has other users
        def fn(a, b, c):
            cat = torch.ops.aten.cat.default([a, b, c], 1)
            split_with_sizes = torch.ops.aten.split_with_sizes.default(
                cat, [2, 3, 5], 1
            )
            return [s**2 for s in split_with_sizes] + [cat**3]

        args = [
            torch.randn(2, 2, device="cuda"),
            torch.randn(2, 3, device="cuda"),
            torch.randn(2, 5, device="cuda"),
        ]
        self.common(fn, args, 0, 0)

        # cat and split dims are different
        def fn(a, b, c):
            cat = torch.ops.aten.cat.default([a, b, c], 1)
            split_with_sizes = torch.ops.aten.split_with_sizes.default(
                cat, [2, 3, 5], 0
            )
            return [s**2 for s in split_with_sizes]

        args = [
            torch.randn(10, 2, device="cuda"),
            torch.randn(10, 3, device="cuda"),
            torch.randn(10, 5, device="cuda"),
        ]
        self.common(fn, args, 0, 0)

        # cat and split lengths are different
        def fn(a, b, c):
            cat = torch.ops.aten.cat.default([a, b, c], 1)
            split_with_sizes = torch.ops.aten.split_with_sizes.default(cat, [5, 5], 1)
            return [s**2 for s in split_with_sizes]

        args = [
            torch.randn(2, 2, device="cuda"),
            torch.randn(2, 3, device="cuda"),
            torch.randn(2, 5, device="cuda"),
        ]
        self.common(fn, args, 0, 0)

        # cat input sizes and split sizes are different
        def fn(a, b, c):
            cat = torch.ops.aten.cat.default([a, b, c], 1)
            split_with_sizes = torch.ops.aten.split_with_sizes.default(
                cat, [2, 5, 3], 1
            )
            return [s**2 for s in split_with_sizes]

        args = [
            torch.randn(2, 2, device="cuda"),
            torch.randn(2, 3, device="cuda"),
            torch.randn(2, 5, device="cuda"),
        ]
        self.common(fn, args, 0, 0)

    def test_symint_pattern_matching(self):
        import torch._inductor.config as config
        from torch._inductor.pattern_matcher import (
            fwd_only,
            PatternMatcherPass,
            register_replacement,
        )

        saved_graph = None

        class _CustomPass(PatternMatcherPass):
            def __init__(self) -> None:
                super().__init__()

            def __call__(self, g: torch.fx.graph.Graph):
                self.apply(g)
                nonlocal saved_graph
                saved_graph = g

        with config.patch(
            # leave custom pass only in post_grad_passes()
            pattern_matcher=False,
            # define pattern match as custom post grad opt pass
            post_grad_custom_pre_pass=None,
            post_grad_custom_post_pass=_CustomPass(),
        ):

            def add(x, y):
                return x + y

            # replacement pattern that exercises symbolic shapes via y.size(0)
            def sym_minus(x, y):
                return (x - (-y.size(0))) - (y * -1) - y.size(0)

            device = "cpu"
            my_args = [
                torch.empty([8, 1], device=device),
                torch.empty([10], device=device),
            ]

            invoked = False

            def extra_check(match):
                nonlocal invoked
                invoked = True
                return True

            register_replacement(
                add,
                sym_minus,
                my_args,
                fwd_only,
                [config.post_grad_custom_post_pass],
                extra_check=extra_check,
            )

            @torch.compile(dynamic=True)
            def foo(x, y):
                return x + y

            x = torch.rand([8, 1])
            y = torch.rand([10])

            self.assertEqual(foo(x, y), x + y)

            self.assertTrue(invoked)
            # we trace out the y.sym_size in replacement
            FileCheck().check("sym_size_int").check_same("num_users=2").check_same(
                "target=torch.ops.aten.sym_size"
            ).run(str(saved_graph))

    @inductor_config.patch(fx_graph_remote_cache=False)
    def test_match_with_mutation(self):
        counter = 0
        test_pass = PatternMatcherPass(pass_name="test")

        @register_graph_pattern(
            CallFunction(
                torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
            ),
            pass_dict=test_pass,
        )
        def _test(match, x):
            nonlocal counter
            counter += 1

        def fn0(x, y):
            a = torch.sin(x)
            b = torch.add(x, a)
            return b

        def fn1(x, y):
            a = torch.sin(x)
            x.copy_(y)
            b = torch.add(x, a)
            return b

        def fn2(x, y):
            a = torch.sin(x)
            with torch.no_grad():
                b = torch.add(x, a)
            return b

        def fn3(x, y):
            a = torch.sin(x)
            with torch.autocast("cuda"):
                b = torch.add(x, a)
            return b

        def fn4(x, y):
            a = torch.sin(x)
            torch.manual_seed(1234)
            b = torch.add(x, a)
            return b

        def fn5(x, y):
            a = torch.sin(x)
            torch.add(y, 1, out=x)
            b = torch.add(x, a)
            return b

        args = [
            torch.randn(5, 5, device="cuda"),
            torch.randn(5, 5, device="cuda"),
        ]

        with unittest.mock.patch(
            "torch._inductor.fx_passes.pre_grad.config.pre_grad_fusion_options",
            {"test": {}},
        ), unittest.mock.patch(
            "torch._inductor.fx_passes.pre_grad.PRE_GRAD_FUSIONS",
            [],
        ), unittest.mock.patch(
            "torch._inductor.fx_passes.pre_grad.PRE_GRAD_PATTERNS",
            {"test": test_pass},
        ):
            for fn in (fn0, fn1, fn2, fn3, fn4, fn5):
                counter = 0
                expected = fn(*copy.deepcopy(args))
                actual = torch.compile(fn)(*copy.deepcopy(args))
                # the pattern should only match fn0; the other fns have a
                # mutation or side-effecting op between the sin and the add
                self.assertEqual(counter, int(fn is fn0))
                torch.testing.assert_close(actual, expected)

    def test_remove_pointless_clones(self):
        @torch.compile(fullgraph=True)
        def fn(a, b):
            return torch.mm(a, b).clone()

        result, (code) = run_and_get_code(fn, torch.randn(8, 8), torch.randn(8, 8))
        # clone would create a buf1
        self.assertIn("return (buf0, )", code[0])
        self.assertNotIn("async_compile.cpp", code[0])

    def test_unfuse_bias_addmm(self):
        args = [
            torch.randn(20, device="cuda"),
            torch.randn(10, 15, device="cuda"),
            torch.randn(15, 20, device="cuda"),
        ]

        @torch.compile()
        def fn(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b)

        _, (code) = run_and_get_code(fn, args[0], args[1], args[2])
        FileCheck().check("extern_kernels.addmm(").run(code[0])

        @torch.compile()
        def fn2(inp, a, b):
            return torch.nn.functional.gelu(torch.ops.aten.addmm(inp, a, b))

        _, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
        FileCheck().check_not("extern_kernels.addmm(").run(code[0])

        @torch.compile()
        def fn3(inp, a, b):
            return torch.nn.functional.gelu(
                torch.ops.aten.addmm(inp, a, b).unsqueeze(0)
            )

        # hit the view path
        _, (code) = run_and_get_code(fn3, args[0], args[1], args[2])
        FileCheck().check_not("extern_kernels.addmm(").run(code[0])

    def test_serialized_patterns_up_to_date(self):
        import torch.utils._pytree as pytree
        from torch._inductor.fx_passes import joint_graph
        from torch._inductor.pattern_matcher import _known_precompiled_patterns

        # Ensure the patterns are loaded
        os.environ.pop("PYTORCH_GEN_PATTERNS", None)
        joint_graph.lazy_init()

        with torch._subclasses.FakeTensorMode() as mode:
            for (
                search_fn,
                example_inputs,
                trace_fn,
                scalar_workaround,
                search_fn_pattern,
            ) in _known_precompiled_patterns:
                # Because the example_inputs were saved as fake tensors in a
                # different FakeTensorMode we need to update them to our
                # FakeTensorMode().
                def remap_fake_tensor(x):
                    if isinstance(x, torch.Tensor):
                        return torch._subclasses.FakeTensor.from_tensor(x, mode)
                    return x

                example_inputs = pytree.tree_map(remap_fake_tensor, example_inputs)

                pattern = gen_pattern(
                    search_fn, example_inputs, trace_fn, scalar_workaround
                )
                pattern_pp = PatternPrettyPrinter.run(pattern)

                self.assertEqual(
                    pattern_pp,
                    PatternPrettyPrinter.run(search_fn_pattern),
                    msg=f"Found mismatched pattern {search_fn.__name__}. Run torchgen/fuse/gen_patterns.py",
                )

                # Since we've already checked that the serialized patterns match,
                # let's verify the serializer by ensuring the generated patterns
                # also match (since search_fn_pattern is the serialized version
                # of search_fn).
                self.assertTrue(pattern.pattern_eq(search_fn_pattern))

    @inductor_config.patch(fx_graph_remote_cache=False)
    def test_match_equivalent_function_invocations1(self):
        counter = 0
        test_pass = PatternMatcherPass()

        args = [
            torch.randn(20, device="cuda"),
            torch.randn(10, 15, device="cuda"),
            torch.randn(15, 20, device="cuda"),
        ]

        def f0(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b)

        def f1(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b, beta=1.0)

        def f2(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)

        # This graph pattern should successfully match all of the above functions
        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.addmm,
                Arg(),
                Arg(),
                Arg(),
                beta=KeywordArg("beta"),
                alpha=KeywordArg("alpha"),
            ),
            pass_dict=test_pass,
        )
        def addmm_replacement(match: Match, inp, mat1, mat2, beta, alpha):
            nonlocal counter
            counter += 1

            def repl(inp, x1, x2):
                return (x1 @ x2) * alpha + inp * beta

            with V.fake_mode:
                match.replace_by_example(repl, [inp, mat1, mat2])

        with unittest.mock.patch(
            "torch._inductor.fx_passes.post_grad.pass_patterns",
            torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
        ):
            for fn in (f0, f1, f2):
                counter = 0
                expected = fn(*copy.deepcopy(args))
                opt_fn = torch.compile(fn)
                actual, (code) = run_and_get_code(opt_fn, args[0], args[1], args[2])
                # pattern should match
                self.assertEqual(counter, 1)
                torch.testing.assert_close(actual, expected)
                # addmm should be replaced
                FileCheck().check_not("extern_kernels.addmm(").run(code[0])

    @inductor_config.patch(fx_graph_remote_cache=False)
    def test_match_equivalent_function_invocations2(self):
        counter = 0
        test_pass = PatternMatcherPass()

        args = [
            torch.randn(20, device="cuda"),
            torch.randn(10, 15, device="cuda"),
            torch.randn(15, 20, device="cuda"),
        ]

        def f0(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b)

        def f1(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b, beta=1.0)

        def f2(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)

        # This graph pattern should only match f0
        @register_graph_pattern(
            CallFunction(torch.ops.aten.addmm, Arg(), Arg(), Arg()),
            pass_dict=test_pass,
        )
        def addmm_replacement(match: Match, inp, mat1, mat2):
            nonlocal counter
            counter += 1

            def repl(inp, x1, x2):
                return x1 @ x2 + inp

            with V.fake_mode:
                match.replace_by_example(repl, [inp, mat1, mat2])

        with unittest.mock.patch(
            "torch._inductor.fx_passes.post_grad.pass_patterns",
            torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
        ):
            for fn in (f0, f1, f2):
                counter = 0
                expected = fn(*copy.deepcopy(args))
                actual = torch.compile(fn)(*copy.deepcopy(args))
                self.assertEqual(counter, 1)
                torch.testing.assert_close(actual, expected)

    @inductor_config.patch(fx_graph_remote_cache=False)
    def test_match_equivalent_function_invocations3(self):
        counter = 0
        test_pass = PatternMatcherPass()

        args = [
            torch.randn(20, device="cuda"),
            torch.randn(10, 15, device="cuda"),
            torch.randn(15, 20, device="cuda"),
        ]

        def f0(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b)

        def f1(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b, beta=1.0)

        def f2(inp, a, b):
            return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)

        # This graph pattern should only match f1
        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.addmm, Arg(), Arg(), Arg(), beta=KeywordArg("beta")
            ),
            pass_dict=test_pass,
        )
        def addmm_replacement(match: Match, inp, mat1, mat2, beta):
            nonlocal counter
            counter += 1

            def repl(inp, x1, x2):
                return x1 @ x2 + inp

            with V.fake_mode:
                match.replace_by_example(repl, [inp, mat1, mat2])

        with unittest.mock.patch(
            "torch._inductor.fx_passes.post_grad.pass_patterns",
            torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
        ):
            for fn in (f0, f1, f2):
                counter = 0
                expected = fn(*copy.deepcopy(args))
                actual = torch.compile(fn)(*copy.deepcopy(args))
                self.assertEqual(counter, 1)
                torch.testing.assert_close(actual, expected)

    def test_stable_topological_sort(self):
        def fn1(a, b):
            return a + b

        graph = torch.fx.Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(fn1, (a, b))
        stable_topological_sort(graph)
        self.assertEqual(list(graph.nodes), [a, b, c])

        graph = torch.fx.Graph()
        b = graph.placeholder("y")
        a = graph.placeholder("x")
        c = graph.call_function(fn1, (a, b))
        stable_topological_sort(graph)
        self.assertEqual(list(graph.nodes), [b, a, c])

        graph = torch.fx.Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(fn1, (b, a))
        c.append(a)
        stable_topological_sort(graph)
        self.assertEqual(list(graph.nodes), [b, a, c])

    def test_scaled_softmax(self):
        def mul_softmax(a, b):
            return F.softmax(a * b, dim=0)

        def div_softmax(x, inv_scale):
            return F.softmax(x / inv_scale, dim=0)

        x = torch.randn(10, 10)
        scale = 1e6
        inv_scale = 1 / scale
        self.common(mul_softmax, (x, scale), 1, 3)
        self.common(mul_softmax, (scale, x), 1, 3)
        self.common(div_softmax, (x, inv_scale), 1, 3)

        scale = torch.randn(10) * 1e6
        inv_scale = 1 / scale
        self.common(mul_softmax, (x, scale), 1, 3)
        self.common(mul_softmax, (scale, x), 1, 3)
        self.common(div_softmax, (x, inv_scale), 1, 3)

        scale = torch.randn(1, 10) * 1e6
        inv_scale = 1 / scale
        self.common(mul_softmax, (x, scale), 1, 3)
        self.common(mul_softmax, (scale, x), 1, 3)
        self.common(div_softmax, (x, inv_scale), 1, 3)

        # Test matching with type promotion
        x = torch.randn(10, 10, dtype=torch.bfloat16)
        scale = torch.randn(10, dtype=torch.bfloat16) * 1e6
        inv_scale = 1 / scale
        self.common(mul_softmax, (x, scale), 1, 4, reference_in_float=True)
        self.common(mul_softmax, (scale, x), 1, 4, reference_in_float=True)
        self.common(div_softmax, (x, inv_scale), 1, 4, reference_in_float=True)

        # No match if scale changes in softmax dim
        scale = torch.randn(10, 10)
        self.common(mul_softmax, (x, scale), 0, 0)
        self.common(mul_softmax, (scale, x), 0, 0)
        self.common(div_softmax, (x, scale), 0, 0)

    def test_mutation_op_matching(self):
        def check(type, func_name, args, kwargs, expect=True):
            assert type in ["call_function", "call_method"]
            graph = torch.fx.Graph()
            getattr(graph, type)(func_name, args, kwargs)
            res = is_mutation_op(next(iter(graph.nodes)))
            if expect:
                self.assertTrue(res)
            else:
                self.assertFalse(res)

        t = torch.randn(1)
        check("call_function", torch._C._set_grad_enabled, (False,), {})
        check("call_method", "copy_", (t, t), {})
        check("call_method", "relu_", (t,), {})
        check("call_function", torch.manual_seed, (0,), {})
        check("call_function", torch.ops.aten.set_.source_Tensor, (t, t), {})
        check(
            "call_function",
            torch.amp.autocast_mode._enter_autocast,
            ("cuda", None, True, None),
            {},
        )
        check("call_function", torch.amp.autocast_mode._exit_autocast, (None,), {})
        check(
            "call_function",
            torch.ops._c10d_functional.all_gather_into_tensor_out,
            (t, 2, "0"),
            {"out": t},
        )
        check("call_function", torch.ops.inductor.resize_storage_bytes_, (t, 0), {})
        check(
            "call_function",
            torch.ops.inductor.resize_storage_bytes_.default,
            (t, 0),
            {},
        )
        check(
            "call_function",
            torch.ops.fsdp.split_with_sizes_copy,
            (t, [64, 128, 8, 8]),
            {"dim": 1, "out": [t, t, t, t]},
        )
        check("call_function", torch.ops.fsdp.set_, (t, t), {})
        check(
            "call_function", torch.ops.aten.__rshift__.Scalar, (t, 2), {}, expect=False
        )
        check(
            "call_function",
            torch.ops._c10d_functional.all_gather_into_tensor,
            (t, 2, "0"),
            {},
            expect=False,
        )


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA:
        run_tests()