• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import weakref
3from typing import Any, Callable, List, Optional
4
5import torch
6import torch.distributed as dist
7from torch.distributed.optim import ZeroRedundancyOptimizer
8from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus
9from torch.nn.parallel.distributed import DistributedDataParallel
10
11
12__all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"]
13
14# Functional optimizers require passing a list of gradients to their `step()`
15# method, and ZeRO requires a functional optimizer to overlap with DDP
16# Passing a `None` instead of an actual gradient indicates to the optimizer
17# to not update the corresponding parameter
18_NO_PARAM_UPDATE: None = None
19
20
21def _perform_local_step(
22    bucket: dist.GradBucket,
23    zero: ZeroRedundancyOptimizer,
24    rank: int,
25):
26    r"""
27    Perform a local optimizer step using the gradients provided by ``bucket``.
28
29    Arguments:
30        bucket (dist.GradBucket): the bucket providing the gradients.
31        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
32            instance to perform the :meth:`_local_step`.
33        rank (int): the calling process's rank.
34
35    .. warning::
36        This function assumes that appropriate synchronization has taken place
37        so that the bucket's gradients can be used.
38    """
39    overlap_info = zero._overlap_info
40    bucket_index = bucket.index()
41    assert (
42        len(zero.optim.param_groups) == 1
43    ), "Overlapping DDP with ZeRO only supports a single parameter group"
44
45    # Construct the `gradients` input for the local optimizer step, which
46    # expects `None` in a list position to indicate that the corresponding
47    # parameter should not be updated
48    num_local_optim_params = len(zero.optim.param_groups[0]["params"])
49    gradients: List[Optional[torch.Tensor]] = [
50        _NO_PARAM_UPDATE for _ in range(num_local_optim_params)
51    ]
52    assert (
53        bucket_index in overlap_info.offsets
54    ), f"Bucket index {bucket_index} was not assigned to rank {rank}"
55    gradients_offset = overlap_info.offsets[bucket_index]
56    bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
57    bucket_offset = bucket_assignment.offset
58    length = len(bucket_assignment.parameters)
59    bucket_gradients = bucket.gradients()[bucket_offset : bucket_offset + length]
60    for i, grad in enumerate(bucket_gradients):
61        gradients[gradients_offset + i] = grad
62
63    zero._local_step(gradients)
64
65
66def _broadcast_bucket(
67    bucket_index: int,
68    zero: ZeroRedundancyOptimizer,
69):
70    r"""
71    Broadcasts a bucket's parameters.
72
73    Arguments:
74        bucket_index (int): the index of the bucket corresponding to the
75            parameters to broadcast.
76        zero (ZeroRedundancyOptimizer): the calling process's
77            :class:`ZeroRedundancyOptimizer` instance.
78    """
79    overlap_info = zero._overlap_info
80    assert (
81        len(overlap_info.assigned_ranks_per_bucket) > bucket_index
82    ), "`assigned_ranks_per_bucket` is not fully constructed"
83    # Sort to ensure the same ordering across ranks
84    assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
85    assert len(assigned_ranks) > 0, (
86        f"Bucket {bucket_index} should be " "assigned to at least one rank"
87    )
88    for assigned_rank in assigned_ranks:
89        bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
90        if bucket_index in bucket_assignments:
91            overlap_info.broadcast_handles.append(
92                dist.broadcast(
93                    bucket_assignments[bucket_index].tensor,
94                    src=dist.get_global_rank(zero.process_group, assigned_rank),
95                    group=zero.process_group,
96                    async_op=True,
97                )
98            )
99
100
101def _save_ddp_bucket_info(
102    bucket: dist.GradBucket,
103    zero: ZeroRedundancyOptimizer,
104):
105    r"""
106    Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``.
107
108    In particular, this function is meant to be called upon seeing each
109    gradient bucket to use when overlapping, meaning it does not save or compute any global
110    information.
111
112    Arguments:
113        bucket (dist.GradBucket): the current gradient bucket.
114        zero (ZeroRedundancyOptimizer): the calling process's
115            :class:`ZeroRedundancyOptimizer` instance.
116    """
117    overlap_info = zero._overlap_info
118    bucket_params = bucket.parameters()
119    assert len(bucket_params) > 0, "Empty bucket"
120
121    # Save the parameters in the bucket
122    overlap_info.params_per_bucket.append(bucket_params)
123    if overlap_info.shard_buckets:
124        # Additionally save the bucket size for the assignment heuristic to use
125        bucket_size = 0
126        for param in bucket_params:
127            bucket_size += param.numel()
128        assert overlap_info.total_size is not None
129        overlap_info.total_size += bucket_size
130
131
132def _hook_with_zero_step_setup(
133    ddp_ref: weakref.ReferenceType,
134    zero: ZeroRedundancyOptimizer,
135    bucket: dist.GradBucket,
136):
137    r"""
138    Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
139
140    This means the logic to run in the
141    hook before the backward pass and optimizer step can actually be
142    overlapped. This is factored out since it is common to both
143    :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
144
145    Arguments:
146        ddp_ref (weakref.ReferenceType): weak reference to the process's
147            :class:`DistributedDataParallel` instance.
148        zero (ZeroRedundancyOptimizer): the calling process's
149            :class:`ZeroRedundancyOptimizer` instance.
150        bucket (dist.GradBucket): the current gradient bucket.
151    """
152    # Proceed as normal until the DDP buckets have been rebuilt
153    if not ddp_ref()._has_rebuilt_buckets:  # type: ignore[union-attr]
154        assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED
155        return
156
157    bucket_index = bucket.index()
158    overlap_info = zero._overlap_info
159    if overlap_info.status == _OverlapStatus.UNINITIALIZED:
160        overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS
161
162    if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
163        if bucket_index == 0 and len(overlap_info.params_per_bucket) > 0:
164            # This corresponds to the first bucket of the backward pass
165            # immediately after all information has been saved, so we
166            # can perform the delayed ZeRO initialization
167            zero._init_zero_for_overlap()
168        else:
169            # Once DDP buckets have been rebuilt but ZeRO has not been
170            # properly initialized yet, save the information needed
171            _save_ddp_bucket_info(bucket, zero)
172
173
174def hook_with_zero_step(
175    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
176    ddp: DistributedDataParallel,
177    zero: ZeroRedundancyOptimizer,
178    shard_buckets: bool = False,
179) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
180    r"""
181    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass.
182
183    This approach overlaps the optimizer computation and communication with the
184    backward communication. In particular, the backward computation proceeds
185    contiguously, and the optimizer computation follows, overlapping with
186    outstanding backward communication (i.e. all-reduces) and possibly other
187    optimizer communication (i.e. broadcasts).
188    The optimizer step computation begins after the last gradient bucket computation has finished.
189
190    This approach may be preferred over :meth:`hook_with_zero_step_interleaved`
191    if communication is relatively slow compared to computation.
192
193    Arguments:
194        hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook
195            to modify.
196        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
197            instance to use.
198        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
199            instance to use.
200        shard_buckets (bool): if ``True``, then the assignment of each
201            :class:`DistributedDataParallel` bucket is partitioned across
202            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
203            across possibly multiple ranks) to approximate uniformity; if
204            ``False``, then each bucket is wholly assigned to a single
205            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
206
207    Returns:
208        The modified hook.
209
210    Raises:
211        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
212        RuntimeError: if using any backend other than NCCL/HCCL since currently
213            Gloo may hang.
214
215    .. warning::
216        Given the way that overlapping :class:`DistributedDataParallel` with
217        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
218        two or three training iterations do not perform parameter updates in
219        the optimizer step, depending on if ``static_graph=False`` or
220        ``static_graph=True``, respectively. This is because it needs
221        information about the gradient bucketing strategy used by
222        :class:`DistributedDataParallel`, which is not finalized until the
223        second forward pass if ``static_graph=False`` or until the third
224        forward pass if ``static_graph=True``.
225    """
226    if not zero._overlap_with_ddp:
227        raise ValueError(
228            "ZeroRedundancyOptimizer must be constructed with "
229            "`overlap_with_ddp=True` to use this hook properly"
230        )
231    ddp_ref = weakref.ref(ddp)
232
233    # NOTE: Gloo may hang with this overlapping approach, so we require
234    # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
235    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
236    if (pg != dist.Backend.NCCL) and (pg != "hccl"):
237        raise RuntimeError(
238            "Overlapping DDP with ZeRO using this approach currently requires "
239            "NCCL/HCCL backend to avoid hangs"
240        )
241
242    if shard_buckets:
243        zero._overlap_info.shard_buckets = True
244        zero._overlap_info.total_size = 0
245
246    def hook_with_zero_fn(
247        state: Any,
248        bucket: dist.GradBucket,
249    ) -> torch.futures.Future[torch.Tensor]:
250        r"""
251        Return :class:`Future` that runs the optimizer step if this corresponds to the last gradient bucket.
252
253        Perform equivalent of :class:`ZeroRedundancyOptimizer` :meth:`step` if ``bucket`` is last gradient bucket.
254        The function gives a gradient bucket tensor and
255        performs additional computation on the iteration that
256        the :class:`DistributedDataParallel` buckets are rebuilt to collect
257        information used to implement the modified hook.
258
259        Arguments:
260            state (Any): any state for the hook.
261            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
262                gradient bucket.
263        """
264        fut = hook(state, bucket)
265        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
266        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
267            return fut
268
269        overlap_info = zero._overlap_info
270        bucket_index = bucket.index()
271        rank = zero.global_rank
272
273        assert overlap_info.status == _OverlapStatus.INITIALIZED
274        assert (
275            len(overlap_info.assigned_ranks_per_bucket) > bucket_index
276        ), "`assigned_ranks_per_bucket` is not fully constructed"
277        assigned_to_bucket = (
278            rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
279        )
280
281        # Save the bucket reference and all-reduce future for the final bucket
282        if assigned_to_bucket:
283            overlap_info.bucket_index_to_bucket[bucket_index] = bucket
284            overlap_info.bucket_index_to_future[bucket_index] = fut
285
286        # Check that buckets are indexed incrementally starting from 0 in the
287        # order of their autograd hooks firing
288        if len(overlap_info.bucket_indices_seen) > 0:
289            assert (
290                overlap_info.bucket_indices_seen[-1] == bucket_index - 1
291            ), "Bucket indices are not in incremental order"
292        else:
293            assert bucket_index == 0, "Bucket indices do not start from 0"
294        overlap_info.bucket_indices_seen.append(bucket_index)
295
296        # Directly return the future without any optimizer computation if this
297        # is not the last bucket
298        num_buckets = len(overlap_info.params_per_bucket)
299        is_last_bucket = bucket_index == num_buckets - 1
300        if not is_last_bucket:
301            return fut
302
303        # Perform partial optimizer step on all buckets after the final
304        # bucket has been computed
305        # NOTE: This should not be chained as a callback to the last bucket's
306        # all-reduce future since that would add synchronization that delays
307        # all optimizer computation to wait for that last all-reduce
308        for bucket_index in range(num_buckets):
309            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
310            if rank in assigned_ranks:
311                # Wait on the bucket's all-reduce future to ensure correct
312                # gradients
313                assert bucket_index in overlap_info.bucket_index_to_future, (
314                    f"All-reduce future for bucket {bucket_index} not saved "
315                    f"on rank {rank}"
316                )
317                allreduce_future = overlap_info.bucket_index_to_future[bucket_index]
318                allreduce_future.wait()
319
320                # Perform the partial optimizer step
321                curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index]
322                _perform_local_step(curr_bucket, zero, rank)
323
324            _broadcast_bucket(bucket_index, zero)
325
326        # Ensure that all parameter updates are finished before the
327        # next forward pass
328        overlap_info.wait_for_broadcasts()
329        overlap_info.clear_per_iter_info()
330
331        return fut
332
333    return hook_with_zero_fn
334
335
336def hook_with_zero_step_interleaved(
337    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
338    ddp: DistributedDataParallel,
339    zero: ZeroRedundancyOptimizer,
340    shard_buckets: bool = False,
341) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
342    r"""
343    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass
344
345    This approach overlaps the optimizer computation and communication with the
346    backward computation and communication. In particular, once a bucket's
347    gradients have been computed, the optimizer computation using those
348    gradients is launched (though the actual computation must wait for the
349    bucket's all-reduce to complete). This yields an interleaving of all-
350    reduces and broadcasts in the communication stream.
351
352    This approach may be preferred over :meth:`hook_with_zero_step` if
353    communication is relatively fast compared to computation.
354
355    Arguments:
356        hook (Any * dist.GradBucket -> torch.futures.Future): the hook to
357            modify.
358        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
359            instance to use.
360        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
361            instance to use.
362        shard_buckets (bool): if ``True``, then the assignment of each
363            :class:`DistributedDataParallel` bucket is partitioned across
364            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
365            across possibly multiple ranks) to approximate uniformity; if
366            ``False``, then each bucket is wholly assigned to a single
367            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).
368
369    Returns:
370        The modified hook.
371
372    Raises:
373        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
374        RuntimeError: if using any backend other than NCCL since currently
375            Gloo may hang.
376
377    .. warning::
378        Given the way that overlapping :class:`DistributedDataParallel` with
379        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
380        two or three training iterations do not perform parameter updates in
381        the optimizer step, depending on if ``static_graph=False`` or
382        ``static_graph=True``, respectively. This is because it needs
383        information about the gradient bucketing strategy used by
384        :class:`DistributedDataParallel`, which is not finalized until the
385        second forward pass if ``static_graph=False`` or until the third
386        forward pass if ``static_graph=True``.
387    """
388    if not zero._overlap_with_ddp:
389        raise ValueError(
390            "ZeroRedundancyOptimizer must be constructed with "
391            "`overlap_with_ddp=True` to use this hook properly"
392        )
393    ddp_ref = weakref.ref(ddp)
394
395    # NOTE: Gloo may hang with this overlapping approach, so we require
396    # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
397    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
398    if (pg != dist.Backend.NCCL) and (pg != "hccl"):
399        raise RuntimeError(
400            "Overlapping DDP with ZeRO using this approach currently requires "
401            "NCCL/HCCL backend to avoid hangs"
402        )
403
404    if shard_buckets:
405        zero._overlap_info.shard_buckets = True
406        zero._overlap_info.total_size = 0
407
408    def hook_with_zero_interleaved_fn(
409        state,
410        bucket: dist.GradBucket,
411    ) -> torch.futures.Future[torch.Tensor]:
412        r"""
413        Return :class:`Future` that gives gradient bucket tensor and performs partial :class:`ZeroRedundancyOptimizer` :meth:`step`.
414
415        This function uses the gradients in gradient in given bucket to perform a partial
416        :class:`ZeroRedundancyOptimizer` :meth:`step`
417
418        Arguments:
419            state: any state for the hook.
420            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
421                gradient bucket.
422        """
423        fut = hook(state, bucket)
424        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
425        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
426            return fut
427
428        def zero_step(fut: torch.futures.Future) -> torch.Tensor:
429            r"""
430            Perform partial :class:`ZeroRedundancyOptimizer` :meth:`step` using gradients in the :class:`DistributedDataParallel`.
431
432            Returns:
433                A :class:`torch.Tensor` representing the contents of the
434                gradient bucket.
435            """
436            overlap_info = zero._overlap_info
437            bucket_index = bucket.index()
438            rank = zero.global_rank
439
440            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
441            overlap_info.bucket_indices_seen.append(bucket_index)
442            if rank in assigned_ranks:
443                _perform_local_step(bucket, zero, rank)
444
445            _broadcast_bucket(bucket_index, zero)
446
447            num_buckets = len(overlap_info.params_per_bucket)
448            if len(overlap_info.bucket_indices_seen) == num_buckets:
449                # Ensure that all parameter updates are finished before the
450                # next forward pass
451                overlap_info.wait_for_broadcasts()
452                overlap_info.clear_per_iter_info()
453
454            return bucket.buffer()
455
456        return fut.then(zero_step)
457
458    return hook_with_zero_interleaved_fn
459