#!/usr/bin/env python3


"""
A set of primitive functions for performing collective ops.

Each should also handle the single-rank scenario (``pg=None``).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, cast, Generic, List, Optional, Tuple, TypeVar, Union

import torch.distributed as dist


T = TypeVar("T")


@dataclass
class SyncPayload(Generic[T]):
    """Envelope exchanged between ranks by the collectives in this module."""

    # Logical stage name; compared across ranks to detect out-of-sync callers.
    stage_name: Optional[str]
    # False when the producing rank failed or the caller signalled failure.
    success: bool
    # The synchronized data; None when the producer did not produce a value.
    payload: T
    # Exception raised while producing ``payload``, if any.
    exception: Optional[Exception] = None


def broadcast(
    data_or_fn: Union[T, Callable[[], T]],
    *,
    success: bool = True,
    stage_name: Optional[str] = None,
    rank: int = 0,
    pg: Optional[dist.ProcessGroup] = None,
) -> T:
    """
    Broadcasts the data payload from ``rank`` to all other ranks.
    Or if a function is passed, execute it on ``rank`` and broadcast the result
    to all other ranks.

    Can be used to broadcast a failure signal to stop all ranks.

    If the function raises an exception, all ranks will raise.

    Args:
        data_or_fn: the data to broadcast or function to execute and broadcast result.
            Only the value on the source rank is used.
        success: False to stop all ranks; ``data_or_fn`` must then be None.
        stage_name: the name of the logical stage for synchronization and debugging
        rank: rank to broadcast data or execute function and broadcast results.
        pg: the process group for sync. None means single-process mode, in which
            ``data_or_fn`` is only evaluated when ``rank == 0``.
    Throws:
        AssertionError if ``success`` is False but ``data_or_fn`` is not None.
        RuntimeError from original exception trace
    Returns:
        the value after synchronization

    Example usage:
        >> new_id = broadcast(data_or_fn=allocate_id, rank=0, pg=ext_pg.my_pg)
    """

    if not success and data_or_fn is not None:
        raise AssertionError(
            "Data or Function is expected to be None if not successful"
        )

    payload: Optional[T] = None
    exception: Optional[Exception] = None
    # Only the source rank produces the payload. Without a pg there is a single
    # process, which is treated as rank 0.
    # NOTE(review): with pg=None and rank != 0 nothing is executed and None is
    # returned; confirm no caller relies on that combination.
    # NOTE(review): pg.rank() is group-local while broadcast_object_list's `src`
    # is documented as a global rank; these agree for the default group but may
    # not for subgroups. Confirm how callers use this with subgroups.
    if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank):
        # Determine whether this is an executable function or a data payload.
        if callable(data_or_fn):
            try:
                payload = data_or_fn()
            except Exception as e:
                success = False
                exception = e
        else:
            payload = data_or_fn

    # Ship the exception (if any) with the payload so every rank can report
    # the original failure.
    sync_obj = SyncPayload(
        stage_name=stage_name,
        success=success,
        payload=payload,
        exception=exception,
    )

    if pg is not None:
        # On non-source ranks the list element is replaced by the source's object.
        broadcast_list = [sync_obj]
        dist.broadcast_object_list(broadcast_list, src=rank, group=pg)
        assert len(broadcast_list) == 1
        sync_obj = broadcast_list[0]

    # A failure on the source rank triggers a throw in every rank.
    if not sync_obj.success:
        error_msg = f"Rank {rank} failed"
        if sync_obj.stage_name is not None:
            error_msg += f": stage {sync_obj.stage_name}"
        if sync_obj.exception is not None:
            error_msg += f": exception {sync_obj.exception}"
        raise RuntimeError(error_msg) from sync_obj.exception

    return cast(T, sync_obj.payload)


def all_gather(
    data_or_fn: Union[T, Callable[[], T]],
    stage_name: Optional[str] = None,
    pg: Optional[dist.ProcessGroup] = None,
) -> List[T]:
    """
    A simple all_gather primitive with basic synchronization guard logic,
    which checks that the payloads from all ranks carry the same stage name.

    Args:
        data_or_fn: the data to be all gathered across ranks or function to be executed
        stage_name: the sync stage name for out-of-sync protection
        pg: the process group for sync; None means single-process mode
    Throws:
        RuntimeError if any rank failed or reported a stage name different from
        rank 0's (all ranks raise). When a rank failed with an exception, the
        first such exception is chained as the cause.
    Returns:
        a list of synced data from all ranks, in rank order

    Example usage:
        >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg)
    """
    payload: Optional[T] = None
    exception: Optional[Exception] = None
    success = True
    # Determine whether this is an executable function or a data payload.
    if callable(data_or_fn):
        try:
            payload = data_or_fn()
        except Exception as e:
            success = False
            exception = e
    else:
        payload = data_or_fn

    sync_obj = SyncPayload(
        stage_name=stage_name,
        success=success,
        payload=payload,
        exception=exception,
    )

    if pg is None:
        # Single-rank scenario: nothing to gather.
        if not sync_obj.success:
            raise RuntimeError(
                f"all_gather failed with exception {sync_obj.exception}",
            ) from sync_obj.exception
        return [sync_obj.payload]  # type: ignore[list-item]

    # One slot per rank; filled in place with every rank's SyncPayload.
    total_list: List[Any] = [None] * dist.get_world_size(pg)
    all_gather_object_enforce_type(pg, total_list, sync_obj)
    return _collect_gathered_payloads(cast(List[SyncPayload[T]], total_list))


def _collect_gathered_payloads(gathered: List[SyncPayload[T]]) -> List[T]:
    """Validate SyncPayloads gathered from all ranks and return their payloads.

    all_gather_object hands every rank the same list, so every rank reaches the
    same verdict here: either all return the payloads or all raise.

    Raises:
        RuntimeError(message, failures) if any rank reports a stage name that
        differs from rank 0's, or if any rank failed. ``failures`` is a list of
        ``(rank, exception)`` pairs, and the first exception is chained as the cause.
    """
    if not gathered:
        return []

    expected_stage = gathered[0].stage_name
    problems: List[str] = []
    failures: List[Tuple[int, Exception]] = []
    results: List[T] = []

    for i, sp in enumerate(gathered):
        if sp.stage_name != expected_stage:
            problems.append(
                f"Unexpected stage name received from rank {i}: {sp.stage_name}"
            )
            continue
        if not sp.success:
            if sp.exception is not None:
                failures.append((i, sp.exception))
                problems.append(f"Rank {i} failed with exception {sp.exception}")
            else:
                problems.append(f"Rank {i} failed")
            continue
        results.append(sp.payload)

    if problems:
        # Chain a real exception (not the (rank, exception) tuple) as the cause.
        cause = failures[0][1] if failures else None
        raise RuntimeError("; ".join(problems), failures) from cause
    return results


# Note: use Any for typing for now so users can pass in
# either a list of None or target type placeholders
# otherwise pyre would complain
def all_gather_object_enforce_type(
    pg: dist.ProcessGroup,
    # pyre-fixme[2]: Parameter must have a type that does not contain `Any`
    object_list: List[Any],
    # pyre-fixme[2]: Parameter must have a type other than `Any`
    obj: Any,
    # pyre-fixme[2]: Parameter must have a type that does not contain `Any`
    type_checker: Callable[[Any, Any], bool] = lambda x, y: type(x) == type(y),
) -> None:
    """
    Similar to plain all_gather_object but with additional type checking
    AFTER gather is done to ensure basic consistency.
    If check does not pass, all ranks will fail with exception.

    This is generally to prevent conditional logic leading to
    unexpected messages being received. This is considered fatal code error,
    but due to logic stacks this might happen implicitly in practice.

    The default check does not check sub type (considered different)
    or covariance (considered same) but users can pass in custom checker
    if more complicated check is needed.

    Raises:
        TypeError if any gathered object fails ``type_checker`` against the
        first gathered object.
    """
    dist.all_gather_object(object_list, obj, group=pg)

    # Conservative check: every gathered object must match the first one.
    if not object_list:
        return
    first_obj = object_list[0]
    for i, other in enumerate(object_list[1:], start=1):
        if not type_checker(first_obj, other):
            raise TypeError(
                f"Object type at index {i} is {type(other)}, "
                f"while first object type is {type(first_obj)}"
            )