• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10from contextlib import contextmanager
11from datetime import timedelta
12from typing import Callable, Iterable, List, Optional
13
14import torch
15
16
# Alias for the store error type exposed by the torch C extension; the helpers
# in this module both catch it and raise it.
DistStoreError = torch._C._DistStoreError

# Key suffixes appended to a caller-supplied key prefix.
_NUM_MEMBERS = "/num_members"  # counter bumped by every barrier arrival
_LAST_MEMBER_CHECKIN = "/last_member"  # set once the counter equals world_size
_TRACE = "/TRACE"  # per-rank key written during missing-rank tracing
_TRACING_GATE = "/TRACING_GATE"  # set by rank 0 once its missing-rank scan is done
# Number of missing ranks after which the rank-0 scan stops probing further ranks.
_MAX_TRACE_MISSING_RANKS = 16


# Names exported as the public API of this module.
__all__ = ["store_timeout", "get_all", "synchronize", "barrier"]
27
28
@contextmanager
def store_timeout(store, timeout: float):
    """
    Temporarily override the timeout of ``store``.

    The store's current timeout is saved, replaced for the duration of the
    ``with`` block, and restored when the block exits -- including when the
    block exits by raising an exception.

    Args:
        store: the store to set the timeout on
        timeout: the timeout to set, in seconds
    """

    old_timeout = store.timeout
    store.set_timeout(timedelta(seconds=timeout))
    try:
        yield
    finally:
        # Restore unconditionally: without the ``finally`` an exception raised
        # inside the ``with`` body (e.g. a store wait timing out) would leave
        # the temporary timeout installed on the store.
        store.set_timeout(old_timeout)
44
45
def get_all(store, rank: int, prefix: str, world_size: int):
    r"""
    Read the values stored under the keys ``{prefix}{idx}`` for every ``idx``
    in ``range(world_size)`` and return them in index order.

    All keys are fetched with a single ``multi_get`` call. Afterwards every
    caller checks in on a non-blocking barrier keyed by ``{prefix}/finished``;
    only rank 0 actually waits for that barrier to complete before returning.

    Usage

    ::

     values = get_all(store, rank, 'torchelastic/data', 3)
     value1 = values[0] # retrieves the data for key torchelastic/data0
     value2 = values[1] # retrieves the data for key torchelastic/data1
     value3 = values[2] # retrieves the data for key torchelastic/data2

    """
    keys = [f"{prefix}{idx}" for idx in range(world_size)]
    values = store.multi_get(keys)

    finished_key = _barrier_nonblocking(
        store=store,
        world_size=world_size,
        key_prefix=f"{prefix}/finished",
    )
    if rank != 0:
        return values

    # Rank0 runs the TCPStore daemon, so it has to be the last one to leave.
    # If it returned (and possibly exited) early, the other processes could
    # time out while they are still inside `get_all`.
    store.wait([finished_key])
    return values
79
80
def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    timeout: float = 300,
) -> List[bytes]:
    """
    Exchange ``data`` between ``world_size`` agents through the underlying c10d store.

    Each agent publishes its own ``data`` under ``{key_prefix}{rank}`` and then
    collects the values published by all agents via ``get_all``. Both steps run
    with the store timeout temporarily set to ``timeout`` seconds.

    Note: Nothing written under ``key_prefix`` is deleted afterwards, so reusing
        the same ``key_prefix`` for a second call may return stale data.

    Time complexity: O(N) per worker, O(N^2) globally.
    """
    with store_timeout(store, timeout):
        # Publish this agent's payload before reading everyone's.
        store.set(f"{key_prefix}{rank}", data)
        return get_all(store, rank, key_prefix, world_size)
