• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import atexit
10import logging
11import os
12import shlex
13import shutil
14import socket
15import subprocess
16import tempfile
17import time
18from typing import Optional, TextIO, Union
19
20
21try:
22    import etcd  # type: ignore[import]
23except ModuleNotFoundError:
24    pass
25
26
# Module-wide logger, named after this module so host applications can
# configure its level/handlers independently.
logger: logging.Logger = logging.getLogger(__name__)
28
29
def find_free_port():
    """
    Find a free port and bind a temporary socket to it so that the port can be "reserved" until used.

    .. note:: the returned socket must be closed before using the port,
              otherwise an ``address already in use`` error will happen.
              The socket should be held and closed as close to the
              consumer of the port as possible since otherwise, there
              is a greater chance of race-condition where a different
              process may see the port as being free and take it.

    Every (family, type, proto) that ``localhost`` resolves to is tried in
    turn; the first socket that can be bound to an OS-assigned port and put
    into listening state is returned.

    Returns: a listening socket bound to the reserved free port

    Raises:
        RuntimeError: if no socket could be created/bound for any of the
            resolved ``localhost`` addresses (chained to the last ``OSError``
            seen, if any).

    Usage::

    sock = find_free_port()
    port = sock.getsockname()[1]
    sock.close()
    use_port(port)
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )

    last_error: Optional[OSError] = None
    for family, sock_type, proto, _, _ in addrs:
        s = None
        try:
            s = socket.socket(family, sock_type, proto)
            # Port 0 asks the OS to pick a currently unused port.
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError as e:
            # ``s`` is still None when socket creation itself failed; only
            # close a socket that was actually created.
            if s is not None:
                s.close()
            last_error = e
            print(f"Socket creation attempt failed: {e}")
    raise RuntimeError("Failed to create a socket") from last_error
65
66
def stop_etcd(subprocess, data_dir: Optional[str] = None):
    """
    Best-effort teardown of an etcd server process and its data directory.

    Args:
        subprocess: the ``Popen`` handle of the etcd server, or ``None``.
            If it is still running it is sent ``terminate()`` and waited on.
            NOTE: this parameter shadows the ``subprocess`` module inside
            this function; the name is kept for backward compatibility.
        data_dir: directory to delete recursively; skipped when falsy.
            Deletion errors are ignored.
    """
    still_running = bool(subprocess) and subprocess.poll() is None
    if still_running:
        logger.info("stopping etcd server")
        subprocess.terminate()
        subprocess.wait()

    if not data_dir:
        return
    logger.info("deleting etcd data dir: %s", data_dir)
    shutil.rmtree(data_dir, ignore_errors=True)
76
77
class EtcdServer:
    """
    .. note:: tested on etcd server v3.4.3.

    Starts and stops a local standalone etcd server on a random free
    port. Useful for single node, multi-worker launches or testing,
    where a sidecar etcd server is more convenient than having to
    separately setup an etcd server.

    A successful ``start()`` registers an ``atexit`` termination handler
    that shuts down the etcd subprocess and deletes its data dir on
    interpreter exit. This termination handler is NOT a substitute for
    calling the ``stop()`` method.

    The etcd binary is chosen as follows:

    1. The path in env var ``TORCHELASTIC_ETCD_BINARY_PATH`` if it is set,
       otherwise ``<this file root>/bin/etcd``.
    2. If the path chosen in step 1 is not an existing file, ``etcd``
       resolved from ``PATH``.

    Usage
    ::

     server = EtcdServer()  # or EtcdServer(data_dir="/tmp/default.etcd")
     server.start()
     client = server.get_client()
     # use client
     server.stop()

    Args:
        data_dir: base directory under which etcd data is stored (one
            sub-directory per start attempt). When omitted, a fresh
            temporary directory is created. Note that ``stop()`` (and a
            ``start()`` that exhausts its retries) deletes this directory,
            even when it was supplied by the caller.
    """

    def __init__(self, data_dir: Optional[str] = None):
        self._port = -1  # assigned by _start(); -1 until then
        self._host = "localhost"

        root = os.path.dirname(__file__)
        default_etcd_bin = os.path.join(root, "bin/etcd")
        self._etcd_binary_path = os.environ.get(
            "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
        )
        if not os.path.isfile(self._etcd_binary_path):
            # Fall back to a bare command name, looked up on PATH at spawn time.
            self._etcd_binary_path = "etcd"

        self._base_data_dir = (
            data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
        )
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self._etcd_cmd = None
        self._etcd_proc: Optional[subprocess.Popen] = None

    def _get_etcd_server_process(self) -> subprocess.Popen:
        """Return the running etcd ``Popen`` handle; raise RuntimeError if none was started."""
        if not self._etcd_proc:
            raise RuntimeError(
                "No etcd server process started. Call etcd_server.start() first"
            )
        return self._etcd_proc

    def get_port(self) -> int:
        """Return the port the server is running on."""
        return self._port

    def get_host(self) -> str:
        """Return the host the server is running on."""
        return self._host

    def get_endpoint(self) -> str:
        """Return the etcd server endpoint (host:port)."""
        return f"{self._host}:{self._port}"

    def start(
        self,
        timeout: int = 60,
        num_retries: int = 3,
        stderr: Union[int, TextIO, None] = None,
    ) -> None:
        """
        Start the server and wait for it to be ready. When this function returns the server is ready to take requests.

        Each attempt uses its own ``<base data dir>/<attempt index>`` data
        directory and newly reserved client/peer ports. On success an
        ``atexit`` handler is registered that stops the server and deletes
        the base data directory.

        Args:
            timeout: time (in seconds) to wait for the server to be ready
                before giving up.
            num_retries: number of retries to start the server. Each retry
                will wait for max ``timeout`` before considering it as failed.
            stderr: the standard error file handle. Valid values are
                `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file
                descriptor (a positive integer), an existing file object, and
                `None`.

        Raises:
            The exception of the final failed attempt (after the base data
            dir has been deleted), e.g. ``TimeoutError`` if the server is not
            ready within the specified timeout, or ``RuntimeError`` if the
            etcd process exits before becoming ready.
        """
        curr_retries = 0
        while True:
            try:
                data_dir = os.path.join(self._base_data_dir, str(curr_retries))
                os.makedirs(data_dir, exist_ok=True)
                self._start(data_dir, timeout, stderr)
            except Exception as e:
                curr_retries += 1
                stop_etcd(self._etcd_proc)
                logger.warning(
                    "Failed to start etcd server, got error: %s, retrying", str(e)
                )
                if curr_retries >= num_retries:
                    shutil.rmtree(self._base_data_dir, ignore_errors=True)
                    raise
            else:
                # Register the exit handler only once the server is up, and
                # before returning, so it captures the process actually running.
                atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
                return

    def _start(
        self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
    ) -> None:
        """Launch one etcd process on freshly reserved ports and wait until it is ready."""
        # Both reserved sockets stay open while the command line is built and
        # are closed on leaving the ``with`` (also if anything in it raises),
        # right before etcd is spawned, to keep the port-stealing window small.
        with find_free_port() as sock, find_free_port() as sock_peer:
            self._port = sock.getsockname()[1]
            peer_port = sock_peer.getsockname()[1]

            # argv is passed as a list: joining into one string and re-splitting
            # with shlex would mangle paths containing spaces, quotes or
            # backslashes (e.g. the binary path or a temp data dir).
            etcd_cmd = [
                self._etcd_binary_path,
                "--enable-v2",
                "--data-dir",
                data_dir,
                "--listen-client-urls",
                f"http://{self._host}:{self._port}",
                "--advertise-client-urls",
                f"http://{self._host}:{self._port}",
                "--listen-peer-urls",
                f"http://{self._host}:{peer_port}",
            ]

            logger.info("Starting etcd server: [%s]", etcd_cmd)

        self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr)
        self._wait_for_ready(timeout)

    def get_client(self):
        """Return an etcd client object that can be used to make requests to this server."""
        return etcd.Client(
            host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
        )

    def _wait_for_ready(self, timeout: int = 60) -> None:
        """
        Poll the server about once per second until it answers or ``timeout`` seconds elapse.

        Raises:
            RuntimeError: if the etcd process exits while waiting.
            TimeoutError: if the server does not answer within ``timeout``.
        """
        client = etcd.Client(
            host=self._host, port=self._port, version_prefix="/v2", read_timeout=5
        )
        # Monotonic clock: wall-clock adjustments cannot shorten/extend the wait.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            proc = self._get_etcd_server_process()
            if proc.poll() is not None:
                # etcd server process finished before becoming ready
                raise RuntimeError(
                    f"Etcd server process exited with the code: {proc.returncode}"
                )
            try:
                # NOTE(review): reading ``client.version`` is the readiness
                # probe; presumably it issues a request to the server. Any
                # failure is treated as "not ready yet".
                logger.info("etcd server ready. version: %s", client.version)
                return
            except Exception:
                time.sleep(1)
        raise TimeoutError("Timed out waiting for etcd server to be ready!")

    def stop(self) -> None:
        """
        Stop the server and clean up auto generated resources (e.g. data dir).

        Terminates the etcd process if it is still running and recursively
        deletes the base data directory (errors ignored). Also callable when
        the server was never started.
        """
        logger.info("EtcdServer stop method called")
        stop_etcd(self._etcd_proc, self._base_data_dir)
243        raise TimeoutError("Timed out waiting for etcd server to be ready!")
244
245    def stop(self) -> None:
246        """Stop the server and cleans up auto generated resources (e.g. data dir)."""
247        logger.info("EtcdServer stop method called")
248        stop_etcd(self._etcd_proc, self._base_data_dir)
249