# mypy: allow-untyped-defs
import itertools
import logging
import operator
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch._higher_order_ops.triton_kernel_wrap import (
    kernel_side_table,
    triton_kernel_wrapper_functional,
)
from torch._inductor import config, inductor_prims
from torch._inductor.fx_utils import get_node_storage, is_node_realized
from torch._inductor.lowering import (
    inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
)
from torch._inductor.virtualized import V
from torch.fx.immutable_collections import immutable_dict
from torch.fx.passes.reinplace import _is_view_op
from torch.utils import _pytree as pytree


log = logging.getLogger(__name__)
aten = torch.ops.aten


@dataclass(frozen=True)
class InplaceableOp:
    inplace_op: Callable[..., Any]
    mutated_arg: int
    extra_check: Callable[[torch.fx.Node], bool] = lambda node: True


_SCATTER_OP_TO_VIEW = {
    torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
    torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
    torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
    torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
}
_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()}


def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
    fake_args, fake_kwargs = pytree.tree_map(
        lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
        (args, kwargs),
    )
    with V.fake_mode:
        fake_result = fn(*fake_args, **fake_kwargs)

    node = graph.call_function(fn, args, kwargs)
    node.meta["val"] = fake_result
    return node


@dataclass
class ViewOp:
    target: torch._ops.OpOverload
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]


def _inplace_generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    tmp = inp
    for view in view_ops:
        fake_args, fake_kwargs = pytree.tree_map(
            lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
            (view.args, view.kwargs),
        )
        tmp = view.target(tmp, *fake_args, **fake_kwargs)
    try:
        tmp.copy_(src)
    except RuntimeError as e:
        raise RuntimeError(
            f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}"
        ) from e
    return inp


def _generalized_scatter(
    inp: torch.Tensor, src: torch.Tensor, view_ops: List[ViewOp]
) -> torch.Tensor:
    out = inp.clone()
    return _inplace_generalized_scatter(out, src, view_ops)

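# Illustration (comment only, not executed): semantically, _generalized_scatter
# writes `src` into the region of `inp` selected by a chain of views, without
# mutating `inp` itself. For example,
#
#   inp = torch.zeros(4, 4)
#   src = torch.ones(2, 4)
#   out = _generalized_scatter(
#       inp, src, [ViewOp(aten.slice.Tensor, args=(0, 0, 2), kwargs={})]
#   )
#
# leaves `inp` untouched and returns a copy whose first two rows are ones,
# while _inplace_generalized_scatter performs the same update on `inp` directly.
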
def _decompose_scatter_functional_helper(
    graph: torch.fx.Graph,
    inp: torch.Tensor,
    src: torch.Tensor,
    view_ops: List[ViewOp],
) -> torch.fx.Node:
    view_op, view_ops_tail = view_ops[0], view_ops[1:]

    if view_ops_tail:
        view = graph_call_function(
            graph, view_op.target, inp, *view_op.args, **view_op.kwargs
        )
        src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:])  # type: ignore[assignment]

    return graph_call_function(
        graph,
        _VIEW_OP_TO_SCATTER[view_op.target],
        inp,
        src,
        *view_op.args,
        **view_op.kwargs,
    )


def _decompose_scatter_functional(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter into a sequence of view_scatter operations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

    view = aten.slice(inp, 0, 0, 10)
    view_updated = aten.slice_scatter(view, src, 1, 10, -10)
    inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10)
    """
    assert node.target is _generalized_scatter
    inp, src, view_ops = node.args
    return _decompose_scatter_functional_helper(graph, *node.args)  # type: ignore[arg-type]


def _decompose_scatter_mutating(
    graph: torch.fx.Graph, node: torch.fx.Node
) -> torch.fx.Node:
    """Decompose _generalized_scatter using mutations

    e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])

    will become

    inp_updated = aten.clone(inp)
    slice1 = aten.slice(inp_updated, 0, 0, 10)
    slice2 = aten.slice(slice1, 1, 10, -10)
    slice2.copy_(src)
    """
    assert node.target in (_generalized_scatter, _inplace_generalized_scatter)
    inp, src, view_ops = node.args
    assert not node.kwargs

    if node.target is _generalized_scatter:
        inp = graph_call_function(graph, aten.clone, inp)

    tmp = inp
    for view in view_ops:  # type: ignore[union-attr]
        tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs)  # type: ignore[union-attr]

    graph_call_function(graph, aten.copy_.default, tmp, src)
    return inp  # type: ignore[return-value]


# View ops whose view_scatter op is lowered into mutations anyway,
# so it is never a pessimisation to decompose.
_ALWAYS_MUTATING_SCATTER_OPS = {
    aten.as_strided.default,
    aten.diagonal.default,
}


def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
    _, _, view_ops = node.args
    return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops)  # type: ignore[union-attr]

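# Illustration of the trade-off weighed by should_reinplace_scatter below
# (comment only, not executed): in a graph like
#
#   tmp = x.sin()
#   y = _generalized_scatter(tmp, src, [slice_view])
#   z = y.cos()
#
# the functional decomposition keeps `tmp` and `y` fusable with the surrounding
# pointwise computations, whereas the mutating decomposition forces both to be
# realized in memory. Mutation is only worthwhile if that realization was going
# to happen anyway.
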
def should_reinplace_scatter(node: torch.fx.Node) -> bool:
    """Choose between mutating and functional scatter decompositions

    Reinplacing view scatter ops can be pessimising as it blocks fusion with the
    input or output tensor computations. However, it is still profitable if the
    input and output would have been realized anyway.
    """
    inp, src, view_ops = node.args

    # Mutating scatter ops unconditionally realize input and output
    if scatter_always_uses_mutation(node):
        return True

    if is_node_realized(inp) and is_node_realized(node):  # type: ignore[arg-type]
        return True

    # If the output is copied back into the input, this forces both to be
    # realized as the output is a user of the input
    if inp.op in ("placeholder", "get_attr") and any(  # type: ignore[union-attr]
        user.target is aten.copy_.default and user.args[0] is inp for user in node.users
    ):
        return True

    # Otherwise, assume fusions will make functional variants profitable
    return False


def decompose_generalized_scatter(graph: torch.fx.Graph) -> None:
    """Replace _generalized_scatter with normal aten ops"""
    for node in itertools.chain(
        graph.find_nodes(op="call_function", target=_generalized_scatter),
        graph.find_nodes(op="call_function", target=_inplace_generalized_scatter),
    ):
        use_mutation = (
            node.target is _inplace_generalized_scatter
            or scatter_always_uses_mutation(node)
        )

        with graph.inserting_before(node):
            if use_mutation:
                new_node = _decompose_scatter_mutating(graph, node)
            else:
                new_node = _decompose_scatter_functional(graph, node)

        node.replace_all_uses_with(new_node)
        graph.erase_node(node)

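# Worked example of the overall flow (comment only, not executed): given
#
#   out = aten.slice_scatter.default(x, src, 0, 0, 10)
#
# canonicalize_view_scatter_ops (below) rewrites the node into
#   out = _generalized_scatter(x, src, [ViewOp(aten.slice.Tensor, (0, 0, 10), {})])
# reinplace_inplaceable_ops_core may then swap in _inplace_generalized_scatter
# when the mutation cannot be observed, and decompose_generalized_scatter
# (above) lowers whichever form remains back into plain aten view/scatter/copy_
# ops.
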
def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
    """
    This canonicalizes view scatter ops into a generalized form, defined as:
      def scatter(inp, src, views):
        tmp = inp.clone()
        for view in views:
          tmp = view(tmp)
        tmp.copy_(src)

    We also fuse consecutive view scatter ops of the form
        a = scatter(view2(self), src, [view1])
        b = scatter(self, a, [view2])
    which can be rewritten as
        b = scatter(self, src, [view2, view1])
        a = view2(b)

    This is both more efficient as we only do a single scatter, and also
    easier to reinplace since there is only one use of `self`
    """

    node_to_view_base: Dict[torch.fx.Node, torch.fx.Node] = {}
    node_to_view_op: Dict[torch.fx.Node, List[ViewOp]] = defaultdict(list)

    def handle_views(node: torch.fx.Node):
        inp = node.args[0]
        node_to_view_base[node] = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
        node_to_view_op[node] = [
            *node_to_view_op[inp],  # type: ignore[index]
            ViewOp(
                node.target,  # type: ignore[arg-type]
                args=node.args[1:],
                kwargs=node.kwargs,
            ),
        ]

    def handle_view_scatter(node: torch.fx.Node):
        assert len(node.args) >= 2
        inp, src = node.args[:2]

        scatter_view_op = ViewOp(
            _SCATTER_OP_TO_VIEW[node.target],
            args=node.args[2:],
            kwargs=node.kwargs,
        )

        def can_fuse():
            if src.target is not _generalized_scatter:  # type: ignore[union-attr]
                return False
            src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]

            inp_base = node_to_view_base.get(inp, inp)  # type: ignore[arg-type]
            src_base = node_to_view_base.get(src_inp, src_inp)  # type: ignore[arg-type]
            return inp_base is src_base and node_to_view_op[src_inp] == [  # type: ignore[index]
                *node_to_view_op[inp],  # type: ignore[index]
                scatter_view_op,
            ]

        if not can_fuse():
            with graph.inserting_before(node):
                new_node = graph_call_function(
                    graph,
                    _generalized_scatter,
                    inp,
                    src,
                    [scatter_view_op],
                )
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)
            return

        src_inp, src_src, src_scatter_view_op = src.args  # type: ignore[union-attr]
        with graph.inserting_before(src):  # type: ignore[arg-type]
            new_node = graph_call_function(
                graph,
                _generalized_scatter,
                inp,
                src_src,
                [scatter_view_op, *src_scatter_view_op],  # type: ignore[misc]
            )
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)

            if src.users:  # type: ignore[union-attr]
                new_src = graph_call_function(
                    graph,
                    _SCATTER_OP_TO_VIEW[node.target],
                    new_node,
                    *node.args[2:],
                    **node.kwargs,
                )

                handle_views(new_src)
                src.replace_all_uses_with(new_src)  # type: ignore[union-attr]

            graph.erase_node(src)  # type: ignore[arg-type]

    for node in graph.nodes:
        if _is_view_op(node.target):
            handle_views(node)
        elif node.target in _SCATTER_OP_TO_VIEW:
            handle_view_scatter(node)

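# Registry of functional ops that have a known in-place variant. Each entry
# maps a functional target to an InplaceableOp describing the in-place target,
# which positional argument is mutated, and an optional extra_check gating the
# rewrite. A hypothetical new entry (illustration only, `my_op` is not a real
# operator) would look like:
#
#   inplaceable_ops[torch.ops.aten.my_op.default] = InplaceableOp(
#       torch.ops.aten.my_op_.default, mutated_arg=0
#   )
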
inplaceable_ops = {
    aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
    aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
    _generalized_scatter: InplaceableOp(
        _inplace_generalized_scatter,
        0,
        extra_check=should_reinplace_scatter,
    ),
}

try:
    c10d_functional = torch.ops._c10d_functional
    inplaceable_collective_ops = {
        c10d_functional.all_reduce.default: InplaceableOp(
            c10d_functional.all_reduce_.default, 0
        ),
        c10d_functional.all_reduce_coalesced.default: InplaceableOp(
            c10d_functional.all_reduce_coalesced_.default, 0
        ),
    }
    inplaceable_ops.update(inplaceable_collective_ops)
except AttributeError:
    # _c10d_functional ops are only available when torch
    # is built with USE_DISTRIBUTED=1.
    pass

inplaceable_foreach_ops: Dict[torch._ops.OpOverload, InplaceableOp] = {}
for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items():
    inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0)


inplaceable_triton_ops = {triton_kernel_wrapper_functional}


# Operators that don't depend on the tensor data
META_ONLY_OPS = {
    aten.sym_size.int,
    aten.sym_stride.int,
    aten.sym_numel.default,
    aten.sym_storage_offset.default,
}


def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
    """
    Reinplaces in-placeable operations.

    If no view of the mutated arg is used after the current node, the op can be
    performed in place. The justification is in terms of observable side
    effects: since we traverse the graph in forward order, only later nodes
    could observe a mutation made by the current node. If neither the node's
    result nor any view of the mutated arg is used later in the graph, the
    mutation cannot be observed, so it is safe to reinplace.

    The condition is slightly different for graph inputs: they can only be
    reinplaced if the above holds and there is a copy_ back to the input in the
    epilogue, which signals that the caller wants to observe the mutation.

    Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from
    the input args, so instead of checking for mutation on a placeholder,
    AOTInductor checks for mutation on a get_attr. This is subject to change in
    the future.
    """

    copy_args_to_copy_nodes = {}
    # maps argument to the first copy_ node that mutates it.
    copy_nodes = {}
    mutated_inputs = set()
    storage_to_nodes = defaultdict(list)
    node_order: Dict[Any, int] = {}
    for i, node in enumerate(reversed(graph.nodes)):
        node_order[node] = len(graph.nodes) - i - 1
        storage_to_nodes[get_node_storage(node)].append(node)
        if node.target == aten.copy_.default and node.args[0].op in (
            "placeholder",
            "get_attr",
        ):
            dst = node.args[0]
            src = node.args[1]
            # If the target is a getitem and it indexes a possible clone,
            # then skip over it
            if src.target == operator.getitem and (
                (
                    src.args[0].target == triton_kernel_wrapper_functional
                    and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
                )
                or (src.args[0].target in inplaceable_foreach_ops)
                or (src.args[0].target == torch.ops.higher_order.auto_functionalized)
            ):
                src = src.args[0]

            copy_args_to_copy_nodes[(dst, src)] = node
            copy_nodes[dst] = node

            mutated_inputs.add(node.args[0])

    def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg):
        node_loc = node_order[node]
        copy_node_loc = node_order[copy_node] if copy_node is not None else None

        def is_meta_only_user(node):
            if _is_view_op(node.target):
                return all(is_meta_only_user(u) for u in node.users)
            return node.target in META_ONLY_OPS

        for view in shared_view_nodes:
            for user in view.users:
                user_loc = node_order[user]
                # Skip all users before node
                if user_loc <= node_loc:
                    continue
                # Ignore uses after the copy_ epilogue node, where the input
                # has already been mutated anyway
                if copy_node_loc is not None and copy_node_loc <= user_loc:
                    continue
                # Reinplacing does not change shape metadata
                if is_meta_only_user(user):
                    continue
                # If our graph looks like:
                #   foo(mutated_arg)
                #   mutated_arg.copy_(other)
                # then it's safe for us to reinplace foo because mutated_arg
                # will get overwritten anyway.
                if (
                    user.target is torch.ops.aten.copy_.default
                    and mutated_arg is user.args[0]
                ):
                    continue
                return True
        return False

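    # Example of the safety condition above (comment only, not executed):
    #
    #   buf = placeholder()
    #   y = aten.index_put.default(buf, idx, vals)   # reinplacing candidate
    #   z = aten.mul(buf, 2)                         # later read of buf
    #   buf.copy_(y)                                 # copy_ epilogue
    #
    # `z` reads buf after the candidate node but before the copy_ epilogue, so
    # rewriting index_put into index_put_ would let `z` observe the mutation;
    # any_use_of_views_after_node returns True and can_inplace (below) bails out.
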
    def can_inplace(node, mutated_arg):
        if isinstance(mutated_arg, (list, tuple)):
            unique_storages = {get_node_storage(arg) for arg in mutated_arg}
            if len(unique_storages) != len(mutated_arg):
                # At least two tensors in mutated_arg alias each other, so we can't reinplace it.
                # We could probably do better (reinplace one of them and clone the other),
                # but that requires more work and mutable List[Tensor] are not that common.
                return False
            return all(can_inplace(node, arg) for arg in mutated_arg)

        if get_node_storage(mutated_arg) is None:
            return False
        shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]

        if mutated_arg.op in ("placeholder", "get_attr"):
            # Get the first copy_ node that mutates the mutated_arg.
            copy_node = copy_nodes.get(mutated_arg, None)
            if copy_node is None:
                # There is no copy_ back to the candidate mutated_arg (which is a graph input).
                # Therefore the semantics of the program are that it does not mutate
                # mutated_arg, so we cannot re-inplace it.
                return False
            if any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg
            ):
                return False

            return True
        elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes):
            # This should never happen in auto_functionalize_v2 non-inference mode,
            # since all mutated args are bases.

            # If the mutated arg is a view of any of the graph inputs, do not
            # allow inplacing; handling that correctly would require a more
            # sophisticated algorithm.
            return False
        else:
            return not any_use_of_views_after_node(
                node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg
            )

    def log_inplace_results(
        node_name,
        old_tensors_to_clone,
        tensors_to_clone,
        possibly_missed_reinplacing_opportunities,
    ):
        log.info(
            "For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
            "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
            "memory usage and performance.",
            node_name,
            old_tensors_to_clone,
            tensors_to_clone,
            possibly_missed_reinplacing_opportunities,
        )
        torch._dynamo.utils.counters["inductor"][
            "possibly_missed_reinplacing_opportunities"
        ] += len(possibly_missed_reinplacing_opportunities)

    replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}

    def reinplace_and_refine_tensors_to_clone(
        old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False
    ):
        tensors_to_clone: List[str] = []
        storage_of_reinplaced_args = set()
        possibly_missed_reinplacing_opportunities = []

        def tensor_with_same_storage_already_reinplaced(arg):
            if isinstance(arg, (list, tuple)):
                return any(
                    get_node_storage(a) in storage_of_reinplaced_args for a in arg
                )
            return get_node_storage(arg) in storage_of_reinplaced_args

        for arg in old_tensors_to_clone:
            assert arg in kwargs

            mutated_arg = kwargs[arg]

            # Let's say we have:
            # - op(x, y) that mutates both x and y
            # - new_x, new_y = functional_op(x, y) is the functional variant
            # If we are presented with functional_op(x, x), we must not reinplace
            # this into op(x, x), because then it would be writing to the same Tensor.
            # Instead, it's OK to reinplace one of them and to clone the other:
            # >>> y = x.clone()
            # >>> op(x, y)
            # This also applies if we have views: functional_op(x, x[0])
            # should not reinplace into op(x, x[0]).
            should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced(
                mutated_arg
            )
            if should_attempt_reinplace and can_inplace(node, mutated_arg):
                # In general, we probably do not need those optimizations.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                if not auto_functionalize_v2:
                    for user in node.users:
                        # For auto_functionalize_v2, arg is the index of the base; the
                        # base at index i corresponds to the output at index
                        # size(out) + i. This comparison used to match a string
                        # against an integer for auto_functionalize_v2; it is unclear
                        # whether it was ever needed for inplaceable_triton_ops.
                        if user.target == operator.getitem and user.args[1] == arg:
                            replace_dict[user] = mutated_arg

                if isinstance(mutated_arg, (list, tuple)):
                    for a in mutated_arg:
                        storage_of_reinplaced_args.add(get_node_storage(a))
                else:
                    storage_of_reinplaced_args.add(get_node_storage(mutated_arg))
            else:
                if should_attempt_reinplace:
                    possibly_missed_reinplacing_opportunities.append(arg)
                tensors_to_clone.append(arg)

        log_inplace_results(
            node_name,
            old_tensors_to_clone,
            tensors_to_clone,
            possibly_missed_reinplacing_opportunities,
        )
        return tensors_to_clone

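    # The loop below handles each kind of mutable op separately:
    #   - ops in inplaceable_ops are retargeted to their in-place variant;
    #   - auto_functionalized / auto_functionalized_v2 HOPs get
    #     node.meta["only_clone_these_tensors"] so the later decomposition only
    #     clones the arguments that could not be reinplaced;
    #   - triton_kernel_wrapper_functional has its "tensors_to_clone" kwarg
    #     narrowed to the tensors that still need cloning;
    #   - foreach ops are reinplaced only if every mutated arg has a copy_
    #     epilogue back to it.
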
    for node in graph.nodes:
        if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
            mutated_arg = node.args[inplaceable_op.mutated_arg]
            if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node):
                # TODO(yifu): this doesn't properly remove copy epilogues for
                # ops that mutate multiple inputs. Need to revise the copy
                # node tracking logic to support the case.
                copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
                if copy_node is not None:
                    replace_dict[copy_node] = copy_node.args[0]
                node.target = inplaceable_op.inplace_op
        elif node.target == torch.ops.higher_order.auto_functionalized_v2:
            _mutable_op = node.args[0]
            kwargs = node.kwargs

            all_bases = kwargs["_all_bases"]
            bases_to_clone = range(len(all_bases))
            base_tensors_dct = dict(enumerate(all_bases))
            new_bases_to_clone: List[int] = reinplace_and_refine_tensors_to_clone(
                bases_to_clone,
                base_tensors_dct,
                node.target,
                auto_functionalize_v2=True,
            )
            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = new_bases_to_clone
        elif node.target == torch.ops.higher_order.auto_functionalized:
            _mutable_op = node.args[0]
            from torch._higher_order_ops.auto_functionalize import get_mutable_args

            tensors_to_clone, _ = get_mutable_args(_mutable_op)
            # Don't try to reinplace Optional[Tensor] args that are None.
            tensors_to_clone = [
                t for t in tensors_to_clone if node.kwargs[t] is not None
            ]
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                tensors_to_clone,
                node.kwargs,
                _mutable_op._name,
                auto_functionalize_v2=False,
            )

            # Stash the metadata. There is a pass later on where we decompose
            # auto_functionalized into clones + a mutable op; this metadata
            # tells the decomp to only clone the following inputs
            node.meta["only_clone_these_tensors"] = tensors_to_clone
        elif node.target in inplaceable_triton_ops:
            kernel_idx = node.kwargs["kernel_idx"]
            kernel = kernel_side_table.get_kernel(kernel_idx)
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction

            if isinstance(kernel, JITFunction):
                kernel_name = kernel.fn.__name__
            elif isinstance(kernel, Autotuner):
                if config.is_fbcode():
                    # Autotuner has different implementations for AMD and NV
                    if torch.version.hip is None:
                        kernel_name = kernel.base_fn.__name__
                    else:
                        kernel_name = kernel.fn.__name__
                else:
                    kernel_name = kernel.base_fn.__name__
            else:
                raise AssertionError("Unknown triton kernel type")

            # inplaceable_triton_ops take an additional argument called
            # tensors_to_clone, which contains a list of tensors to clone.
            # This pass iterates over them to see which of the clones are safe
            # to eliminate (i.e. are no longer needed).
            tensors_to_clone = reinplace_and_refine_tensors_to_clone(
                node.kwargs["tensors_to_clone"], node.kwargs["kwargs"], kernel_name
            )

            kwargs = dict(node.kwargs)
            kwargs["tensors_to_clone"] = tensors_to_clone
            node.kwargs = immutable_dict(kwargs)
        elif (
            inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
        ) is not None:
            mutated_args = node.args[inplaceable_op.mutated_arg]

            if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args):
                continue

            if can_inplace(node, mutated_args):
                for arg in mutated_args:
                    copy_node = copy_args_to_copy_nodes[(arg, node)]
                    replace_dict[copy_node] = copy_node.args[0]

                node.target = inplaceable_op.inplace_op

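    # Apply the deferred replacements. replace_dict can contain chains (a node
    # mapped to another node that is itself scheduled for replacement), so
    # follow each chain to its final target before rewriting. For example
    # (comment only), after reinplacing
    #
    #   buf = placeholder()
    #   out = aten.index_put.default(buf, idx, vals)
    #   buf.copy_(out)
    #
    # index_put has been retargeted to index_put_ and the copy_ node mapped to
    # `buf`, so the loop below erases the now-redundant epilogue copy.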
    for node, replacement in replace_dict.items():
        while replacement in replace_dict:
            replacement = replace_dict[replacement]
        replace_dict[node] = replacement

        node.replace_all_uses_with(replacement)
        graph.erase_node(node)


def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
    canonicalize_view_scatter_ops(graph)
    reinplace_inplaceable_ops_core(graph)
    decompose_generalized_scatter(graph)