• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3
4"""
5A set of primitive functions for performing collective ops.
6
7Each should also handle single rank scenario.
8"""
9
10from __future__ import annotations
11
12from dataclasses import dataclass
13from typing import Any, Callable, cast, Generic, List, Optional, Tuple, TypeVar, Union
14
15import torch.distributed as dist
16
17
18T = TypeVar("T")
19
20
@dataclass
class SyncPayload(Generic[T]):
    """
    Envelope object that the collective helpers in this module exchange
    between ranks.

    ``broadcast`` and ``all_gather`` wrap the user payload in this class
    before handing it to the ``torch.distributed`` object collectives, so
    that the receiving side can tell a successful payload from a failure and
    can compare stage names.
    """

    # Name of the logical stage the sending rank was in (None if unnamed).
    stage_name: Optional[str]
    # False when the sender signals failure, e.g. its callable raised.
    success: bool
    # The synchronized value; the helpers leave it as None when the sender
    # failed or did not evaluate the payload.
    payload: T
    # The exception captured on the sender, if any.
    exception: Optional[Exception] = None
27
28
def broadcast(
    data_or_fn: Union[T, Callable[[], T]],
    *,
    success: bool = True,
    stage_name: Optional[str] = None,
    rank: int = 0,
    pg: Optional[dist.ProcessGroup] = None,
) -> T:
    """
    Broadcasts the data payload from the source ``rank`` to all other ranks.
    Or if a function is passed, execute it on the source rank only and
    broadcast its result to all other ranks.

    Can be used to broadcast a failure signal to stop all ranks: pass
    ``success=False`` together with ``data_or_fn=None``.

    If the function raises an exception, all ranks will raise.

    Single rank scenario: when ``pg`` is None no communication happens. The
    payload is only evaluated when ``rank`` is 0; for any other ``rank`` the
    returned value is None.

    Args:
        data_or_fn: the data to broadcast or function to execute and broadcast result.
        success: False to stop all ranks.
        stage_name: the name of the logical stage for synchronization and debugging
        rank: rank to broadcast data or execute function and broadcast results.
        pg: the process group for sync
    Throws:
        AssertionError if ``success`` is False while ``data_or_fn`` is not None
        RuntimeError from original exception trace
    Returns:
        the value after synchronization

    Example usage:
    >> id = broadcast(data_or_fn=allocate_id, rank=0, pg=ext_pg.my_pg)
    """

    if not success and data_or_fn is not None:
        raise AssertionError(
            "Data or Function is expected to be None if not successful"
        )

    payload: Optional[T] = None
    exception: Optional[Exception] = None
    # Only the source rank evaluates the payload. Without a process group the
    # caller counts as the source only when ``rank`` is 0.
    is_source = rank == 0 if pg is None else pg.rank() == rank
    if is_source:
        # determine if it is an executable function or data payload only
        if callable(data_or_fn):
            try:
                payload = data_or_fn()
            except Exception as e:
                # Deliberately broad: the failure is captured here and
                # re-raised on every rank after synchronization below.
                success = False
                exception = e
        else:
            payload = data_or_fn

    # broadcast the exception type if any to all ranks for failure categorization
    sync_obj = SyncPayload(
        stage_name=stage_name,
        success=success,
        payload=payload,
        exception=exception,
    )

    if pg is not None:
        # NOTE(review): ``pg.rank()`` above is the rank inside ``pg``, while
        # ``src`` of ``broadcast_object_list`` is documented as a global
        # rank -- confirm both agree when ``pg`` is a sub-group.
        broadcast_list = [sync_obj]
        dist.broadcast_object_list(broadcast_list, src=rank, group=pg)
        assert len(broadcast_list) == 1
        # From here on ``sync_obj`` is the source rank's view, not the local one.
        sync_obj = broadcast_list[0]

    # failure in the source rank will trigger a throw in every rank.
    if not sync_obj.success:
        error_msg = f"Rank {rank} failed"
        # Report the synchronized stage name (the source rank's), not the
        # caller-local argument, so test and message refer to the same value.
        if sync_obj.stage_name is not None:
            error_msg += f": stage {sync_obj.stage_name}"
        if sync_obj.exception is not None:
            error_msg += f": exception {sync_obj.exception}"
        raise RuntimeError(error_msg) from sync_obj.exception

    return cast(T, sync_obj.payload)
