# Owner(s): ["oncall: distributed"]
import collections
import inspect
import logging
import math
import operator
from dataclasses import dataclass
from functools import partial
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.fx as fx
from torch._dynamo.utils import counters
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

from .. import config
from ..fx_utils import get_fake_args_kwargs
from ..virtualized import V


aten = torch.ops.aten
logger: logging.Logger = logging.getLogger("comm_fusion")

def move_block_after(block: List[fx.Node], target_node: fx.Node) -> None:
    # Append each node right after the previously moved one so the block
    # keeps its original relative order.
    for node in block:
        target_node.append(node)
        target_node = node


def move_block_before(block: List[fx.Node], target_node: fx.Node) -> None:
    # Insert every node directly before the (unchanged) target node so the
    # block keeps its original relative order. Updating target_node to the
    # just-moved node here would reverse the block.
    for node in block:
        target_node.prepend(node)

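# A minimal sketch of the two helpers above (illustrative node names, not part
# of this pass): for a graph ordered [n1, n2, t],
#
#   move_block_after([n1, n2], t)   # order becomes [t, n1, n2]
#   move_block_before([n1, n2], t)  # order becomes [n1, n2, t]
#
# fx.Node.append/prepend remove a node from its current position before
# re-inserting it, so these helpers reorder nodes in place.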

def call_function(
    graph: fx.Graph,
    target: Union[str, Callable[..., Any]],
    args: Optional[Tuple[fx.node.Argument, ...]] = None,
    kwargs: Optional[Dict[str, fx.node.Argument]] = None,
) -> fx.Node:
    # The signature accepts a str target only to match the type of
    # fx.Node.target, which is Union[str, Callable[..., Any]]; this spares
    # callers a type check at every call site. A str target is still
    # rejected at runtime.
    if isinstance(target, str):
        raise RuntimeError(f"Call function should not get a str target {target=}")
    node = graph.call_function(target, args, kwargs)
    _, args, kwargs = get_fake_args_kwargs(node)
    with V.fake_mode:
        node.meta["val"] = target(*args, **kwargs)
        # node.meta["val"] may be a container, so we use tree_map here
        # to recursively extract the tensor metadata.
        node.meta["tensor_meta"] = tree_map(
            _extract_tensor_metadata, (node.meta["val"],)
        )[0]
    return node

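# Illustrative usage of call_function (hedged: `graph`, `a`, and `b` are
# assumed to be an existing fx.Graph and two of its tensor-producing nodes):
#
#   add_node = call_function(graph, aten.add.Tensor, (a, b))
#   add_node.meta["val"]          # fake-tensor result of the op
#   add_node.meta["tensor_meta"]  # TensorMetadata extracted from that result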

@dataclass(unsafe_hash=True)
class CommBlock:
    # Shape(s) of the communicated tensor(s); a list for coalesced collectives.
    shape: Union[torch.Size, List[torch.Size]]
    # All nodes that belong to this communication, in discovery order.
    node_list: List[fx.Node]
    # The input nodes feeding comm_node.
    inputs: List[fx.Node]
    # The wait_tensor node(s) of the collective.
    wait_nodes: List[fx.Node]
    # The collective node itself (e.g., all_reduce).
    comm_node: fx.Node
    # The last nodes of the block, whose results are used outside the block.
    outputs: Set[fx.Node]


def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]:
    """
    Given a collective node (e.g., allreduce), find all the nodes that belong
    to this communication.

    Args:
        comm_node(fx.Node): The target communication/collective node.
    Returns:
        The CommBlock that encapsulates the related nodes (e.g., wait_node) of
        the given comm_node.
    """
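    # The two graph patterns recognized below, sketched with illustrative
    # names (not taken from a real trace):
    #
    #   single output:
    #       ar = all_reduce(x)
    #       w = wait_tensor(ar)               # the only user of ar
    #   multiple outputs (coalesced):
    #       ar = all_reduce_coalesced([x, y])
    #       g0, g1 = getitem(ar, 0), getitem(ar, 1)
    #       w0, w1 = wait_tensor(g0), wait_tensor(g1)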
    node_list = []
    wait_nodes = []
    inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs))
    input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)]
    wait_prefixes = "wait_tensor"
    # Users of a wait node whose names begin with the following prefixes are
    # considered intermediate nodes of this communication block; their users
    # are traversed further to find the block's real outputs.
    intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias")

    first_user = next(iter(comm_node.users))
    if (
        len(comm_node.users) == 1
        and first_user.target == torch.ops._c10d_functional.wait_tensor.default
    ):
        # Collective with only one output
        node_list = [comm_node, first_user]
        wait_nodes.append(first_user)
    elif len(comm_node.users) > 1 and first_user.target == operator.getitem:
        # Collective with more than one output
        node_list.append(comm_node)
        for user in comm_node.users:
            if user.target != operator.getitem:
                return None
            if len(user.users) != 1:
                return None
            wait_node = next(iter(user.users))
            if wait_node.target != torch.ops._c10d_functional.wait_tensor.default:
                return None
            wait_nodes.append(wait_node)
            node_list.append(user)
        node_list.extend(wait_nodes)
    else:
        return None

    # Identify all the outputs of this collective block.
    outputs: Set[fx.Node] = set()
    nodes = collections.deque(wait_nodes)
    while nodes:
        node = nodes.popleft()
        for user in node.users:
            if isinstance(user, fx.Node) and user.name.startswith(intermediate_outputs):
                nodes.append(user)
                node_list.append(user)
            else:
                outputs.add(node)
                break

    tensor_meta = input_nodes[0].meta["tensor_meta"]
    shape: Union[torch.Size, List[torch.Size]]
    if isinstance(tensor_meta, TensorMetadata):
        shape = tensor_meta.shape
    elif isinstance(tensor_meta, (list, tuple)):
        shape = [tm.shape for tm in tensor_meta]
    else:
        logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta))
        return None

    return CommBlock(
        shape=shape,
        node_list=node_list,
        wait_nodes=wait_nodes,
        comm_node=comm_node,
        inputs=input_nodes,
        outputs=outputs,
    )


def get_all_comm_blocks(
    graph: fx.Graph,
    comm_ops: Tuple[torch._ops.OpOverload, ...],
    comm_filter: Optional[Callable[..., bool]] = None,
) -> List[CommBlock]:
    """Find all CommBlocks in the graph whose comm node matches ``comm_ops``
    and passes the optional ``comm_filter``."""
    if comm_filter is None:

        def always_true(comm_block: CommBlock) -> bool:
            return True

        comm_filter = always_true

    blocks = []
    for node in graph.nodes:
        if node.target not in comm_ops:
            continue
        comm_block = get_comm_block(node)
        if comm_block is not None and comm_filter(comm_block):
            blocks.append(comm_block)
    return blocks


def _fuse_allreduce_by_concat(
    graph: fx.Graph,
    last_input_node: fx.Node,
    all_input_nodes: List[fx.Node],
    last_comm_block: CommBlock,
) -> CommBlock:
    """Given a list of inputs in order, create a fused allreduce using concat."""
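    # Sketch of the rewrite this function performs (illustrative names):
    #
    #   before:  ar0 = all_reduce(div(g0, n)); w0 = wait_tensor(ar0)
    #            ar1 = all_reduce(div(g1, n)); w1 = wait_tensor(ar1)
    #   after:   cat = aten.cat([flatten(g0), flatten(g1)])
    #            ar = all_reduce(div(cat, n)); w = wait_tensor(ar)
    #
    # The split/reshape that scatters `w` back to the original users is added
    # later by _scatter_fused_allreduce_waits.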
    # Flatten all the inputs to the all_reduce nodes.
    with graph.inserting_after(last_input_node):
        cat_inputs = []
        for input_node in all_input_nodes:
            assert isinstance(input_node.args[0], fx.Node)
            input_node = input_node.args[0]
            cat_inputs.append(
                call_function(graph, aten.flatten.using_ints, (input_node,))
            )

    # Concat all the flattened nodes.
    with graph.inserting_after(cat_inputs[0]):
        cat_node = call_function(graph, aten.cat, (cat_inputs,))

    # Insert the fused div node and remove the input div nodes.
    # This is an optimization and is not mandatory for fusion.
    divisors = [div.args[1] for div in all_input_nodes]
    assert all(divisor == divisors[0] for divisor in divisors)
    with graph.inserting_after(cat_node):
        div_node = call_function(graph, last_input_node.target, (cat_node, divisors[0]))

    # Create a new Comm/all_reduce node.
    last_comm_node = last_comm_block.comm_node
    last_wait_node = last_comm_block.wait_nodes[0]
    with graph.inserting_after(div_node):
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = div_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs)

    # Create a new Wait node.
    with graph.inserting_after(fused_comm_node):
        flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
        flatten_args[0] = fused_comm_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs)

    # Move the fused all_reduce and its args to right after the input node.
    nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node]
    move_block_after(nodes_to_move, last_input_node)

    return CommBlock(
        shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape,
        node_list=[fused_comm_node, fused_wait_node],
        wait_nodes=[fused_wait_node],
        comm_node=fused_comm_node,
        inputs=[div_node],
        outputs={fused_wait_node},
    )


def _fuse_with_coalesced_op(
    graph: fx.Graph,
    last_input_node: fx.Node,
    all_input_nodes: List[fx.Node],
    last_comm_block: CommBlock,
) -> CommBlock:
    """Given a list of inputs in order, create a fused allreduce using
    all_reduce_coalesced."""
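    # Sketch of the rewrite this function performs (illustrative names):
    #
    #   before:  ar0 = all_reduce(div(g0, n)); w0 = wait_tensor(ar0)
    #            ar1 = all_reduce(div(g1, n)); w1 = wait_tensor(ar1)
    #   after:   ds = aten._foreach_div.Scalar([g0, g1], n)
    #            ar = all_reduce_coalesced(ds)
    #            w0 = wait_tensor(getitem(ar, 0))
    #            w1 = wait_tensor(getitem(ar, 1))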
    last_comm_node = last_comm_block.comm_node
    last_wait_node = last_comm_block.wait_nodes[0]

    # Insert the fused div node and remove the input div nodes.
    # This is an optimization and is not mandatory for fusion.
    dividends = [div.args[0] for div in all_input_nodes]
    divisors = [div.args[1] for div in all_input_nodes]
    assert all(divisor == divisors[0] for divisor in divisors)
    with graph.inserting_before(last_input_node):
        last_input_node = call_function(
            graph, aten._foreach_div.Scalar, (dividends, divisors[0])
        )
    input_node = last_input_node

    # Create a new Comm/all_reduce_coalesced node.
    with graph.inserting_after(last_comm_node):
        flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs))
        flatten_args[0] = input_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        fused_comm_node = call_function(
            graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs
        )

    # Create a new wait node.
    getitem_nodes = []
    wait_nodes = []
    flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs))
    for idx in range(len(all_input_nodes)):
        with graph.inserting_after(fused_comm_node):
            gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx))
        getitem_nodes.append(gi_node)
        flatten_args[0] = gi_node
        args, kwargs = tree_unflatten(flatten_args, spec)
        with graph.inserting_after(gi_node):
            wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs))

    # Move the new all_reduce_coalesced and its args to right after the input node.
    nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes
    move_block_after(nodes_to_move, last_input_node)

    return CommBlock(
        shape=[
            tm.shape
            for tm in cast(
                List[TensorMetadata], fused_comm_node.meta.get("tensor_meta")
            )
        ],
        node_list=[fused_comm_node] + getitem_nodes + wait_nodes,
        wait_nodes=wait_nodes,
        comm_node=fused_comm_node,
        inputs=[input_node],
        outputs=set(wait_nodes),
    )


def _scatter_fused_allreduce_waits(
    graph: fx.Graph,
    fused_comm_block: CommBlock,
    orig_comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
    split_and_reshape: bool = True,
) -> None:
    """
    Scatter the result of the fused communication node to the original users.

    If the fusion method is concat, split and reshape nodes are inserted after
    the fused wait node before the results are scattered. Otherwise, the wait
    nodes of the coalesced op are used as the scattered outputs directly.
    """
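    # Sketch of the concat case (illustrative names): given the fused wait
    # node `w` and two original blocks with shapes s0 and s1,
    #
    #   parts = aten.split(w, [numel(s0), numel(s1)])
    #   out0 = aten.reshape(getitem(parts, 0), s0)
    #   out1 = aten.reshape(getitem(parts, 1), s1)
    #
    # and each original wait node's users are redirected to out0/out1.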

    # Before we mess up the order, we need to get the index of the last wait
    # node in orig_comm_blocks. This index will later be used to determine
    # which user nodes need to be moved to maintain a correct topological
    # sort order.
    last_wait_node_idx = 0
    for node in graph.nodes:
        last_wait_node_idx = max(
            node_indices.get(node, last_wait_node_idx), last_wait_node_idx
        )
        if node == orig_comm_blocks[-1].wait_nodes[0]:
            break

    if split_and_reshape:
        fused_wait_node = fused_comm_block.wait_nodes[0]
        with graph.inserting_after(fused_wait_node):
            split_node = call_function(
                graph,
                aten.split,
                (
                    fused_wait_node,
                    [math.prod(cast(List[int], cb.shape)) for cb in orig_comm_blocks],
                ),
            )
        with graph.inserting_after(split_node):
            fused_outputs = []
            for idx, comm_block in enumerate(orig_comm_blocks):
                split_idx_node = call_function(
                    graph, operator.getitem, (split_node, idx)
                )
                with graph.inserting_after(split_idx_node):
                    fused_outputs.append(
                        call_function(
                            graph, aten.reshape, (split_idx_node, comm_block.shape)
                        )
                    )
    else:
        fused_outputs = fused_comm_block.wait_nodes

    # Scatter the fused outputs.
    incorrect_order_nodes = []
    for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs):
        # Some descendant users of the orig_comm_blocks may be scheduled before
        # the fused all_reduce. For example, the user nodes of the very first
        # all_reduce may be scheduled before the second all_reduce. Since the
        # fused all_reduce is inserted right after the last all_reduce, the
        # order can be wrong.
        # `incorrect_order_nodes` records these nodes.

        orig_wait = comm_block.wait_nodes[0]
        nodes = collections.deque(list(orig_wait.users))
        while nodes:
            user_node = nodes.popleft()
            if not isinstance(user_node, fx.Node):
                continue
            if node_indices[user_node] < last_wait_node_idx:
                incorrect_order_nodes.append(user_node)
                nodes.extend(list(user_node.users))

        orig_wait.replace_all_uses_with(fused_output)

    last_fused_result = fused_outputs[0]
    fused_outputs_set = set(fused_outputs)
    for node in graph.nodes:
        if node in fused_outputs_set:
            last_fused_result = node

    # Move the incorrect_order_nodes to right after the last fused_result.
    incorrect_order_nodes = sorted(
        incorrect_order_nodes, key=lambda node: node_indices[node]
    )
    move_block_after(incorrect_order_nodes, last_fused_result)


def _fuse_allreduce(
    graph: fx.Graph,
    comm_blocks: List[CommBlock],
    node_indices: Dict[fx.Node, int],
    use_concat: bool,
) -> CommBlock:
    """Given a list of allreduce CommBlocks, fuse them into one CommBlock."""

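    # The fusion proceeds in three steps (sketch):
    #   1. find the last input node among all blocks, which serves as the
    #      insertion point of the fused collective;
    #   2. build the fused collective, either by concat or by
    #      all_reduce_coalesced;
    #   3. scatter the fused wait results back to the original users and
    #      erase the original comm/wait nodes.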
    if len(comm_blocks) == 1:
        return comm_blocks[0]

    # Find the last input node of all the CommBlocks. This node will serve
    # as the insertion point of the new collective op.
    last_input_node = comm_blocks[0].inputs[0]
    last_input_index = -1
    all_input_nodes = []
    for comm_block in comm_blocks:
        input_node = comm_block.inputs[0]
        all_input_nodes.append(input_node)
        index = node_indices[input_node]
        if index >= last_input_index:
            assert index != last_input_index
            last_input_node = input_node
            last_input_index = index

    if use_concat:
        fused_comm_block = _fuse_allreduce_by_concat(
            graph, last_input_node, all_input_nodes, comm_blocks[-1]
        )
    else:
        fused_comm_block = _fuse_with_coalesced_op(
            graph, last_input_node, all_input_nodes, comm_blocks[-1]
        )

    _scatter_fused_allreduce_waits(
        graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat
    )

    for comm_block in comm_blocks:
        for wait in comm_block.wait_nodes:
            graph.erase_node(wait)
        graph.erase_node(comm_block.comm_node)
    graph.eliminate_dead_code()

    return fused_comm_block


def _bucket_size_fusion(
    graph: fx.Graph, comm_blocks: List[CommBlock], bucket_size_mb: int
) -> Generator[List[CommBlock], None, None]:
    """Yield lists of CommBlocks whose accumulated tensor size reaches the
    bucket cap. The first bucket is capped at 1 MB; all later buckets are
    capped at ``bucket_size_mb``."""
    MB = 1024**2
    bucket_size = 1 * MB
    bucket_cap_size = bucket_size_mb * MB
    curr_size = 0
    curr_blocks = []

    count = 0
    fuse_count = 0
    for i, block in enumerate(comm_blocks):
        curr_blocks.append(block)
        itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize
        curr_size += cast(torch.Size, block.shape).numel() * itemsize
        count += 1
        if curr_size < bucket_size and i != len(comm_blocks) - 1:
            continue

        fuse_count += 1
        if torch.distributed.get_rank() == 0:
            logger.info(
                "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d",
                fuse_count,
                count,
                curr_size,
                bucket_size,
            )

        # Set the debug counters.
        counters["inductor"]["ddp_buckets"] = fuse_count
        yield curr_blocks

        # All buckets after the first use the full cap.
        bucket_size = bucket_cap_size
        curr_blocks = []
        curr_size = 0
        count = 0
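
# Worked example for _bucket_size_fusion (illustrative numbers): with
# bucket_size_mb=25 and four fp32 gradients of 0.4 MB each, the first bucket
# closes after three gradients (1.2 MB >= the 1 MB first-bucket cap); the
# fourth gradient is flushed as the final bucket because it is the last block.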


def _fuse_ddp_communication(
    graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any]
) -> None:
    # Find the graph's output node; it is needed by the filter below.
    for output in reversed(graph.nodes):
        if output.op == "output":
            break

    def ddp_reducer_filter(block: CommBlock) -> bool:
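        # The allreduce pattern emitted for DDP gradient reduction that this
        # filter accepts, sketched with illustrative names:
        #
        #   div = aten.div.Tensor(grad, world_size)
        #   ar = all_reduce(div)
        #   w = wait_tensor(ar)   # w's single user is the graph output, or an
        #                         # aten.copy_ into the gradient buffer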
        if (
            not isinstance(block.comm_node.args[0], fx.Node)
            or block.comm_node.args[0].target != aten.div.Tensor
        ):
            return False

        if len(block.wait_nodes[0].users) != 1:
            # The gradient/wait node should have exactly one user.
            return False

        # Two cases:
        # 1. The gradient/wait node is directly used by the output,
        #    if the gradient was None before backward.
        # 2. The gradient/wait node is directly used by copy_.
        if (
            output not in block.wait_nodes[0].users
            and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default
        ):
            return False

        return True

    ops = (
        torch.ops._c10d_functional.all_reduce_.default,
        torch.ops._c10d_functional.all_reduce.default,
    )
    comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter)
    node_indices = {node: i for i, node in enumerate(graph.nodes)}

    for block in algorithm_fn(graph, comm_blocks):
        fusion_fn(graph, block, node_indices)


def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None:
    _fuse_ddp_communication(
        graph,
        partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb),
        partial(_fuse_allreduce, use_concat=False),
    )


def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None:
    _fuse_ddp_communication(
        graph,
        partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb),
        partial(_fuse_allreduce, use_concat=True),
    )


def schedule_comm_wait(graph: fx.Graph) -> None:
    """
    Delay the execution of each allreduce's wait tensors until their first user.

    This algorithm also considers the intermediate users of the wait node,
    such as split and getitem, and schedules those intermediate users as well.
    This results in better overlap between communication and computation.
    """
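    # Sketch (illustrative names): given graph order
    #   ar = all_reduce(x); w = wait_tensor(ar); y = matmul(...); z = add(w, y)
    # the wait node `w` is moved to just before its first user `z`, so the
    # matmul can overlap with the in-flight allreduce.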
    ops = (
        torch.ops._c10d_functional.all_reduce_.default,
        torch.ops._c10d_functional.all_reduce.default,
        torch.ops._c10d_functional.all_reduce_coalesced.default,
        torch.ops._c10d_functional.all_reduce_coalesced_.default,
    )
    comm_blocks = get_all_comm_blocks(graph, ops)
    if not comm_blocks:
        return

    # Find all the end users.
    allreduce_users: Set[fx.Node] = set()
    for allreduce in comm_blocks:
        for output in allreduce.outputs:
            allreduce_users.update(output.users)

    node_indices = {node: i for i, node in enumerate(graph.nodes)}
    for allreduce in comm_blocks:
        # Find the earliest/first user -- target_node.
        assert (
            len(allreduce.outputs) >= 1
        ), f"Found an allreduce that has zero outputs/users -- {allreduce}."
        # Initialize the target node to avoid typing issues.
        target_node = next(iter(next(iter(allreduce.outputs)).users))
        target_node_index = 2**31
        for user in (user for output in allreduce.outputs for user in output.users):
            index = node_indices[user]
            if index < target_node_index:
                target_node = user
                target_node_index = index

        # Move the wait nodes and all subsequent nodes in the comm_block to
        # before the first user -- target_node.
        wait_idx = -1
        for wait_idx, node in enumerate(allreduce.node_list):
            if node == allreduce.wait_nodes[0]:
                break
        assert wait_idx >= 0
        move_block_before(allreduce.node_list[wait_idx:], target_node)


def fuse_ddp_communication(
    graph: fx.Graph, passes: List[Union[Callable[..., None], str]], bucket_size_mb: int
) -> None:
    for i, pa in enumerate(passes):
        with GraphTransformObserver(
            graph.owning_module,
            f"fuse_ddp_communication_pass_{i}",
            config.trace.log_url_for_graph_xform,
        ):
            if isinstance(pa, str):
                func = globals()[pa]
            else:
                func = pa
            if "bucket_size_mb" in {
                v.name for v in inspect.signature(func).parameters.values()
            }:
                func(graph, bucket_size_mb=bucket_size_mb)
            else:
                func(graph)
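

# Illustrative invocation (hedged: `gm` and the pass list below are an
# example, not a definitive default configuration):
#
#   fuse_ddp_communication(
#       gm.graph,
#       ["fuse_ddp_with_concat_op", "schedule_comm_wait"],
#       bucket_size_mb=25,
#   )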
600