#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Library that launches and manages ``n`` copies of worker subprocesses, specified either by a function or by a binary.

For functions, it uses ``torch.multiprocessing`` (and therefore Python
``multiprocessing``) to spawn/fork worker processes. For binaries it uses Python's
``subprocess.Popen`` to create worker processes.


Usage 1: Launching two trainers as a function

::

 from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std, start_processes

 def trainer(a, b, c):
     pass # train


 # runs two trainers
 # LOCAL_RANK=0 trainer(1,2,3)
 # LOCAL_RANK=1 trainer(4,5,6)
 ctx = start_processes(
         name="trainer",
         entrypoint=trainer,
         args={0: (1, 2, 3), 1: (4, 5, 6)},
         envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
         logs_specs=DefaultLogsSpecs(
             log_dir="/tmp/foobar",
             redirects=Std.ALL,  # write all worker stdout/stderr to a log file
             tee={0: Std.ERR},  # tee only local rank 0's stderr to console
         ),
       )

 # waits for all copies of trainer to finish
 ctx.wait()

Usage 2: Launching 2 echo workers as a binary

::

 # same as invoking
 # echo hello
 # echo world > stdout.log
 ctx = start_processes(
         name="echo",
         entrypoint="echo",
         args={0: ("hello",), 1: ("world",)},
         envs={0: {}, 1: {}},
         logs_specs=DefaultLogsSpecs(
             log_dir="/tmp/foobar",
             redirects={1: Std.OUT},
         ),
        )

Just like ``torch.multiprocessing``, the return value of the function
:func:`start_processes` is a process context (:class:`api.PContext`). If a function
was launched, a :class:`api.MultiprocessContext` is returned and if a binary
was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
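
A minimal sketch of consuming the returned context (continuing Usage 1, where
``ctx.wait()`` with no timeout blocks until all workers finish and returns a
:class:`api.RunProcsResult`):

::

 result = ctx.wait()
 if result.is_failed():
     # local rank -> ProcessFailure for each worker that failed
     for local_rank, failure in result.failures.items():
         print(local_rank, failure.exitcode, failure.error_file)
 else:
     # local rank -> return value (populated for function entrypoints)
     print(result.return_values)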
"""

from typing import Callable, Dict, Optional, Tuple, Union

from torch.distributed.elastic.multiprocessing.api import (  # noqa: F401
    _validate_full_rank,
    DefaultLogsSpecs,
    LogsDest,
    LogsSpecs,
    MultiprocessContext,
    PContext,
    ProcessFailure,
    RunProcsResult,
    SignalException,
    Std,
    SubprocessContext,
    to_map,
)
from torch.distributed.elastic.utils.logging import get_logger


__all__ = [
    "start_processes",
    "MultiprocessContext",
    "PContext",
    "ProcessFailure",
    "RunProcsResult",
    "SignalException",
    "Std",
    "LogsDest",
    "LogsSpecs",
    "DefaultLogsSpecs",
    "SubprocessContext",
    "to_map",
]


def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    logs_specs: LogsSpecs,
    log_line_prefixes: Optional[Dict[int, str]] = None,
    start_method: str = "spawn",
) -> PContext:
    """
    Start ``n`` copies of ``entrypoint`` processes with the provided options.

    ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
    The number of copies is determined by the number of entries in the ``args`` and
    ``envs`` arguments, which need to have the same key set.

    The ``args`` and ``envs`` parameters are the arguments and environment variables
    to pass down to the entrypoint, mapped by the replica index (local rank).
    All local ranks must be accounted for.
    That is, the keyset should be ``{0,1,...,(nprocs-1)}``.

    .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings.
              If any other type is given, then it is cast to a string representation
              (e.g. ``str(arg1)``). Furthermore, a binary failure will only write
              an ``error.json`` error file if the main function is annotated with
              ``torch.distributed.elastic.multiprocessing.errors.record``. For function launches,
              this is done by default and there is no need to manually annotate
              with the ``@record`` annotation.

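    For example, a binary entrypoint's main module might be annotated as follows
    (a minimal sketch; the script name and body are illustrative):

    ::

     # trainer_main.py -- the script launched as the binary entrypoint
     from torch.distributed.elastic.multiprocessing.errors import record

     @record
     def main():
         ...  # training logic; an uncaught exception is recorded to error.json

     if __name__ == "__main__":
         main()
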
    The ``redirects`` and ``tee`` settings (supplied through ``logs_specs``, e.g.
    :class:`DefaultLogsSpecs`) are bitmasks specifying which std stream(s) to redirect
    to a log file in the log directory. Valid mask values are defined in ``Std``.
    To redirect/tee only certain local ranks, pass ``redirects`` as a map keyed by
    the local rank whose redirect behavior should be specified.
    Any missing local ranks will default to ``Std.NONE``.

    ``tee`` acts like the Unix "tee" command in that it both redirects and prints to the console.
    To keep worker stdout/stderr from printing to the console, use the ``redirects`` setting.

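    For instance, a per-rank configuration might look like the following
    (a minimal sketch using :class:`DefaultLogsSpecs`; the log directory is illustrative):

    ::

     logs_specs = DefaultLogsSpecs(
         log_dir="/tmp/test",
         redirects=Std.ALL,  # write stdout and stderr of every rank to log files
         tee={0: Std.ERR},  # additionally print only local rank 0's stderr to the console
     )
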
    For each process, the ``log_dir`` will contain:

    #. ``{local_rank}/error.json``: if the process failed, a file with the error info
    #. ``{local_rank}/stdout.log``: if the rank's ``redirects`` includes ``Std.OUT``
    #. ``{local_rank}/stderr.log``: if the rank's ``redirects`` includes ``Std.ERR``

    .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory.

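    For example, with two workers and ``redirects=Std.ALL``, the tree under ``log_dir``
    might look like this (a sketch; the exact layout, including any run/attempt
    sub-directories, depends on the ``logs_specs`` implementation):

    ::

     /tmp/test/
     ├── 0/
     │   ├── error.json    # present only if local rank 0 failed
     │   ├── stdout.log
     │   └── stderr.log
     └── 1/
         ├── stdout.log
         └── stderr.log
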
    Example:
    ::

     logs_specs = DefaultLogsSpecs(log_dir="/tmp/test")

     # ok; two copies of foo: foo("bar0"), foo("bar1")
     start_processes(
        name="trainer",
        entrypoint=foo,
        args={0: ("bar0",), 1: ("bar1",)},
        envs={0: {}, 1: {}},
        logs_specs=logs_specs,
     )

     # invalid; envs missing for local rank 1
     start_processes(
        name="trainer",
        entrypoint=foo,
        args={0: ("bar0",), 1: ("bar1",)},
        envs={0: {}},
        logs_specs=logs_specs,
     )

     # ok; two copies of /usr/bin/touch: touch file1, touch file2
     start_processes(
        name="trainer",
        entrypoint="/usr/bin/touch",
        args={0: ("file1",), 1: ("file2",)},
        envs={0: {}, 1: {}},
        logs_specs=logs_specs,
     )

     # caution; arguments are cast to strings, runs:
     # echo "1" "2" "3" and echo "[1, 2, 3]"
     start_processes(
        name="trainer",
        entrypoint="/usr/bin/echo",
        args={0: (1, 2, 3), 1: ([1, 2, 3],)},
        envs={0: {}, 1: {}},
        logs_specs=logs_specs,
     )

    Args:
        name: a human-readable short name that describes what the processes are
              (used as the header when tee'ing stdout/stderr outputs)
        entrypoint: either a ``Callable`` (function) or a ``str`` (binary)
        args: arguments to each replica, keyed by local rank
        envs: env vars to each replica, keyed by local rank
        logs_specs: logging options (log directory, redirects, tee, rank filter);
                    see :class:`DefaultLogsSpecs` for the default implementation
        log_line_prefixes: optional per-rank prefix prepended to tee'd log lines
        start_method: multiprocessing start method (spawn, fork, forkserver)
                      ignored for binaries

    """

    nprocs = len(args)
    _validate_full_rank(args, nprocs, "args")
    _validate_full_rank(envs, nprocs, "envs")

    context: PContext
    if isinstance(entrypoint, str):
        # binary entrypoint: workers are launched with subprocess.Popen
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            logs_specs=logs_specs,
            log_line_prefixes=log_line_prefixes,
        )
    else:
        # function entrypoint: workers are spawned/forked via torch.multiprocessing
        context = MultiprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            log_line_prefixes=log_line_prefixes,
            start_method=start_method,
            logs_specs=logs_specs,
        )

    try:
        context.start()
        return context
    except Exception:
        context.close()
        raise