.. _ddp:

Distributed Data Parallel
=========================

.. warning::
  The implementation of :class:`torch.nn.parallel.DistributedDataParallel`
  evolves over time. This design note is written based on the state as of v1.4.


:class:`torch.nn.parallel.DistributedDataParallel` (DDP) transparently performs
distributed data parallel training. This page describes how it works and reveals
implementation details.

Example
^^^^^^^

Let us start with a simple :class:`torch.nn.parallel.DistributedDataParallel`
example. This example uses a :class:`torch.nn.Linear` as the local model, wraps
it with DDP, and then runs one forward pass, one backward pass, and an optimizer
step on the DDP model. After that, parameters on the local model will be
updated, and all models on different processes should be exactly the same.


.. code::

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP


    def example(rank, world_size):
        # create default process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        # create local model
        model = nn.Linear(10, 10).to(rank)
        # construct DDP model
        ddp_model = DDP(model, device_ids=[rank])
        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    def main():
        world_size = 2
        mp.spawn(example,
            args=(world_size,),
            nprocs=world_size,
            join=True)

    if __name__ == "__main__":
        # Environment variables which need to be
        # set when using c10d's default "env://"
        # initialization mode.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        main()

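
To confirm the claim above that replicas end up exactly the same, one can, for
example, gather a parameter from every rank after ``optimizer.step()`` and
compare the copies. Below is a minimal sketch of such a check; the helper name
``check_replicas_match`` is illustrative and not part of DDP itself.

.. code::

    def check_replicas_match(ddp_model, world_size):
        # Gather a parameter from every rank and compare each copy against rank 0's.
        param = next(ddp_model.parameters()).detach()
        gathered = [torch.zeros_like(param) for _ in range(world_size)]
        dist.all_gather(gathered, param)
        assert all(torch.allclose(gathered[0], copy) for copy in gathered)
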

DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper
before compiling the model, so that TorchDynamo can apply ``DDPOptimizer``
(graph-break optimizations) based on DDP bucket sizes. (See `TorchDynamo DDPOptimizer <./ddp.html#torchdynamo-ddpoptimizer>`_ for more information.)


.. code::

        ddp_model = DDP(model, device_ids=[rank])
        ddp_model = torch.compile(ddp_model)


Internal Design
^^^^^^^^^^^^^^^

This section reveals how :class:`torch.nn.parallel.DistributedDataParallel`
works under the hood by diving into the details of every step in one iteration.


- **Prerequisite**: DDP relies on c10d ``ProcessGroup`` for communications.
  Hence, applications must create ``ProcessGroup`` instances before constructing
  DDP.
- **Construction**: The DDP constructor takes a reference to the local module
  and broadcasts ``state_dict()`` from the process with rank 0 to all other
  processes in the group to make sure that all model replicas start from the
  exact same state. Then, each DDP process creates a local ``Reducer``, which
  will later take care of gradient synchronization during the backward pass. To
  improve communication efficiency, the ``Reducer`` organizes parameter
  gradients into buckets and reduces one bucket at a time. Bucket size can be
  configured by setting the ``bucket_cap_mb`` argument in the DDP constructor
  (see the constructor sketch after the figure below). The mapping from
  parameter gradients to buckets is determined at construction time, based on
  the bucket size limit and parameter sizes. Model parameters are allocated into
  buckets in (roughly) the reverse order of ``Model.parameters()`` from the
  given model. The reverse order is used because DDP expects gradients to become
  ready during the backward pass in approximately that order. The figure below
  shows an example. Note that ``grad0`` and ``grad1`` are in ``bucket1``, and
  the other two gradients are in ``bucket0``. Of course, this assumption might
  not always hold, and when that happens it could hurt DDP backward speed
  because the ``Reducer`` cannot kick off the communication at the earliest
  possible time. Besides bucketing, the ``Reducer`` also registers autograd
  hooks during construction, one hook per parameter. These hooks will be
  triggered during the backward pass when the gradient becomes ready.
- **Forward Pass**: DDP takes the input and passes it to the local model, and
  then analyzes the output from the local model if ``find_unused_parameters``
  is set to ``True``. This mode allows running backward on a subgraph of the
  model: DDP finds out which parameters are involved in the backward pass by
  traversing the autograd graph from the model output and marking all unused
  parameters as ready for reduction. During the backward pass, the ``Reducer``
  only waits for unready parameters, but it still reduces all buckets. Marking
  a parameter gradient as ready does not help DDP skip buckets for now, but it
  prevents DDP from waiting for absent gradients forever during the backward
  pass. Note that traversing the autograd graph introduces extra overhead, so
  applications should only set ``find_unused_parameters`` to ``True`` when
  necessary.
- **Backward Pass**: The ``backward()`` function is directly invoked on the loss
  ``Tensor``, which is out of DDP's control, so DDP uses the autograd hooks
  registered at construction time to trigger gradient synchronization. When one
  gradient becomes ready, its corresponding DDP hook on that grad accumulator
  fires, and DDP then marks that parameter gradient as ready for reduction.
  When all gradients in one bucket are ready, the ``Reducer`` kicks off an
  asynchronous ``allreduce`` on that bucket to calculate the mean of gradients
  across all processes. When all buckets are ready, the ``Reducer`` blocks,
  waiting for all ``allreduce`` operations to finish. When this is done, the
  averaged gradients are written to the ``param.grad`` field of all parameters.
  So after the backward pass, the ``grad`` field of corresponding parameters
  across different DDP processes should be the same.
- **Optimizer Step**: From the optimizer's perspective, it is optimizing a local
  model. Model replicas on all DDP processes can stay in sync because they all
  start from the same state and they have the same averaged gradients in
  every iteration.



.. image:: https://user-images.githubusercontent.com/16999635/72401724-d296d880-371a-11ea-90ab-737f86543df9.png
    :alt: ddp_grad_sync.png
    :width: 700 px

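
Concretely, the bucketing and unused-parameter behavior described above is
controlled through arguments of the DDP constructor. A minimal sketch (the
values shown here are illustrative, not recommendations):

.. code::

    ddp_model = DDP(
        model,
        device_ids=[rank],
        bucket_cap_mb=25,              # bucket size limit used when grouping gradients
        find_unused_parameters=False,  # set to True only when parts of the model may not
                                       # receive gradients; the graph traversal adds overhead
    )
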

.. note::
  DDP requires ``Reducer`` instances on all processes to invoke ``allreduce``
  in exactly the same order, which is done by always running ``allreduce``
  in the bucket index order instead of the actual bucket ready order. Mismatched
  ``allreduce`` order across processes can lead to wrong results or a DDP
  backward hang.

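
The ordering requirement in the note above can be illustrated with a conceptual
sketch. This is not the actual ``Reducer`` code, which is written in C++ and
overlaps communication with the still-running backward pass; it only shows the
fixed bucket-index order, the asynchronous launch, and the sum-then-divide
averaging:

.. code::

    import torch.distributed as dist

    def reduce_buckets(buckets, world_size):
        # buckets: flattened gradient tensors, indexed identically on every rank
        works = []
        for bucket in buckets:  # always bucket index order, never "ready" order
            works.append(dist.all_reduce(bucket, op=dist.ReduceOp.SUM, async_op=True))
        for bucket, work in zip(buckets, works):
            work.wait()              # block until the asynchronous allreduce finishes
            bucket.div_(world_size)  # allreduce sums; divide to obtain the mean
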

Implementation
^^^^^^^^^^^^^^

Below are pointers to the DDP implementation components. The stacked graph below
shows the structure of the code.

ProcessGroup
------------


- `ProcessGroup.hpp <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/lib/c10d/ProcessGroup.hpp>`__:
  contains the abstract API of all process group implementations. The ``c10d``
  library provides three implementations out of the box, namely
  ``ProcessGroupGloo``, ``ProcessGroupNCCL``, and ``ProcessGroupMPI``.
  ``DistributedDataParallel`` uses ``ProcessGroup::broadcast()`` to send
  model states from the process with rank 0 to others during initialization
  and ``ProcessGroup::allreduce()`` to sum gradients; a Python-level sketch of
  these two collectives follows this list.


- `Store.hpp <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/lib/c10d/Store.hpp>`__:
  assists the rendezvous service for process group instances to find each other.

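
At the Python level, these two collectives correspond to ``torch.distributed.broadcast``
and ``torch.distributed.all_reduce``. The following is only a rough sketch of what DDP
does with them; the real implementation runs in C++ and operates on coalesced,
bucketed tensors rather than on one tensor at a time:

.. code::

    import torch.distributed as dist

    def sync_initial_state(model, src=0):
        # Broadcast parameters and buffers from rank 0 so all replicas start identical.
        for tensor in list(model.parameters()) + list(model.buffers()):
            dist.broadcast(tensor.data, src=src)

    def average_gradients(model, world_size):
        # Sum each gradient across processes, then divide to obtain the mean.
        for param in model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad.div_(world_size)
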

DistributedDataParallel
-----------------------

- `distributed.py <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/nn/parallel/distributed.py>`__:
  is the Python entry point for DDP. It implements the initialization steps and
  the ``forward`` function for the ``nn.parallel.DistributedDataParallel``
  module, which calls into C++ libraries. Its ``_sync_param`` function performs
  intra-process parameter synchronization when one DDP process works on multiple
  devices, and it also broadcasts model buffers from the process with rank 0 to
  all other processes. The inter-process parameter synchronization happens in
  ``Reducer.cpp``.


- `comm.h <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/csrc/distributed/c10d/comm.h>`__:
  implements the coalesced broadcast helper function which is invoked to
  broadcast model states during initialization and synchronize model buffers
  before the forward pass.

- `reducer.h <https://github.com/pytorch/pytorch/blob/v1.7.0/torch/csrc/distributed/c10d/reducer.h>`__:
  provides the core implementation for gradient synchronization in the backward
  pass. It has three entry point functions (a conceptual sketch of the hook
  mechanism follows this list):

  * ``Reducer``: The constructor is called in ``distributed.py``, and it registers
    ``Reducer::autograd_hook()`` with gradient accumulators.
  * ``autograd_hook()``: This function is invoked by the autograd engine when
    a gradient becomes ready.
  * ``prepare_for_backward()``: This is called at the end of the DDP forward pass
    in ``distributed.py``. It traverses the autograd graph to find unused
    parameters when ``find_unused_parameters`` is set to ``True`` in the DDP
    constructor.

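
The hook registration performed by the ``Reducer`` lives in C++ and attaches to
gradient accumulators, but the idea can be approximated at the Python level with
per-tensor hooks. The snippet below is only a conceptual sketch of "one hook per
parameter, fired when that parameter's gradient is produced", not the actual
mechanism:

.. code::

    ready = set()

    def make_hook(index):
        def hook(grad):
            ready.add(index)  # mark this parameter's gradient as ready for reduction
            return grad
        return hook

    for i, param in enumerate(model.parameters()):
        if param.requires_grad:
            param.register_hook(make_hook(i))
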

.. image:: https://user-images.githubusercontent.com/16999635/72313120-4e7c1c80-3658-11ea-9c6d-44336b2daeac.png
    :alt: ddp_code.png
    :width: 400 px



TorchDynamo DDPOptimizer
------------------------

DDP's performance advantage comes from overlapping allreduce collectives with computation during the backward pass.
AotAutograd prevents this overlap when used with TorchDynamo to compile a whole forward and a whole backward graph,
because the allreduce ops are launched by autograd hooks *after* the whole optimized backward computation finishes.

TorchDynamo's DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP's allreduce buckets
during the backward pass.  Note: the goal is to break the graph during the backward pass, and the simplest implementation
is to break the forward graph and then call AotAutograd and compilation on each section.  This allows DDP's allreduce
hooks to fire in between the backward sections and to schedule communication so that it overlaps with compute.

See `this blog post <https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860/1>`_ for
a more in-depth explanation and experimental results, or read the docs and code at
`torch/_dynamo/backends/distributed.py <https://github.com/pytorch/pytorch/blob/bbc39b7bb48d28d67e3253a89cc82df3687ddd1b/torch/_dynamo/backends/distributed.py#L124>`_.

To debug DDPOptimizer, set ``TORCH_LOGS='ddp_graphs'`` for full graph dumps. For logs without graphs (basic information
about bucket boundaries), add any of ``dynamo``, ``distributed``, or ``dist_ddp`` to ``TORCH_LOGS``.
To disable DDPOptimizer, set ``torch._dynamo.config.optimize_ddp=False``.
DDP and TorchDynamo should still work correctly without DDPOptimizer, but with performance degradation.
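
For example, the DDPOptimizer toggle mentioned above can be flipped from Python.
This is a minimal sketch; ``ddp_model`` is the DDP-wrapped model from the earlier
example, and ``TORCH_LOGS`` is normally set in the shell before launching the
training script (the script name below is a placeholder):

.. code::

    # e.g. launch with:  TORCH_LOGS='ddp_graphs' python train.py
    # to dump the sub-graphs that DDPOptimizer creates.

    import torch
    import torch._dynamo

    # Disable DDPOptimizer: DDP + TorchDynamo still produce correct results, but lose
    # the bucket-aligned graph breaks and hence the communication/compute overlap.
    torch._dynamo.config.optimize_ddp = False

    ddp_model = torch.compile(ddp_model)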
226