• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2# Copyright (c) Facebook, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import binascii
9from base64 import b64decode, b64encode
10from typing import cast, Optional, Tuple
11
12import urllib3.exceptions  # type: ignore[import]
13from etcd import (  # type: ignore[import]
14    Client as EtcdClient,
15    EtcdAlreadyExist,
16    EtcdCompareFailed,
17    EtcdException,
18    EtcdKeyNotFound,
19    EtcdResult,
20)
21
22from torch.distributed import Store
23
24from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
25from .dynamic_rendezvous import RendezvousBackend, Token
26from .etcd_store import EtcdStore
27from .utils import parse_rendezvous_endpoint
28
29
class EtcdRendezvousBackend(RendezvousBackend):
    """Represents an etcd-based rendezvous backend.

    The rendezvous state is stored base64-encoded under a single etcd key and
    the key's ``modifiedIndex`` is used as the state token for
    compare-and-swap updates.

    Args:
        client:
            The ``etcd.Client`` instance to use to communicate with etcd.
        run_id:
            The run id of the rendezvous.
        key_prefix:
            The path under which to store the rendezvous state in etcd.
        ttl:
            The TTL of the rendezvous state. If not specified, defaults to two hours.
    """

    _DEFAULT_TTL = 7200  # 2 hours

    _client: EtcdClient
    _key: str
    _ttl: int

    def __init__(
        self,
        client: EtcdClient,
        run_id: str,
        key_prefix: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")

        self._client = client

        if key_prefix:
            self._key = key_prefix + "/" + run_id
        else:
            self._key = run_id

        # A missing, zero, or negative TTL falls back to the default.
        if ttl and ttl > 0:
            self._ttl = ttl
        else:
            self._ttl = self._DEFAULT_TTL

    @property
    def name(self) -> str:
        """See base class."""
        return "etcd-v2"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class.

        Returns ``None`` when the state key does not exist in etcd.

        Raises:
            RendezvousConnectionError: The etcd request failed or timed out.
            RendezvousStateError: The stored state is not valid base64.
        """
        try:
            result = self._client.read(self._key)
        except EtcdKeyNotFound:
            return None
        except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to etcd has failed. See inner exception for details."
            ) from exc

        return self._decode_state(result)

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class.

        With a truthy token, the write only succeeds if the key's
        ``modifiedIndex`` still equals the token. Without one, the write only
        succeeds if the key does not exist yet. If the write is rejected, the
        currently stored state is returned with the flag ``False``. On
        success, the written state and its new token are returned with the
        flag ``True``.

        Raises:
            RendezvousConnectionError: The etcd request failed or timed out.
            RendezvousStateError: The stored state is not valid base64.
        """
        base64_state = b64encode(state).decode()

        def current_state() -> Optional[Tuple[bytes, Token, bool]]:
            # The state as currently stored in etcd, flagged as "not written by us".
            result = self.get_state()
            if result is None:
                return None
            return (*result, False)

        if token:
            try:
                token = int(token)
            except (TypeError, ValueError):
                # A token that is not an etcd index can never match the stored
                # key, so the write is treated as stale.
                return current_state()

        kwargs = {}

        if token:
            # Compare-and-swap against the index the caller last observed.
            kwargs["prevIndex"] = token
        else:
            # No token: only create the key, never overwrite an existing one.
            kwargs["prevExist"] = False

        try:
            result = self._client.write(self._key, base64_state, self._ttl, **kwargs)
        except (EtcdAlreadyExist, EtcdCompareFailed):
            result = None
        except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to etcd has failed. See inner exception for details."
            ) from exc

        if result is None:
            return current_state()

        return (*self._decode_state(result), True)

    def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]:
        """Decode an etcd read/write result into ``(state, token)``.

        The token is the key's ``modifiedIndex``.

        Raises:
            RendezvousStateError: The stored value is not strictly valid base64.
        """
        base64_state = result.value.encode()

        try:
            # ``validate=True`` rejects non-alphabet characters instead of
            # silently dropping them, so corruption is actually detected.
            state = b64decode(base64_state, validate=True)
        except binascii.Error as exc:
            raise RendezvousStateError(
                "The state object is corrupt. See inner exception for details."
            ) from exc

        return state, result.modifiedIndex
144
145
def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
    """Construct an ``etcd.Client`` configured from rendezvous parameters.

    Reads ``read_timeout``, ``protocol``, ``ssl_cert``, ``ssl_cert_key`` and
    ``ca_cert`` from ``params``; the endpoint port defaults to 2379.

    Raises:
        ValueError: The read timeout is not positive, or the protocol is
            neither HTTP nor HTTPS.
        RendezvousConnectionError: Constructing the client failed with an
            etcd error or a urllib3 timeout.
    """
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379)

    # Read timeout for etcd operations, in seconds.
    timeout = cast(int, params.get_as_int("read_timeout", 60))
    if timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")

    # Transport protocol, normalized so surrounding whitespace and case do
    # not matter.
    scheme = params.get("protocol", "http").strip().lower()
    if scheme not in ("http", "https"):
        raise ValueError("The protocol must be HTTP or HTTPS.")

    # Client certificate. When a private key is also supplied, the etcd
    # client expects a (certificate, key) tuple. The key is only looked up
    # when a certificate is present.
    cert = params.get("ssl_cert")
    cert_key = params.get("ssl_cert_key") if cert else None
    if cert_key:
        cert = (cert, cert_key)

    # Root CA certificate used to verify the server.
    root_ca = params.get("ca_cert")

    try:
        client = EtcdClient(
            host,
            port,
            read_timeout=timeout,
            protocol=scheme,
            cert=cert,
            ca_cert=root_ca,
            allow_reconnect=True,
        )
    except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
        raise RendezvousConnectionError(
            "The connection to etcd has failed. See inner exception for details."
        ) from exc

    return client
185
186
def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
    """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters.

    The rendezvous state is stored under ``/torch/elastic/rendezvous/<run_id>``.
    The accompanying :py:class:`EtcdStore` uses the ``/torch/elastic/store``
    prefix. Both share a single etcd client.

    +--------------+-----------------------------------------------------------+
    | Parameter    | Description                                               |
    +==============+===========================================================+
    | read_timeout | The read timeout, in seconds, for etcd operations.        |
    |              | Defaults to 60 seconds.                                   |
    +--------------+-----------------------------------------------------------+
    | protocol     | The protocol to use to communicate with etcd. Valid       |
    |              | values are "http" and "https". Defaults to "http".        |
    +--------------+-----------------------------------------------------------+
    | ssl_cert     | The path to the SSL client certificate to use along with  |
    |              | HTTPS. Defaults to ``None``.                              |
    +--------------+-----------------------------------------------------------+
    | ssl_cert_key | The path to the private key of the SSL client certificate |
    |              | to use along with HTTPS. Defaults to ``None``.            |
    +--------------+-----------------------------------------------------------+
    | ca_cert      | The path to the root SSL authority certificate. Defaults  |
    |              | to ``None``.                                              |
    +--------------+-----------------------------------------------------------+

    Returns:
        A ``(backend, store)`` tuple.

    Raises:
        ValueError: A parameter value is invalid.
        RendezvousConnectionError: The etcd client could not be created.
    """
    client = _create_etcd_client(params)

    backend = EtcdRendezvousBackend(
        client, params.run_id, key_prefix="/torch/elastic/rendezvous"
    )

    store = EtcdStore(client, "/torch/elastic/store")

    return backend, store
218