# Owner(s): ["module: inductor"]
import contextlib
import functools
import importlib
import itertools
import os
import sys
import unittest
import weakref

import torch
from torch import nn
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import override_lowering, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import skipIfRocm


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM


importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


aten = torch.ops.aten
prims = torch.ops.prims
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")


class TestCase(InductorTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "debug": True,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                    "freezing": True,
                    "freezing_discard_parameters": True,
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()


class ConvBN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)

    def forward(self, x):
        return self.bn(self.conv(x))


class ConvBNHardswish(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)
        self.hardswish = nn.Hardswish(inplace=True)

    def forward(self, x):
        return self.hardswish(self.bn(self.conv(x)))


class ConvFunctionalBN(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        kernel_size=3,
        stride=2,
        running_mean=None,
        running_var=None,
        weight=None,
        bn_bias=None,
    ):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, out_channels, bias=bias, kernel_size=kernel_size, stride=stride
        )
        self.running_mean = running_mean
        self.running_var = running_var
        self.weight = weight
        self.bias = bn_bias

    def forward(self, x):
        return torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )


class ConvMultiBN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
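        # two BN layers with different eps consume the same conv output
        # (exercises folding when the conv has multiple users)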
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float)
        self.bn2 = torch.nn.BatchNorm2d(out_channels, eps=0.1, dtype=torch.float)

    def forward(self, x):
        tmp = self.bn(self.conv(x))
        tmp2 = self.bn2(self.conv(x))
        return tmp + tmp2


class ConvMultiFunctionalBN(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        kernel_size=3,
        stride=2,
        running_mean=None,
        running_var=None,
        weight=None,
        bn_bias=None,
        running_mean2=None,
    ):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, out_channels, bias=bias, kernel_size=kernel_size, stride=stride
        )
        self.running_mean = running_mean
        self.running_var = running_var
        self.weight = weight
        self.bias = bn_bias
        self.running_mean2 = running_mean2

    def forward(self, x):
        tmp = torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )
        tmp2 = torch.nn.functional.batch_norm(
            self.conv(x),
            self.running_mean2,
            self.running_var,
            self.weight,
            self.bias,
            False,
            0.1,
            1e-5,
        )
        return tmp + tmp2


class OptimizeForInferenceTemplate(TestCase):
    def test_mutation(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mutated_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                self.mutated_param.add_(10)
                return self.mutated_param

        with torch.no_grad():
            mod = Mod().to(self.device)
            out_eager = mod()
            out_eager2 = mod()

            mod = Mod().to(self.device)

            @torch.compile
            def foo(mod):
                return mod()

            out_comp = foo(mod)
            out_comp2 = foo(mod)

            self.assertEqual(out_eager, out_comp)
            self.assertEqual(out_eager2, out_comp2)

    def test_aliased_param_return(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.aliased_param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self):
                return self.aliased_param[1:], self.aliased_param

        mod = Mod().to(self.device).eval()

        @torch.compile()
        def foo(mod):
            return mod()

        with torch.no_grad():
            mod_eager = mod()
            self.assertEqual(foo(mod), mod_eager)

    def test_autocast(self):
        if self.device == "cpu":
            raise unittest.SkipTest("MKLDNN Bug")

        mod = torch.nn.Linear(10, 10).to(self.device).eval()
        inp = torch.rand([10, 10]).to(self.device).to(torch.half)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            with self.autocast():
                out_eager = mod(inp)
                out_compiled, code = run_and_get_code(foo, mod, inp)

                FileCheck().check_not("@triton.jit").run(code[0])
                self.assertEqual(out_eager, out_compiled)

    def test_mm_concat(self):
        # CPU path will replace mm with mkl._linear,
        # skip this case for now.
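        # On CUDA, the equal-shape mms below should be concatenated so that only
        # a single mm remains in the generated code (see the check_count below);
        # mismatched weight shapes must stay as separate mms.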
        if self.device == "cpu":
            raise unittest.SkipTest("NYI CPU")

        class MM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.t1 = torch.nn.Parameter(torch.rand(10, 10))
                self.t2 = torch.nn.Parameter(torch.rand(10, 10))
                self.t3 = torch.nn.Parameter(torch.rand(10, 10))

            def forward(self, x):
                return x @ self.t1, x @ self.t2, x @ self.t3

        class MM2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.t1 = torch.nn.Parameter(torch.rand(10, 10))
                self.t2 = torch.nn.Parameter(torch.rand(10, 10))

            def forward(self, x):
                return x @ self.t1, x @ self.t2

        class AddMM(MM):
            def __init__(self) -> None:
                super().__init__()

                self.b1 = torch.nn.Parameter(torch.rand([10]))
                self.b2 = torch.nn.Parameter(torch.rand([10]))
                self.b3 = torch.nn.Parameter(torch.rand([10]))

            def forward(self, x):
                return [
                    aten.addmm(b, x, p)
                    for b, p in [
                        (self.b1, self.t1),
                        (self.b2, self.t2),
                        (self.b3, self.t3),
                    ]
                ]

        for mod_fn in [
            lambda: MM().to(self.device),
            lambda: MM2().to(self.device),
            lambda: AddMM().to(self.device),
        ]:
            mod = mod_fn()
            inp = torch.rand([10, 10]).to(self.device)

            @torch.compile()
            def foo(mod, inp):
                return mod(inp)

            kernel_invoke = "kernel_cpp_0" if self.device == "cpu" else "triton.jit"

            with torch.no_grad():
                out_eager = mod(inp)
                out, code = run_and_get_code(foo, mod, inp)
                FileCheck().check_not(kernel_invoke).check_count(
                    "mm(", count=1, exactly=True
                ).run(code[0])
                self.assertEqual(out_eager, out)

            mod2 = mod_fn()
            mod2.t1 = torch.nn.Parameter(torch.rand([10, 15], device=self.device))
            mod2.t2 = torch.nn.Parameter(torch.rand([10, 20], device=self.device))

            if hasattr(mod2, "b1"):
                mod2.b1 = torch.nn.Parameter(torch.rand([15], device=self.device))
                mod2.b2 = torch.nn.Parameter(torch.rand([20], device=self.device))

            # the weights now have different shapes, so the mms are not fused
            count = 3 if hasattr(mod2, "t3") else 2

            with torch.no_grad():
                out_eager = mod2(inp)
                out, code = run_and_get_code(foo, mod2, inp)
                FileCheck().check_not(kernel_invoke).check_count(
                    "mm(", count=count, exactly=True
                ).run(code[0])
                self.assertEqual(out_eager, out)

    # With inlining of inbuilt nn modules, Dynamo traces the innards of the inbuilt
    # module and does not modify the eager module.
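    # Disabling inlining restores the behavior this test asserts on: freezing
    # (with freezing_discard_parameters=True, see setUpClass) invalidates the
    # original eager module.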
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
    def test_error_on_eager(self):
        mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

        x = torch.rand(3, 3, 32, 32).to(self.device)

        @torch.compile()
        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            foo(mod, x)

        with self.assertRaisesRegex(
            RuntimeError, "Trying to run Pytorch Eager Module after Dynamo Freezing"
        ):
            mod(x)

    def test_rng_op(self):
        @torch.compile()
        def foo():
            return torch.rand([4, 4], device=self.device) + 1

        with torch.no_grad():
            o1 = foo()
            o2 = foo()
            self.assertNotEqual(o1, o2)

    def test_symint_not_folded(self):
        def fn(a):
            return a.cos(), torch.zeros(a.shape[0], a.shape[1])

        fn_opt = torch._dynamo.optimize("inductor", dynamic=True)(fn)
        inp = torch.randn(2, 4, 6).to(self.device)
        torch._dynamo.mark_dynamic(inp, 0)
        torch._dynamo.mark_dynamic(inp, 1)

        with torch.no_grad():
            self.assertEqual(fn(inp), fn_opt(inp))
            inp2 = torch.randn(3, 5, 6).to(self.device)
            torch._dynamo.mark_dynamic(inp2, 0)
            torch._dynamo.mark_dynamic(inp2, 1)
            self.assertEqual(fn(inp2), fn_opt(inp2))

    @requires_cuda
    def test_conv_multiple_uses(self):
        from torch import nn

        class ToyModel(nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.bn1 = nn.BatchNorm2d(1)
                self.bn1.weight.data.normal_()

            def forward(self, x, y):
                return self.conv1(x) + self.bn1(self.conv1(y))

        model = ToyModel()
        model.eval().cuda()

        a = torch.rand(64, 1, 32, 32).cuda()
        b = torch.rand(64, 1, 32, 32).cuda()

        output = model(a, b)

        with torch.no_grad():
            output2 = torch.compile(model)(a, b)

        self.assertEqual(output, output2)

    def test_unfolded_bn(self):
        x = torch.rand([3, 32, 15, 15]).to(self.device)

        mod = torch.nn.BatchNorm2d(32, eps=0.001).eval().to(self.device)

        @torch.compile()
        def foo(mod, x):
            return mod(x) + 10

        out_compiled_no_inference = foo(mod, x)

        # would error if not decomposed
        with torch.no_grad():
            out_compiled = foo(mod, x)

        self.assertEqual(out_compiled_no_inference, out_compiled)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn(self):
        for use_bias, dtype in itertools.product(
            [True, False], [torch.float16, torch.bfloat16, torch.float32]
        ):
            if self.device == "cpu" and dtype == torch.float16:
                continue

            if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
                continue

            mod = (
                ConvBN(3, 32, bias=use_bias, kernel_size=3, stride=2)
                .eval()
                .to(self.device)
                .to(dtype)
            )

            x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

            torch._dynamo.reset()
            counters.clear()

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            # TODO - bias is a separate kernel right now; we should only unfuse it
            # from the conv if it can be fused

            with torch.no_grad():
                out_eager = mod(x)
                out_optimized_for_inference, code = run_and_get_code(foo, mod, x)

            # we unfuse the conv bias, but there should only be one constant in the kernel
            if self.device == "cuda":
                FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
                    "frozen_param"
                ).check_not("frozen_param").check_next("return").run(code[0])
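
            # folding reorders floating-point math, so compare with loose tolerances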
            self.assertEqual(
                out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2
            )
            self.assertEqual(counters["inductor"]["binary_folding"], 4)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn_hardswish(self):
        for use_bias, dtype in itertools.product(
            [True, False], [torch.float16, torch.bfloat16, torch.float32]
        ):
            if self.device == "cpu" and dtype == torch.float16:
                continue

            if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
                continue

            mod = (
                ConvBNHardswish(3, 32, bias=use_bias, kernel_size=3, stride=2)
                .eval()
                .to(self.device)
                .to(dtype)
            )

            x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

            torch._dynamo.reset()
            counters.clear()

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            # TODO - bias is a separate kernel right now; we should only unfuse it
            # from the conv if it can be fused

            with torch.no_grad():
                out_eager = mod(x)
                out_optimized_for_inference, code = run_and_get_code(foo, mod, x)

            # we unfuse the conv bias, but there should only be one constant in the kernel
            if self.device == "cuda":
                FileCheck().check_not(".run(").check("conv").check(".run(").check_same(
                    "frozen_param"
                ).check_not("frozen_param").check_next("return").run(code[0])

            self.assertEqual(
                out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2
            )
            self.assertEqual(counters["inductor"]["binary_folding"], 4)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_bn_with_module_sharing(self):
        mod = (
            ConvBN(32, 32, bias=True, kernel_size=3, stride=2)
            .to(self.device)
            .to(torch.float32)
        )

        # Update the BN module's running stats away from their defaults
        for _ in range(10):
            mod(torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32))

        mod.eval()
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)

        def foo(mod, x):
            mod(x)
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_folded_conv_functional_bn_with_module_sharing(self):
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)
        running_mean = torch.mean(x, dim=(0, 2, 3)).to(self.device)
        running_var = torch.var(x, dim=(0, 2, 3)).to(self.device)

        mod = (
            ConvFunctionalBN(
                32,
                32,
                bias=True,
                kernel_size=3,
                stride=2,
                running_mean=running_mean,
                running_var=running_var,
                weight=torch.ones(32).to(self.device),
                bn_bias=torch.zeros(32).to(self.device),
            )
            .eval()
            .to(self.device)
            .to(torch.float32)
        )

        def foo(mod, x):
            mod(x)
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_conv_bn_with_multi_bn_share_conv(self):
        mod = (
            ConvMultiBN(32, 32, bias=True, kernel_size=3, stride=2)
            .to(self.device)
            .to(torch.float32)
        )

        # Update the BN modules' running stats away from their defaults
        for _ in range(10):
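            # mod has not been switched to eval() yet, so each call updates the stats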
            mod(torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32))

        mod.eval()
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)

        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )

        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_conv_functional_bn_with_multi_bn_share_conv(self):
        x = torch.rand(3, 32, 32, 32).to(self.device).to(torch.float32)
        running_mean = torch.mean(x, dim=(0, 2, 3)).to(self.device)
        running_var = torch.var(x, dim=(0, 2, 3)).to(self.device)
        running_mean2 = torch.mean(x, dim=(0, 2, 3)).to(self.device)

        mod = (
            ConvMultiFunctionalBN(
                32,
                32,
                bias=True,
                kernel_size=3,
                stride=2,
                running_mean=running_mean,
                running_var=running_var,
                weight=torch.ones(32).to(self.device),
                bn_bias=torch.zeros(32).to(self.device),
                running_mean2=running_mean2,
            )
            .eval()
            .to(self.device)
            .to(torch.float32)
        )

        def foo(mod, x):
            return mod(x)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_optimized_for_inference, _ = run_and_get_code(
                torch.compile(foo), mod, x
            )
        self.assertEqual(out_optimized_for_inference, out_eager, atol=1e-2, rtol=1e-2)

    @torch._inductor.config.patch(layout_optimization=False)
    def test_dont_change_dtype_folding(self):
        dtype = torch.float16 if self.device == "cuda" else torch.bfloat16

        mod = (
            torch.nn.Conv2d(3, 32, bias=None, kernel_size=3, stride=2)
            .eval()
            .to(self.device)
            .to(dtype)
        )
        x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype)

        def foo(mod, x):
            return mod(x) * torch.full([1], 2.0, device=self.device)

        foo_c = torch.compile(foo)

        with torch.no_grad():
            out_eager = foo(mod, x)
            out_compiled = foo_c(mod, x)
            self.assertEqual(out_eager, out_compiled)

    def test_param_deallocated(self):
        # TODO: the cpu path keeps an extra copy of the graph around somewhere;
        # memory is not as important for cpu
        if self.device == "cpu":
            raise unittest.SkipTest("NYI CPU")

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros([10, 10]))

            def forward(self, x):
                return (self.param + 10) + x

        mod = Mod().eval().to(self.device)
        inp = torch.rand([10], device=self.device)

        with torch.no_grad():
            eager = mod(inp)

        weight_ref = weakref.ref(mod.param)

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            compiled = foo(mod, inp)

        self.assertEqual(eager, compiled)
        self.assertTrue(weight_ref() is None)

    @skipIfRocm
    def test_conv_with_as_strided(self):
        class Model(nn.Module):
            def __init__(self, groups):
                super().__init__()
                self.kv = torch.nn.Conv2d(
                    256,
                    384,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    bias=False,
                    groups=groups,
                )

            def forward(self, x):
                convolution = self.kv(x)
                constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                    convolution, [2, 2, 2, 2], 0.0
                )
                # the as_strided inputs depend on the input's size and stride
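                # (the hard-coded sizes/strides below assume the 8x256x16x16 input)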
                as_strided = torch.ops.aten.as_strided.default(
                    constant_pad_nd, [8, 384, 2, 20, 12], [153600, 400, 160, 1, 20]
                )
                as_strided_1 = torch.ops.aten.as_strided.default(
                    as_strided, [8, 384, 2, 2, 12, 12], [153600, 400, 160, 8, 20, 1]
                )
                clone = torch.ops.aten.clone.default(
                    as_strided_1, memory_format=torch.contiguous_format
                )
                return clone

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            x = torch.randn(8, 256, 16, 16).to(self.device)
            for groups in [1, 2]:
                mod = Model(groups).to(self.device).eval()
                mod_eager = mod(x)
                self.assertEqual(foo(mod, x), mod_eager)

    def test_cpp_wrapper(self):
        mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device)

        x = torch.rand(3, 3, 32, 32).to(self.device)

        @torch.compile(options={"cpp_wrapper": True})
        def foo(mod, x):
            return mod(x)

        out_eager = mod(x)

        with torch.no_grad():
            self.assertEqual(foo(mod, x), out_eager)
            self.assertEqual(foo(mod, x), out_eager)

    def test_conv_layout_convert_with_view(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )
                self.bn = nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.bn(x)
                x = self.conv(x)
                return torch.flatten(x, 1)

        mod = Model().to(self.device).eval()

        @torch.compile()
        def foo(mod, inp):
            return mod(inp)

        with torch.no_grad():
            x = torch.rand(2, 3, 5, 5).to(self.device)
            mod_eager = mod(x)
            self.assertEqual(foo(mod, x), mod_eager)

    @skipIfRocm
    def test_conv_weight_layout_convert(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                return self.conv(x)

            @staticmethod
            def get_example_inputs():
                return (torch.rand(2, 3, 5, 5).to(self.device),)

        from torch._inductor.compile_fx import compile_fx, compile_fx_inner

        nconv = 0

        def my_inner_compile(gm, example_inputs, *args, **kwargs):
            out = compile_fx_inner(gm, example_inputs, *args, **kwargs)

            nonlocal nconv
            convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
            nconv += len(convs)
            for conv in convs:
                weight_node = conv.args[1]
                weight_const_tensor = getattr(gm, weight_node.target)
                self.assertTrue(
                    weight_const_tensor.is_contiguous(memory_format=torch.channels_last)
                )
                self.assertTrue(
                    weight_node.meta["val"].is_contiguous(
                        memory_format=torch.channels_last
                    )
                )

            return out

        mod = torch.compile(
            Model().eval().to(self.device),
            backend=functools.partial(compile_fx, inner_compile=my_inner_compile),
        )
        inp = mod.get_example_inputs()
        with torch.no_grad():
            mod(*inp)

        # Only check the assertion for CUDA.
        # For CPU, we may get torch.ops.mkldnn._convolution_pointwise.default
        # in the joint graph rather than torch.ops.aten.convolution.default.
        # Currently we only handle aten.convolution.default in layout
        # optimization. That's why the count may be 0 here for CPU.
        if self.device == "cuda":
            self.assertTrue(nconv == 1)

    def test_unequal_bias_horizontal_addmm_fusion(self):
        device = self.device

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.tensor(
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], device=device
                )
                self.b1 = torch.zeros(3, device=device)
                self.w2 = torch.tensor(
                    [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0]], device=device
                )
                self.b2 = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
                self.w3 = torch.tensor(
                    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device=device
                )
                self.b3 = torch.tensor([1.0, 2.0, 3.0], device=device)

            def forward(self, x):
                out1 = torch.nn.functional.linear(x, self.w1, self.b1)
                out2 = torch.nn.functional.linear(x, self.w2, self.b2)
                out3 = torch.nn.functional.linear(x, self.w3, self.b3)
                return (out1, out2, out3)

        func = Model().to(device).eval()
        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device=device)

        with torch.no_grad():
            out_eager = func(x.clone())

            func1 = torch.compile(func)
            out_compiled = func1(x.clone())
            self.assertEqual(out_eager, out_compiled)

    @skipIfRocm
    def test_redundant_clone_for_layout_convert(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                y = x + 1
                return self.conv(x), y

            @staticmethod
            def get_example_inputs():
                return (torch.rand(2, 3, 5, 5).to(self.device),)

        mod = Model().eval().to(self.device)
        inp = mod.get_example_inputs()
        with torch.no_grad():
            expected_outputs = mod(*inp)

        num_same_stride = 0
        num_diff_stride = 0

        def debug_inductor_force_stride_order(orig_fn, input_tensor, stride):
            nonlocal num_same_stride, num_diff_stride
            input_tensor.realize()
            if tuple(input_tensor.get_stride()) == tuple(stride):
                num_same_stride += 1
            else:
                num_diff_stride += 1
            return orig_fn(input_tensor, stride)

        with override_lowering(
            prims.inductor_force_stride_order.default, debug_inductor_force_stride_order
        ):
            opt_mod = torch.compile(mod)
            with torch.no_grad():
                actual_outputs = opt_mod(*inp)

        self.assertEqual(len(actual_outputs), len(expected_outputs))
        self.assertEqual(2, len(actual_outputs))
        for i, actual, expected in zip(
            itertools.count(), actual_outputs, expected_outputs
        ):
            self.assertTrue(
                torch.allclose(expected, actual, atol=1e-4, rtol=1e-4),
                f"{i}th output: expected {expected}, actual {actual}",
            )

        if self.device == "cpu":
            # CPU uses a different convolution implementation; skip the checks below
            return

        self.assertTrue(
            actual_outputs[0].is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            actual_outputs[1].is_contiguous(memory_format=torch.contiguous_format)
        )

        # we don't change the stride of y returned by forward, so there will
        # be no extra copy
        self.assertTrue(num_same_stride == 1, f"num_same_stride is {num_same_stride}")
        # we changed the stride of self.conv(x) returned by forward, so there
        # may be an extra copy
        self.assertTrue(num_diff_stride == 1, f"num_diff_stride is {num_diff_stride}")


if TEST_WITH_ROCM:
    torch._inductor.config.force_layout_optimization = 1
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"

if HAS_CPU and not torch.backends.mps.is_available():

    class FreezingCpuTests(TestCase):
        common = check_model
        device = "cpu"
        autocast = torch.cpu.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, FreezingCpuTests, "cpu")

if HAS_CUDA and not TEST_WITH_ASAN:

    class FreezingCudaTests(TestCase):
        common = check_model_cuda
        device = "cuda"
        autocast = torch.cuda.amp.autocast

    copy_tests(OptimizeForInferenceTemplate, FreezingCudaTests, "cuda")


del OptimizeForInferenceTemplate


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")