1# Owner(s): ["oncall: distributed"] 2import collections 3import inspect 4import logging 5import math 6import operator 7from dataclasses import dataclass 8from functools import partial 9from typing import ( 10 Any, 11 Callable, 12 cast, 13 Dict, 14 Generator, 15 List, 16 Optional, 17 Set, 18 Tuple, 19 Union, 20) 21 22import torch 23import torch.fx as fx 24from torch._dynamo.utils import counters 25from torch.fx.passes.graph_transform_observer import GraphTransformObserver 26from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata 27from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten 28 29from .. import config 30from ..fx_utils import get_fake_args_kwargs 31from ..virtualized import V 32 33 34aten = torch.ops.aten 35logger: logging.Logger = logging.getLogger("comm_fusion") 36 37 38def move_block_after(block: List[fx.Node], target_node: fx.Node) -> None: 39 for node in block: 40 target_node.append(node) 41 target_node = node 42 43 44def move_block_before(block: List[fx.Node], target_node: fx.Node) -> None: 45 for node in block: 46 target_node.prepend(node) 47 target_node = node 48 49 50def call_function( 51 graph: fx.Graph, 52 target: Union[str, Callable[..., Any]], 53 args: Optional[Tuple[fx.node.Argument, ...]] = None, 54 kwargs: Optional[Dict[str, fx.node.Argument]] = None, 55) -> fx.Node: 56 # We accept target as a str to avoid typing error as the type of 57 # a node.target is Union[str, Callable[..., Any]]. 58 # This also allows us to avoid writing check for every call. 59 if isinstance(target, str): 60 raise RuntimeError(f"Call function should not get a str target {target=}") 61 node = graph.call_function(target, args, kwargs) 62 _, args, kwargs = get_fake_args_kwargs(node) 63 with V.fake_mode: 64 node.meta["val"] = target(*args, **kwargs) 65 # node.meta["val"] may be a container. So we use tree_map here 66 # to recursively extract the tensor metadata. 67 node.meta["tensor_meta"] = tree_map( 68 _extract_tensor_metadata, (node.meta["val"],) 69 )[0] 70 return node 71 72 73@dataclass(unsafe_hash=True) 74class CommBlock: 75 shape: Union[torch.Size, List[torch.Size]] 76 node_list: List[fx.Node] 77 inputs: List[fx.Node] 78 wait_nodes: List[fx.Node] 79 comm_node: fx.Node 80 outputs: Set[fx.Node] 81 82 83def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]: 84 """ 85 Given a collective node (e.g., allreduce), find out all the nodes belong to 86 this communcation. 87 88 Args: 89 comm_node(fx.Node): The target communication/collective node. 90 Returns: 91 The CommBlock that encapsulates the related nodes (e.g., wait_node) of 92 the given comm_node. 93 """ 94 node_list = [] 95 wait_nodes = [] 96 inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs)) 97 input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)] 98 wait_prefixes = "wait_tensor" 99 # If the users of the wait node are following items, we consinder them 100 # to be a part of the output. 
def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
    """
    Given a collective node (e.g., allreduce), find all the nodes that belong
    to this communication.

    Args:
        comm_node(fx.Node): The target communication/collective node.
    Returns:
        The CommBlock that encapsulates the related nodes (e.g., wait_node) of
        the given comm_node.
    """
    node_list = []
    wait_nodes = []
    inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs))
    input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)]
    # If the users of the wait node are one of the following ops, we consider
    # them to be a part of the output.
    intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias")

    first_user = next(iter(comm_node.users))
    if (
        len(comm_node.users) == 1
        and first_user.target == torch.ops._c10d_functional.wait_tensor.default
    ):
        # Collective with only one output.
        node_list = [comm_node, first_user]
        wait_nodes.append(first_user)
    elif len(comm_node.users) > 1 and first_user.target == operator.getitem:
        # Collective with more than one output.
        node_list.append(comm_node)
        for user in comm_node.users:
            if user.target != operator.getitem:
                return None
            if len(user.users) != 1:
                return None
            wait_node = next(iter(user.users))
            if wait_node.target != torch.ops._c10d_functional.wait_tensor.default:
                return None
            wait_nodes.append(wait_node)
            node_list.append(user)
        node_list.extend(wait_nodes)
    else:
        return None

    # Identify all the outputs of this collective block.
    outputs: Set[fx.Node] = set()
    nodes = collections.deque(wait_nodes)
    while nodes:
        node = nodes.popleft()
        for user in node.users:
            if isinstance(user, fx.Node) and user.name.startswith(
                intermediate_outputs
            ):
                nodes.append(user)
                node_list.append(user)
            else:
                outputs.add(node)
                break

    tensor_meta = input_nodes[0].meta["tensor_meta"]
    shape: Union[torch.Size, List[torch.Size]]
    if isinstance(tensor_meta, TensorMetadata):
        shape = tensor_meta.shape
    elif isinstance(tensor_meta, (list, tuple)):
        shape = [tm.shape for tm in tensor_meta]
    else:
        logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta))
        return None

    return CommBlock(
        shape=shape,
        node_list=node_list,
        wait_nodes=wait_nodes,
        comm_node=comm_node,
        inputs=input_nodes,
        outputs=outputs,
    )


def get_all_comm_blocks(
    graph: fx.Graph,
    comm_ops: Tuple[torch._ops.OpOverload, ...],
    comm_filter: Optional[Callable[..., bool]] = None,
) -> List[CommBlock]:
    if comm_filter is None:

        def always_true(comm_block: CommBlock) -> bool:
            return True

        comm_filter = always_true

    blocks = []
    for node in graph.nodes:
        if node.target not in comm_ops:
            continue
        comm_block = get_comm_block(node)
        if comm_block is not None and comm_filter(comm_block):
            blocks.append(comm_block)
    return blocks
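
# Usage sketch (illustrative only; `graph` is an assumed fx.Graph): collect
# every functional allreduce in a graph and log its element count.
#
#     blocks = get_all_comm_blocks(
#         graph, (torch.ops._c10d_functional.all_reduce.default,)
#     )
#     for b in blocks:
#         logger.info(
#             "allreduce %s: %d elements",
#             b.comm_node.name,
#             cast(torch.Size, b.shape).numel(),
#         )
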
def _fuse_allreduce_by_concat(
    graph: fx.Graph,
    last_input_node: fx.Node,
    all_input_nodes: List[fx.Node],
    last_comm_block: CommBlock,
) -> CommBlock:
    """Given a list of inputs in order, create a fused allreduce using concat."""
    # Flatten all the inputs to the all_reduce nodes.
    with graph.inserting_after(last_input_node):
        cat_inputs = []
        for input_node in all_input_nodes:
            assert isinstance(input_node.args[0], fx.Node)
            input_node = input_node.args[0]
            cat_inputs.append(
                call_function(graph, aten.flatten.using_ints, (input_node,))
            )

    # Concat all the flattened nodes.
    with graph.inserting_after(cat_inputs[0]):
        cat_node = call_function(graph, aten.cat, (cat_inputs,))

    # Insert the fused div node and remove the input div nodes.
    # This is an optimization and is not mandatory for the fusion.
    divisors = [div.args[1] for div in all_input_nodes]
    assert all(divisor == divisors[0] for divisor in divisors)
    with graph.inserting_after(cat_node):
        div_node = call_function(
            graph, last_input_node.target, (cat_node, divisors[0])
        )

    # Create a new Comm/all_reduce node.
    last_comm_node = last_comm_block.comm_node
    last_wait_node = last_comm_block.wait_nodes[0]
    with graph.inserting_after(div_node):
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = div_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs)

    # Create a new Wait node.
    with graph.inserting_after(fused_comm_node):
        flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
        flatten_args[0] = fused_comm_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs)

    # Move the fused all_reduce and its args to right after the input node.
    nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node]
    move_block_after(nodes_to_move, last_input_node)

    return CommBlock(
        shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape,
        node_list=[fused_comm_node, fused_wait_node],
        wait_nodes=[fused_wait_node],
        comm_node=fused_comm_node,
        inputs=[div_node],
        outputs={fused_wait_node},
    )
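
# Concat fusion, sketched on two gradients g1 and g2 (illustrative names):
#
#     before:  wait(all_reduce(g1 / ws)); wait(all_reduce(g2 / ws))
#     after:   flat = cat([flatten(g1), flatten(g2)])
#              out = wait(all_reduce(flat / ws))
#
# The fused result is later split and reshaped back to the original gradient
# shapes by _scatter_fused_allreduce_waits.
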
def _fuse_with_coalesced_op(
    graph: fx.Graph,
    last_input_node: fx.Node,
    all_input_nodes: List[fx.Node],
    last_comm_block: CommBlock,
) -> CommBlock:
    """Given a list of inputs in order, create a fused allreduce using a coalesced op."""
    last_comm_node = last_comm_block.comm_node
    last_wait_node = last_comm_block.wait_nodes[0]

    # Insert the fused div node and remove the input div nodes.
    # This is an optimization and is not mandatory for the fusion.
    dividends = [div.args[0] for div in all_input_nodes]
    divisors = [div.args[1] for div in all_input_nodes]
    assert all(divisor == divisors[0] for divisor in divisors)
    with graph.inserting_before(last_input_node):
        last_input_node = call_function(
            graph, aten._foreach_div.Scalar, (dividends, divisors[0])
        )
    input_node = last_input_node

    # Create a new Comm/all_reduce_coalesced node.
    with graph.inserting_after(last_comm_node):
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = input_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = call_function(
            graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs
        )

    # Create the new wait nodes, one per coalesced output.
    getitem_nodes = []
    wait_nodes = []
    flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
    for idx in range(len(all_input_nodes)):
        with graph.inserting_after(fused_comm_node):
            gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx))
        getitem_nodes.append(gi_node)
        flatten_args[0] = gi_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        with graph.inserting_after(gi_node):
            wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs))

    # Move the new all_reduce_coalesced and its args to right after the input node.
    nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes
    move_block_after(nodes_to_move, last_input_node)

    return CommBlock(
        shape=[
            tm.shape
            for tm in cast(
                List[TensorMetadata], fused_comm_node.meta.get("tensor_meta")
            )
        ],
        node_list=[fused_comm_node] + getitem_nodes + wait_nodes,
        wait_nodes=wait_nodes,
        comm_node=fused_comm_node,
        inputs=[input_node],
        outputs=set(wait_nodes),
    )
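
# Coalesced fusion, sketched on the same two gradients (illustrative names):
#
#     before:  wait(all_reduce(g1 / ws)); wait(all_reduce(g2 / ws))
#     after:   d1, d2 = _foreach_div([g1, g2], ws)
#              out = all_reduce_coalesced([d1, d2])
#              wait(out[0]); wait(out[1])
#
# Unlike concat fusion, no flatten/cat copies are needed and each result
# already has its original shape, so no split/reshape scatter is required.
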
def _scatter_fused_allreduce_waits(
    graph: fx.Graph,
    fused_comm_block: CommBlock,
    orig_comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
    split_and_reshape: bool = True,
) -> None:
    """
    Scatter the result of the fused communication node to the original users.

    If the fusion method is concat, split and reshape nodes are inserted
    before the getitem nodes. Otherwise, the wait nodes are used directly
    as the fused outputs.
    """
    # Before we mess up the order, we need to get the index of the last wait
    # node in orig_comm_blocks. This index will later be used to determine
    # which user nodes need to be moved to maintain a correct topological
    # sort order.
    last_wait_node_idx = 0
    for node in graph.nodes:
        last_wait_node_idx = max(
            node_indices.get(node, last_wait_node_idx), last_wait_node_idx
        )
        if node == orig_comm_blocks[-1].wait_nodes[0]:
            break

    if split_and_reshape:
        fused_wait_node = fused_comm_block.wait_nodes[0]
        with graph.inserting_after(fused_wait_node):
            split_node = call_function(
                graph,
                aten.split,
                (
                    fused_wait_node,
                    [math.prod(cast(List[int], cb.shape)) for cb in orig_comm_blocks],
                ),
            )
        with graph.inserting_after(split_node):
            fused_outputs = []
            for idx, comm_block in enumerate(orig_comm_blocks):
                split_idx_node = call_function(
                    graph, operator.getitem, (split_node, idx)
                )
                with graph.inserting_after(split_idx_node):
                    fused_outputs.append(
                        call_function(
                            graph, aten.reshape, (split_idx_node, comm_block.shape)
                        )
                    )
    else:
        fused_outputs = fused_comm_block.wait_nodes

    # Scatter the fused outputs.
    incorrect_order_nodes = []
    for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs):
        # Some descendant users of the orig_comm_blocks may be scheduled before
        # the fused all_reduce. For example, the user nodes of the very first
        # all_reduce may be scheduled before the second all_reduce. Since the
        # fused all_reduce is inserted right after the last all_reduce, the
        # order can be wrong.
        # `incorrect_order_nodes` records these nodes.
        orig_wait = comm_block.wait_nodes[0]
        nodes = collections.deque(list(orig_wait.users))
        while nodes:
            user_node = nodes.popleft()
            if not isinstance(user_node, fx.Node):
                continue
            if node_indices[user_node] < last_wait_node_idx:
                incorrect_order_nodes.append(user_node)
                nodes.extend(list(user_node.users))

        orig_wait.replace_all_uses_with(fused_output)

    last_fused_result = fused_outputs[0]
    fused_outputs_set = set(fused_outputs)
    for node in graph.nodes:
        if node in fused_outputs_set:
            last_fused_result = node

    # Move the incorrect_order_nodes to right after the last fused result.
    incorrect_order_nodes = sorted(
        incorrect_order_nodes, key=lambda node: node_indices[node]
    )
    move_block_after(incorrect_order_nodes, last_fused_result)


def _fuse_allreduce(
    graph: fx.Graph,
    comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
    use_concat: bool,
) -> CommBlock:
    """Given a list of allreduce CommBlocks, fuse them into one CommBlock."""
    if len(comm_blocks) == 1:
        return comm_blocks[0]

    # Find the last input node of all the CommBlocks. This node will serve
    # as the insertion point of the new collective op.
    last_input_node = comm_blocks[0].inputs[0]
    last_input_index = -1
    all_input_nodes = []
    for comm_block in comm_blocks:
        input_node = comm_block.inputs[0]
        all_input_nodes.append(input_node)
        index = node_indices[input_node]
        if index >= last_input_index:
            assert index != last_input_index
            last_input_node = input_node
            last_input_index = index

    if use_concat:
        fused_comm_block = _fuse_allreduce_by_concat(
            graph, last_input_node, all_input_nodes, comm_blocks[-1]
        )
    else:
        fused_comm_block = _fuse_with_coalesced_op(
            graph, last_input_node, all_input_nodes, comm_blocks[-1]
        )

    _scatter_fused_allreduce_waits(
        graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat
    )

    for comm_block in comm_blocks:
        for wait in comm_block.wait_nodes:
            graph.erase_node(wait)
        graph.erase_node(comm_block.comm_node)
    graph.eliminate_dead_code()

    return fused_comm_block


def _bucket_size_fusion(
    graph: fx.Graph, comm_blocks: List[CommBlock], bucket_size_mb: int
) -> Generator[List[CommBlock], None, None]:
    MB = 1024**2
    # The first bucket is kept small (1 MB), mirroring DDP's first-bucket
    # behavior; subsequent buckets use the user-specified capacity.
    bucket_size = 1 * MB
    bucket_cap_size = bucket_size_mb * MB
    curr_size = 0
    curr_blocks = []

    count = 0
    fuse_count = 0
    for i, block in enumerate(comm_blocks):
        curr_blocks.append(block)
        itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize
        curr_size += cast(torch.Size, block.shape).numel() * itemsize
        count += 1
        if curr_size < bucket_size and i != len(comm_blocks) - 1:
            continue

        fuse_count += 1
        if torch.distributed.get_rank() == 0:
            logger.info(
                "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d",
                fuse_count,
                count,
                curr_size,
                bucket_size,
            )

        # Set the debug counter.
        counters["inductor"]["ddp_buckets"] = fuse_count
        yield curr_blocks

        bucket_size = bucket_cap_size
        curr_blocks = []
        curr_size = 0
        count = 0
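
# Bucketing sketch (illustrative numbers, assuming bucket_size_mb=25): given
# four fp32 gradient allreduces of 0.4 MB, 0.8 MB, 12 MB, and 20 MB,
# _bucket_size_fusion yields
#
#     [g0, g1]   # 1.2 MB crosses the small 1 MB first bucket
#     [g2, g3]   # 32 MB crosses the 25 MB cap (the last block always flushes)
#
# Each yielded list is then fused into a single collective by the fusion
# function, e.g., _fuse_allreduce.
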
def _fuse_ddp_communication(
    graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any]
) -> None:
    # Locate the output node; the filter below needs it.
    for output in reversed(graph.nodes):
        if output.op == "output":
            break

    def ddp_reducer_filter(block: CommBlock) -> bool:
        if (
            not isinstance(block.comm_node.args[0], fx.Node)
            or block.comm_node.args[0].target != aten.div.Tensor
        ):
            return False

        if len(block.wait_nodes[0].users) != 1:
            # The gradient/wait node should only have one user.
            return False

        # Two cases:
        # 1. The gradient/wait node is directly used by the output
        #    if the gradient is None before the backward pass.
        # 2. The gradient/wait node is directly used by copy_.
        if (
            output not in block.wait_nodes[0].users
            and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default
        ):
            return False

        return True

    ops = (
        torch.ops._c10d_functional.all_reduce_.default,
        torch.ops._c10d_functional.all_reduce.default,
    )
    comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter)
    node_indices = {node: i for i, node in enumerate(graph.nodes)}

    for block in algorithm_fn(graph, comm_blocks):
        fusion_fn(graph, block, node_indices)


def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None:
    _fuse_ddp_communication(
        graph,
        partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb),
        partial(_fuse_allreduce, use_concat=False),
    )


def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None:
    _fuse_ddp_communication(
        graph,
        partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb),
        partial(_fuse_allreduce, use_concat=True),
    )
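
# Wait-scheduling sketch (illustrative only): given a program order like
#
#     ar = all_reduce(div); w = wait_tensor(ar); <unrelated compute>; use(w)
#
# schedule_comm_wait below moves `w` (plus any trailing split/getitem nodes in
# the comm block) to just before `use`, so the unrelated compute can overlap
# with the in-flight collective.
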
def schedule_comm_wait(graph: fx.Graph) -> None:
    """
    Delay the execution of an allreduce's wait tensor until its first user.

    This algorithm also considers the intermediate users of the wait node,
    such as split and getitem, and schedules those intermediate users as
    well. This results in better overlap between communication and
    computation.
    """
    ops = (
        torch.ops._c10d_functional.all_reduce_.default,
        torch.ops._c10d_functional.all_reduce.default,
        torch.ops._c10d_functional.all_reduce_coalesced.default,
        torch.ops._c10d_functional.all_reduce_coalesced_.default,
    )
    comm_blocks = get_all_comm_blocks(graph, ops)
    if not comm_blocks:
        return

    # Find all the end users.
    allreduce_users: Set[fx.Node] = set()
    for allreduce in comm_blocks:
        for output in allreduce.outputs:
            allreduce_users.update(output.users)

    node_indices = {node: i for i, node in enumerate(graph.nodes)}
    for allreduce in comm_blocks:
        # Find the earliest/first user -- target_node.
        assert (
            len(allreduce.outputs) >= 1
        ), f"Found an allreduce that has zero outputs/users -- {allreduce}."
        # Initialize the target node to avoid typing issues.
        target_node = next(iter(next(iter(allreduce.outputs)).users))
        target_node_index = 2**31
        for user in (user for output in allreduce.outputs for user in output.users):
            index = node_indices[user]
            if index < target_node_index:
                target_node = user
                target_node_index = index

        # Move the wait nodes and all subsequent nodes in the comm_block to
        # just before the first user -- target_node.
        wait_idx = -1
        for wait_idx, node in enumerate(allreduce.node_list):
            if node == allreduce.wait_nodes[0]:
                break
        assert wait_idx >= 0
        move_block_before(allreduce.node_list[wait_idx:], target_node)


def fuse_ddp_communication(
    graph: fx.Graph, passes: List[Union[Callable[..., None], str]], bucket_size_mb: int
) -> None:
    for i, pa in enumerate(passes):
        with GraphTransformObserver(
            graph.owning_module,
            f"fuse_ddp_communication_pass_{i}",
            config.trace.log_url_for_graph_xform,
        ):
            if isinstance(pa, str):
                func = globals()[pa]
            else:
                func = pa
            if "bucket_size_mb" in {
                v.name for v in inspect.signature(func).parameters.values()
            }:
                func(graph, bucket_size_mb=bucket_size_mb)
            else:
                func(graph)
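
# Usage sketch (illustrative only; `gm` is an assumed fx.GraphModule): run the
# concat-based fusion followed by wait scheduling with a 25 MB bucket cap.
#
#     fuse_ddp_communication(
#         gm.graph,
#         ["fuse_ddp_with_concat_op", "schedule_comm_wait"],
#         bucket_size_mb=25,
#     )
#     gm.recompile()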