.. _ddp:

Distributed Data Parallel
=========================

.. warning::
    The implementation of :class:`torch.nn.parallel.DistributedDataParallel`
    evolves over time. This design note is written based on the state as of v1.4.


:class:`torch.nn.parallel.DistributedDataParallel` (DDP) transparently performs
distributed data parallel training. This page describes how it works and reveals
implementation details.

Example
^^^^^^^

Let us start with a simple :class:`torch.nn.parallel.DistributedDataParallel`
example. This example uses a :class:`torch.nn.Linear` as the local model, wraps
it with DDP, and then runs one forward pass, one backward pass, and an optimizer
step on the DDP model. After that, parameters on the local model will be
updated, and all models on different processes should be exactly the same.

.. code::

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP


    def example(rank, world_size):
        # create default process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        # create local model
        model = nn.Linear(10, 10).to(rank)
        # construct DDP model
        ddp_model = DDP(model, device_ids=[rank])
        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()


    def main():
        world_size = 2
        mp.spawn(example,
                 args=(world_size,),
                 nprocs=world_size,
                 join=True)


    if __name__ == "__main__":
        # Environment variables which need to be
        # set when using c10d's default "env"
        # initialization mode.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        main()
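
To confirm the claim above that all replicas end up with identical parameters,
a small check can be appended to ``example()``. The snippet below is a minimal
sketch, not part of the original example: it gathers one weight tensor from
every rank onto the CPU and compares the copies on rank 0.

.. code::

    # Hypothetical continuation of ``example()`` above (a sketch, not part of
    # the original example): verify that all replicas hold identical weights
    # after the optimizer step.
    weight = ddp_model.module.weight.detach().cpu()
    gathered = [torch.zeros_like(weight) for _ in range(world_size)]
    dist.all_gather(gathered, weight)  # gloo supports all_gather on CPU tensors
    if rank == 0:
        assert all(torch.equal(gathered[0], w) for w in gathered[1:])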

DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper
before compiling the model, so that TorchDynamo can apply ``DDPOptimizer``
(graph-break optimizations) based on DDP bucket sizes. (See
`TorchDynamo DDPOptimizer <./ddp.html#torchdynamo-ddpoptimizer>`_ for more information.)


.. code::

    ddp_model = DDP(model, device_ids=[rank])
    ddp_model = torch.compile(ddp_model)

Internal Design
^^^^^^^^^^^^^^^

This section reveals how :class:`torch.nn.parallel.DistributedDataParallel` works
under the hood by diving into the details of every step in one iteration.

- **Prerequisite**: DDP relies on c10d ``ProcessGroup`` for communications.
  Hence, applications must create ``ProcessGroup`` instances before constructing
  DDP.
- **Construction**: The DDP constructor takes a reference to the local module,
  and broadcasts ``state_dict()`` from the process with rank 0 to all other
  processes in the group to make sure that all model replicas start from the
  exact same state. Then, each DDP process creates a local ``Reducer``, which
  will later take care of gradient synchronization during the backward
  pass. To improve communication efficiency, the ``Reducer`` organizes parameter
  gradients into buckets, and reduces one bucket at a time. Bucket size can be
  configured by setting the ``bucket_cap_mb`` argument in the DDP constructor. The
  mapping from parameter gradients to buckets is determined at construction
  time, based on the bucket size limit and parameter sizes. Model parameters are
  allocated into buckets in (roughly) the reverse order of
  ``Model.parameters()`` from the given model. The reason for using the reverse
  order is that DDP expects gradients to become ready during the backward
  pass in approximately that order. The figure below shows an example. Note
  that ``grad0`` and ``grad1`` are in ``bucket1``, and the other two
  gradients are in ``bucket0``. Of course, this assumption might not always
  be true, and when that happens it could hurt DDP backward speed as the
  ``Reducer`` cannot kick off the communication at the earliest possible time.
  Besides bucketing, the ``Reducer`` also registers autograd hooks during
  construction, one hook per parameter. These hooks will be triggered during
  the backward pass when the gradient becomes ready.
- **Forward Pass**: DDP takes the input and passes it to the local model,
  and then analyzes the output from the local model if
  ``find_unused_parameters`` is set to ``True``. This mode allows running
  backward on a subgraph of the model, and DDP finds out which parameters are
  involved in the backward pass by traversing the autograd graph from the model
  output and marking all unused parameters as ready for reduction. During the
  backward pass, the ``Reducer`` would only wait for unready parameters, but it
  would still reduce all buckets. Marking a parameter gradient as ready does not
  help DDP skip buckets for now, but it will prevent DDP from waiting for
  absent gradients forever during the backward pass. Note that traversing the
  autograd graph introduces extra overhead, so applications should only set
  ``find_unused_parameters`` to ``True`` when necessary.
- **Backward Pass**: The ``backward()`` function is directly invoked on the loss
  ``Tensor``, which is out of DDP's control, and DDP uses autograd hooks
  registered at construction time to trigger gradient synchronization. When
  one gradient becomes ready, its corresponding DDP hook on that grad
  accumulator will fire, and DDP will then mark that parameter gradient as
  ready for reduction. When gradients in one bucket are all ready, the
  ``Reducer`` kicks off an asynchronous ``allreduce`` on that bucket to
  calculate the mean of gradients across all processes. When all buckets are
  ready, the ``Reducer`` will block waiting for all ``allreduce`` operations to
  finish. When this is done, averaged gradients are written to the
  ``param.grad`` field of all parameters. So after the backward pass, the
  ``grad`` field on the corresponding parameter across different DDP processes
  should be the same. (A conceptual sketch of this averaging appears after the
  note below.)
- **Optimizer Step**: From the optimizer's perspective, it is optimizing a local
  model. Model replicas on all DDP processes can stay in sync because they all
  start from the same state and they have the same averaged gradients in
  every iteration.


.. image:: https://user-images.githubusercontent.com/16999635/72401724-d296d880-371a-11ea-90ab-737f86543df9.png
    :alt: ddp_grad_sync.png
    :width: 700 px

.. note::
    DDP requires ``Reducer`` instances on all processes to invoke ``allreduce``
    in exactly the same order, which is done by always running ``allreduce``
    in the bucket index order instead of the actual bucket readiness order. A
    mismatched ``allreduce`` order across processes can lead to wrong results
    or a DDP backward hang.
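
The per-parameter hooks and bucketed ``allreduce`` described above live in C++
inside the ``Reducer``, but the end result of one backward pass is easy to state
in plain Python. The following is a conceptual sketch only, with a hypothetical
helper name, and is not DDP's actual implementation: it averages each gradient
with a blocking ``all_reduce`` after ``backward()`` returns, which produces the
same ``param.grad`` values but without bucketing and without overlapping
communication with computation.

.. code::

    import torch.distributed as dist


    # Conceptual sketch only: what DDP's gradient averaging amounts to, minus
    # the bucketing and compute/communication overlap that the real Reducer
    # provides. Assumes the process group has already been initialized.
    def naive_allreduce_gradients(model, world_size):
        for param in model.parameters():
            if param.grad is not None:
                # sum this gradient across all processes ...
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                # ... and divide by the number of processes to get the mean
                param.grad /= world_size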

Implementation
^^^^^^^^^^^^^^

Below are pointers to the DDP implementation components. The stacked graph shows
the structure of the code.

ProcessGroup
------------

- `ProcessGroup.hpp <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/lib/c10d/ProcessGroup.hpp>`__:
  contains the abstract API of all process group implementations. The ``c10d``
  library provides 3 implementations out of the box, namely,
  ``ProcessGroupGloo``, ``ProcessGroupNCCL``, and ``ProcessGroupMPI``.
  ``DistributedDataParallel`` uses ``ProcessGroup::broadcast()`` to send
  model states from the process with rank 0 to others during initialization
  and ``ProcessGroup::allreduce()`` to sum gradients.


- `Store.hpp <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/lib/c10d/Store.hpp>`__:
  assists the rendezvous service for process group instances to find each other.

DistributedDataParallel
-----------------------

- `distributed.py <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/nn/parallel/distributed.py>`__:
  is the Python entry point for DDP. It implements the initialization steps and
  the ``forward`` function for the ``nn.parallel.DistributedDataParallel``
  module, which call into C++ libraries. Its ``_sync_param`` function performs
  intra-process parameter synchronization when one DDP process works on multiple
  devices, and it also broadcasts model buffers from the process with rank 0 to
  all other processes. The inter-process parameter synchronization happens in
  ``Reducer.cpp``.

- `comm.h <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/csrc/distributed/c10d/comm.h>`__:
  implements the coalesced broadcast helper function which is invoked to
  broadcast model states during initialization and synchronize model buffers
  before the forward pass.

- `reducer.h <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/csrc/distributed/c10d/reducer.h>`__:
  provides the core implementation for gradient synchronization in the backward
  pass. It has three entry point functions:

  * ``Reducer``: The constructor is called in ``distributed.py``; it registers
    ``Reducer::autograd_hook()`` to gradient accumulators (a Python-level
    illustration of this idea follows the figure below).
  * ``autograd_hook()``: This function is invoked by the autograd engine when
    a gradient becomes ready.
  * ``prepare_for_backward()``: This function is called at the end of the DDP
    forward pass in ``distributed.py``. It traverses the autograd graph to find
    unused parameters when ``find_unused_parameters`` is set to ``True`` in the
    DDP constructor.

.. image:: https://user-images.githubusercontent.com/16999635/72313120-4e7c1c80-3658-11ea-9c6d-44336b2daeac.png
    :alt: ddp_code.png
    :width: 400 px
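
The ``Reducer``'s hook registration happens in C++, but the idea can be
illustrated with Python-level tensor hooks. The sketch below is an illustration
only, not how ``reducer.h`` is implemented: it registers one hook per parameter,
counts how many gradients have become ready, and reports when all of them have
arrived, which is roughly the signal the real ``Reducer`` uses to kick off a
bucket's ``allreduce``.

.. code::

    import torch
    import torch.nn as nn

    # Illustration only: mimic the Reducer's "one hook per parameter" setup
    # with Python tensor hooks (the real hooks live in C++ and attach to
    # gradient accumulators).
    model = nn.Linear(10, 10)
    ready = set()

    def make_hook(name, total):
        def hook(grad):
            # fires when this parameter's gradient has been computed
            ready.add(name)
            if len(ready) == total:
                print("all gradients ready -- a real Reducer would allreduce here")
        return hook

    params = dict(model.named_parameters())
    for name, p in params.items():
        p.register_hook(make_hook(name, len(params)))

    model(torch.randn(4, 10)).sum().backward()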

TorchDynamo DDPOptimizer
------------------------

DDP's performance advantage comes from overlapping allreduce collectives with computations during backwards.
AotAutograd prevents this overlap when used with TorchDynamo for compiling a whole forward and whole backward graph,
because allreduce ops are launched by autograd hooks *after* the whole optimized backwards computation finishes.

TorchDynamo's DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP's allreduce buckets
during backwards. Note: the goal is to break the graph during backwards, and the simplest implementation is to
break the forward graph and then call AotAutograd and compilation on each section. This allows DDP's allreduce hooks
to fire in between sections of backwards, and schedule communications to overlap with compute.

See `this blog post <https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860/1>`_ for
a more in-depth explanation and experimental results, or read the docs and code at
`torch/_dynamo/optimizations/distributed.py <https://github.com/pytorch/pytorch/blob/bbc39b7bb48d28d67e3253a89cc82df3687ddd1b/torch/_dynamo/backends/distributed.py#L124>`_.

To debug DDPOptimizer, set ``TORCH_LOGS='ddp_graphs'`` for full graph dumps. For logs without graphs, add any of
``'dynamo'``, ``'distributed'``, or ``'dist_ddp'`` to ``TORCH_LOGS`` (for basic information about bucket boundaries).
To disable DDPOptimizer, set ``torch._dynamo.config.optimize_ddp=False``. DDP and TorchDynamo should still work
correctly without DDPOptimizer, but with performance degradation.
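
Both switches above are a single environment variable and a single config flag.
The sketch below shows one way they might be wired into a run; the script name
``train.py`` and the reuse of ``ddp_model`` from the first example are
assumptions for illustration only.

.. code::

    # From the shell, before launching the training script
    # (dumps the full graphs produced by DDPOptimizer):
    #
    #   TORCH_LOGS='ddp_graphs' python train.py
    #
    # From Python, DDPOptimizer can be turned off entirely; DDP and
    # TorchDynamo still work correctly, just without the bucket-aligned
    # graph breaks:
    import torch._dynamo

    torch._dynamo.config.optimize_ddp = False
    ddp_model = torch.compile(ddp_model)  # ddp_model constructed as in the example above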