.. _elastic_train_script:

Train script
-------------

If your train script works with ``torch.distributed.launch``, it will continue
to work with ``torchrun``, with these differences:

1. There is no need to manually pass ``RANK``, ``WORLD_SIZE``,
   ``MASTER_ADDR``, and ``MASTER_PORT``; they are exported as environment
   variables for you.

2. ``rdzv_backend`` and ``rdzv_endpoint`` can be provided. For most users
   ``rdzv_backend`` will be set to ``c10d`` (see `rendezvous <rendezvous.html>`_).
   The default ``rdzv_backend`` creates a non-elastic rendezvous where
   ``rdzv_endpoint`` holds the master address.

3. Make sure your script has ``load_checkpoint(path)`` and
   ``save_checkpoint(path)`` logic. When any number of workers fail, all
   workers are restarted with the same program arguments, so you will lose
   progress up to the most recent checkpoint (see `elastic launch <run.html>`_
   and the sketch at the bottom of this page).

4. The ``use_env`` flag has been removed. If you were previously obtaining the
   local rank by parsing the ``--local-rank`` option, you now need to read it
   from the ``LOCAL_RANK`` environment variable
   (e.g. ``int(os.environ["LOCAL_RANK"])``).

Below is an expository example of a training script that checkpoints on each
epoch; the worst-case progress lost on failure is therefore one full epoch
worth of training.

.. code-block:: python

    import sys

    import torch.distributed

    # parse_args, initialize, train, load_checkpoint, and save_checkpoint
    # are user-defined helpers.
    def main():
        args = parse_args(sys.argv[1:])
        state = load_checkpoint(args.checkpoint_path)
        initialize(state)

        # torch.distributed.run ensures that this will work
        # by exporting all the env vars needed to initialize the process group
        torch.distributed.init_process_group(backend=args.backend)

        for i in range(state.epoch, state.total_num_epochs):
            for batch in iter(state.dataset):
                train(batch, state.model)

            state.epoch += 1
            save_checkpoint(state)

For concrete examples of torchelastic-compliant train scripts, visit
our `examples <examples.html>`_ page.
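
The ``load_checkpoint``/``save_checkpoint`` helpers referenced in step 3 and
used in the script above are left to the user; torchelastic only guarantees
that every worker is restarted with the same arguments. A minimal sketch is
given below, assuming a hypothetical ``State`` container that bundles the
model, optimizer, and epoch counter; the placeholder model, the field names,
and the write-then-rename strategy are illustrative choices, not part of the
torchelastic API.

.. code-block:: python

    import os
    from dataclasses import dataclass, field
    from typing import Optional

    import torch
    import torch.nn as nn


    @dataclass
    class State:
        # Hypothetical container for everything needed to resume training.
        checkpoint_path: str
        epoch: int = 0
        total_num_epochs: int = 10
        model: nn.Module = field(default_factory=lambda: nn.Linear(10, 10))
        optimizer: Optional[torch.optim.Optimizer] = None

        def __post_init__(self):
            if self.optimizer is None:
                self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)


    def save_checkpoint(state: State) -> None:
        # Write to a temporary file and rename it so that a worker killed
        # mid-write never leaves a truncated checkpoint behind.
        tmp_path = state.checkpoint_path + ".tmp"
        torch.save(
            {
                "epoch": state.epoch,
                "model": state.model.state_dict(),
                "optimizer": state.optimizer.state_dict(),
            },
            tmp_path,
        )
        os.replace(tmp_path, state.checkpoint_path)


    def load_checkpoint(checkpoint_path: str) -> State:
        # Start from a fresh state and overlay the snapshot if one exists,
        # so the very first run (no checkpoint yet) also works.
        state = State(checkpoint_path)
        if os.path.exists(checkpoint_path):
            snapshot = torch.load(checkpoint_path, map_location="cpu")
            state.epoch = snapshot["epoch"]
            state.model.load_state_dict(snapshot["model"])
            state.optimizer.load_state_dict(snapshot["optimizer"])
        return state

Writing to a temporary file and renaming it keeps the latest checkpoint
readable even if a worker dies mid-save, which matters because torchelastic
restarts all workers on failure and they immediately re-read the checkpoint.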
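
Similarly, steps 1 and 4 boil down to reading everything from the environment
that ``torchrun`` exports. The sketch below shows one common pattern; binding
each worker to the GPU indexed by ``LOCAL_RANK`` and choosing the backend
based on GPU availability are conventions assumed here, not requirements.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist


    def setup_distributed():
        # RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are already exported
        # by torchrun, so the default "env://" init method needs no extra
        # arguments.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend)

        # LOCAL_RANK replaces the old --local-rank argument; use it to bind
        # this worker to one GPU on the node.
        local_rank = int(os.environ["LOCAL_RANK"])
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        return local_rank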