103
104
def all_gather(
    data_or_fn: Union[T, Callable[[], T]],
    stage_name: Optional[str] = None,
    pg: Optional[dist.ProcessGroup] = None,
) -> List[T]:
    """
    A simple all_gather primitive with basic synchronization guard logic,
    by checking payload from all ranks has the same stage name.

    If ``data_or_fn`` is callable it is executed locally on every rank and
    its result is gathered; otherwise the value itself is gathered.

    Single rank scenario: when ``pg`` is None no communication happens and
    the result is a one-element list holding the local payload.

    Args:
        data_or_fn: the data to be all gathered across ranks or function to be executed
        stage_name: the sync stage name for out-of-sync protection
        pg: the process group for sync
    Throws:
        RuntimeError from original exception trace when the callable failed
        on any rank, or when a rank reported a stage name different from the
        one reported by the first entry of the gathered list. With a process
        group, ``args`` of the error is ``(message, [(rank, exception), ...])``.
    Returns:
        a list of synced data from all ranks

    Example usage:
    >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
    """
    payload: Optional[T] = None
    exception: Optional[Exception] = None
    success = True
    # determine if it is an executable function or data payload only
    if callable(data_or_fn):
        try:
            payload = data_or_fn()
        except Exception as e:
            # Deliberately broad: the failure is captured here and reported
            # to every rank through the gathered payloads.
            success = False
            exception = e
    else:
        payload = data_or_fn

    sync_obj = SyncPayload(
        stage_name=stage_name,
        success=success,
        payload=payload,
        exception=exception,
    )

    if pg is None:
        # Single rank: nothing to gather, just surface a local failure.
        if not sync_obj.success:
            raise RuntimeError(
                f"all_gather failed with exception {sync_obj.exception}",
            ) from sync_obj.exception
        return [sync_obj.payload]  # type: ignore[list-item]

    # One slot per rank; filled in by the gather below.
    total_list = [None] * dist.get_world_size(pg)
    all_gather_object_enforce_type(pg, total_list, sync_obj)
    # The first gathered entry's stage name is the reference for all ranks.
    stage_name = cast(SyncPayload[T], total_list[0]).stage_name
    exception_list: List[Tuple[int, Exception]] = []
    ret_list: List[T] = []
    error_msg: str = ""

    for i, sp in enumerate(cast(List[SyncPayload[T]], total_list)):
        if sp.stage_name != stage_name:
            error_msg += (
                f"Unexpected stage name received from rank {i}: {sp.stage_name} "
            )
            continue
        if not sp.success and sp.exception is not None:
            exception_list.append((i, sp.exception))
            continue
        ret_list.append(sp.payload)

    # Each rank will throw RuntimeError in case of failure on any rank. A
    # stage-name mismatch is a failure too: returning would hand the caller a
    # list that is shorter than the world size.
    if exception_list or error_msg:
        # ``exception_list`` holds (rank, exception) tuples; the cause must
        # be the exception itself, a tuple is rejected by ``raise ... from``.
        cause = exception_list[0][1] if exception_list else None
        raise RuntimeError(error_msg, exception_list) from cause
    return ret_list
178
179
180# Note: use Any for typing for now so users can pass in
181# either a list of None or target type placeholders
182# otherwise pyre would complain
def all_gather_object_enforce_type(
    pg: dist.ProcessGroup,
    # pyre-fixme[2]: Parameter must have a type that does not contain `Any`
    object_list: List[Any],
    # pyre-fixme[2]: Parameter must have a type other than `Any`
    obj: Any,
    # pyre-fixme[2]: Parameter must have a type that does not contain `Any`
    type_checker: Callable[[Any, Any], bool] = lambda x, y: type(x) == type(y),
) -> None:
    """
    Run ``all_gather_object`` on ``pg`` and then verify that the gathered
    entries are mutually consistent.

    The verification happens AFTER the gather has completed: every entry of
    ``object_list`` from index 1 onwards is compared against the entry at
    index 0 using ``type_checker``, and the first entry that fails the
    comparison raises ``TypeError``.

    The intent is to catch ranks that, because of diverging conditional
    logic, ended up sending different kinds of messages. That is a fatal
    programming error, but one that can creep in implicitly.

    The default ``type_checker`` compares the exact runtime types with
    ``==``: a subclass instance does not match its base class, and the
    contents of containers are not inspected. Pass a custom checker when a
    more elaborate comparison is required.

    Args:
        pg: the process group to gather over
        object_list: output list, filled in place by the gather
        obj: the local object contributed by this rank
        type_checker: predicate called as ``type_checker(first, other)``
    Throws:
        TypeError if any gathered entry fails ``type_checker``
    """
    dist.all_gather_object(object_list, obj, group=pg)

    # Nothing was gathered, so there is nothing to compare.
    if not object_list:
        return

    reference = object_list[0]
    for idx, candidate in enumerate(object_list[1:], start=1):
        if type_checker(reference, candidate):
            continue
        raise TypeError(
            f"Object type at index {idx} is {type(candidate)}, "
            f"while first object type is {type(reference)}"
        )
218