# mypy: allow-untyped-defs
import itertools
import logging
import operator
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch._higher_order_ops.triton_kernel_wrap import (
    kernel_side_table,
    triton_kernel_wrapper_functional,
)
from torch._inductor import config, inductor_prims
from torch._inductor.fx_utils import get_node_storage, is_node_realized
from torch._inductor.lowering import (
    inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
)
from torch._inductor.virtualized import V
from torch.fx.immutable_collections import immutable_dict
from torch.fx.passes.reinplace import _is_view_op
from torch.utils import _pytree as pytree


log = logging.getLogger(__name__)
aten = torch.ops.aten


@dataclass(frozen=True)
class InplaceableOp:
    inplace_op: Callable[..., Any]
    mutated_arg: int
    extra_check: Callable[[torch.fx.Node], bool] = lambda node: True


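# Each *_scatter op below corresponds to a view op. As an illustrative sketch
# (not part of this pass): aten.slice_scatter(inp, src, dim, start, end)
# produces the same result as cloning `inp` and copying `src` into
# aten.slice(clone, dim, start, end).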
_SCATTER_OP_TO_VIEW = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()}


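# Helper that inserts a call_function node into the graph and also populates
# node.meta["val"] by running `fn` under V.fake_mode on the FakeTensor values
# of any FX-node arguments. Illustrative usage (hypothetical node name):
#   cloned = graph_call_function(graph, aten.clone, some_node)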
def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
    fake_args, fake_kwargs = pytree.tree_map(
        lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
        (args, kwargs),
    )
    with V.fake_mode:
        fake_result = fn(*fake_args, **fake_kwargs)

    node = graph.call_function(fn, args, kwargs)
    node.meta["val"] = fake_result
    return node


@dataclass
class ViewOp:
    target: torch._ops.OpOverload
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]


def _inplace_generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    tmp = inp
    for view in view_ops:
        fake_args, fake_kwargs = pytree.tree_map(
            lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
            (view.args, view.kwargs),
        )
        tmp = view.target(tmp, *fake_args, **fake_kwargs)
    try:
        tmp.copy_(src)
    except RuntimeError as e:
        raise RuntimeError(
            f"shape error in scatter op, cannot broadcast {src.shape} to {tmp.shape}"
        ) from e
    return inp


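# As an illustrative example (not invoked here): _generalized_scatter(inp, src,
# [ViewOp(aten.slice.Tensor, args=(0, 0, 10), kwargs={})]) computes the same
# result as aten.slice_scatter(inp, src, 0, 0, 10) -- clone the input, take the
# view, and copy `src` into it.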
def _generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    out = inp.clone()
    return _inplace_generalized_scatter(out, src, view_ops)


def _decompose_scatter_functional_helper(
    graph: torch.fx.Graph,
    inp: torch.Tensor,
    src: torch.Tensor,
    view_ops: List[ViewOp],
) -> torch.fx.Node:
    view_op, view_ops_tail = view_ops[0], view_ops[1:]

    if view_ops_tail:
        view = graph_call_function(
            graph, view_op.target, inp, *view_op.args, **view_op.kwargs
        )
        src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:])  # type: ignore[assignment]

    return graph_call_function(
        graph,
        _VIEW_OP_TO_SCATTER[view_op.target],
        inp,
        src,
        *view_op.args,
        **view_op.kwargs,
    )


def _decompose_scatter_functional(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter to a sequence of view_scatter operations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

    view = aten.slice(inp, 0, 0, 10)
    view_updated = aten.slice_scatter(view, src, 1, 10, -10)
    inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10)
    """
    assert node.target is _generalized_scatter
    inp, src, view_ops = node.args
    return _decompose_scatter_functional_helper(graph, *node.args)  # type: ignore[arg-type]


def _decompose_scatter_mutating(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter using mutations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

    inp_updated = aten.clone(inp)
    slice1 = aten.slice(inp_updated, 0, 0, 10)
    slice2 = aten.slice(slice1, 1, 10, -10)
    slice2.copy_(src)

    """
    assert node.target in (_generalized_scatter, _inplace_generalized_scatter)
    inp, src, view_ops = node.args
    assert not node.kwargs

    if node.target is _generalized_scatter:
        inp = graph_call_function(graph, aten.clone, inp)

    tmp = inp
    for view in view_ops:  # type: ignore[union-attr]
        tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs)  # type: ignore[union-attr]

    graph_call_function(graph, aten.copy_.default, tmp, src)
    return inp  # type: ignore[return-value]


# View ops whose view_scatter op is lowered into mutations anyway,
# so it is never a pessimisation to decompose them.
_ALWAYS_MUTATING_SCATTER_OPS = {
    aten.as_strided.default,
    aten.diagonal.default,
}


def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
    _, _, view_ops = node.args
    return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops)  # type: ignore[union-attr]


def should_reinplace_scatter(node: torch.fx.Node) -> bool:
    """Choose between mutating and functional scatter decompositions

    Reinplacing view scatter ops can be pessimising as it blocks fusion with the
    input or output tensor computations. However, it is still profitable if the
    input and output would have been realized anyway.

    """
    inp, src, view_ops = node.args

    # Mutating scatter ops unconditionally realize input and output
    if scatter_always_uses_mutation(node):
        return True

    if is_node_realized(inp) and is_node_realized(node):  # type: ignore[arg-type]
        return True

    # If the output is copied back into the input, this forces both to be
    # realized as the output is a user of the input
    if inp.op in ("placeholder", "get_attr") and any(  # type: ignore[union-attr]
        user.target is aten.copy_.default and user.args[0] is inp for user in node.users
    ):
        return True

    # Otherwise, assume fusions will make functional variants profitable
    return False


def decompose_generalized_scatter(graph: torch.fx.Graph) -> None:
    """Replace _generalized_scatter with normal aten ops"""
    for node in itertools.chain(
        graph.find_nodes(op="call_function", target=_generalized_scatter),
        graph.find_nodes(op="call_function", target=_inplace_generalized_scatter),
    ):
        use_mutation = (
            node.target is _inplace_generalized_scatter
            or scatter_always_uses_mutation(node)
        )

        with graph.inserting_before(node):
            if use_mutation:
                new_node = _decompose_scatter_mutating(graph, node)
            else:
                new_node = _decompose_scatter_functional(graph, node)

        node.replace_all_uses_with(new_node)
        graph.erase_node(node)


def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
    """
    This canonicalizes view scatter ops into a generalized form, defined as:
      def scatter(inp, src, views):
        tmp = inp.clone()
        for view in views:
          tmp = view(tmp)
        tmp.copy_(src)

    We also fuse consecutive view scatter ops of the form
        a = scatter(view2(self), src, [view1])
        b = scatter(self, a, [view2])
    which can be rewritten as
        b = scatter(self, src, [view2, view1])
        a = view2(b)

    This is both more efficient, as we only do a single scatter, and also
    easier to reinplace, since there is only one use of `self`.
    """
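    # Concretely (an illustrative sketch, not emitted verbatim by this pass):
    #   a = aten.slice_scatter(aten.slice(self, 0, 0, 10), src, 1, 2, 5)
    #   b = aten.slice_scatter(self, a, 0, 0, 10)
    # is canonicalized and fused into
    #   b = _generalized_scatter(
    #       self, src, [ViewOp(aten.slice, (0, 0, 10), {}), ViewOp(aten.slice, (1, 2, 5), {})]
    #   )
    #   a = aten.slice(b, 0, 0, 10)  # re-created only if `a` still has users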

    node_to_view_base: Dict[torch.fx.Node, torch.fx.Node] = {}
    node_to_view_op: Dict[torch.fx.Node, List[ViewOp]] = defaultdict(list)

    def handle_views(node: torch.fx.Node):
        inp = node.args[0]
        node_to_view_base[node] = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
        node_to_view_op[node] = [
            *node_to_view_op[inp],  # type: ignore[index]
            ViewOp(
                node.target,  # type: ignore[arg-type]
                args=node.args[1:],
                kwargs=node.kwargs,
            ),
        ]

    def handle_view_scatter(node: torch.fx.Node):
        assert len(node.args) >= 2
        inp, src = node.args[:2]

        scatter_view_op = ViewOp(
            _SCATTER_OP_TO_VIEW[node.target],
            args=node.args[2:],
            kwargs=node.kwargs,
        )

        def can_fuse():
            if src.target is not _generalized_scatter:  # type: ignore[union-attr]
                return False
            src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]

            inp_base = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
            src_base = node_to_view_base.get(src_inp, src_inp)  # type: ignore[arg-type]
            return inp_base is src_base and node_to_view_op[src_inp] == [  # type: ignore[index]
                *node_to_view_op[inp],  # type: ignore[index]
                scatter_view_op,
            ]

        if not can_fuse():
            with graph.inserting_before(node):
                new_node = graph_call_function(
                    graph,
                    _generalized_scatter,
                    inp,
                    src,
                    [scatter_view_op],
                )
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)
            return

        src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]
        with graph.inserting_before(src):  # type: ignore[arg-type]
            new_node = graph_call_function(
                graph,
                _generalized_scatter,
                inp,
                src_src,
                [scatter_view_op, *src_scatter_view_op],  # type: ignore[misc]
            )
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

            if src.users:  # type: ignore[union-attr]
                new_src = graph_call_function(
                    graph,
                    _SCATTER_OP_TO_VIEW[node.target],
                    new_node,
                    *node.args[2:],
                    **node.kwargs,
                )

                handle_views(new_src)
                src.replace_all_uses_with(new_src)  # type: ignore[union-attr]

            graph.erase_node(src)  # type: ignore[arg-type]

    for node in graph.nodes:
        if _is_view_op(node.target):
            handle_views(node)
        elif node.target in _SCATTER_OP_TO_VIEW:
            handle_view_scatter(node)


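# Table of functional ops that can be rewritten in place: each entry maps a
# functional op to its in-place variant and the index of the argument that the
# in-place op mutates; `extra_check` can veto the rewrite for a particular node.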
inplaceable_ops = {
    aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
    aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
    _generalized_scatter: InplaceableOp(
        _inplace_generalized_scatter,
        0,
        extra_check=should_reinplace_scatter,
    ),
}

try:
    c10d_functional = torch.ops._c10d_functional
    inplaceable_collective_ops = {
        c10d_functional.all_reduce.default: InplaceableOp(
            c10d_functional.all_reduce_.default, 0
        ),
        c10d_functional.all_reduce_coalesced.default: InplaceableOp(
            c10d_functional.all_reduce_coalesced_.default, 0
        ),
    }
    inplaceable_ops.update(inplaceable_collective_ops)
except AttributeError:
    # _c10d_functional ops are only available when torch
    # is built with USE_DISTRIBUTED=1.
    pass

inplaceable_foreach_ops: Dict[torch._ops.OpOverload, InplaceableOp] = {}
for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items():
    inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0)


inplaceable_triton_ops = {triton_kernel_wrapper_functional}


# Operators that don't depend on the tensor data
META_ONLY_OPS = {
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.sym_numel.default,
    aten.sym_storage_offset.default,
}


def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
    """
    Reinplaces in-placeable operations.
    If there are no uses of a view of the mutated arg after the current node,
    it is possible to inplace the op.
    The above algorithm can be justified by reasoning about side effects: as we
    traverse the graph in the forward direction, only later nodes can observe
    the side effects of the current node. If neither the current node nor any
    view of it is used later in the graph, it is safe to inplace because there
    is no way to observe the side effects.
    The condition is slightly different for graph inputs: they can only be
    inplaced if the above condition holds and there is a copy_ in the epilogue
    that signals that the caller wants to observe the mutation.

    Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from
    input args, so instead of checking mutation on placeholder, AOTInductor
    checks mutation on get_attr. This is subject to change in the future.
    """

    # Maps (dst, src) to the copy_ node that writes src (or the op producing it)
    # back into the graph input dst.
    copy_args_to_copy_nodes = {}
    # Maps each mutated graph input to the first copy_ node that mutates it.
    copy_nodes = {}
    mutated_inputs = set()
    storage_to_nodes = defaultdict(list)
    node_order: Dict[Any, int] = {}
    for i, node in enumerate(reversed(graph.nodes)):
        node_order[node] = len(graph.nodes) - i - 1
        storage_to_nodes[get_node_storage(node)].append(node)
        if node.target == aten.copy_.default and node.args[0].op in (
            "placeholder",
            "get_attr",
        ):
            dst = node.args[0]
            src = node.args[1]
            # If the target is a getitem and it indexes a possible clone,
            # then skip over it
            if src.target == operator.getitem and (
                (
                    src.args[0].target == triton_kernel_wrapper_functional
                    and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
                )
                or (src.args[0].target in inplaceable_foreach_ops)
                or (src.args[0].target == torch.ops.higher_order.auto_functionalized)
            ):
                src = src.args[0]

            copy_args_to_copy_nodes[(dst, src)] = node
            copy_nodes[dst] = node

            mutated_inputs.add(node.args[0])

    def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg):
        node_loc = node_order[node]
        copy_node_loc = node_order[copy_node] if copy_node is not None else None

        def is_meta_only_user(node):
            if _is_view_op(node.target):
                return all(is_meta_only_user(u) for u in node.users)
            return node.target in META_ONLY_OPS

        for view in shared_view_nodes:
            for user in view.users:
                user_loc = node_order[user]
                # Skip all users before node
                if user_loc <= node_loc:
                    continue
                # Ignore uses after the copy_ epilogue node, where the input
                # has already been mutated anyway
                if copy_node_loc is not None and copy_node_loc <= user_loc:
                    continue
                # Reinplacing does not change shape metadata
                if is_meta_only_user(user):
                    continue
                # If our graph looks like:
                # foo(mutated_arg)
                # mutated_arg.copy_(other)
                # then it's safe for us to reinplace foo because mutated_arg
                # will get overwritten anyways.
                if (
                    user.target is torch.ops.aten.copy_.default
                    and mutated_arg is user.args[0]
                ):
                    continue
                return True
        return False

    def can_inplace(node, mutated_arg):
        if isinstance(mutated_arg, (list, tuple)):
            unique_storages = {get_node_storage(arg) for arg in mutated_arg}
            if len(unique_storages) != len(mutated_arg):
                # at least two Tensors in mutated_arg alias each other, so we can't reinplace it.
                # We can probably do better (that is, reinplace one of them and clone the other)
                # but that requires more work and mutable List[Tensor] are not that common.
                return False
            return all(can_inplace(node, arg) for arg in mutated_arg)

        if get_node_storage(mutated_arg) is None:
            return False
        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]

        if mutated_arg.op in ("placeholder", "get_attr"):
            # Get the first copy_ node that mutates the mutated_arg.
            copy_node = copy_nodes.get(mutated_arg, None)
            if copy_node is None:
                # There is no copy_ back to the candidate mutated_arg (which is a graph input).
                # Therefore the semantics of the program are that it does not mutate
                # mutated_arg, so we cannot re-inplace it.
                return False
            if any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg
            ):
                return False

            return True
        elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes):
            # This should never happen in auto_functionalize_v2 non-inference mode,
            # since all mutated_args are bases.

            # If the mutated arg is a view of any of the graph inputs, do not
            # allow inplacing; handling this correctly would require a more
            # sophisticated algorithm.
            return False
        else:
            return not any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg
            )

    def log_inplace_results(
        node_name,
        old_tensors_to_clone,
        tensors_to_clone,
        possibly_missed_reinplacing_opportunities,
    ):
        log.info(
            "For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
            "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
            "memory usage and performance.",
            node_name,
            old_tensors_to_clone,
            tensors_to_clone,
            possibly_missed_reinplacing_opportunities,
        )
        torch._dynamo.utils.counters["inductor"][
            "possibly_missed_reinplacing_opportunities"
        ] += len(possibly_missed_reinplacing_opportunities)

    replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}

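    # Decides which of the arguments named in `old_tensors_to_clone` can safely
    # be mutated in place (recording any copy_/getitem nodes to drop in
    # replace_dict) and returns the subset that still has to be cloned.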
    def reinplace_and_refine_tensors_to_clone(
        old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False
    ):
        tensors_to_clone: List[str] = []
        storage_of_reinplaced_args = set()
        possibly_missed_reinplacing_opportunities = []

        def tensor_with_same_storage_already_reinplaced(arg):
            if isinstance(arg, (list, tuple)):
                return any(
                    get_node_storage(a) in storage_of_reinplaced_args for a in arg
                )
            return get_node_storage(mutated_arg) in storage_of_reinplaced_args

        for arg in old_tensors_to_clone:
            assert arg in kwargs

            mutated_arg = kwargs[arg]

            # Let's say we have:
            # - op(x, y) that mutates both x and y
            # - new_x, new_y = functional_op(x, y) is the functional variant
            # If we are presented with functional_op(x, x), we must not reinplace
            # this into op(x, x), because then it would be writing to the same Tensor.
            # Instead, it's OK to reinplace one of them and to clone the other:
            # >>> y = x.clone()
            # >>> op(x, y)
            # This also applies if we have views: functional_op(x, x[0])
            # should not reinplace into op(x, x[0]).
            should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced(
                mutated_arg
            )
            if should_attempt_reinplace and can_inplace(node, mutated_arg):
                # In general, we probably do not need those optimizations.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                if not auto_functionalize_v2:
                    for user in node.users:
                        # For auto_functionalize_v2, arg is the index of the base, and the base
                        # at index i corresponds to the output at index size(out) + i.
                        # Previously this compared a string against an integer for
                        # auto_functionalize_v2; not sure if it was needed for inplaceable_triton_ops.
                        if user.target == operator.getitem and user.args[1] == arg:
                            replace_dict[user] = mutated_arg

                if isinstance(mutated_arg, (list, tuple)):
                    for a in mutated_arg:
                        storage_of_reinplaced_args.add(get_node_storage(a))
                else:
                    storage_of_reinplaced_args.add(get_node_storage(mutated_arg))
            else:
                if should_attempt_reinplace:
                    possibly_missed_reinplacing_opportunities.append(arg)
                tensors_to_clone.append(arg)

        log_inplace_results(
            node_name,
            old_tensors_to_clone,
            tensors_to_clone,
            possibly_missed_reinplacing_opportunities,
        )
        return tensors_to_clone

    for node in graph.nodes:
        if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
            mutated_arg = node.args[inplaceable_op.mutated_arg]
            if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node):
                # TODO(yifu): this doesn't properly remove copy epilogues for
                # ops that mutate multiple inputs. Need to revise the copy
                # node tracking logic to support the case.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                node.target = inplaceable_op.inplace_op
        elif node.target == torch.ops.higher_order.auto_functionalized_v2:
            _mutable_op = node.args[0]
            kwargs = node.kwargs

            all_bases = kwargs["_all_bases"]
            bases_to_clone = range(len(all_bases))
            base_tensors_dct = dict(enumerate(all_bases))
            new_bases_to_clone: List[int] = reinplace_and_refine_tensors_to_clone(
                bases_to_clone,
                base_tensors_dct,
                node.target,
                auto_functionalize_v2=True,
            )
            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = new_bases_to_clone
        elif node.target == torch.ops.higher_order.auto_functionalized:
            _mutable_op = node.args[0]
            from torch._higher_order_ops.auto_functionalize import get_mutable_args

            tensors_to_clone, _ = get_mutable_args(_mutable_op)
            # Don't try to reinplace Optional[Tensor] args that are None.
            tensors_to_clone = [
                t for t in tensors_to_clone if node.kwargs[t] is not None
            ]
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                tensors_to_clone,
                node.kwargs,
                _mutable_op._name,
                auto_functionalize_v2=False,
            )

            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = tensors_to_clone
        elif node.target in inplaceable_triton_ops:
            kernel_idx = node.kwargs["kernel_idx"]
            kernel = kernel_side_table.get_kernel(kernel_idx)
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction

            if isinstance(kernel, JITFunction):
                kernel_name = kernel.fn.__name__
            elif isinstance(kernel, Autotuner):
                if config.is_fbcode():
                    # Autotuner has different implementations for AMD and NV
                    if torch.version.hip is None:
                        kernel_name = kernel.base_fn.__name__
                    else:
                        kernel_name = kernel.fn.__name__
                else:
                    kernel_name = kernel.base_fn.__name__
            else:
                raise AssertionError("Unknown triton kernel type")

            # inplaceable_triton_ops take an additional argument called
            # tensors_to_clone, which contains a list of tensors to clone.
            # This pass iterates over them and sees which ones are safe
            # to eliminate (i.e. which no longer need the clones).
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                node.kwargs["tensors_to_clone"], node.kwargs["kwargs"], kernel_name
            )

            kwargs = dict(node.kwargs)
            kwargs["tensors_to_clone"] = tensors_to_clone
            node.kwargs = immutable_dict(kwargs)
        elif (
            inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
        ) is not None:
            mutated_args = node.args[inplaceable_op.mutated_arg]

            if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args):
                continue

            if can_inplace(node, mutated_args):
                for arg in mutated_args:
                    copy_node = copy_args_to_copy_nodes[(arg, node)]
                    replace_dict[copy_node] = copy_node.args[0]

                node.target = inplaceable_op.inplace_op

    for node, replacement in replace_dict.items():
        while replacement in replace_dict:
            replacement = replace_dict[replacement]
        replace_dict[node] = replacement

        node.replace_all_uses_with(replacement)
        graph.erase_node(node)


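# Top-level entry point: canonicalize view scatter ops into _generalized_scatter
# first so the core pass can reinplace them like any other inplaceable op, then
# lower any remaining _generalized_scatter nodes back into ordinary aten ops.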
def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
    canonicalize_view_scatter_ops(graph)
    reinplace_inplaceable_ops_core(graph)
    decompose_generalized_scatter(graph)