1#!/usr/bin/env python3 2# mypy: allow-untyped-defs 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9 10from contextlib import contextmanager 11from datetime import timedelta 12from typing import Callable, Iterable, List, Optional 13 14import torch 15 16 17DistStoreError = torch._C._DistStoreError 18 19_NUM_MEMBERS = "/num_members" 20_LAST_MEMBER_CHECKIN = "/last_member" 21_TRACE = "/TRACE" 22_TRACING_GATE = "/TRACING_GATE" 23_MAX_TRACE_MISSING_RANKS = 16 24 25 26__all__ = ["store_timeout", "get_all", "synchronize", "barrier"] 27 28 29@contextmanager 30def store_timeout(store, timeout: float): 31 """ 32 This sets the timeout and then restores the old timeout when the context 33 manager exits. 34 35 Args: 36 store: the store to set the timeout on 37 timeout: the timeout to set 38 """ 39 40 old_timeout = store.timeout 41 store.set_timeout(timedelta(seconds=timeout)) 42 yield 43 store.set_timeout(old_timeout) 44 45 46def get_all(store, rank: int, prefix: str, world_size: int): 47 r""" 48 Given a store and a prefix, the method goes through the array of keys 49 of the following format: ``{prefix}{idx}``, where idx is in a range 50 from 0 to size, and tries to retrieve the data. 51 52 The Rank0 process waits at the end to make sure all other processes 53 finished the procedure before exiting. 54 55 Usage 56 57 :: 58 59 values = get_all(store, 'torchelastic/data', 3) 60 value1 = values[0] # retrieves the data for key torchelastic/data0 61 value2 = values[1] # retrieves the data for key torchelastic/data1 62 value3 = values[2] # retrieves the data for key torchelastic/data2 63 64 """ 65 data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) 66 67 barrier_key = _barrier_nonblocking( 68 store=store, 69 world_size=world_size, 70 key_prefix=f"{prefix}/finished", 71 ) 72 if rank == 0: 73 # Rank0 runs the TCPStore daemon, as a result it needs to exit last. 74 # Otherwise, the barrier may timeout if rank0 process finished the work 75 # before other processes finished `get_all` method 76 store.wait([barrier_key]) 77 78 return data_arr 79 80 81def synchronize( 82 store, 83 data: bytes, 84 rank: int, 85 world_size: int, 86 key_prefix: str, 87 timeout: float = 300, 88) -> List[bytes]: 89 """ 90 Synchronizes ``world_size`` agents between each other using the underlying c10d store. 91 The ``data`` will be available on each of the agents. 92 93 Note: The data on the path is not deleted, as a result there can be stale data if 94 you use the same key_prefix twice. 95 96 Time complexity: O(N) per worker, O(N^2) globally. 97 """ 98 with store_timeout(store, timeout): 99 store.set(f"{key_prefix}{rank}", data) 100 agent_data = get_all(store, rank, key_prefix, world_size) 101 return agent_data 102 103 104def _try_detecting_missing_ranks( 105 store, 106 world_size: int, 107 key_prefix: str, 108 rank: int, 109 rank_decoder: Callable[[int], str], 110 trace_timeout: float, 111) -> Optional[Iterable[str]]: 112 store.set(f"{key_prefix}{rank}{_TRACE}", "<val_ignored>") 113 114 def _find_missing_ranks(): 115 missing_rank_info = set() 116 ranks_missing = 0 117 for i in range(1, world_size): 118 # reduce noise, assuming in general 8 ranks per node 119 # It is valuable to know that 1 or >1 nodes have timed-out. 120 if ranks_missing >= _MAX_TRACE_MISSING_RANKS: 121 break 122 try: 123 if ranks_missing == 0: 124 store.wait( 125 [f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout) 126 ) 127 else: 128 # use a shortest timeout, some ranks have failed to check-in 129 store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1)) 130 except DistStoreError: 131 ranks_missing += 1 132 missing_rank_info.add(rank_decoder(i)) 133 return missing_rank_info 134 135 def _checkin(): 136 try: 137 store.wait([f"{key_prefix}{_TRACING_GATE}"]) 138 return [f"[<check rank 0 ({rank_decoder(0)}) for missing rank info>]"] 139 except DistStoreError: 140 # in case rank0 is the source of the timeout, original exception will be raised 141 return None 142 143 if rank == 0: 144 missing_rank_info = _find_missing_ranks() 145 store.set(f"{key_prefix}{_TRACING_GATE}", "<val_ignored>") 146 return missing_rank_info 147 else: 148 return _checkin() 149 150 151def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: 152 """ 153 Does all the non-blocking operations for a barrier and returns the final key 154 that can be waited on. 155 """ 156 num_members_key = key_prefix + _NUM_MEMBERS 157 last_member_key = key_prefix + _LAST_MEMBER_CHECKIN 158 159 idx = store.add(num_members_key, 1) 160 if idx == world_size: 161 store.set(last_member_key, "<val_ignored>") 162 163 return last_member_key 164 165 166def barrier( 167 store, 168 world_size: int, 169 key_prefix: str, 170 barrier_timeout: float = 300, 171 rank: Optional[int] = None, 172 rank_tracing_decoder: Optional[Callable[[int], str]] = None, 173 trace_timeout: float = 10, 174) -> None: 175 """ 176 A global lock between agents. This will pause all workers until at least 177 ``world_size`` workers respond. 178 179 This uses a fast incrementing index to assign waiting ranks and a success 180 flag set by the last worker. 181 182 Time complexity: O(1) per worker, O(N) globally. 183 184 Optionally, passing rank will enable tracing of missing ranks on timeouts. 185 `rank_tracing_decoder` lambda arg can be used to convert rank data 186 into a more meaninful information at an app level (e.g. hostname). 187 188 Note: Since the data is not removed from the store, the barrier can be used 189 once per unique ``key_prefix``. 190 """ 191 192 if rank is None: 193 assert rank_tracing_decoder is None, "Tracing requires rank information" 194 195 with store_timeout(store, barrier_timeout): 196 last_member_key = _barrier_nonblocking( 197 store=store, world_size=world_size, key_prefix=key_prefix 198 ) 199 try: 200 store.wait([last_member_key]) 201 except DistStoreError as e: 202 if rank is None: 203 raise e 204 else: 205 missing_ranks = _try_detecting_missing_ranks( 206 store, 207 world_size, 208 key_prefix, 209 rank, 210 rank_tracing_decoder or (lambda x: str(x)), 211 trace_timeout, 212 ) 213 if missing_ranks is not None: 214 raise DistStoreError( 215 "Timed out waiting on barrier on " 216 "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format( 217 rank, 218 key_prefix, 219 world_size, 220 f"[{', '.join(missing_ranks)}]", 221 barrier_timeout, 222 ) 223 ) from None 224 else: 225 raise e 226