# mypy: allow-untyped-defs
# pyre-strict
from __future__ import annotations

import heapq
import operator
import sys
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING

import torch

from . import config, ir
from .dependencies import WeakDep
from .utils import (
    contains_collective,
    contains_wait,
    find_recursive_deps_of_node,
    find_recursive_users_of_node,
    is_collective,
    is_fallback_op,
    is_wait,
)

overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")

if TYPE_CHECKING:
    from .scheduler import BaseSchedulerNode


def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Greedily schedules waits as late as possible.
    """
    return _schedule_for_comm(
        snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False
    )


def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Greedily schedules comms as early as possible.
    """
    return _schedule_for_comm(
        snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False
    )


def reorder_compute_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    """
    This achieves the following overall scheduling procedure:
        Step 1: Given that we've currently scheduled comm N, we now schedule all
            compute nodes that are required for comm N + 1 but do not depend on
            comm N, to run at the same time as comm N.
        Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
            Otherwise, we now need to look elsewhere to find compute that overlaps
            with comm N. We prioritize compute nodes that are needed sooner.
        Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
        Step 4: We schedule comm N + 1.
        Repeat this for subsequent comm nodes.
    """
    return _schedule_for_comm(
        snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True
    )


def _schedule_for_comm(
    snodes: List[BaseSchedulerNode],
    raise_comms: bool,
    sink_waits: bool,
    reorder_for_overlap: bool,
) -> List[BaseSchedulerNode]:
    """
    Schedule `snodes` for various comm optimization objectives.

    Args:
        snodes: the nodes to be scheduled.
        raise_comms: whether to greedily schedule collectives as early as possible.
        sink_waits: whether to greedily schedule waits as late as possible.
        reorder_for_overlap: whether to reorder compute nodes to optimize for
            compute/communication overlapping.

    Returns:
        The new schedule order.

    Some notes on the synergy between different options:
    - `raise_comms` provides more overlapping opportunities for
      `reorder_compute_for_overlap`.
    - When both `raise_comms` and `sink_waits` are `True`, `raise_comms` is
      prioritized.
    """
    # We assign each node a tuple of scores (score_0, score_1, score_2),
    # decreasing in importance, with a lower value indicating a higher ranking:
    #
    # - score_0: the lowest comm_idx among the comm nodes that the node blocks.
    #   If a node doesn't block any comm nodes, its score_0 is set to
    #   sys.maxsize. This score ensures that comm nodes get scheduled as early
    #   as possible.
    # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures
    #   that wait nodes are deferred as late as possible.
    # - score_2: the index of the node in the original topological order. This
    #   score provides stability in case of ties.
    #
    # When only raise_comms is True, only score_0 and score_2 are considered.
    # When only sink_waits is True, only score_1 and score_2 are considered.
    # When neither is True, the original order is yielded.
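    #
    # Illustrative example (hypothetical nodes, raise_comms=True, sink_waits=True):
    # for an input order [mm1, all_reduce, wait, mm2] where mm1 feeds the
    # all_reduce and mm2 is independent compute, the scores come out roughly as
    #   mm1        -> (0, 0, 0)          # blocks comm 0, so it inherits comm_idx 0
    #   all_reduce -> (0, 0, 1)          # comm 0 itself
    #   wait       -> (maxsize, 1, 2)    # deferred as late as possible
    #   mm2        -> (maxsize, 0, 3)
    # so the ready-queue loop below yields [mm1, all_reduce, mm2, wait], sinking
    # the wait past the independent compute (dependencies permitting).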
    buf_name_to_snode = {}
    name_to_fused_node = {}
    scores_0, scores_1, scores_2 = {}, {}, {}
    for idx, snode in enumerate(snodes):
        for buf_name in snode.get_buffer_names():
            buf_name_to_snode[buf_name] = snode

        for op_name in snode.get_operation_names():
            name_to_fused_node[op_name] = snode
        name_to_fused_node[snode.get_name()] = snode

        node_name = snode.get_name()
        scores_0[node_name] = sys.maxsize
        scores_1[node_name] = 0
        scores_2[node_name] = idx

    comm_idx = 0
    for snode in snodes:
        if raise_comms and contains_collective(snode):
            scores_0[snode.get_name()] = comm_idx
            for anc in snode.ancestors:
                anc_fused_name = name_to_fused_node[anc].get_name()
                scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
            comm_idx += 1
        elif sink_waits and contains_wait(snode):
            scores_1[snode.get_name()] = 1

    class Runnable:
        def __init__(self, snode) -> None:
            self.snode = snode
            name = next(iter(snode.get_operation_names()))
            fused_name = name_to_fused_node[name].get_name()
            self.score = (
                scores_0[fused_name],
                scores_1[fused_name],
                scores_2[fused_name],
            )

        def __lt__(self, other):
            return self.score < other.score

    unmet_deps: Dict[BaseSchedulerNode, Set[str]] = {
        snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes
    }

    ready: List[Runnable] = []
    buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set)
    snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}

    for snode, deps in unmet_deps.items():
        if len(deps) == 0:
            heapq.heappush(ready, Runnable(snode))
        for dep in deps:
            buffer_users[dep].add(snode)

    scheduled = []

    def schedule(snode):
        """
        Schedules `snode` and puts all unblocked nodes onto the ready queue.
        """
        scheduled.append(snode)
        for buf_name in snode.get_buffer_names():
            for snode in buffer_users[buf_name]:
                unmet_deps[snode].remove(buf_name)
                if len(unmet_deps[snode]) == 0:
                    heapq.heappush(ready, Runnable(snode))

    def get_overlapping_candidate():
        """
        Returns the next node in the ready queue that's neither a collective
        nor a wait.
        """
        candidates = [
            x
            for x in ready
            if not contains_collective(x.snode) and not contains_wait(x.snode)
        ]
        if len(candidates) == 0:
            return None
        return min(candidates, key=lambda x: x.score)

    def schedule_collective_for_overlap(snode):
        """
        Schedules collective node `snode`, along with one or more compute nodes
        to overlap with it. The strategy is described in the comment of
        `reorder_compute_for_overlap`.
        """
        assert contains_collective(snode)
        schedule(snode)

        collective_cost = snode_to_cost[snode]
        while (
            collective_cost > 0
            and (candidate := get_overlapping_candidate()) is not None
        ):
            ready.remove(candidate)
            schedule(candidate.snode)
            collective_cost -= snode_to_cost[candidate.snode]
        heapq.heapify(ready)

    while len(ready):
        snode = heapq.heappop(ready).snode
        if reorder_for_overlap and contains_collective(snode):
            schedule_collective_for_overlap(snode)
        else:
            schedule(snode)

    for snode, deps in unmet_deps.items():
        assert len(deps) == 0, (
            "Detected unscheduled nodes. "
            f"Nodes with unmet dependencies: {unmet_deps}"
        )
    return scheduled


def decide_global_ordering_of_comms(
    nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
) -> List[BaseSchedulerNode]:
    """
    Decide global ordering of comms by just enforcing the ordering that's in
    the input graph (which might not be the same ordering as in the eager mode
    program).
    TODO: Come up with a better approach
    """
    # If FSDP2 is used, we apply FSDP-specific passes.
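    # Note: the `all_gather_copy_in` / `chunk_cat` fallback ops below are, as far
    # as we know, only emitted by FSDP2 (per-parameter-sharding FSDP), so their
    # presence in the graph is used as a proxy for "FSDP2 is in use".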
    if any(
        is_fallback_op(
            x.node,
            {
                torch.ops.fsdp.all_gather_copy_in.default,
                torch.ops.fsdp.chunk_cat.default,
            },
        )
        for x in nodes
    ):
        nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node)

    comm_nodes = [n for n in nodes if contains_collective(n)]

    for i in range(1, len(comm_nodes)):
        # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
        mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
        for buf in comm_nodes[i - 1].get_buffer_names():
            comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))

    return nodes


def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
    """
    Returns estimated op runtime in nanoseconds (ns).
    """
    if config.estimate_op_runtime == "default":
        runtime = snode.get_estimated_runtime()
    else:
        assert callable(config.estimate_op_runtime)
        runtime = config.estimate_op_runtime(snode)
    return runtime


def node_summary(snode):
    detail = ""
    if isinstance(snode.node, ir.ExternKernelOut):
        detail = f" ({snode.node.python_kernel_name})"
    out_tensor_info = ""
    if (
        hasattr(snode.node, "layout")
        and hasattr(snode.node.layout, "size")
        and hasattr(snode.node.layout, "stride")
    ):
        out_tensor_info = (
            f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
        )
    node_name = ""
    if hasattr(snode.node, "name"):
        node_name = snode.node.name
    return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"


def visualize_overlap(order):
    total_est_runtime: float = 0.0
    cur_comm_node = None
    for snode in order:
        if cur_comm_node is None:
            if contains_collective(snode):
                total_est_runtime += estimate_op_runtime(snode)
                cur_comm_node = snode.node
            elif is_wait(snode.node):
                raise AssertionError(
                    "Wait is not expected when there is no collective running"
                )
            else:  # exposed compute op
                total_est_runtime += estimate_op_runtime(snode)
            overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
        else:  # cur_comm_node is not None
            if contains_collective(snode):
                raise AssertionError(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            elif is_wait(snode.node):  # end of this comm op
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                cur_comm_node = None
            else:  # overlapped compute op
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
    overlap_log.debug(
        f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}"  # noqa: G004
    )


def reorder_compute_and_comm_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    order = snodes

    for p in config.reorder_for_compute_comm_overlap_passes:
        if isinstance(p, str) and p in globals():
            p = globals()[p]  # it is a builtin pass
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap before reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
        order = p(order)  # type: ignore[operator]
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap after reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
    return order


def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    try:
        import torch.distributed._composable.fsdp._fsdp_collectives

        assert torch.distributed.is_available()
        # Assert existence of these ops
        assert (
            torch.ops._c10d_functional.all_gather_into_tensor
            and torch.ops._c10d_functional.all_gather_into_tensor_out
        )
    except (ImportError, AttributeError, AssertionError):
        return

    from .pattern_matcher import (
        CallFunction,
        KeywordArg,
        Match,
        PatternMatcherPass,
        register_graph_pattern,
    )

    """
    all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
    getitem = all_gather_copy_in[0];
    (getitem_1 = all_gather_copy_in[1];)  # optional
    all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);

    ->

    all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
    getitem = all_gather_copy_in[0];
    getitem_1 = all_gather_copy_in[1];
    all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
    """

    def remove_unused_getitem(g):
        # Remove `getitem_X = all_gather_copy_in[1]` which is never used.
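        # Rationale (our reading of `repl` below): the replacement re-creates
        # `all_gather_copy_in[1]` and passes it to the out-variant op as `out=`,
        # so the pre-existing, never-used getitem is erased up front instead of
        # being left behind as a dead node after the rewrite.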
        node_list = list(g.nodes)
        for n in node_list:
            if (
                n.target == operator.getitem
                and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
                and n.args[1] == 1
            ):
                g.erase_node(n)

    graph_pass = PatternMatcherPass()

    @register_graph_pattern(
        CallFunction(
            torch.ops._c10d_functional.all_gather_into_tensor.default,
            CallFunction(
                operator.getitem,
                CallFunction(
                    torch.ops.fsdp.all_gather_copy_in.default,
                    KeywordArg("all_gather_inputs"),
                    KeywordArg("inp_split_sizes"),
                    KeywordArg("all_gather_input_numel"),
                    KeywordArg("world_size"),
                    KeywordArg("rank"),
                    KeywordArg("dtype"),
                    KeywordArg("device"),
                ),
                KeywordArg("item_idx"),
            ),
            KeywordArg("group_size"),
            KeywordArg("group_name"),
        ),
        pass_dict=graph_pass,
        extra_check=lambda match: match.kwargs["item_idx"] == 0,
    )
    def reinplace_all_gather(match: Match, *args, **kwargs):
        def repl(*args):
            copy_in_args = args[:-2]
            group_size = args[-2]
            group_name = args[-1]
            all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
                *copy_in_args
            )
            getitem = all_gather_copy_in[0]
            getitem_1 = all_gather_copy_in[1]
            all_gather_into_tensor = (
                torch.ops._c10d_functional.all_gather_into_tensor_out.default(
                    getitem, group_size, group_name, out=getitem_1
                )
            )
            return all_gather_into_tensor

        match.replace_by_example(
            repl,
            [
                kwargs["all_gather_inputs"],
                kwargs["inp_split_sizes"],
                kwargs["all_gather_input_numel"],
                kwargs["world_size"],
                kwargs["rank"],
                kwargs["dtype"],
                kwargs["device"],
                kwargs["group_size"],
                kwargs["group_name"],
            ],
        )

    remove_unused_getitem(graph)
    graph_pass.apply(graph)  # type: ignore[arg-type]


def get_op_idx(snode):
    assert not isinstance(
        snode,
        (
            torch._inductor.scheduler.FusedSchedulerNode,
            torch._inductor.scheduler.GroupedSchedulerNode,
        ),
    )
    return int(snode.get_name()[2:])


def enforce_comm_ordering_for_fsdp(
    snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
    name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
    name_to_fused_node: Dict[str, BaseSchedulerNode],
) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
    from . import scheduler

    new_order: list[BaseSchedulerNode] = []
    scheduled = set()
    ag_exists = False
    rs_exists = False
    ag_grouped_node_to_wait_grouped_node = {}
    rs_grouped_node_to_wait_grouped_node = {}
    snode_name_to_final_snode = {}

    def _create_group_node(snodes_to_group):
        group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
        for snode in snodes_to_group:
            snode_name_to_final_snode[snode.get_name()] = group_node
        snode_name_to_final_snode[group_node.get_name()] = group_node
        return group_node

    # Create grouped nodes for specific sets of ops
    for snode in snodes:
        # Case 1: Handle AllGather
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
        ) and any(
            is_fallback_op(
                name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
            )
            for x in snode.ancestors
        ):
            ag_exists = True
            ag_snode = snode
            ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set()

            # Find the "cast + copy_in + getitem + all_gather" code block
            find_recursive_deps_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block
            allowed_ops = {
                torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                torch.ops._c10d_functional.wait_tensor.default,
                torch.ops.fsdp.split_with_sizes_copy.default,
                torch.ops.aten.set_.source_Tensor,
            }
            find_recursive_users_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=lambda x: not (
                    isinstance(x, scheduler.NopKernelSchedulerNode)
                    or (
                        isinstance(x, scheduler.ExternKernelSchedulerNode)
                        and x.node.op_overload in allowed_ops  # type: ignore[union-attr]
                    )
                ),
            )

            # sort nodes by original operation order
            ag_related_snodes = sorted(
                ag_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # In the "reuse layer" case, some ops in the 2nd all-gather code block could also
            # depend on ops in the 1st all-gather code block, and we don't want to group them together.
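            # Concretely, we truncate the block at the second `split_with_sizes_copy`
            # (copy-out) op we encounter, so each group contains at most one
            # all-gather's copy-out and never extends into the next block.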
            end_idx_of_current_ag_block = len(ag_related_snodes)
            copy_out_count = 0
            for i in range(len(ag_related_snodes)):
                cur_snode = ag_related_snodes[i]
                if is_fallback_op(
                    cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
                ):
                    copy_out_count += 1
                if copy_out_count > 1:
                    end_idx_of_current_ag_block = i
                    break

            ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]

            # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode
            wait_node_idx = None
            for i in range(len(ag_related_snodes) - 1):
                if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])

            # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode
            ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])

            ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

        # Case 2: Handle ReduceScatter
        elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
            rs_exists = True
            rs_snode = snode

            # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
            rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set()
            find_recursive_users_of_node(
                rs_snode,
                rs_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # sort nodes by original operation order
            rs_related_snodes = sorted(
                rs_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode
            wait_node_idx = None
            for i in range(len(rs_related_snodes) - 1):
                if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx])

            # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode
            rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:])

            rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node

    assert len(snode_name_to_final_snode) > 0
    if ag_exists:
        assert len(ag_grouped_node_to_wait_grouped_node) > 0
    if rs_exists:
        assert len(rs_grouped_node_to_wait_grouped_node) > 0

    # Build the new node schedule, taking GroupedSchedulerNode into account
    for snode in snodes:
        if snode.get_name() in snode_name_to_final_snode:
            snode = snode_name_to_final_snode[snode.get_name()]
        if snode in scheduled:
            continue
        new_order.append(snode)
        scheduled.add(snode)

    # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
    # before next AllGather's "copy_in then AG" group node
    prev_ag_wait = None
    for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items():
        if prev_ag_wait is not None:
            mutating_buf = next(iter(ag_group_node.get_buffer_names()))
            for o in prev_ag_wait.get_outputs():
                ag_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
                )
        prev_ag_wait = wait_group_node

    # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
    # before next ReduceScatter's "copy_in then RS" group node
    prev_rs_wait = None
    for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items():
        if prev_rs_wait is not None:
            mutating_buf = next(iter(rs_group_node.get_buffer_names()))
            for o in prev_rs_wait.get_outputs():
                rs_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
                )
        prev_rs_wait = wait_group_node

    return new_order  # type: ignore[return-value]