1# Owner(s): ["module: functionalization"] 2 3import numpy as np 4 5import torch 6import torch._dynamo.testing 7import torch._inductor.config as inductor_config 8import torch._inductor.test_case 9import torch.onnx.operators 10import torch.utils._pytree as pytree 11import torch.utils.cpp_extension 12from torch import Tensor 13from torch.testing._internal.logging_utils import logs_to_string 14 15 16class AutoFunctionalizeTests(torch._inductor.test_case.TestCase): 17 def test_auto_functionalize_can_with_default(self): 18 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 19 torch.library.define( 20 "mylib::foo", 21 "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()", 22 tags=torch.Tag.pt2_compliant_tag, 23 lib=lib, 24 ) 25 26 @torch.library.impl("mylib::foo", "cpu", lib=lib) 27 def foo_impl(a, b, c=None, d=None, e=-1): 28 a + b 29 return 30 31 def f(a, mode): 32 return torch.ops.mylib.foo( 33 a, 34 0, 35 ) 36 37 a = torch.tensor([10, 10, 10], dtype=torch.int64) 38 39 torch.compile(f)(a, 0) 40 41 def test_auto_functionalize_can_with_none_return(self): 42 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 43 lib.define("foo(Tensor x, Tensor(a!) out) -> None") 44 45 def foo_impl(x, out): 46 out.copy_(x) 47 48 lib.impl("foo", foo_impl, "CompositeExplicitAutograd") 49 x = torch.randn(3) 50 out = torch.zeros(3) 51 52 @torch.compile 53 def f(x, out): 54 torch.ops.mylib.foo(x, out) 55 56 f(x, out) 57 58 def test_auto_functionalize_self_as_mutate_arg(self): 59 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 60 lib.define("foo(Tensor(a!) self) -> None") 61 62 def foo_impl(self: torch.Tensor) -> None: 63 self.sin_() 64 65 x = torch.randn(3) 66 lib.impl("foo", foo_impl, "CompositeExplicitAutograd") 67 68 @torch.compile(backend="inductor", fullgraph=True) 69 def f(x): 70 torch.ops.mylib.foo(x) 71 72 f(x) 73 74 def test_auto_functionalize_tensorlist(self): 75 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 76 torch.library.define( 77 "mylib::foo", 78 "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()", 79 tags=torch.Tag.pt2_compliant_tag, 80 lib=lib, 81 ) 82 83 @torch.library.impl("mylib::foo", "cpu", lib=lib) 84 @torch._dynamo.disable 85 def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out): 86 for o in out: 87 o.copy_(all_gather_output) 88 89 def f(all_gather_output, all_gather_input_split_sizes, dim, out): 90 torch.ops.mylib.foo( 91 all_gather_output, all_gather_input_split_sizes, dim, out 92 ) 93 94 a = torch.ones(4) 95 b = [2, 3] 96 c = 0 97 d = [torch.empty(4) for _ in range(2)] 98 orig_args = (a, b, c, d) 99 100 compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 101 torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) 102 103 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 104 f(*eager_args) 105 self.assertEqual(compiled_args, eager_args) 106 107 def test_can_auto_functionalize(self): 108 from torch._higher_order_ops.auto_functionalize import can_auto_functionalize 109 110 expected_true = [ 111 "(Tensor(a!) x) -> ()", 112 "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", 113 "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", 114 "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()", 115 "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor", 116 "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)", 117 ] 118 expected_false = [ 119 "(Tensor x) -> ()", 120 "(Tensor(a) x) -> Tensor(a)", 121 "(Tensor(a!) x) -> Tensor(a!)", 122 "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)", 123 "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", 124 "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", 125 "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])", 126 ] 127 for schema in expected_true: 128 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 129 torch.library.define("mylib::a", schema, lib=lib) 130 131 self.assertTrue( 132 can_auto_functionalize(torch.ops.mylib.a.default), msg=schema 133 ) 134 self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) 135 136 for schema in expected_false: 137 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 138 torch.library.define("mylib::a", schema, lib=lib) 139 self.assertFalse( 140 can_auto_functionalize(torch.ops.mylib.a.default), msg=schema 141 ) 142 self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) 143 144 @torch._inductor.config.patch(enable_auto_functionalized_v2=False) 145 def test_auto_functionalize_old(self): 146 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 147 torch.library.define( 148 "mylib::foo", 149 "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", 150 tags=torch.Tag.pt2_compliant_tag, 151 lib=lib, 152 ) 153 154 @torch.library.impl("mylib::foo", "cpu", lib=lib) 155 @torch._dynamo.disable 156 def foo_impl(x, y, z, w, n): 157 x.add_(y[0] + w) 158 z.add_(y[1] + n) 159 160 def f(x, y, z, n): 161 torch.ops.mylib.foo(x, y, z, 2, n) 162 163 x = torch.randn(3) 164 y = (torch.randn(3), torch.randn(3)) 165 z = torch.randn(3) 166 n = torch.randn(3) 167 orig_args = (x, y, z, n) 168 compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 169 log_stream, ctx = logs_to_string( 170 "torch._inductor.compile_fx", "post_grad_graphs" 171 ) 172 with ctx(): 173 torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) 174 175 post_grad_graphs = "\n".join( 176 log_stream.getvalue().strip().split("\n")[3:] 177 ).strip() 178 179 # Check the graph under static shapes 180 if torch._dynamo.config.assume_static_by_default: 181 self.assertExpectedInline( 182 post_grad_graphs, 183 """\ 184def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \ 185"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): 186 # No stacktrace found for following nodes 187 foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \ 188arg3_1 = arg1_1 = arg0_1 = foo_default = None 189 return ()""", 190 ) 191 192 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 193 f(*eager_args) 194 self.assertEqual(compiled_args, eager_args) 195 196 @torch._inductor.config.patch(enable_auto_functionalized_v2=False) 197 def test_auto_functionalize_with_returns_old(self): 198 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 199 torch.library.define( 200 "mylib::foo", 201 "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", 202 tags=torch.Tag.pt2_compliant_tag, 203 lib=lib, 204 ) 205 206 @torch.library.impl("mylib::foo", "cpu", lib=lib) 207 @torch._dynamo.disable 208 def foo_impl(x, y, z, w, n): 209 x.add_(y[0] + w) 210 z.add_(y[1] + n) 211 return y[0] + w, y[1] + n 212 213 @torch.library.impl_abstract("mylib::foo", lib=lib) 214 def foo_abstract(x, y, z, w, n): 215 return y[0] + w, y[1] + n 216 217 def f(x, y, z, n): 218 return torch.ops.mylib.foo(x, y, z, 2, n) 219 220 x = torch.randn(3) 221 y = (torch.randn(3), torch.randn(3)) 222 z = torch.randn(3) 223 n = torch.randn(3) 224 orig_args = (x, y, z, n) 225 226 compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 227 log_stream, ctx = logs_to_string( 228 "torch._inductor.compile_fx", "post_grad_graphs" 229 ) 230 with ctx(): 231 compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( 232 *compiled_args 233 ) 234 235 if torch._dynamo.config.assume_static_by_default: 236 post_grad_graphs = "\n".join( 237 log_stream.getvalue().strip().split("\n")[3:] 238 ).strip() 239 self.assertExpectedInline( 240 post_grad_graphs, 241 """\ 242def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): 243 foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None 244 getitem_4: "f32[3][1]cpu" = foo_default[0] 245 getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None 246 return (getitem_4, getitem_5)""", # noqa: B950 247 ignore_comments=True, 248 ) 249 250 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 251 eager_out = f(*eager_args) 252 self.assertEqual(compiled_args, eager_args) 253 self.assertEqual(compiled_out, eager_out) 254 255 def test_auto_functionalize_on_view(self): 256 for value in [True, False]: 257 with torch.library._scoped_library( 258 "mylib", "FRAGMENT" 259 ) as lib, inductor_config.patch({"enable_auto_functionalized_v2": value}): 260 torch.library.define( 261 "mylib::foo", 262 "(Tensor(a!) x) -> ()", 263 tags=torch.Tag.pt2_compliant_tag, 264 lib=lib, 265 ) 266 267 @torch.library.impl("mylib::foo", "cpu", lib=lib) 268 @torch._dynamo.disable 269 def foo_impl(x): 270 x_np = x.detach().numpy() # view 271 np.sin(x_np, out=x_np) 272 return 273 274 x = torch.randn(3) 275 expected = x.sin() 276 torch.ops.mylib.foo(x) 277 assert torch.allclose(x, expected) 278 279 @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) 280 def f(x): 281 x = x.clone() 282 y = x[:] 283 torch.ops.mylib.foo(y) 284 return x 285 286 y = f(x) 287 self.assertEqual(y, x.sin()) 288 289 @torch._inductor.config.patch(enable_auto_functionalized_v2=False) 290 def test_auto_functionalize_optional_old(self): 291 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 292 torch.library.define( 293 "mylib::foo", 294 "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", 295 tags=torch.Tag.pt2_compliant_tag, 296 lib=lib, 297 ) 298 299 @torch.library.impl("mylib::foo", "cpu", lib=lib) 300 @torch._dynamo.disable 301 def foo_impl(x, y, z, w, n): 302 if x is not None: 303 x.add_(y[0] + w) 304 if z is not None: 305 z.add_(y[1] + n) 306 307 def f(x, y, z, n): 308 torch.ops.mylib.foo(x, y, z, 2, n) 309 310 x = None 311 y = (torch.randn(3), torch.randn(3)) 312 z = torch.randn(3) 313 n = torch.randn(3) 314 orig_args = (x, y, z, n) 315 compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 316 log_stream, ctx = logs_to_string( 317 "torch._inductor.compile_fx", "post_grad_graphs" 318 ) 319 with ctx(): 320 torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) 321 if torch._dynamo.config.assume_static_by_default: 322 post_grad_graphs = "\n".join( 323 log_stream.getvalue().strip().split("\n")[3:] 324 ).strip() 325 self.assertExpectedInline( 326 post_grad_graphs, 327 """\ 328def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): 329 # No stacktrace found for following nodes 330 foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ 331arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None 332 return ()""", 333 ) 334 335 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 336 f(*eager_args) 337 self.assertEqual(compiled_args, eager_args) 338 339 @torch._dynamo.config.patch( 340 capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True 341 ) 342 def test_unbacked_auto_functionalize_op(self): 343 @torch.library.custom_op( 344 "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"] 345 ) 346 def mk_image(decoder: Tensor) -> Tensor: 347 return torch.randn(2, 3, 4, 5) 348 349 @torch.library.register_fake("mylib::mk_image") 350 def _(decoder: Tensor) -> Tensor: 351 image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)] 352 return torch.empty(image_size) 353 354 @torch.compile(fullgraph=True) 355 def f(x): 356 return torch.ops.mylib.mk_image.default(x) 357 358 x = torch.zeros(100, dtype=torch.int64) 359 f(x) 360 361 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 362 def test_auto_functionalize_v2(self, _dynamic=False): 363 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 364 torch.library.define( 365 "mylib::foo", 366 "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", 367 tags=torch.Tag.pt2_compliant_tag, 368 lib=lib, 369 ) 370 371 @torch.library.impl("mylib::foo", "cpu", lib=lib) 372 @torch._dynamo.disable 373 def foo_impl(x, y, z, w, n): 374 x.add_(y[0] + w) 375 z.add_(y[1] + n) 376 377 def f(x, y, z, n): 378 torch.ops.mylib.foo(x, y, z, 2, n) 379 380 x = torch.randn(3) 381 y = (torch.randn(3), torch.randn(3)) 382 z = torch.randn(3) 383 n = torch.randn(3) 384 orig_args = (x, y, z, n) 385 386 compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 387 388 log_stream, ctx = logs_to_string( 389 "torch._inductor.compile_fx", "post_grad_graphs" 390 ) 391 with ctx(): 392 torch.compile(f, backend="inductor", dynamic=_dynamic, fullgraph=True)( 393 *compiled_args 394 ) 395 396 post_grad_graphs = "\n".join( 397 log_stream.getvalue().strip().split("\n")[3:] 398 ).strip() 399 400 if torch._dynamo.config.assume_static_by_default: 401 if _dynamic: 402 self.assertExpectedInline( 403 post_grad_graphs, 404 """\ 405def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"): 406 foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None 407 copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None 408 copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None 409 return ()""", # noqa: B950 410 ignore_comments=True, 411 ignore_empty_lines=True, 412 ) 413 else: 414 self.assertExpectedInline( 415 post_grad_graphs, 416 """\ 417def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): 418 foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None 419 copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None 420 copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None 421 return ()""", # noqa: B950 422 ignore_comments=True, 423 ignore_empty_lines=True, 424 ) 425 426 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 427 f(*eager_args) 428 self.assertEqual(compiled_args, eager_args) 429 430 def run_aot_eager(self, f, orig_args, _dynamic=False): 431 aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 432 433 log_stream, ctx = logs_to_string( 434 "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" 435 ) 436 437 result = None 438 with ctx(): 439 result = torch.compile( 440 f, backend="aot_eager", fullgraph=True, dynamic=_dynamic 441 )(*aot_eager_args) 442 443 graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() 444 return [aot_eager_args, result, graph] 445 446 def run_inductor(self, f, orig_args, _dynamic=False): 447 compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 448 449 log_stream, ctx = logs_to_string( 450 "torch._inductor.compile_fx", "post_grad_graphs" 451 ) 452 result = None 453 with ctx(): 454 result = torch.compile( 455 f, backend="inductor", fullgraph=True, dynamic=_dynamic 456 )(*compiled_args) 457 458 graph = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip() 459 460 return [compiled_args, result, graph] 461 462 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 463 def test_auto_functionalize_with_returns_v2(self): 464 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 465 torch.library.define( 466 "mylib::foo", 467 "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", 468 tags=torch.Tag.pt2_compliant_tag, 469 lib=lib, 470 ) 471 472 @torch.library.impl("mylib::foo", "cpu", lib=lib) 473 @torch._dynamo.disable 474 def foo_impl(x, y, z, w, n): 475 x.add_(y[0] + w) 476 z.add_(y[1] + n) 477 return y[0] + w, y[1] + n 478 479 @torch.library.impl_abstract("mylib::foo", lib=lib) 480 def foo_abstract(x, y, z, w, n): 481 return y[0] + w, y[1] + n 482 483 def f(x, y, z, n): 484 return torch.ops.mylib.foo(x, y, z, 2, n) 485 486 x = torch.randn(3) 487 y = (torch.randn(3), torch.randn(3)) 488 z = torch.randn(3) 489 n = torch.randn(3) 490 orig_args = (x, y, z, n) 491 compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 492 log_stream, ctx = logs_to_string( 493 "torch._inductor.compile_fx", "post_grad_graphs" 494 ) 495 with ctx(): 496 compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( 497 *compiled_args 498 ) 499 if torch._dynamo.config.assume_static_by_default: 500 post_grad_graphs = "\n".join( 501 log_stream.getvalue().strip().split("\n")[3:] 502 ).strip() 503 self.assertExpectedInline( 504 post_grad_graphs, 505 """\ 506def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): 507 foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None 508 getitem_4: "f32[3][1]cpu" = foo_default[0] 509 getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None 510 511 copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None 512 copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None 513 return (getitem_4, getitem_5)""", # noqa: B950 514 ignore_comments=True, 515 ignore_empty_lines=True, 516 ) 517 518 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 519 eager_out = f(*eager_args) 520 self.assertEqual(compiled_args, eager_args) 521 self.assertEqual(compiled_out, eager_out) 522 523 # foo takes two inputs that are not views. 524 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 525 def test_auto_functionalize_extra1(self, _dynamic=False): 526 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 527 torch.library.define( 528 "mylib::foo", 529 "(Tensor(a!) x, Tensor(b!) y) -> ()", 530 tags=torch.Tag.pt2_compliant_tag, 531 lib=lib, 532 ) 533 534 @torch.library.impl("mylib::foo", "cpu", lib=lib) 535 @torch._dynamo.disable 536 def foo_impl(x, y): 537 x.sin_() 538 y.sin_() 539 540 def f(x, y): 541 torch.ops.mylib.foo(x, y) 542 return x + y 543 544 orig_args = (torch.randn(2), torch.randn(2)) 545 546 [aot_eager_args, result1, graph_aot] = self.run_aot_eager( 547 f, orig_args, _dynamic 548 ) 549 [inductor_args, result2, graph_inductor] = self.run_inductor( 550 f, orig_args, _dynamic 551 ) 552 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 553 result3 = f(*eager_args) 554 555 self.assertEqual(inductor_args, eager_args) 556 self.assertEqual(inductor_args, aot_eager_args) 557 558 self.assertEqual(result3, result1) 559 self.assertEqual(result3, result2) 560 561 if torch._dynamo.config.assume_static_by_default: 562 if _dynamic: 563 self.assertExpectedInline( 564 graph_aot, 565 """\ 566def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): 567 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1]) 568 getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] 569 getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None 570 add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) 571 copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None 572 copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None 573 return (add,)""", # noqa: B950 574 ignore_comments=True, 575 ignore_empty_lines=True, 576 ) 577 else: 578 self.assertExpectedInline( 579 graph_aot, 580 """\ 581def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): 582 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1]) 583 getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] 584 getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None 585 add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) 586 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None 587 copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None 588 return (add,)""", # noqa: B950 589 ignore_comments=True, 590 ignore_empty_lines=True, 591 ) 592 593 if torch._dynamo.config.assume_static_by_default: 594 if _dynamic: 595 self.assertExpectedInline( 596 graph_inductor, 597 """\ 598def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): 599 foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None 600 add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1) 601 copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None 602 copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None 603 return (add,)""", 604 ignore_comments=True, 605 ignore_empty_lines=True, 606 ) 607 else: 608 self.assertExpectedInline( 609 graph_inductor, 610 """\ 611def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): 612 foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None 613 add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1) 614 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None 615 copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None 616 return (add,)""", 617 ignore_comments=True, 618 ignore_empty_lines=True, 619 ) 620 621 # foo takes two views on the same input, function does not have return. 622 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 623 def test_auto_functionalize_extra2(self, _dynamic=False): 624 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 625 torch.library.define( 626 "mylib::foo", 627 "(Tensor(a!) x, Tensor(b!) y) -> ()", 628 tags=torch.Tag.pt2_compliant_tag, 629 lib=lib, 630 ) 631 632 @torch.library.impl("mylib::foo", "cpu", lib=lib) 633 @torch._dynamo.disable 634 def foo_impl(x, y): 635 x.sin_() 636 y.sin_() 637 638 def f(x): 639 a = x[0] 640 b = x[1] 641 torch.ops.mylib.foo(a, b) 642 return 643 644 orig_args = [torch.randn(2)] 645 646 [aot_eager_args, result1, graph_aot] = self.run_aot_eager( 647 f, orig_args, _dynamic 648 ) 649 [inductor_args, result2, graph_inductor] = self.run_inductor( 650 f, orig_args, _dynamic 651 ) 652 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 653 result3 = f(*eager_args) 654 655 self.assertEqual(inductor_args, eager_args) 656 self.assertEqual(inductor_args, aot_eager_args) 657 658 self.assertEqual(result3, result1) 659 self.assertEqual(result3, result2) 660 661 if torch._dynamo.config.assume_static_by_default: 662 if _dynamic: 663 self.assertExpectedInline( 664 graph_aot, 665 """\ 666def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): 667 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1]) 668 getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None 669 copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None 670 return ()""", # noqa: B950 671 ignore_comments=True, 672 ignore_empty_lines=True, 673 ) 674 else: 675 self.assertExpectedInline( 676 graph_aot, 677 """\ 678def forward(self, arg0_1: "f32[2][1]cpu"): 679 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1]) 680 getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None 681 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None 682 return ()""", # noqa: B950 683 ignore_comments=True, 684 ignore_empty_lines=True, 685 ) 686 687 # 2. Run with inductor backend 688 689 if torch._dynamo.config.assume_static_by_default: 690 if _dynamic: 691 self.assertExpectedInline( 692 graph_inductor, 693 """\ 694def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): 695 as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0) 696 as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1) 697 foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None 698 copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None 699 return ()""", # noqa: B950 700 ignore_comments=True, 701 ignore_empty_lines=True, 702 ) 703 else: 704 self.assertExpectedInline( 705 graph_inductor, 706 """\ 707def forward(self, arg0_1: "f32[2][1]cpu"): 708 as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) 709 as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1) 710 foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None 711 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None 712 return ()""", # noqa: B950 713 ignore_comments=True, 714 ignore_empty_lines=True, 715 ) 716 717 # foo takes two views on the same input, function returns both views and the input 718 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 719 def test_auto_functionalize_extra3(self): 720 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 721 torch.library.define( 722 "mylib::foo", 723 "(Tensor(a!) x, Tensor(b!) y) -> ()", 724 tags=torch.Tag.pt2_compliant_tag, 725 lib=lib, 726 ) 727 728 @torch.library.impl("mylib::foo", "cpu", lib=lib) 729 @torch._dynamo.disable 730 def foo_impl(x, y): 731 x.sin_() 732 y.sin_() 733 734 def f(x): 735 a = x[0] 736 b = x[1] 737 torch.ops.mylib.foo(a, b) 738 return (a, b, x) 739 740 orig_args = [torch.randn(2)] 741 742 [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) 743 [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) 744 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 745 result3 = f(*eager_args) 746 747 self.assertEqual(inductor_args, eager_args) 748 self.assertEqual(inductor_args, aot_eager_args) 749 750 self.assertEqual(result3, result1) 751 self.assertEqual(result3, result2) 752 753 if torch._dynamo.config.assume_static_by_default: 754 self.assertExpectedInline( 755 graph_aot, 756 """\ 757def forward(self, arg0_1: "f32[2][1]cpu"): 758 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1]) 759 getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None 760 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None 761 select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0) 762 select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None 763 return (select_2, select_3)""", # noqa: B950 764 ignore_comments=True, 765 ignore_empty_lines=True, 766 ) 767 768 # 2. Run with inductor backend 769 770 if torch._dynamo.config.assume_static_by_default: 771 self.assertExpectedInline( 772 graph_inductor, 773 """\ 774def forward(self, arg0_1: "f32[2][1]cpu"): 775 as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) 776 as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1) 777 foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None 778 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None 779 select_2: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) 780 select_3: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None 781 return (select_2, select_3)""", # noqa: B950 782 ignore_comments=True, 783 ignore_empty_lines=True, 784 ) 785 786 # foo takes a mutable list with views in addition to other args. 787 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 788 def test_auto_functionalize_extra4(self): 789 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 790 torch.library.define( 791 "mylib::foo", 792 "(Tensor(a!) x, Tensor(b!)[] y) -> ()", 793 tags=torch.Tag.pt2_compliant_tag, 794 lib=lib, 795 ) 796 797 @torch.library.impl("mylib::foo", "cpu", lib=lib) 798 @torch._dynamo.disable 799 def foo_impl(x, y): 800 x.sin_() 801 y[0].sin_() 802 803 def f(x, y, z): 804 a = x[0] 805 b = z[0] 806 torch.ops.mylib.foo(a, [b, y]) 807 808 orig_args = [torch.randn(2), torch.randn(2), torch.randn(2)] 809 810 [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) 811 [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) 812 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 813 result3 = f(*eager_args) 814 815 self.assertEqual(inductor_args[2], eager_args[2]) 816 self.assertEqual(inductor_args, aot_eager_args) 817 818 self.assertEqual(result3, result1) 819 self.assertEqual(result3, result2) 820 821 if torch._dynamo.config.assume_static_by_default: 822 self.assertExpectedInline( 823 graph_aot, 824 """\ 825def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"): 826 auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_length = 2, _y_0_base_index = 1, _y_0_size = (), _y_0_stride = (), _y_0_storage_offset = 0, _y_1_base_index = 2, _all_bases = [arg0_1, arg1_1, arg2_1]) 827 getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] 828 getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2] 829 getitem_3: "f32[2][1]cpu" = auto_functionalized_v2[3]; auto_functionalized_v2 = None 830 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None 831 copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None 832 copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_3); arg2_1 = getitem_3 = copy__2 = None 833 return ()""", # noqa: B950 834 ignore_comments=True, 835 ignore_empty_lines=True, 836 ) 837 838 # 2. Run with inductor backend 839 840 if torch._dynamo.config.assume_static_by_default: 841 self.assertExpectedInline( 842 graph_inductor, 843 """\ 844def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"): 845 as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) 846 as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0) 847 foo_default = torch.ops.mylib.foo.default(as_strided_default, [as_strided_default_1, arg2_1]); as_strided_default = as_strided_default_1 = foo_default = None 848 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None 849 copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None 850 copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__2 = None 851 return ()""", # noqa: B950 852 ignore_comments=True, 853 ignore_empty_lines=True, 854 ) 855 856 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 857 def test_auto_functionalize_optional_v2(self): 858 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 859 torch.library.define( 860 "mylib::foo", 861 "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", 862 tags=torch.Tag.pt2_compliant_tag, 863 lib=lib, 864 ) 865 866 @torch.library.impl("mylib::foo", "cpu", lib=lib) 867 @torch._dynamo.disable 868 def foo_impl(x, y, z, w, n): 869 if x is not None: 870 x.add_(y[0] + w) 871 if z is not None: 872 z.add_(y[1] + n) 873 874 def f(x, y, z, n): 875 torch.ops.mylib.foo(x, y, z, 2, n) 876 877 x = None 878 y = (torch.randn(3), torch.randn(3)) 879 z = torch.randn(3) 880 n = torch.randn(3) 881 orig_args = (x, y, z, n) 882 883 compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 884 log_stream, ctx = logs_to_string( 885 "torch._inductor.compile_fx", "post_grad_graphs" 886 ) 887 with ctx(): 888 torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) 889 890 if torch._dynamo.config.assume_static_by_default: 891 post_grad_graphs = "\n".join( 892 log_stream.getvalue().strip().split("\n")[3:] 893 ).strip() 894 self.assertExpectedInline( 895 post_grad_graphs, 896 """\ 897def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): 898 foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None 899 copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None 900 return ()""", # noqa: B950 901 ignore_comments=True, 902 ignore_empty_lines=True, 903 ) 904 905 eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) 906 f(*eager_args) 907 self.assertEqual(compiled_args, eager_args) 908 909 @torch._inductor.config.patch(enable_auto_functionalized_v2=False) 910 def test_inference_mode1_v2(self): 911 with torch.inference_mode(): 912 self.test_auto_functionalize_extra1() 913 914 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 915 def test_inference_mode2_v2(self): 916 with torch.inference_mode(): 917 self.test_auto_functionalize_extra2() 918 919 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 920 def test_inference_mode3_v2(self): 921 with torch.inference_mode(): 922 self.test_auto_functionalize_extra3() 923 924 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 925 def test_inference_mode4_v2(self): 926 with torch.inference_mode(): 927 self.test_auto_functionalize_extra4() 928 929 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 930 def test_dynamic_v2(self): 931 self.test_auto_functionalize_v2(_dynamic=True) 932 933 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 934 def test_dynamic2_v2(self): 935 self.test_auto_functionalize_extra1(_dynamic=True) 936 937 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 938 def test_dynamic3_v2(self): 939 self.test_auto_functionalize_extra2(_dynamic=True) 940 941 # foo takes two views on the same input, function does not have return. 942 @torch._inductor.config.patch(enable_auto_functionalized_v2=True) 943 def test_graph_input_is_view(self): 944 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 945 torch.library.define( 946 "mylib::foo", 947 "(Tensor(a!) x) -> ()", 948 tags=torch.Tag.pt2_compliant_tag, 949 lib=lib, 950 ) 951 952 @torch.library.impl("mylib::foo", "cpu", lib=lib) 953 @torch._dynamo.disable 954 def foo_impl(x): 955 pass 956 957 @torch.compile(fullgraph=True, dynamic=False, backend="aot_eager") 958 def f(x): 959 a = x[0] 960 torch.ops.mylib.foo(a) 961 return 962 963 x = torch.tensor([[1, 2], [3, 4]]) 964 # This would fail if auto_functionalized_v2 uses clone and not clone_preserve_strides 965 # to clone not-inplaced args. 966 f(x[1]) 967 968 969if __name__ == "__main__": 970 from torch._inductor.test_case import run_tests 971 972 run_tests() 973