#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Library that launches and manages ``n`` copies of worker subprocesses either specified by a function or a binary.

For functions, it uses ``torch.multiprocessing`` (and therefore python
``multiprocessing``) to spawn/fork worker processes. For binaries it uses python
``subprocess.Popen`` to create worker processes.


Usage 1: Launching two trainers as a function

::

 from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std, start_processes

 def trainer(a, b, c):
     pass # train


 # runs two trainers
 # LOCAL_RANK=0 trainer(1,2,3)
 # LOCAL_RANK=1 trainer(4,5,6)
 ctx = start_processes(
         name="trainer",
         entrypoint=trainer,
         args={0: (1,2,3), 1: (4,5,6)},
         envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
         logs_specs=DefaultLogsSpecs(
             log_dir="/tmp/foobar",
             redirects=Std.ALL, # write all worker stdout/stderr to a log file
             tee={0: Std.ERR}, # tee only local rank 0's stderr to console
         ),
     )

 # waits for all copies of trainer to finish
 ctx.wait()

Usage 2: Launching 2 echo workers as a binary

::

 # same as invoking
 # echo hello
 # echo world > stdout.log
 ctx = start_processes(
         name="echo",
         entrypoint="echo",
         args={0: "hello", 1: "world"},
         envs={0: {}, 1: {}},
         logs_specs=DefaultLogsSpecs(
             log_dir="/tmp/foobar",
             redirects={1: Std.OUT},
         ),
     )

Just like ``torch.multiprocessing``, the return value of the function
:func:`start_processes` is a process context (:class:`api.PContext`). If a function
was launched, a :class:`api.MultiprocessContext` is returned and if a binary
was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
63""" 64 65from typing import Callable, Dict, Optional, Tuple, Union 66 67from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 68 _validate_full_rank, 69 DefaultLogsSpecs, 70 LogsDest, 71 LogsSpecs, 72 MultiprocessContext, 73 PContext, 74 ProcessFailure, 75 RunProcsResult, 76 SignalException, 77 Std, 78 SubprocessContext, 79 to_map, 80) 81from torch.distributed.elastic.utils.logging import get_logger 82 83 84__all__ = [ 85 "start_processes", 86 "MultiprocessContext", 87 "PContext", 88 "ProcessFailure", 89 "RunProcsResult", 90 "SignalException", 91 "Std", 92 "LogsDest", 93 "LogsSpecs", 94 "DefaultLogsSpecs", 95 "SubprocessContext", 96 "to_map", 97] 98 99 100def start_processes( 101 name: str, 102 entrypoint: Union[Callable, str], 103 args: Dict[int, Tuple], 104 envs: Dict[int, Dict[str, str]], 105 logs_specs: LogsSpecs, 106 log_line_prefixes: Optional[Dict[int, str]] = None, 107 start_method: str = "spawn", 108) -> PContext: 109 """ 110 Start ``n`` copies of ``entrypoint`` processes with the provided options. 111 112 ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary). 113 The number of copies is determined by the number of entries for ``args`` and 114 ``envs`` arguments, which need to have the same key set. 115 116 ``args`` and ``env`` parameters are the arguments and environment variables 117 to pass down to the entrypoint mapped by the replica index (local rank). 118 All local ranks must be accounted for. 119 That is, the keyset should be ``{0,1,...,(nprocs-1)}``. 120 121 .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings. 122 If any other type is given, then it is casted to a string representation 123 (e.g. ``str(arg1)``). Furthermore, a binary failure will only write 124 an ``error.json`` error file if the main function is annotated with 125 ``torch.distributed.elastic.multiprocessing.errors.record``. 
For function launches, 126 this is done by default and there is no need to manually annotate 127 with the ``@record`` annotation. 128 129 ``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect 130 to a log file in the ``log_dir``. Valid mask values are defined in ``Std``. 131 To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as 132 the local rank to specify the redirect behavior for. 133 Any missing local ranks will default to ``Std.NONE``. 134 135 ``tee`` acts like the unix "tee" command in that it redirects + prints to console. 136 To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter. 137 138 For each process, the ``log_dir`` will contain: 139 140 #. ``{local_rank}/error.json``: if the process failed, a file with the error info 141 #. ``{local_rank}/stdout.json``: if ``redirect & STDOUT == STDOUT`` 142 #. ``{local_rank}/stderr.json``: if ``redirect & STDERR == STDERR`` 143 144 .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory. 
145 146 Example: 147 :: 148 149 log_dir = "/tmp/test" 150 151 # ok; two copies of foo: foo("bar0"), foo("bar1") 152 start_processes( 153 name="trainer", 154 entrypoint=foo, 155 args:{0:("bar0",), 1:("bar1",), 156 envs:{0:{}, 1:{}}, 157 log_dir=log_dir 158 ) 159 160 # invalid; envs missing for local rank 1 161 start_processes( 162 name="trainer", 163 entrypoint=foo, 164 args:{0:("bar0",), 1:("bar1",), 165 envs:{0:{}}, 166 log_dir=log_dir 167 ) 168 169 # ok; two copies of /usr/bin/touch: touch file1, touch file2 170 start_processes( 171 name="trainer", 172 entrypoint="/usr/bin/touch", 173 args:{0:("file1",), 1:("file2",), 174 envs:{0:{}, 1:{}}, 175 log_dir=log_dir 176 ) 177 178 # caution; arguments casted to string, runs: 179 # echo "1" "2" "3" and echo "[1, 2, 3]" 180 start_processes( 181 name="trainer", 182 entrypoint="/usr/bin/echo", 183 args:{0:(1,2,3), 1:([1,2,3],), 184 envs:{0:{}, 1:{}}, 185 log_dir=log_dir 186 ) 187 188 Args: 189 name: a human readable short name that describes what the processes are 190 (used as header when tee'ing stdout/stderr outputs) 191 entrypoint: either a ``Callable`` (function) or ``cmd`` (binary) 192 args: arguments to each replica 193 envs: env vars to each replica 194 log_dir: directory used to write log files 195 start_method: multiprocessing start method (spawn, fork, forkserver) 196 ignored for binaries 197 redirects: which std streams to redirect to a log file 198 tee: which std streams to redirect + print to console 199 local_ranks_filter: which ranks' logs to print to console 200 201 """ 202 203 nprocs = len(args) 204 _validate_full_rank(args, nprocs, "args") 205 _validate_full_rank(envs, nprocs, "envs") 206 207 context: PContext 208 if isinstance(entrypoint, str): 209 context = SubprocessContext( 210 name=name, 211 entrypoint=entrypoint, 212 args=args, 213 envs=envs, 214 logs_specs=logs_specs, 215 log_line_prefixes=log_line_prefixes, 216 ) 217 else: 218 context = MultiprocessContext( 219 name=name, 220 entrypoint=entrypoint, 
221 args=args, 222 envs=envs, 223 log_line_prefixes=log_line_prefixes, 224 start_method=start_method, 225 logs_specs=logs_specs, 226 ) 227 228 try: 229 context.start() 230 return context 231 except Exception: 232 context.close() 233 raise 234