# Owner(s): ["module: functionalization"]

import numpy as np

import torch
import torch._dynamo.testing
import torch._inductor.config as inductor_config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from torch import Tensor
from torch.testing._internal.logging_utils import logs_to_string


class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
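    # The mutable arg (Tensor(d!)? c) is optional and has a default, so the
    # call site below can omit it entirely.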
    def test_auto_functionalize_can_with_default(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            def foo_impl(a, b, c=None, d=None, e=-1):
                a + b
                return

            def f(a, mode):
                return torch.ops.mylib.foo(
                    a,
                    0,
                )

            a = torch.tensor([10, 10, 10], dtype=torch.int64)

            torch.compile(f)(a, 0)

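    # An op declared to return None (rather than an empty tuple) should also
    # be auto-functionalized.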
    def test_auto_functionalize_can_with_none_return(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, Tensor(a!) out) -> None")

            def foo_impl(x, out):
                out.copy_(x)

            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
            x = torch.randn(3)
            out = torch.zeros(3)

            @torch.compile
            def f(x, out):
                torch.ops.mylib.foo(x, out)

            f(x, out)

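    # The op's single mutable arg is named `self`.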
    def test_auto_functionalize_self_as_mutate_arg(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor(a!) self) -> None")

            def foo_impl(self: torch.Tensor) -> None:
                self.sin_()

            x = torch.randn(3)
            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

            @torch.compile(backend="inductor", fullgraph=True)
            def f(x):
                torch.ops.mylib.foo(x)

            f(x)

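    # The op mutates every element of its Tensor(a!)[] list arg.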
    def test_auto_functionalize_tensorlist(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out):
                for o in out:
                    o.copy_(all_gather_output)

            def f(all_gather_output, all_gather_input_split_sizes, dim, out):
                torch.ops.mylib.foo(
                    all_gather_output, all_gather_input_split_sizes, dim, out
                )

            a = torch.ones(4)
            b = [2, 3]
            c = 0
            d = [torch.empty(4) for _ in range(2)]
            orig_args = (a, b, c, d)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

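    # Schema-level check: ops that mutate args and return nothing or plain
    # Tensors qualify; schemas whose outputs alias inputs do not.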
    def test_can_auto_functionalize(self):
        from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

        expected_true = [
            "(Tensor(a!) x) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
        ]
        expected_false = [
            "(Tensor x) -> ()",
            "(Tensor(a) x) -> Tensor(a)",
            "(Tensor(a!) x) -> Tensor(a!)",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])",
        ]
        for schema in expected_true:
            with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
                torch.library.define("mylib::a", schema, lib=lib)

                self.assertTrue(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))

        for schema in expected_false:
            with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
                torch.library.define("mylib::a", schema, lib=lib)
                self.assertFalse(
                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
                )
                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))

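    # V1 (enable_auto_functionalized_v2=False): the post-grad graph is expected
    # to call foo directly on the original args, with no copy_ epilogue.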
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)
            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            # Check the graph under static shapes
            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = \
arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

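    # Same as above, but the op also returns tensors that flow out of the graph.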
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_with_returns_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return y[0] + w, y[1] + n

            @torch.library.impl_abstract("mylib::foo", lib=lib)
            def foo_abstract(x, y, z, w, n):
                return y[0] + w, y[1] + n

            def f(x, y, z, n):
                return torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
                    *compiled_args
                )

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None
        return (getitem_4, getitem_5)""",  # noqa: B950
                    ignore_comments=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            eager_out = f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
            self.assertEqual(compiled_out, eager_out)

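    # The impl mutates its input through a numpy view; checked under both
    # v1 and v2.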
    def test_auto_functionalize_on_view(self):
        for value in [True, False]:
            with torch.library._scoped_library(
                "mylib", "FRAGMENT"
            ) as lib, inductor_config.patch({"enable_auto_functionalized_v2": value}):
                torch.library.define(
                    "mylib::foo",
                    "(Tensor(a!) x) -> ()",
                    tags=torch.Tag.pt2_compliant_tag,
                    lib=lib,
                )

                @torch.library.impl("mylib::foo", "cpu", lib=lib)
                @torch._dynamo.disable
                def foo_impl(x):
                    x_np = x.detach().numpy()  # view
                    np.sin(x_np, out=x_np)
                    return

                x = torch.randn(3)
                expected = x.sin()
                torch.ops.mylib.foo(x)
                assert torch.allclose(x, expected)

                @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
                def f(x):
                    x = x.clone()
                    y = x[:]
                    torch.ops.mylib.foo(y)
                    return x

                y = f(x)
                self.assertEqual(y, x.sin())

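    # V1 with optional mutable args; x is None at the call site.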
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    def test_auto_functionalize_optional_old(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                if x is not None:
                    x.add_(y[0] + w)
                if z is not None:
                    z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = None
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)
            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        # No stacktrace found for following nodes
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  \
arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
        return ()""",
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

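    # The fake impl allocates its output with unbacked dynamic sizes via
    # torch.library.get_ctx().new_dynamic_size().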
    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_unbacked_auto_functionalize_op(self):
        @torch.library.custom_op(
            "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"]
        )
        def mk_image(decoder: Tensor) -> Tensor:
            return torch.randn(2, 3, 4, 5)

        @torch.library.register_fake("mylib::mk_image")
        def _(decoder: Tensor) -> Tensor:
            image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)]
            return torch.empty(image_size)

        @torch.compile(fullgraph=True)
        def f(x):
            return torch.ops.mylib.mk_image.default(x)

        x = torch.zeros(100, dtype=torch.int64)
        f(x)

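    # V2 (enable_auto_functionalized_v2=True): the post-grad graph calls foo on
    # the base tensors and writes results back through copy_ epilogue nodes.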
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_v2(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", dynamic=_dynamic, fullgraph=True)(
                    *compiled_args
                )

            post_grad_graphs = "\n".join(
                log_stream.getvalue().strip().split("\n")[3:]
            ).strip()

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1);  arg3_1 = arg4_1 = arg1_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1);  arg5_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        post_grad_graphs,
                        """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1);  arg4_1 = copy__1 = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

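    # Helper: compile f with the aot_eager backend and return the cloned args,
    # the result, and the captured aot_graphs log.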
    def run_aot_eager(self, f, orig_args, _dynamic=False):
        aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
        )

        result = None
        with ctx():
            result = torch.compile(
                f, backend="aot_eager", fullgraph=True, dynamic=_dynamic
            )(*aot_eager_args)

            graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        return [aot_eager_args, result, graph]

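    # Helper: compile f with the inductor backend and return the cloned args,
    # the result, and the captured post-grad graph.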
    def run_inductor(self, f, orig_args, _dynamic=False):
        compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)

        log_stream, ctx = logs_to_string(
            "torch._inductor.compile_fx", "post_grad_graphs"
        )
        result = None
        with ctx():
            result = torch.compile(
                f, backend="inductor", fullgraph=True, dynamic=_dynamic
            )(*compiled_args)

            graph = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()

        return [compiled_args, result, graph]

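    # V2 counterpart of test_auto_functionalize_with_returns_old.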
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_with_returns_v2(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return y[0] + w, y[1] + n

            @torch.library.impl_abstract("mylib::foo", lib=lib)
            def foo_abstract(x, y, z, w, n):
                return y[0] + w, y[1] + n

            def f(x, y, z, n):
                return torch.ops.mylib.foo(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)
            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
                    *compiled_args
                )
            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg0_1 = None
        getitem_4: "f32[3][1]cpu" = foo_default[0]
        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None

        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1);  arg4_1 = copy__1 = None
        return (getitem_4, getitem_5)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            eager_out = f(*eager_args)
            self.assertEqual(compiled_args, eager_args)
            self.assertEqual(compiled_out, eager_out)

    # foo takes two inputs that are not views.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra1(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x, y):
                torch.ops.mylib.foo(x, y)
                return x + y

            orig_args = (torch.randn(2), torch.randn(2))

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(
                f, orig_args, _dynamic
            )
            [inductor_args, result2, graph_inductor] = self.run_inductor(
                f, orig_args, _dynamic
            )
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1])
        getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1);  arg2_1 = getitem_1 = copy__1 = None
        return (add,)""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2];  auto_functionalized_v2 = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2)
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2);  arg0_1 = getitem_2 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1);  arg1_1 = getitem_1 = copy__1 = None
        return (add,)""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1);  foo_default = None
        add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1)
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy__1 = None
        return (add,)""",
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1);  foo_default = None
        add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1)
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  arg0_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy__1 = None
        return (add,)""",
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

    # foo takes two views of the same input; the function has no return value.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra2(self, _dynamic=False):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x):
                a = x[0]
                b = x[1]
                torch.ops.mylib.foo(a, b)
                return

            orig_args = [torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(
                f, orig_args, _dynamic
            )
            [inductor_args, result2, graph_inductor] = self.run_inductor(
                f, orig_args, _dynamic
            )
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1])
        getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1];  auto_functionalized_v2 = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1);  arg1_1 = getitem_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_aot,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1];  auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

            # 2. Run with inductor backend

            if torch._dynamo.config.assume_static_by_default:
                if _dynamic:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1);  as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )
                else:
                    self.assertExpectedInline(
                        graph_inductor,
                        """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1);  as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  arg0_1 = copy_ = None
        return ()""",  # noqa: B950
                        ignore_comments=True,
                        ignore_empty_lines=True,
                    )

    # foo takes two views of the same input; the function returns both views and the input.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra3(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()

            def f(x):
                a = x[0]
                b = x[1]
                torch.ops.mylib.foo(a, b)
                return (a, b, x)

            orig_args = [torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_aot,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1];  auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = copy_ = None
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1);  getitem_1 = None
        return (select_2, select_3)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            # 2. Run with inductor backend

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_inductor,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1);  as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  copy_ = None
        select_2: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
        select_3: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1);  arg0_1 = None
        return (select_2, select_3)""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

    # foo takes a mutable list with views in addition to other args.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_extra4(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!)[] y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y[0].sin_()

            def f(x, y, z):
                a = x[0]
                b = z[0]
                torch.ops.mylib.foo(a, [b, y])

            orig_args = [torch.randn(2), torch.randn(2), torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args[2], eager_args[2])
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_aot,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"):
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_length = 2, _y_0_base_index = 1, _y_0_size = (), _y_0_stride = (), _y_0_storage_offset = 0, _y_1_base_index = 2, _all_bases = [arg0_1, arg1_1, arg2_1])
        getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]
        getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]
        getitem_3: "f32[2][1]cpu" = auto_functionalized_v2[3];  auto_functionalized_v2 = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2);  arg1_1 = getitem_2 = copy__1 = None
        copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_3);  arg2_1 = getitem_3 = copy__2 = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            # 2. Run with inductor backend

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_inductor,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"):
        as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0)
        as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0)
        foo_default = torch.ops.mylib.foo.default(as_strided_default, [as_strided_default_1, arg2_1]);  as_strided_default = as_strided_default_1 = foo_default = None
        copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  arg0_1 = copy_ = None
        copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy__1 = None
        copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1);  arg2_1 = copy__2 = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

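    # V2 with optional mutable args; x is None at the call site, so only z's
    # base gets a copy_ epilogue in the post-grad graph.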
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_auto_functionalize_optional_v2(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y, z, w, n):
                if x is not None:
                    x.add_(y[0] + w)
                if z is not None:
                    z.add_(y[1] + n)

            def f(x, y, z, n):
                torch.ops.mylib.foo(x, y, z, 2, n)

            x = None
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            log_stream, ctx = logs_to_string(
                "torch._inductor.compile_fx", "post_grad_graphs"
            )
            with ctx():
                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)

            if torch._dynamo.config.assume_static_by_default:
                post_grad_graphs = "\n".join(
                    log_stream.getvalue().strip().split("\n")[3:]
                ).strip()
                self.assertExpectedInline(
                    post_grad_graphs,
                    """\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg0_1 = foo_default = None
        copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1);  arg1_1 = copy_ = None
        return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            f(*eager_args)
            self.assertEqual(compiled_args, eager_args)

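    # Re-run the extra tests under torch.inference_mode().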
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode1_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra1()

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode2_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra2()

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode3_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra3()

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode4_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra4()

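    # Re-run selected tests with dynamic shapes.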
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic_v2(self):
        self.test_auto_functionalize_v2(_dynamic=True)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic2_v2(self):
        self.test_auto_functionalize_extra1(_dynamic=True)

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic3_v2(self):
        self.test_auto_functionalize_extra2(_dynamic=True)

    # The graph input is itself a view of another tensor.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_graph_input_is_view(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x):
                pass

            @torch.compile(fullgraph=True, dynamic=False, backend="aot_eager")
            def f(x):
                a = x[0]
                torch.ops.mylib.foo(a)
                return

            x = torch.tensor([[1, 2], [3, 4]])
            # This would fail if auto_functionalized_v2 used clone instead of
            # clone_preserve_strides to clone args that are not mutated in place.
            f(x[1])


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()