102
103
def _try_detecting_missing_ranks(
    store,
    world_size: int,
    key_prefix: str,
    rank: int,
    rank_decoder: Callable[[int], str],
    trace_timeout: float,
) -> Optional[Iterable[str]]:
    """
    Best-effort identification of the ranks that did not reach a barrier.

    Every caller first writes its own ``{key_prefix}{rank}/TRACE`` key.

    * Rank 0 then waits on the trace key of each rank in
      ``1..world_size - 1``. The wait uses ``trace_timeout`` seconds until the
      first rank is found missing and 1 ms afterwards. Scanning stops once
      ``_MAX_TRACE_MISSING_RANKS`` ranks are missing. Rank 0 finally sets the
      tracing gate key and returns the set of decoded missing ranks.
    * Any other rank waits on the tracing gate key. On success it returns a
      single-element list pointing at rank 0; if the wait raises
      ``DistStoreError`` it returns ``None``.
    """
    # Announce that this rank made it to the tracing phase.
    store.set(f"{key_prefix}{rank}{_TRACE}", "<val_ignored>")

    if rank != 0:
        try:
            store.wait([f"{key_prefix}{_TRACING_GATE}"])
            return [f"[<check rank 0 ({rank_decoder(0)}) for missing rank info>]"]
        except DistStoreError:
            # rank0 itself may be the cause of the timeout; returning None lets
            # the caller surface the original exception instead
            return None

    absent = set()
    absent_count = 0
    for peer in range(1, world_size):
        # reduce noise, assuming in general 8 ranks per node
        # It is valuable to know that 1 or >1 nodes have timed-out.
        if absent_count >= _MAX_TRACE_MISSING_RANKS:
            break
        if absent_count == 0:
            wait_for = timedelta(seconds=trace_timeout)
        else:
            # some rank already failed to check in, so keep further waits minimal
            wait_for = timedelta(milliseconds=1)
        try:
            store.wait([f"{key_prefix}{peer}{_TRACE}"], wait_for)
        except DistStoreError:
            absent_count += 1
            absent.add(rank_decoder(peer))

    # Release the other ranks that are blocked on the gate.
    store.set(f"{key_prefix}{_TRACING_GATE}", "<val_ignored>")
    return absent
149
150
def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str:
    """
    Check in on a barrier without blocking and return the key to wait on.

    The caller increments the member counter stored under ``key_prefix``. The
    caller whose increment brings the counter to ``world_size`` also sets the
    "last member" key. That key is returned so callers can ``store.wait`` on it.
    """
    done_key = key_prefix + _LAST_MEMBER_CHECKIN

    arrivals = store.add(key_prefix + _NUM_MEMBERS, 1)
    if arrivals == world_size:
        # This caller is the final arrival: flag the barrier as complete.
        store.set(done_key, "<val_ignored>")

    return done_key
164
165
def barrier(
    store,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
    rank: Optional[int] = None,
    rank_tracing_decoder: Optional[Callable[[int], str]] = None,
    trace_timeout: float = 10,
) -> None:
    """
    A global lock between agents. Callers block until ``world_size`` workers
    have checked in under ``key_prefix``.

    Each worker bumps a shared counter; the worker that brings it to
    ``world_size`` sets a completion key that all workers wait on. The wait
    runs with the store timeout temporarily set to ``barrier_timeout`` seconds.

    Time complexity: O(1) per worker, O(N) globally.

    Optionally, passing ``rank`` enables tracing of missing ranks on timeouts.
    The ``rank_tracing_decoder`` callable can be used to convert a rank into
    more meaningful information at an app level (e.g. hostname); ranks are
    rendered with ``str`` when it is not given. If tracing yields a result, a
    new ``DistStoreError`` listing the missing ranks is raised; otherwise the
    original ``DistStoreError`` is re-raised.

    Note: Since the data is not removed from the store, the barrier can be used
        once per unique ``key_prefix``.
    """

    if rank is None:
        assert rank_tracing_decoder is None, "Tracing requires rank information"

    with store_timeout(store, barrier_timeout):
        done_key = _barrier_nonblocking(
            store=store, world_size=world_size, key_prefix=key_prefix
        )
        try:
            store.wait([done_key])
        except DistStoreError as err:
            # Without a rank there is nothing to trace; surface the timeout as is.
            if rank is None:
                raise err

            decoder = rank_tracing_decoder or str
            missing_ranks = _try_detecting_missing_ranks(
                store,
                world_size,
                key_prefix,
                rank,
                decoder,
                trace_timeout,
            )
            # Tracing produced no information: fall back to the original error.
            if missing_ranks is None:
                raise err

            missing_desc = f"[{', '.join(missing_ranks)}]"
            raise DistStoreError(
                "Timed out waiting on barrier on "
                "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format(
                    rank,
                    key_prefix,
                    world_size,
                    missing_desc,
                    barrier_timeout,
                )
            ) from None
226