#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import atexit
import logging
import os
import shlex
import shutil
import socket
import subprocess
import tempfile
import time
from typing import Optional, TextIO, Union


try:
    import etcd  # type: ignore[import]
except ModuleNotFoundError:
    pass


logger = logging.getLogger(__name__)


def find_free_port():
    """
    Find a free port and bind a temporary socket to it so that the port can be "reserved" until used.

    .. note:: the returned socket must be closed before using the port,
              otherwise an ``address already in use`` error will happen.
              The socket should be held and closed as close to the
              consumer of the port as possible since otherwise, there
              is a greater chance of race-condition where a different
              process may see the port as being free and take it.

    Returns: a socket bound to the reserved free port

    Raises:
        RuntimeError: if no address returned for ``localhost`` could be
            bound; the last underlying ``OSError`` is chained as the cause.

    Usage::

     sock = find_free_port()
     port = sock.getsockname()[1]
     sock.close()
     use_port(port)
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )

    last_error: Optional[OSError] = None
    for family, sock_type, proto, _, _ in addrs:
        # Reset per attempt so a failure in socket() never closes a stale
        # (or undefined) socket from a previous iteration.
        s = None
        try:
            s = socket.socket(family, sock_type, proto)
            # Port 0 lets the OS pick any free ephemeral port.
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError as e:
            if s is not None:
                s.close()
            last_error = e
            logger.warning("Socket creation attempt failed: %s", e)
    raise RuntimeError("Failed to create a socket") from last_error


def stop_etcd(subprocess, data_dir: Optional[str] = None):
    """
    Terminate an etcd process (if still running) and optionally delete its data dir.

    Safe to call repeatedly: an already-exited (or ``None``) process is
    skipped, and the directory removal ignores errors.

    Args:
        subprocess: the ``Popen`` handle of the etcd server, or ``None``.
            (The parameter name shadows the ``subprocess`` module inside this
            function; it is kept for backward compatibility with callers.)
        data_dir: directory to remove recursively, if given.
    """
    if subprocess and subprocess.poll() is None:
        logger.info("stopping etcd server")
        subprocess.terminate()
        subprocess.wait()

    if data_dir:
        logger.info("deleting etcd data dir: %s", data_dir)
        shutil.rmtree(data_dir, ignore_errors=True)


class EtcdServer:
    """
    .. note:: tested on etcd server v3.4.3.

    Starts and stops a local standalone etcd server on a random free
    port. Useful for single node, multi-worker launches or testing,
    where a sidecar etcd server is more convenient than having to
    separately setup an etcd server.

    A successful ``start()`` registers an ``atexit`` handler to shutdown the
    etcd subprocess and delete the data dir on interpreter exit. This
    termination handler is NOT a substitute for calling the ``stop()`` method.

    The etcd binary is located as follows:

    1. The path in env var ``TORCHELASTIC_ETCD_BINARY_PATH`` if set, otherwise
       ``<this file root>/bin/etcd``.
    2. If the path chosen in step 1 is not an existing file, ``etcd`` from
       ``PATH`` is used.

    Usage
    ::

     server = EtcdServer("/tmp/default.etcd")
     server.start()
     client = server.get_client()
     # use client
     server.stop()

    Args:
        data_dir: base directory for etcd data. Each start attempt uses a
            numbered sub-directory of it. When not given, a fresh temporary
            directory is created. Note that ``stop()`` deletes this
            directory.
    """

    def __init__(self, data_dir: Optional[str] = None):
        self._port = -1
        self._host = "localhost"

        root = os.path.dirname(__file__)
        default_etcd_bin = os.path.join(root, "bin/etcd")
        self._etcd_binary_path = os.environ.get(
            "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
        )
        if not os.path.isfile(self._etcd_binary_path):
            # Fall back to resolving ``etcd`` via PATH.
            self._etcd_binary_path = "etcd"

        self._base_data_dir = (
            data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
        )
        # argv of the most recent launch attempt (set by ``_start``).
        self._etcd_cmd = None
        self._etcd_proc: Optional[subprocess.Popen] = None

    def _get_etcd_server_process(self) -> subprocess.Popen:
        """Return the etcd ``Popen`` handle, raising if ``start()`` was never called."""
        if not self._etcd_proc:
            raise RuntimeError(
                "No etcd server process started. Call etcd_server.start() first"
            )
        return self._etcd_proc

    def get_port(self) -> int:
        """Return the port the server is running on."""
        return self._port

    def get_host(self) -> str:
        """Return the host the server is running on."""
        return self._host

    def get_endpoint(self) -> str:
        """Return the etcd server endpoint (host:port)."""
        return f"{self._host}:{self._port}"

    def start(
        self,
        timeout: int = 60,
        num_retries: int = 3,
        stderr: Union[int, TextIO, None] = None,
    ) -> None:
        """
        Start the server, and wait for it to be ready. When this function returns the server is ready to take requests.

        On success an ``atexit`` handler is registered that stops the server
        and removes the base data dir when the interpreter exits.

        Args:
            timeout: time (in seconds) to wait for the server to be ready
                before giving up.
            num_retries: number of retries to start the server. Each retry
                will wait for max ``timeout`` before considering it as failed.
            stderr: the standard error file handle. Valid values are
                `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file
                descriptor (a positive integer), an existing file object, and
                `None`.

        Raises:
            TimeoutError: if the server is not ready within the specified timeout
        """
        curr_retries = 0
        while True:
            try:
                # Each attempt gets its own data dir so a half-initialized
                # directory from a failed attempt is never reused.
                data_dir = os.path.join(self._base_data_dir, str(curr_retries))
                os.makedirs(data_dir, exist_ok=True)
                self._start(data_dir, timeout, stderr)
                break
            except Exception as e:
                curr_retries += 1
                stop_etcd(self._etcd_proc)
                logger.warning(
                    "Failed to start etcd server, got error: %s, retrying", str(e)
                )
                if curr_retries >= num_retries:
                    shutil.rmtree(self._base_data_dir, ignore_errors=True)
                    raise
        # Previously this line was unreachable (placed after a loop that only
        # returned or raised), so the exit-time cleanup was never registered.
        atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)

    def _start(
        self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
    ) -> None:
        """Launch one etcd process on freshly reserved ports and wait until it is ready."""
        # The ``with`` closes both reservation sockets (even if the second
        # reservation fails) right before etcd is launched to bind the ports.
        with find_free_port() as sock, find_free_port() as sock_peer:
            self._port = sock.getsockname()[1]
            peer_port = sock_peer.getsockname()[1]

        # Build argv directly; joining and re-splitting would break paths
        # that contain spaces or quotes.
        etcd_cmd = [
            self._etcd_binary_path,
            "--enable-v2",
            "--data-dir",
            data_dir,
            "--listen-client-urls",
            f"http://{self._host}:{self._port}",
            "--advertise-client-urls",
            f"http://{self._host}:{self._port}",
            "--listen-peer-urls",
            f"http://{self._host}:{peer_port}",
        ]
        self._etcd_cmd = etcd_cmd

        logger.info("Starting etcd server: [%s]", etcd_cmd)

        self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr)
        self._wait_for_ready(timeout)

    def get_client(self):
        """Return an etcd client object that can be used to make requests to this server."""
        return etcd.Client(
            host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
        )

    def _wait_for_ready(self, timeout: int = 60) -> None:
        """
        Poll the server until it answers, the process exits, or ``timeout`` elapses.

        Raises:
            RuntimeError: if the etcd process exits before becoming ready.
            TimeoutError: if the server is not ready within ``timeout`` seconds.
        """
        client = etcd.Client(
            host=self._host, port=self._port, version_prefix="/v2", read_timeout=5
        )
        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            proc = self._get_etcd_server_process()
            if proc.poll() is not None:
                # etcd server process finished
                raise RuntimeError(
                    f"Etcd server process exited with the code: {proc.returncode}"
                )
            try:
                # Any error while fetching the version is treated as "not ready yet".
                logger.info("etcd server ready. version: %s", client.version)
                return
            except Exception:
                time.sleep(1)
        raise TimeoutError("Timed out waiting for etcd server to be ready!")

    def stop(self) -> None:
        """Stop the server and clean up auto generated resources (e.g. data dir)."""
        logger.info("EtcdServer stop method called")
        stop_etcd(self._etcd_proc, self._base_data_dir)