.. _elastic_train_script:

Train script
-------------

If your train script works with ``torch.distributed.launch``, it will continue
to work with ``torchrun``, with these differences:

1. There is no need to manually pass ``RANK``, ``WORLD_SIZE``,
   ``MASTER_ADDR``, and ``MASTER_PORT``; ``torchrun`` sets them for each worker.

2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users,
   ``rdzv_backend`` will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_).
   The default ``rdzv_backend`` creates a non-elastic rendezvous where
   ``rdzv_endpoint`` holds the master address.

3. Make sure your script has ``load_checkpoint(path)`` and
   ``save_checkpoint(path)`` logic. When any number of workers fail, we
   restart all the workers with the same program arguments, so you will lose
   progress up to the most recent checkpoint (see `elastic launch <run.html>`_).

4. The ``use_env`` flag has been removed. If you were obtaining the local rank
   by parsing the ``--local-rank`` option, read it from the ``LOCAL_RANK``
   environment variable instead (e.g. ``int(os.environ["LOCAL_RANK"])``).

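To illustrate items 1 and 4, the sketch below reads the per-worker values that
``torchrun`` exports (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``) from the
environment instead of from command-line arguments. The ``WorkerInfo`` class
and the ``worker_info_from_env`` helper are hypothetical names used only for
this illustration. ``MASTER_ADDR`` and ``MASTER_PORT`` are not read here
because ``init_process_group`` consumes them through its default ``env://``
initialization.

.. code-block:: python

  import os
  from dataclasses import dataclass


  @dataclass(frozen=True)
  class WorkerInfo:
      # Hypothetical container for the per-worker values exported by torchrun.
      rank: int
      local_rank: int
      world_size: int


  def worker_info_from_env(env=None):
      """Build a WorkerInfo from environment variables (hypothetical helper).

      Replaces parsing a ``--local-rank`` command-line option. ``env``
      defaults to ``os.environ`` but can be any mapping, which eases testing.
      """
      env = os.environ if env is None else env
      return WorkerInfo(
          rank=int(env["RANK"]),
          local_rank=int(env["LOCAL_RANK"]),
          world_size=int(env["WORLD_SIZE"]),
      )

In a GPU script, ``local_rank`` would typically be used to select the device
for this worker before the process group is initialized.
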
Below is an expository example of a training script that checkpoints after
every epoch, so the worst-case progress lost on failure is one full epoch's
worth of training.

.. code-block:: python

  import sys

  import torch.distributed as dist

  # parse_args, load_checkpoint, initialize, train and save_checkpoint stand in
  # for your own application logic; they are not provided by PyTorch.


  def main():
      args = parse_args(sys.argv[1:])
      state = load_checkpoint(args.checkpoint_path)
      initialize(state)

      # torchrun (torch.distributed.run) ensures that this will work
      # by exporting all the env vars needed to initialize the process group
      dist.init_process_group(backend=args.backend)

      # Resume from the epoch recorded in the most recent checkpoint.
      for _ in range(state.epoch, state.total_num_epochs):
          for batch in state.dataset:
              train(batch, state.model)

          state.epoch += 1
          save_checkpoint(state, args.checkpoint_path)


  if __name__ == "__main__":
      main()

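The example above leaves ``load_checkpoint`` and ``save_checkpoint`` to the
user. The sketch below shows one possible shape for them, under stated
assumptions: the ``State`` dataclass and its fields are hypothetical, the state
is serialized with the standard library's ``pickle`` to keep the sketch
self-contained (a PyTorch script would more typically use ``torch.save`` and
``torch.load``), and the checkpoint path is on a filesystem visible to all
workers. The file is written to a temporary path and then renamed with
``os.replace``, so a worker that fails mid-write does not leave a truncated
checkpoint behind.

.. code-block:: python

  import os
  import pickle
  import tempfile
  from dataclasses import dataclass, field


  @dataclass
  class State:
      # Hypothetical training state; model_state could hold model.state_dict().
      epoch: int = 0
      total_num_epochs: int = 10
      model_state: dict = field(default_factory=dict)


  def load_checkpoint(path):
      """Return the saved State, or a fresh State if no checkpoint exists yet."""
      if not os.path.exists(path):
          return State()
      with open(path, "rb") as f:
          return pickle.load(f)


  def save_checkpoint(state, path):
      """Write ``state`` to ``path`` so readers never see a partial file."""
      directory = os.path.dirname(os.path.abspath(path))
      # Create the temporary file in the same directory so that os.replace
      # is a rename within one filesystem.
      fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
      try:
          with os.fdopen(fd, "wb") as f:
              pickle.dump(state, f)
          os.replace(tmp_path, path)
      except BaseException:
          # Clean up the temporary file if serialization or the rename failed.
          os.remove(tmp_path)
          raise

On restart, every worker calls ``load_checkpoint`` and resumes from
``state.epoch``. In a multi-worker job, typically only one worker (for example,
global rank 0) calls ``save_checkpoint``.
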
For concrete examples of torchelastic-compliant train scripts, visit
our `examples <examples.html>`_ page.
