• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2# Copyright (c) Facebook, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import binascii
9import logging
10import os
11import tempfile
12from base64 import b64decode, b64encode
13from datetime import timedelta
14from typing import Any, cast, Optional, Tuple
15
16from torch.distributed import FileStore, Store, TCPStore
17from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
18
19from .api import (
20    RendezvousConnectionError,
21    RendezvousError,
22    RendezvousParameters,
23    RendezvousStateError,
24)
25from .dynamic_rendezvous import RendezvousBackend, Token
26from .utils import _matches_machine_hostname, parse_rendezvous_endpoint
27
28
logger = logging.getLogger(__name__)

# Port used for the TCP store when the rendezvous endpoint does not specify one.
DEFAULT_PORT = 29400
33
34
class C10dRendezvousBackend(RendezvousBackend):
    """Represents a C10d-backed rendezvous backend.

    The rendezvous state is kept base64-encoded under a single store key
    derived from the run id. The encoded value read back from the store
    doubles as the token used for optimistic (compare-and-set) updates.

    Args:
        store:
            The :py:class:`torch.distributed.Store` instance to use to
            communicate with the C10d store.
        run_id:
            The run id of the rendezvous.
    """

    # Placeholder value meaning "no state yet"; see ``__init__`` for why a
    # placeholder is needed at all.
    _NULL_SENTINEL = "Y2FuaW1hZGFt"

    _store: Store
    _key: str

    def __init__(self, store: Store, run_id: str) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")

        self._store = store
        self._key = "torch.rendezvous." + run_id

        # Reading a key from a store blocks until that key exists, so the
        # store cannot be probed like a plain dictionary. Seed the key with the
        # sentinel (compare_set against "" only writes when the key is still
        # absent); reading the sentinel back later is interpreted as "None".
        self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)

    @property
    def name(self) -> str:
        """See base class."""
        return "c10d"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        return self._decode_state(self._call_store("get", self._key))

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        encoded_state: str = b64encode(state).decode()

        if not token:
            # No token: only succeed if nobody has written a state yet.
            expected = self._NULL_SENTINEL
        elif isinstance(token, bytes):
            expected = token.decode()
        else:
            # Tokens produced by this backend are always ``bytes``, so any
            # other token can never match; skip the write and report the
            # current state as a failed attempt.
            current = self.get_state()
            if current is None:
                return None
            return (*current, False)

        raw_state: bytes = self._call_store(
            "compare_set", self._key, expected, encoded_state
        )

        decoded = self._decode_state(raw_state)
        if decoded is None:
            return None

        remote_state, remote_token = decoded

        # compare_set does not report whether the swap happened, so compare
        # what is now in the store with what we attempted to write.
        return remote_state, remote_token, remote_state == state

    def _call_store(self, store_op: str, *args, **kwargs) -> Any:
        # Invoke ``store_op`` on the store, translating store failures into
        # RendezvousConnectionError.
        try:
            operation = getattr(self._store, store_op)
            return operation(*args, **kwargs)
        except (ValueError, RuntimeError, TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to the C10d store has failed. See inner exception for details."
            ) from exc

    def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
        # Map a raw store value to ``(state, token)``, or None for the sentinel.
        sentinel = self._NULL_SENTINEL.encode()
        if base64_state == sentinel:
            return None

        try:
            decoded = b64decode(base64_state)
        except binascii.Error as exc:
            raise RendezvousStateError(
                "The state object is corrupt. See inner exception for details."
            ) from exc

        # The encoded value itself serves as the token for the next update.
        return decoded, base64_state
135
136
def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
    """Create the :py:class:`torch.distributed.TCPStore` for ``params.endpoint``.

    Whether this process runs the store server is taken from the ``is_host``
    parameter when it is given. Otherwise it is inferred by matching the
    endpoint host against this machine's hostname and IP address. If the
    server role was inferred and starting the server fails, the store is
    created once more as a client only.

    Raises:
        ValueError: ``read_timeout`` is not a positive integer.
        RendezvousConnectionError: the TCP store could not be created.
    """
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT)

    explicit_is_host = params.get_as_bool("is_host")
    if explicit_is_host is None:
        # Nothing configured: guess from our hostname and IP address.
        is_host = _matches_machine_hostname(host)
    else:
        # The user stated explicitly whether this process hosts the store.
        is_host = explicit_is_host

    # Read timeout, in seconds, for store operations.
    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
    if read_timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")
    store_timeout = timedelta(seconds=read_timeout)

    # The second (client-only) attempt is used only when the first attempt
    # fails after we *inferred* that we should be the server.
    for as_server in (is_host, False):
        try:
            store = TCPStore(
                host,
                port,
                is_master=as_server,
                multi_tenant=True,
                timeout=store_timeout,
            )

            if as_server:
                msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
                construct_and_record_rdzv_event(
                    run_id=params.run_id, message=msg, node_state=NodeState.INIT
                )
                logger.info(msg)

            break
        except (ValueError, RuntimeError, TimeoutError) as exc:
            # Several processes of the same rendezvous may share this machine
            # and only one of them will end up hosting the store, so a failed
            # inferred server attempt falls back to connecting as a client.
            if as_server and explicit_is_host is None:
                continue
            raise RendezvousConnectionError(
                "The connection to the C10d store has failed. See inner exception for details."
            ) from exc

    return store  # type: ignore[possibly-undefined]
