# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import binascii
from base64 import b64decode, b64encode
from typing import cast, Optional, Tuple

import urllib3.exceptions  # type: ignore[import]
from etcd import (  # type: ignore[import]
    Client as EtcdClient,
    EtcdAlreadyExist,
    EtcdCompareFailed,
    EtcdException,
    EtcdKeyNotFound,
    EtcdResult,
)

from torch.distributed import Store

from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint


class EtcdRendezvousBackend(RendezvousBackend):
    """Represents an etcd-based rendezvous backend.

    Args:
        client:
            The ``etcd.Client`` instance to use to communicate with etcd.
        run_id:
            The run id of the rendezvous.
        key_prefix:
            The path under which to store the rendezvous state in etcd.
        ttl:
            The TTL of the rendezvous state. If not specified, defaults to two hours.

    Raises:
        ValueError:
            If ``run_id`` is empty.
    """

    _DEFAULT_TTL = 7200  # 2 hours

    _client: EtcdClient
    _key: str
    _ttl: int

    def __init__(
        self,
        client: EtcdClient,
        run_id: str,
        key_prefix: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")

        self._client = client

        # The state lives under "<key_prefix>/<run_id>", or under the bare run
        # id when no prefix is given.
        if key_prefix:
            self._key = key_prefix + "/" + run_id
        else:
            self._key = run_id

        # A missing, zero, or negative TTL falls back to the default.
        if ttl and ttl > 0:
            self._ttl = ttl
        else:
            self._ttl = self._DEFAULT_TTL

    @property
    def name(self) -> str:
        """See base class."""
        return "etcd-v2"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class.

        Returns ``None`` when the key does not exist in etcd.

        Raises:
            RendezvousConnectionError:
                If reading the key fails with an etcd error or a timeout.
            RendezvousStateError:
                If the stored value is not valid base64.
        """
        try:
            result = self._client.read(self._key)
        except EtcdKeyNotFound:
            return None
        except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to etcd has failed. See inner exception for details."
            ) from exc

        return self._decode_state(result)

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class.

        The state is stored base64-encoded. With a truthy ``token`` the write
        is conditioned on ``prevIndex`` being equal to it; otherwise it is
        conditioned on the key not existing yet (``prevExist=False``).

        Returns:
            ``(state, token, True)`` decoded from the write result on success.
            If the token cannot be converted to ``int`` or the conditional
            write is rejected, the currently stored ``(state, token, False)``,
            or ``None`` if no state is stored.

        Raises:
            RendezvousConnectionError:
                If the write fails with an etcd error or a timeout.
            RendezvousStateError:
                If a value read back from etcd is not valid base64.
        """
        base64_state = b64encode(state).decode()

        kwargs = {}

        if token:
            try:
                token = int(token)
            except (TypeError, ValueError):
                # A token that is not convertible to an integer (wrong type or
                # malformed text) is treated as invalid: report the current
                # state instead of attempting a write.
                return self._get_current_state()

        if token:
            kwargs["prevIndex"] = token
        else:
            kwargs["prevExist"] = False

        try:
            result = self._client.write(self._key, base64_state, self._ttl, **kwargs)
        except (EtcdAlreadyExist, EtcdCompareFailed):
            # The write condition did not hold.
            result = None
        except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to etcd has failed. See inner exception for details."
            ) from exc

        if result is None:
            return self._get_current_state()

        return (*self._decode_state(result), True)

    def _get_current_state(self) -> Optional[Tuple[bytes, Token, bool]]:
        """Return the stored state tagged with ``False``, or ``None`` if absent."""
        current = self.get_state()
        if current is None:
            return None
        return (*current, False)

    def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]:
        """Decode an etcd result into ``(state, token)``.

        The state is the base64-decoded ``result.value``; the token is
        ``result.modifiedIndex``.

        Raises:
            RendezvousStateError:
                If the value is not valid base64.
        """
        base64_state = result.value.encode()

        try:
            state = b64decode(base64_state)
        except binascii.Error as exc:
            raise RendezvousStateError(
                "The state object is corrupt. See inner exception for details."
            ) from exc

        return state, result.modifiedIndex


def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
    """Build an ``etcd.Client`` from the rendezvous parameters.

    Reads the endpoint (default port 2379) and the ``read_timeout``,
    ``protocol``, ``ssl_cert``, ``ssl_cert_key``, and ``ca_cert`` settings.

    Raises:
        ValueError:
            If the read timeout is not positive or the protocol is neither
            "http" nor "https".
        RendezvousConnectionError:
            If constructing the client fails with an etcd error or a timeout.
    """
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379)

    # The timeout
    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
    if read_timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")

    # The communication protocol
    protocol = params.get("protocol", "http").strip().lower()
    if protocol not in ("http", "https"):
        raise ValueError("The protocol must be HTTP or HTTPS.")

    # The SSL client certificate
    ssl_cert = params.get("ssl_cert")
    if ssl_cert:
        ssl_cert_key = params.get("ssl_cert_key")
        if ssl_cert_key:
            # The etcd client expects the certificate key as the second element
            # of the `cert` tuple.
            ssl_cert = (ssl_cert, ssl_cert_key)

    # The root certificate
    ca_cert = params.get("ca_cert")

    try:
        return EtcdClient(
            host,
            port,
            read_timeout=read_timeout,
            protocol=protocol,
            cert=ssl_cert,
            ca_cert=ca_cert,
            allow_reconnect=True,
        )
    except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
        raise RendezvousConnectionError(
            "The connection to etcd has failed. See inner exception for details."
        ) from exc


def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
    """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters.

    +--------------+-----------------------------------------------------------+
    | Parameter    | Description                                               |
    +==============+===========================================================+
    | read_timeout | The read timeout, in seconds, for etcd operations.        |
    |              | Defaults to 60 seconds.                                   |
    +--------------+-----------------------------------------------------------+
    | protocol     | The protocol to use to communicate with etcd. Valid       |
    |              | values are "http" and "https". Defaults to "http".        |
    +--------------+-----------------------------------------------------------+
    | ssl_cert     | The path to the SSL client certificate to use along with  |
    |              | HTTPS. Defaults to ``None``.                              |
    +--------------+-----------------------------------------------------------+
    | ssl_cert_key | The path to the private key of the SSL client certificate |
    |              | to use along with HTTPS. Defaults to ``None``.            |
    +--------------+-----------------------------------------------------------+
    | ca_cert      | The path to the root SSL authority certificate. Defaults  |
    |              | to ``None``.                                              |
    +--------------+-----------------------------------------------------------+

    Returns:
        The backend together with an :py:class:`EtcdStore`; both share the
        same etcd client.
    """
    client = _create_etcd_client(params)

    backend = EtcdRendezvousBackend(
        client, params.run_id, key_prefix="/torch/elastic/rendezvous"
    )

    store = EtcdStore(client, "/torch/elastic/store")

    return backend, store