188
189
def _create_file_store(params: RendezvousParameters) -> FileStore:
    """Create a :py:class:`torch.distributed.FileStore` for the C10d backend.

    If ``params.endpoint`` is non-empty it is used as the path of the backing
    file; otherwise a fresh temporary file is created for the store.

    Raises:
        RendezvousError: the temporary file could not be created.
        RendezvousConnectionError: the FileStore could not be constructed.
    """
    # If a user specifies an endpoint, we treat it as a path to a file.
    if params.endpoint:
        path = params.endpoint
    else:
        try:
            # The temporary file is readable and writable only by the user of
            # this process.
            fd, path = tempfile.mkstemp()
            # mkstemp returns an open OS-level file descriptor alongside the
            # path. FileStore is constructed from the path alone, so close the
            # descriptor here; discarding it leaked one fd per call.
            os.close(fd)
        except OSError as exc:
            raise RendezvousError(
                "The file creation for C10d store has failed. See inner exception for details."
            ) from exc

    try:
        store = FileStore(path)
    except (ValueError, RuntimeError) as exc:
        raise RendezvousConnectionError(
            "The connection to the C10d store has failed. See inner exception for details."
        ) from exc

    return store
212
213
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
    """Build a :py:class:`C10dRendezvousBackend` and its store from ``params``.

    Any exception raised while creating the store or the backend is recorded
    as a ``NodeState.FAILED`` rendezvous event and then re-raised unchanged.

    Returns:
        A ``(backend, store)`` tuple.

    Raises:
        ValueError: ``store_type`` is neither ``"tcp"`` nor ``"file"``.

    +--------------+-----------------------------------------------------------+
    | Parameter    | Description                                               |
    +==============+===========================================================+
    | store_type   | The type of the C10d store. The currently supported types |
    |              | are "tcp" and "file" which correspond to                  |
    |              | :py:class:`torch.distributed.TCPStore` and                |
    |              | :py:class:`torch.distributed.FileStore`, respectively.    |
    |              | Defaults to "tcp".                                        |
    +--------------+-----------------------------------------------------------+
    | read_timeout | The read timeout, in seconds, for store operations.       |
    |              | Defaults to 60 seconds.                                   |
    |              |                                                           |
    |              | Note this only applies to                                 |
    |              | :py:class:`torch.distributed.TCPStore`. It is not relevant|
    |              | to :py:class:`torch.distributed.FileStore` which does not |
    |              | take in timeout as a parameter.                           |
    +--------------+-----------------------------------------------------------+
    | is_host      | A boolean value indicating whether this backend instance  |
    |              | will host the C10d store. If not specified it will be     |
    |              | inferred heuristically by matching the hostname or the IP |
    |              | address of this machine against the specified rendezvous  |
    |              | endpoint. Defaults to ``None``.                           |
    |              |                                                           |
    |              | Note that this configuration option only applies to       |
    |              | :py:class:`torch.distributed.TCPStore`. In normal         |
    |              | circumstances you can safely skip it; the only time when  |
    |              | it is needed is if its value cannot be correctly          |
    |              | determined (e.g. the rendezvous endpoint has a CNAME as   |
    |              | the hostname or does not match the FQDN of the machine).  |
    +--------------+-----------------------------------------------------------+
    """
    # Only TCPStore and FileStore are supported today; other store types do
    # not yet provide required functionality such as compare_set.
    store_type = params.get("store_type", "tcp").strip().lower()
    store: Store

    try:
        # Built inside the try so that any failure is recorded as an event.
        store_factories = {
            "file": _create_file_store,
            "tcp": _create_tcp_store,
        }
        factory = store_factories.get(store_type)
        if factory is None:
            raise ValueError(
                "Invalid store type given. Currently only supports file and tcp."
            )

        store = factory(params)
        backend = C10dRendezvousBackend(store, params.run_id)
    except Exception as e:
        construct_and_record_rdzv_event(
            message=f"{type(e).__name__}: {str(e)}",
            run_id=params.run_id,
            node_state=NodeState.FAILED,
        )
        raise

    return backend, store
274