1# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2#
3# This source code is licensed under the BSD license found in the
4# LICENSE file in the root directory of this source tree.
5
6r"""Zero Redundancy Optimizer."""
7import collections
8import copy
9import enum
10import inspect
11import io
12import logging
13from itertools import chain
14from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
15
16import torch
17import torch.distributed as dist
18from torch.distributed.algorithms.join import Join, Joinable, JoinHook
19from torch.distributed.optim.utils import functional_optim_map
20from torch.optim import Optimizer
21
22
23__all__ = ["ZeroRedundancyOptimizer"]
24
25
26logger = logging.getLogger(__name__)
27
28
29# Credits:  classy_vision/generic/distributed_util.py
30def _recursive_copy_to_device(
31    value: Any,
32    non_blocking: bool,
33    device: torch.device,
34) -> Any:
35    r"""
36    Recursively searches lists, tuples, dicts and copies tensors to device if possible.
37
38    Non-tensor values are passed as-is in the result.
39
40    .. note::  These are all copies, so if there are two objects that reference
41        the same object, then after this call, there will be two different
42        objects referenced on the device.
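
    A minimal illustrative sketch (the names here are hypothetical)::

        >>> # xdoctest: +SKIP
        >>> batch = {"x": torch.ones(2), "meta": [torch.zeros(3), 7]}
        >>> moved = _recursive_copy_to_device(
        >>>     batch, non_blocking=False, device=torch.device("cuda:0")
        >>> )
        >>> # tensors are copied to the device; the plain int 7 is passed through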
43    """
44    if isinstance(value, torch.Tensor):
45        return value.to(device, non_blocking=non_blocking)
46
47    if isinstance(value, (list, tuple)):
48        values = [
49            _recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
50            for val in value
51        ]
52        return values if isinstance(value, list) else tuple(values)
53
54    if isinstance(value, collections.abc.Mapping):
55        return {
56            key: _recursive_copy_to_device(
57                val, non_blocking=non_blocking, device=device
58            )
59            for key, val in value.items()
60        }
61
62    return value
63
64
65def _is_trainable(param: torch.Tensor) -> bool:
66    r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient."""
67    return param.requires_grad
68
69
70def _broadcast_object(
71    obj: Any,
72    src_rank: int,
73    group: object = dist.group.WORLD,
74    device: torch.device = torch.device("cpu"),
75) -> Any:
76    r"""
77    Broadcast an object to the given group.
78
79    The object is sent if this function is called on the source rank and
80    received otherwise.
81
82    Arguments:
83        obj: object to broadcast; only used if called on the source rank.
84        src_rank (int): source rank.
85        group (``ProcessGroup``, optional): group used for the broadcast
86            (default: ``dist.group.WORLD``).
87        device (``torch.device``, optional): device to send from or receive
88            to (default: ``torch.device("cpu")``).
89
90    Returns:
91        The broadcasted object.
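
    A minimal usage sketch, assuming the default process group has already
    been initialized and rank 0 is the source::

        >>> # xdoctest: +SKIP
        >>> src = 0
        >>> obj = {"step": 10} if dist.get_rank() == src else None
        >>> obj = _broadcast_object(obj, src_rank=src)
        >>> # every rank now holds {"step": 10}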
92    """
93    if dist.get_rank() == src_rank:
94        # Send the object
95        buffer = io.BytesIO()
96        torch.save(obj, buffer)
97        data = bytearray(buffer.getbuffer())
98        length_tensor = torch.LongTensor([len(data)]).to(device)
99        data_send_tensor = torch.ByteTensor(data).to(device)
100        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
101        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
102    else:
103        # Receive the object
104        length_tensor = torch.LongTensor([0]).to(device)
105        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
106        data_recv_tensor = torch.empty(
107            [int(length_tensor.item())], dtype=torch.uint8, device=device
108        )
109        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
110        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
111        obj = torch.load(buffer, map_location=device, weights_only=False)
112    return obj
113
114
115class _ZeROJoinHook(JoinHook):
116    def __init__(self, zero):
117        assert isinstance(zero, ZeroRedundancyOptimizer), (
118            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer "
119            "instance as the state"
120        )
121        self.zero = zero
122        super().__init__()
123
124    def main_hook(self):
125        """
126        Perform an optimizer step.
127
128        This step updates the joined process's shard of
129        the parameters and broadcasts those parameters.
130        """
131        self.zero.step()
132
133
134class _DDPBucketAssignment:
135    r"""
136    Represent a :class:`DistributedDataParallel` bucket assignment.
137
138    This represents a (possibly non-strict) subset of the parameters
139    corresponding to a DDP bucket that is assigned to a rank to update.
140
141    Attributes:
142        bucket_index (int): index of the bucket determined by the DDP gradient
143            bucket all-reduce order.
144        parameters (List[torch.Tensor]): model parameters in the bucket
145            assigned to this rank.
146        offset (int): offset into the :class:`GradBucket` 's :meth:`parameters`
147            giving the index of the first element in the passed-in
148            ``parameters``; this equivalently indexes into the
149            :class:`GradBucket` 's :meth:`gradients`.
150        device (torch.device): device on which the parameters are stored.
151        tensor (torch.Tensor): flattened tensor giving the data of the
152            parameter subset assigned to the rank.
153    """
154
155    def __init__(
156        self,
157        bucket_index: int,
158        parameters: List[torch.Tensor],
159        offset: int,
160    ):
161        self.bucket_index = bucket_index
162        self.parameters = parameters
163        self.offset = offset
164        if len(self.parameters) == 0:
165            raise ValueError("Empty bucket assignment")
166        # DDP guarantees all parameters in the bucket have the same device
167        self.device: torch.device = self.parameters[0].device
168        self.tensor: Optional[torch.Tensor] = None
169
170
171class _OverlapStatus(enum.IntEnum):
172    r"""
173    Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`.
174
175    Attributes:
176        ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and
177            is waiting for DDP to finalize its bucketing.
178        ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that
179            its bucketing is finalized. The ZeRO instance can now collect the
180            necessary information about the DDP bucketing.
181        ``INITIALIZED``: The ZeRO instance is fully initialized and can now
182            optimize parameters.
183    """
184
185    UNINITIALIZED = 0
186    DDP_HAS_REBUILT_BUCKETS = 1
187    INITIALIZED = 2
188
189
190class _OverlapInfo:
191    r"""
192    Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`.
193
194    Arguments:
195        world_size (int): world size of the process group being used.
196
197    Attributes:
198        shard_buckets (bool): if ``True``, then the assignment of each
199            :class:`DistributedDataParallel` bucket is partitioned across
200            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
201            across possibly multiple ranks) to approximate uniformity following
202            a threshold given by the total parameter size divided by the world
203            size; if ``False``, then each bucket is wholly assigned to a single
204            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
205            this should be set to the value passed into the hook constructor.
206        status (_OverlapStatus): current status; see :class:`_OverlapStatus`
207            for more information.
208        params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
209            gives the model parameters in the ``i``th bucket.
210        params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]``
211            gives the model parameters assigned to the ``i``th rank, where the
212            parameters are grouped by increasing bucket indices.
213        offsets (Dict[int, int]): maps from bucket index to the offset in
214            ``self.params_per_rank[rank]`` giving the index of the first
215            parameter in that bucket, where ``rank`` is this process's own
216            rank; the keys of this :class:`dict` are the bucket indices
217            assigned to this rank.
218        num_bucket_assignments (int): total number of bucket assignments across
219            all ranks; this is equal to the number of
220            :class:`DistributedDataParallel` gradient buckets if
221            ``shard_buckets=False`` and possibly greater otherwise.
222        total_size (int, optional): total size of all buckets (i.e. sum of
223            ``param.numel()`` for all ``param`` across all buckets) if
224            ``shard_buckets=True``; otherwise, ``None``.
225        broadcast_handles (List[Work]): :class:`list` of async work handles for
226            the parameter broadcasts.
227        bucket_index_to_future (Dict[int, torch.futures.Future]):
228            :class:`dict` mapping bucket index to the corresponding all-reduce
229            future.
230        bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict`
231            mapping bucket index to the corresponding bucket.
232        bucket_indices_seen (List[int]): :class:`list` of the bucket indices
233            seen on this iteration.
234    """
235
236    def __init__(self, world_size) -> None:
237        self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
238        self.shard_buckets: bool = False
239
240        # Modified per bucket reconstruction
241        self.params_per_bucket: List[List[torch.Tensor]] = []
242        self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
243        self.offsets: Dict[int, int] = {}
244        # Group Ranks
245        self.assigned_ranks_per_bucket: List[Set[int]] = []
246        self.num_bucket_assignments: int = 0
247        self.total_size: Optional[int] = None
248
249        # Modified per iteration
250        self.broadcast_handles: List[Any] = []
251        self.bucket_indices_seen: List[int] = []
252        # Used by `hook_with_zero_step()`
253        self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
254        self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}
255
256    def wait_for_broadcasts(self) -> None:
257        r"""
258        Wait for all parameter broadcasts.
259
260        This function should be called once all broadcasts have been scheduled,
261        meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
262        in preparation for the next iteration.
263        """
264        assert (
265            len(self.broadcast_handles) == self.num_bucket_assignments
266        ), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
267        _ = [x.wait() for x in self.broadcast_handles]
268        self.broadcast_handles.clear()
269
270    def clear_per_iter_info(self) -> None:
271        r"""
272        Clear the data structures that are modified per-iteration.
273
274        This function should be called at the end of an iteration.
275        """
276        self.bucket_indices_seen.clear()
277        self.bucket_index_to_future.clear()
278        self.bucket_index_to_bucket.clear()
279
280
281class ZeroRedundancyOptimizer(Optimizer, Joinable):
282    r"""
283    Wrap an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>` and shard its states across ranks in the group.
284
285    The sharding is done as described by ZeRO_.
286
287    The local optimizer instance in each rank is only
288    responsible for updating approximately ``1 / world_size`` parameters and
289    hence only needs to keep ``1 / world_size`` optimizer states. After
290    parameters are updated locally, each rank will broadcast its parameters to
291    all other peers to keep all model replicas in the same state.
292    ``ZeroRedundancyOptimizer`` can be used in conjunction with
293    :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
294    memory consumption.
295
296    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
297    of parameters at each rank. Each parameter belongs to a single rank and is
298    not divided among ranks. The partition is arbitrary and might not match
299    the parameter registration or usage order.
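
    For example, with a world size of 2 and parameter sizes of 4, 3, 2, and 2
    elements, the sorted-greedy packing assigns the parameters of size 4 and 2
    to rank 0 (6 elements total) and those of size 3 and 2 to rank 1 (5
    elements total).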
300
301    Arguments:
302        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
303            or :class:`dict` s giving all parameters, which will be sharded
304            across ranks.
305
306    Keyword Args:
307        optimizer_class (:class:`torch.optim.Optimizer`): the class of the local
308            optimizer.
309        process_group (``ProcessGroup``, optional): ``torch.distributed``
310            ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by
311            :meth:`torch.distributed.init_process_group`).
312        parameters_as_bucket_view (bool, optional): if ``True``, parameters are
313            packed into buckets to speed up communication, and ``param.data``
314            fields point to bucket views at different offsets; if ``False``,
315            each individual parameter is communicated separately, and each
316            ``param.data`` stays intact (default: ``False``).
317        overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is
318            overlapped with :class:`DistributedDataParallel` 's gradient
319            synchronization; this requires (1) either a functional optimizer
320            for the ``optimizer_class`` argument or one with a functional
321            equivalent and (2) registering a DDP communication hook
322            constructed from one of the functions in ``ddp_zero_hook.py``;
323            parameters are packed into buckets matching those in
324            :class:`DistributedDataParallel`, meaning that the
325            ``parameters_as_bucket_view`` argument is ignored.
326            If ``False``, :meth:`step` runs disjointly after the backward pass
327            (per normal).
328            (default: ``False``)
329        **defaults: any trailing arguments, which are forwarded to the local
330            optimizer.
331
332    Example::
333
334        >>> # xdoctest: +SKIP
335        >>> import torch.nn as nn
336        >>> from torch.distributed.optim import ZeroRedundancyOptimizer
337        >>> from torch.nn.parallel import DistributedDataParallel as DDP
338        >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
339        >>> ddp = DDP(model, device_ids=[rank])
340        >>> opt = ZeroRedundancyOptimizer(
341        >>>     ddp.parameters(),
342        >>>     optimizer_class=torch.optim.Adam,
343        >>>     lr=0.01
344        >>> )
345        >>> ddp(inputs).sum().backward()
346        >>> opt.step()
347
348    .. warning::
349        Currently, ``ZeroRedundancyOptimizer`` requires that all of the
350        passed-in parameters are of the same dense type.
351
352    .. warning::
353        If you pass ``overlap_with_ddp=True``, be wary of the following: Given
354        the way that overlapping :class:`DistributedDataParallel` with
355        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
356        two or three training iterations do not perform parameter updates in
357        the optimizer step, depending on if ``static_graph=False`` or
358        ``static_graph=True``, respectively. This is because it needs
359        information about the gradient bucketing strategy used by
360        :class:`DistributedDataParallel`, which is not finalized until the
361        second forward pass if ``static_graph=False`` or until the third
362        forward pass if ``static_graph=True``. To adjust for this, one option
363        is to prepend dummy inputs.
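
        A minimal warm-up sketch, assuming ``dummy_input`` is a placeholder
        batch supplied by the user, ``static_graph`` mirrors the value passed
        to :class:`DistributedDataParallel`, and a ZeRO communication hook
        from ``ddp_zero_hook.py`` has already been registered::

            >>> # xdoctest: +SKIP
            >>> warmup_iters = 3 if static_graph else 2
            >>> for _ in range(warmup_iters):
            >>>     ddp(dummy_input).sum().backward()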
364
365    .. warning:: ZeroRedundancyOptimizer is experimental and subject to change.
366
367    .. _ZeRO: https://arxiv.org/abs/1910.02054
368
369    """
370
371    def __init__(
372        self,
373        params,
374        optimizer_class: Type[Optimizer],
375        process_group: Optional[Any] = None,
376        parameters_as_bucket_view: bool = False,
377        overlap_with_ddp: bool = False,
378        **defaults: Any,
379    ):
380        r"""Init."""
381        # Perform type and assumption checks on the input parameters
382        params = self._verify_and_init_params(params)
383        self._verify_same_dense_param_type()
384
385        # NOTE: The parent constructor uses `add_param_group()` which is
386        # partially overloaded in ZeroRedundancyOptimizer, so we use the
387        # `initialized` flag to dissociate the behaviour of `add_param_group()`
388        # between the parent and child.
389        self.initialized = False
390
391        Optimizer.__init__(self, params, defaults)
392        Joinable.__init__(self)
393        # Now, all parameters are held in both `self._all_params` and
394        # `self.param_groups`
395
396        # Internal data structures (`_cache` indicates lazily evaluated)
397        self._param_to_rank_cache: Dict[torch.Tensor, int] = {}
398        self._param_to_index_cache: Dict[torch.Tensor, int] = {}
399        self._partition_parameters_cache: List[List[Dict]] = []
400        self._index_to_param_cache: List[torch.Tensor] = []
401        self._device_to_params_per_rank_cache: Dict[
402            torch.device, List[List[torch.Tensor]]
403        ] = {}
404        self._bucket_assignments_per_rank_cache: List[
405            Dict[int, _DDPBucketAssignment]
406        ] = []
407        self._is_trainable_mask = self._get_is_trainable_mask()
408
409        # Default device for collective communication and buckets
410        self._default_device = self._all_params[0].device
411
412        self.process_group = (
413            process_group if process_group is not None else dist.group.WORLD
414        )
415        self.world_size: int = dist.get_world_size(self.process_group)
416        self.rank: int = dist.get_rank(self.process_group)
417        self.global_rank: int = dist.distributed_c10d.get_global_rank(
418            self.process_group, self.rank
419        )
420
421        self._overlap_with_ddp: bool = overlap_with_ddp
422        self._optim_defaults = defaults
423        self._optim_constructor = self._get_optimizer_constructor(optimizer_class)
424
425        # If `overlap_with_ddp=True`, local optimizer initialization is delayed
426        # to run time after the necessary information has been collected
427        if not overlap_with_ddp:
428            self._init_local_optimizer()
429        else:
430            self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
431            if parameters_as_bucket_view:
432                logger.warning(
433                    "`parameters_as_bucket_view=True` will be ignored since "
434                    "`overlap_with_ddp=True`; instead, a different bucketing "
435                    "strategy will be used"
436                )
437
438        # `self._buckets` is used if `parameters_as_bucket_view=True`, in
439        # which case parameter data is flattened into contiguous bucket tensors
440        self.parameters_as_bucket_view = parameters_as_bucket_view
441        self._buckets: List[List[torch.Tensor]] = []
442        self._build_param_buckets()
443
444        # Optional consolidated optimizer state, only populated if this rank
445        # is the target in `consolidate_state_dict()`
446        self._all_state_dicts: List[Dict[str, Any]] = []
447
448        self.initialized = True
449
450    def _clear_cache(self) -> None:
451        r"""Clear the cached data structures giving partition information."""
452        self._partition_parameters_cache.clear()
453        self._param_to_rank_cache.clear()
454        self._index_to_param_cache.clear()
455        self._param_to_index_cache.clear()
456        self._device_to_params_per_rank_cache.clear()
457        self._bucket_assignments_per_rank_cache.clear()
458
459    def add_param_group(self, param_group: Dict[str, Any]) -> None:
460        r"""
461        Add a parameter group to the :class:`Optimizer` 's ``param_groups``.
462
463        This can be useful when fine-tuning a pre-trained network, as frozen
464        layers can be made trainable and added to the :class:`Optimizer` as
465        training progresses.
466
467        Arguments:
468            param_group (dict): specifies the parameters to be optimized and
469                group-specific optimization options.
470
471        .. warning:: This method handles updating the shards on all partitions
472            but needs to be called on all ranks. Calling this on a subset of
473            the ranks will cause the training to hang because communication
474            primitives are called depending on the managed parameters and
475            expect all the ranks to participate on the same set of parameters.
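
        Example (a sketch; ``zero`` and ``new_layer`` are hypothetical names
        for an initialized ``ZeroRedundancyOptimizer`` and a module whose
        parameters were not passed to the constructor)::

            >>> # xdoctest: +SKIP
            >>> # Must be called on all ranks with the same parameter group
            >>> zero.add_param_group({"params": new_layer.parameters(), "lr": 1e-4})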
476        """
477        if self.initialized and self._overlap_with_ddp:
478            raise RuntimeError(
479                "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only "
480                "supports a single parameter group"
481            )
482
483        super().add_param_group(param_group)
484        # NOTE: The rest of the method assumes that the call to the parent's
485        # `add_param_group()` appends the new parameter group and preserves
486        # the previous parameter-group ordering
487
488        if self.initialized:
489            # Force a re-partitioning of the parameters
490            self._clear_cache()
491            param_groups = self._partition_parameters()[self.rank]
492            # NOTE: All parameters in the old parameter groups should be
493            # assigned to the same ranks so that the local optimizers do not
494            # need to be reinitialized
495
496            # Add the parameters assigned to this rank from the new parameter
497            # group to the local optimizer, if any
498            if len(param_groups) == len(self.optim.param_groups) + 1:
499                self.optim.add_param_group(param_groups[-1])
500
501            # Update the bucketing strategy accordingly
502            if self.parameters_as_bucket_view:
503                self._build_param_buckets()
504
505    def consolidate_state_dict(self, to: int = 0) -> None:
506        r"""
507        Consolidate a list of ``state_dict`` s (one per rank) on the target rank.
508
509        Arguments:
510            to (int): the rank that receives the optimizer states (default: 0).
511
512        Raises:
513            RuntimeError: if ``overlap_with_ddp=True`` and this method is
514                called before this :class:`ZeroRedundancyOptimizer` instance
515                has been fully initialized, which happens once
516                :class:`DistributedDataParallel` gradient buckets have been
517                rebuilt.
518
519        .. warning:: This needs to be called on all ranks.
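
        Example (a sketch; ``opt`` is a hypothetical
        ``ZeroRedundancyOptimizer`` instance)::

            >>> # xdoctest: +SKIP
            >>> opt.consolidate_state_dict(to=0)  # must run on every rank
            >>> if dist.get_rank() == 0:
            >>>     torch.save(opt.state_dict(), "zero_ckpt.pt")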
520        """
521        self._check_overlap_initialized()
522
523        # Sync the exposed `param_groups` attributes to the local optimizer in
524        # case they have been updated
525        self._sync_param_groups(self.param_groups, self.optim.param_groups)
526
527        # Pull the sharded state from all ranks and store them in rank order
528        empty_messenger = torch.tensor(
529            [0], dtype=torch.uint8, device=self._default_device
530        )
531
532        # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
533        # due to compatibility issues with NCCL backend; a possible follow-up
534        # is to move all sharded state management to RPC RRef
535        self._all_state_dicts = []
536        for rank in range(self.world_size):
537            global_rank = dist.distributed_c10d.get_global_rank(
538                self.process_group, rank
539            )
540            if self.rank == to:
541                # Consolidate all local `state_dict`s on this rank, storing on
542                # CPU to save GPU memory
543                if rank == self.rank:
544                    # Directly append own optimizer state
545                    self._all_state_dicts.append(
546                        _recursive_copy_to_device(
547                            self.optim.state_dict(),
548                            non_blocking=True,
549                            device=torch.device("cpu"),
550                        )
551                    )
552                else:
553                    # Receive the optimizer state from the source rank
554                    local_state_dict = _broadcast_object(
555                        empty_messenger,
556                        src_rank=global_rank,
557                        group=self.process_group,
558                        device=self._default_device,
559                    )
560                    self._all_state_dicts.append(
561                        _recursive_copy_to_device(
562                            local_state_dict,
563                            non_blocking=True,
564                            device=torch.device("cpu"),
565                        )
566                    )
567            else:
568                if rank == self.rank:
569                    # Send the optimizer state to the target rank
570                    _ = _broadcast_object(
571                        self.optim.state_dict(),
572                        src_rank=self.global_rank,
573                        group=self.process_group,
574                        device=self._default_device,
575                    )
576                elif rank != to:
577                    # Discard the received object; `broadcast()` is used for
578                    # compatibility reasons
579                    _ = _broadcast_object(
580                        empty_messenger,
581                        src_rank=global_rank,
582                        group=self.process_group,
583                        device=self._default_device,
584                    )
585
586    def _verify_params_per_rank(
587        self,
588        params_per_rank: List[List[torch.Tensor]],
589    ) -> None:
590        r"""
591        Verify ``params_per_rank`` for :meth:`_partition_parameters`.
592
593        The verification is done by checking that ``params_per_rank`` has length equal
594        to the world size and that it does not contain any parameters not passed into the
595        :class:`ZeroRedundancyOptimizer` constructor.
596
597        The parameters in ``params_per_rank`` being a strict subset of those
598        passed into the constructor is valid since some parameters may be
599        frozen.
600
601        Raises:
602            ValueError: if ``params_per_rank`` does not have length equal to
603                the world size or if it contains a parameter that was not
604                passed into the :class:`ZeroRedundancyOptimizer` constructor.
605        """
606        if len(params_per_rank) != self.world_size:
607            raise ValueError(
608                "`params_per_rank` must have length equal to the world size"
609            )
610        all_params_set = set(self._all_params)
611        for params in params_per_rank:
612            for param in params:
613                if param not in all_params_set:
614                    raise ValueError(
615                        "Passing a new parameter in `params_per_rank` that "
616                        "was not passed into the ZeroRedundancyOptimizer "
617                        "constructor"
618                    )
619
620    def _partition_param_group(
621        self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]]
622    ) -> None:
623        r"""
624        Partition the parameter group ``param_group`` according to ``params_per_rank``.
625
626        The partition will modify the ``self._partition_parameters_cache``. This method should
627        only be used as a subroutine for :meth:`_partition_parameters`.
628
629        Arguments:
630            param_group (dict[str, Any]): a parameter group as normally defined
631                in an optimizer state.
632            params_per_rank (list[list[torch.Tensor]]): a :class:`list` of
633                length world size containing :class:`list` s of parameters to
634                assign to each rank.
635        """
636        for rank, params in enumerate(params_per_rank):
637            rank_param_group = copy.copy(param_group)
638            rank_param_group["params"] = params
639            self._partition_parameters_cache[rank].append(rank_param_group)
640
641    def _partition_parameters(
642        self,
643        params_per_rank: Optional[List[List[torch.Tensor]]] = None,
644    ) -> List[List[Dict]]:
645        r"""
646        Partition the parameters across distributed data parallel ranks.
647
648        Arguments:
649            params_per_rank (list[list[torch.Tensor]], optional): a
650                :class:`list` of length world size containing :class:`list` s
651                of parameters to assign to each rank; this provides a way to
652                specify a partition manually.
653                If ``None``, the parameters are partitioned according to an
654                internal algorithm.
655                (default: ``None``)
656
657        Returns:
658            A :class:`list` where each element of the list contains the
659            ``param_groups`` for a rank (which itself is a :class:`list` of
660            :class:`dict`); element 0 corresponds to rank 0, etc.; each rank
661            stores the ``param_groups`` for all ranks for the collective
662            communication in :meth:`step`.
663
664        Raises:
665            ValueError: see :meth:`_verify_params_per_rank`.
666            RuntimeError: if ``params_per_rank`` is not ``None`` and this
667                :class:`ZeroRedundancyOptimizer` instance is using more than
668                one parameter group.
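
        For illustration, a manual partition could be specified as follows (a
        sketch assuming a world size of 2, a single parameter group, no prior
        partitioning, and hypothetical ``model`` and ``zero`` objects)::

            >>> # xdoctest: +SKIP
            >>> params = list(model.parameters())
            >>> manual = [params[0::2], params[1::2]]  # one list per rank
            >>> per_rank_param_groups = zero._partition_parameters(manual)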
669        """
670        if params_per_rank is None:
671            # Partition the parameters optimizing for uniformity
672            if len(self._partition_parameters_cache) == 0:
673                self._partition_parameters_cache = [[] for _ in range(self.world_size)]
674                sizes = [0] * self.world_size
675                for param_group in self.param_groups:
676                    param_group_params_per_rank: List[List] = [
677                        [] for _ in range(self.world_size)
678                    ]
679                    # Sort the parameters by size (largest first)
680                    params_sorted = sorted(
681                        param_group["params"], key=lambda t: t.numel(), reverse=True
682                    )
683                    for param in params_sorted:
684                        # Greedily add the parameter to rank with smallest size so far
685                        rank = self._get_min_index(sizes)
686                        param_group_params_per_rank[rank].append(param)
687                        sizes[rank] += param.numel()
688                    # Apply the constructed partition of the parameter group
689                    self._partition_param_group(
690                        param_group, param_group_params_per_rank
691                    )
692
693            return self._partition_parameters_cache
694
695        # Partition the parameters according to `params_per_rank`
696        assert len(self._partition_parameters_cache) == 0, (
697            "Specifying `params_per_rank` should only be done when the "
698            "parameters have not been partitioned yet"
699        )
700        if len(self.param_groups) != 1:
701            raise RuntimeError(
702                "Specifying `params_per_rank` only supports a single parameter group"
703            )
704        self._verify_params_per_rank(params_per_rank)
705        self._partition_parameters_cache = [[] for _ in range(self.world_size)]
706
707        # Apply the passed-in partition of the parameter group
708        param_group = self.param_groups[0]
709        self._partition_param_group(param_group, params_per_rank)
710
711        return self._partition_parameters_cache
712
713    @property
714    def _param_to_rank(self) -> Dict[torch.Tensor, int]:
715        r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition."""
716        if len(self._param_to_rank_cache) == 0:
717            for rank, param_groups in enumerate(self._partition_parameters()):
718                for param_group in param_groups:
719                    for param in param_group["params"]:
720                        self._param_to_rank_cache[param] = rank
721        return self._param_to_rank_cache
722
723    @property
724    def _param_to_index(self) -> Dict[torch.Tensor, int]:
725        r"""
726        :class:`dict` mapping parameters to their indices in the global optimizer state.
727
728        NOTE: This assumes that the global optimizer state's indexing (in
729        ``state_dict``) follows a linear ordering over the parameter groups.
730        """
731        if len(self._param_to_index_cache) == 0:
732            self._param_to_index_cache = {
733                p: i
734                for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))
735            }
736        return self._param_to_index_cache
737
738    @property
739    def _index_to_param(self) -> List[torch.Tensor]:
740        r"""List mapping parameter indices in the global optimizer scheme to the actual params."""
741        if len(self._index_to_param_cache) == 0:
742            self._index_to_param_cache = list(
743                chain(*(g["params"] for g in self.param_groups))
744            )
745        return self._index_to_param_cache
746
747    def _broadcast_params_from_rank(self, rank: int):
748        r"""
749        Broadcast the shard of parameters from a given rank to all other ranks asynchronously.
750
751        Arguments:
752            rank (int): the source rank.
753
754        Returns:
755            A :class:`list` of async work handles for the ``broadcast()`` s
756            performed to synchronize the parameters.
757        """
758        assert not self._overlap_with_ddp, (
759            "`_broadcast_params_from_rank()` should not be used if "
760            "`overlap_with_ddp=True`; instead, the broadcasting should "
761            "happen in the DDP communication hook"
762        )
763        handles = []
764        if self.parameters_as_bucket_view:
765            for dev_i_buckets in self._buckets:
766                bucket = dev_i_buckets[rank]
767                global_rank = dist.distributed_c10d.get_global_rank(
768                    self.process_group, rank
769                )
770                handles.append(
771                    dist.broadcast(
772                        tensor=bucket,
773                        src=global_rank,
774                        group=self.process_group,
775                        async_op=True,
776                    )
777                )
778        else:
779            param_groups = self._partition_parameters()[rank]
780            global_rank = dist.distributed_c10d.get_global_rank(
781                self.process_group, rank
782            )
783            for param_group in param_groups:
784                for param in param_group["params"]:
785                    handles.append(
786                        dist.broadcast(
787                            tensor=param.data,
788                            src=global_rank,
789                            group=self.process_group,
790                            async_op=True,
791                        )
792                    )
793        return handles
794
795    def _sync_params(self):
796        r"""
797        Sync all parameter shards across the ranks.
798
799        This rank sends its shard of the parameters to all other ranks and
800        receives a shard from each other rank. This is done using
801        ``broadcast()``. Parameters are sent bucket-by-bucket if
802        ``parameters_as_bucket_view=True`` and sent parameter-by-parameter
803        otherwise.
804        """
805        handles = []
806        for rank in range(self.world_size):
807            handles.extend(self._broadcast_params_from_rank(rank))
808        _ = [x.wait() for x in handles]
809
810    @property
811    def _device_to_params_per_rank(
812        self,
813    ) -> Dict[torch.device, List[List[torch.Tensor]]]:
814        r"""
815        Return device parameters assigned per rank.
816
817        :class:`dict` mapping each device to a :class:`list` of the per-rank parameter
818        lists filtered to only include the parameters stored on that device.
819        Each per-rank parameter list gives the parameters assigned to that rank
820        to update.
821
822        This is used for constructing the parameter buckets if
823        ``parameters_as_bucket_view=True``.
824
825        Let ``dev_i`` denote the ``i``th device for this rank. Then:
826        ``dev_0`` maps to a list containing:
827            rank 0's assigned parameters stored on ``dev_0``,
828            rank 1's assigned parameters stored on ``dev_0``,
829            ...
830        ``dev_1`` maps to a list containing:
831            rank 0's assigned parameters stored on ``dev_1``,
832            rank 1's assigned parameters stored on ``dev_1``,
833            ...
834        ...
835        """
836        assert self.parameters_as_bucket_view, (
837            "`_device_to_params_per_rank` should only be used if "
838            "`parameters_as_bucket_view=True`"
839        )
840        if len(self._device_to_params_per_rank_cache) == 0:
841            for rank, param_groups in enumerate(self._partition_parameters()):
842                for param_group in param_groups:
843                    for param in param_group["params"]:
844                        device = param.device
845                        if device not in self._device_to_params_per_rank_cache:
846                            self._device_to_params_per_rank_cache[device] = [
847                                [] for _ in range(self.world_size)
848                            ]
849                        self._device_to_params_per_rank_cache[device][rank].append(
850                            param
851                        )
852        return self._device_to_params_per_rank_cache
853
854    def _get_min_index(
855        self,
856        values: List[int],
857        disallowed_indices: Optional[Set[int]] = None,
858    ) -> int:
859        r"""
860        Return ``values.index(min(values))``, except it only uses one pass.
861
862        It also excludes any indices in ``disallowed_indices`` if provided.
863
864        Arguments:
865            values: (List[int]): :class:`list` of values.
866            disallowed_indices (Optional[Set[int]]): indices that are
867                disallowed from being the returned min index.
868        """
869        min_index = -1
870        min_value = float("inf")
871        for i, value in enumerate(values):
872            if disallowed_indices and i in disallowed_indices:
873                continue
874            if value < min_value:
875                min_value = value
876                min_index = i
877        assert min_index >= 0, "All indices are disallowed"
878        return min_index
879
880    def _assign_bucket_subset_to_rank(
881        self,
882        bucket_index: int,
883        bucket_params: List[torch.Tensor],
884        bucket_offset: int,
885        assigned_rank: int,
886        assigned_ranks_per_bucket: List[Set[int]],
887    ) -> None:
888        r"""
889        Assign ``bucket_params`` to the rank with the least size assigned so far and collect relevant information.
890
891        The model parameters given by ``bucket_params`` represent a (possibly non-strict)
892        subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket.
893
894        Arguments:
895            bucket_index (int): index of the :class:`DistributedDataParallel`
896                gradient bucket.
897            bucket_params (List[torch.Tensor]): subset of the parameters
898                corresponding to the bucket to assign.
899            bucket_offset (int): offset giving the index of the first element
900                in ``bucket_params`` in the bucket's full parameter list.
901            assigned_rank (int): group rank to assign to.
902            assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of group ranks
903                assigned to each bucket.
904        """
905        overlap_info = self._overlap_info
906        if len(bucket_params) == 0:
907            raise ValueError("Empty bucket assignment")
908        params_per_rank = overlap_info.params_per_rank
909        offsets = overlap_info.offsets
910
911        self._bucket_assignments_per_rank_cache[assigned_rank][
912            bucket_index
913        ] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
914        if self.global_rank == assigned_rank:
915            offsets[bucket_index] = len(params_per_rank[assigned_rank])
916        params_per_rank[assigned_rank].extend(bucket_params)
917        assigned_ranks_per_bucket[bucket_index].add(assigned_rank)
918        self._overlap_info.num_bucket_assignments += 1
919
920    @property
921    def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]:
922        r"""
923        Return DDP bucket parameters assigned per rank.
924
925        :class:`list` of length world size consisting of :class:`dict` s
926        mapping bucket indices to :class:`_DDPBucketAssignment` s for each
927        rank.
928        """
929        assert (
930            self._overlap_with_ddp
931        ), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`"
932        if len(self._bucket_assignments_per_rank_cache) > 0:
933            return self._bucket_assignments_per_rank_cache
934
935        overlap_info = self._overlap_info
936        assert overlap_info.status == _OverlapStatus.INITIALIZED
937
938        self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
939        params_per_bucket = overlap_info.params_per_bucket
940
941        if overlap_info.shard_buckets:
942            # Define the assignment threshold to approximate uniformity
943            assert overlap_info.total_size is not None, "`total_size` was not computed"
944            threshold = overlap_info.total_size / self.world_size  # type: ignore[operator]
945            size_per_rank = [0 for _ in range(self.world_size)]
946
947        num_buckets = len(params_per_bucket)
948        overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
949        assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket
950        if not overlap_info.shard_buckets:
951            # Assign each DDP bucket entirely to a single rank
952            for bucket_index, bucket_params in enumerate(params_per_bucket):
953                assert len(bucket_params) > 0, "Empty bucket"
954                assigned_rank = self._get_assigned_rank(bucket_index)
955                self._assign_bucket_subset_to_rank(
956                    bucket_index,
957                    bucket_params,
958                    0,
959                    assigned_rank,
960                    assigned_ranks_per_bucket,
961                )
962        else:
963            # Assign each DDP bucket to possibly multiple ranks
964            # Specifically, sort the DDP buckets by increasing size, and for
965            # each bucket, iteratively assign the maximal unassigned subset
966            # with size less than `threshold` to the rank with the least total
967            # size so far -- each such assignment is represented by a
968            # `_DDPBucketAssignment` instance and only contains parameters from
969            # a single DDP bucket
970            params_per_bucket_enum = sorted(
971                enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1])
972            )
973            for bucket_index, bucket_params in params_per_bucket_enum:
974                assert len(bucket_params) > 0, "Empty bucket"
975                bucket_offset = 0
976                assignment_size = 0
977                for param_index, param in enumerate(bucket_params):
978                    param_numel = param.numel()
979                    if (
980                        assignment_size + param_numel >= threshold
981                        and param_index > bucket_offset
982                    ):
983                        assigned_rank = self._get_min_index(
984                            size_per_rank, assigned_ranks_per_bucket[bucket_index]
985                        )
986                        # Include up to but not including the parameter that
987                        # exceeded the threshold
988                        self._assign_bucket_subset_to_rank(
989                            bucket_index,
990                            bucket_params[bucket_offset:param_index],
991                            bucket_offset,
992                            assigned_rank,
993                            assigned_ranks_per_bucket,
994                        )
995                        size_per_rank[assigned_rank] += assignment_size
996                        bucket_offset = param_index
997                        assignment_size = 0
998                    assignment_size += param_numel
999                # Assign the remainder of the bucket so that no assignment
1000                # spans across two buckets
1001                assigned_rank = self._get_min_index(
1002                    size_per_rank, assigned_ranks_per_bucket[bucket_index]
1003                )
1004                self._assign_bucket_subset_to_rank(
1005                    bucket_index,
1006                    bucket_params[bucket_offset:],
1007                    bucket_offset,
1008                    assigned_rank,
1009                    assigned_ranks_per_bucket,
1010                )
1011                size_per_rank[assigned_rank] += assignment_size
1012
1013        return self._bucket_assignments_per_rank_cache
1014
1015    def _local_step(
1016        self,
1017        gradients: Optional[List[Optional[torch.Tensor]]] = None,
1018        closure: Optional[Callable[[], float]] = None,
1019        **kwargs: Any,
1020    ) -> Optional[float]:
1021        r"""
1022        Perform a single optimizer step without syncing parameters across ranks.
1023
1024        Arguments:
1025            gradients (list[Optional[torch.Tensor]], optional): a :class:`list`
1026                of length equal to the number of parameters assigned to this
1027                rank containing gradient tensors or ``None`` as its elements;
1028                a ``None`` in the :class:`list` indicates that the
1029                corresponding parameter should not be updated.
1030                If the argument itself is ``None``, then all parameters are
1031                updated, and the gradients are assumed to be already populated.
1032                (default: ``None``)
1033            closure (Callable): a closure that re-evaluates the model and
1034                returns the loss; optional for most optimizers and should be
1035                ``None`` if ``gradients`` is not ``None``; (default: ``None``)
1036        Returns:
1037            Optional loss depending on the underlying local optimizer.
1038
1039        .. warning::
1040            The argument ``gradients`` should only be specified (i.e. not
1041            ``None``) if ``overlap_with_ddp=True``, in which case
1042            :class:`ZeroRedundancyOptimizer` wraps a functional optimizer.
1043        """
1044        Join.notify_join_context(self)
1045        # Check if the model trainability has changed
1046        is_trainable_mask = self._get_is_trainable_mask()
1047        if is_trainable_mask != self._is_trainable_mask:
1048            if self._overlap_with_ddp:
1049                raise RuntimeError(
1050                    "ZeroRedundancyOptimizer with `overlap_with_ddp=True` "
1051                    "does not support changing parameter trainability at run "
1052                    "time"
1053                )
1054            logger.warning(
1055                "ZeroRedundancyOptimizer detected that the trainable "
1056                "parameters changed; rebuilding the parameter buckets if "
1057                "enabled"
1058            )
1059            self._build_param_buckets()
1060            self._is_trainable_mask = is_trainable_mask
1061
1062        # Sync the exposed `param_groups` attributes to the local optimizer in
1063        # case they have been updated
1064        self._sync_param_groups(self.param_groups, self.optim.param_groups)
1065
1066        # Run the optimizer step on this shard only
1067        if gradients is None:
1068            loss = (
1069                self.optim.step(**kwargs)
1070                if closure is None
1071                else self.optim.step(closure=closure, **kwargs)
1072            )
1073        else:
1074            assert self._overlap_with_ddp, (
1075                "Specifying `gradients` should not "
1076                "be used when `overlap_with_ddp=False`"
1077            )
1078            assert (
1079                closure is None
1080            ), "`closure` is not supported when using a local functional optimizer"
1081            loss = self.optim.step(gradients=gradients)
1082
1083        # Sync any updated attributes in the local optimizer to the exposed
1084        # `param_groups`
1085        self._sync_param_groups(self.optim.param_groups, self.param_groups)
1086
1087        return loss
1088
1089    def step(
1090        self,
1091        closure: Optional[Callable[[], float]] = None,
1092        **kwargs: Any,
1093    ) -> Optional[float]:
1094        r"""
1095        Perform a single optimizer step and sync parameters across all ranks.
1096
1097        Arguments:
1098            closure (Callable): a closure that re-evaluates the model and
1099                returns the loss; optional for most optimizers.
1100        Returns:
1101            Optional loss depending on the underlying local optimizer.
1102
1103        .. note:: Any extra keyword arguments are passed to the base optimizer as-is.
1104        """
1105        if self._overlap_with_ddp:
1106            logger.warning(
1107                "`step()` should not be included in the training loop when "
1108                "`overlap_with_ddp=True`"
1109            )
1110            return None
1111
1112        # Perform the local optimizer step
1113        loss = self._local_step(closure=closure, **kwargs)
1114
1115        # Sync all of the updated parameter shards across the ranks
1116        self._sync_params()
1117
1118        return loss
1119
1120    def join_hook(self, **kwargs):
1121        r"""
1122        Return the ZeRO join hook.
1123
1124        It enables training on uneven inputs by
1125        shadowing the collective communications in the optimizer step.
1126
1127        Gradients must be properly set before this hook is called.
1128
1129        Arguments:
1130            kwargs (dict): a :class:`dict` containing any keyword arguments
1131                to modify the behavior of the join hook at run time; all
1132                :class:`Joinable` instances sharing the same join context
1133                manager are forwarded the same value for ``kwargs``.
1134
1135        This hook does not support any keyword arguments; i.e. ``kwargs`` is
1136        unused.
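
        Example (a sketch of training on uneven inputs; ``ddp``, ``zero``, and
        ``inputs`` are hypothetical names)::

            >>> # xdoctest: +SKIP
            >>> from torch.distributed.algorithms.join import Join
            >>> with Join([ddp, zero]):
            >>>     for input in inputs:
            >>>         zero.zero_grad()
            >>>         ddp(input).sum().backward()
            >>>         zero.step()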
1137        """
1138        return _ZeROJoinHook(self)
1139
1140    @property
1141    def join_device(self) -> torch.device:
1142        r"""Return default device."""
1143        return self._default_device
1144
1145    @property
1146    def join_process_group(self) -> Any:
1147        r"""Return process group."""
1148        return self.process_group
1149
1150    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
1151        r"""
1152        Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed.
1153
1154        Arguments:
1155            state_dict (dict): optimizer state; should be an object returned
1156                from a call to :meth:`state_dict`.
1157
1158        Raises:
1159            RuntimeError: if ``overlap_with_ddp=True`` and this method is
1160                called before this :class:`ZeroRedundancyOptimizer` instance
1161                has been fully initialized, which happens once
1162                :class:`DistributedDataParallel` gradient buckets have been
1163                rebuilt.
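
        Example (a sketch of restoring from a checkpoint previously saved on
        rank 0 via :meth:`consolidate_state_dict` and :meth:`state_dict`;
        ``opt`` is a hypothetical ``ZeroRedundancyOptimizer`` instance)::

            >>> # xdoctest: +SKIP
            >>> state_dict = torch.load("zero_ckpt.pt", map_location="cpu")
            >>> opt.load_state_dict(state_dict)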
1164        """
1165        self._check_overlap_initialized()
1166
1167        for index, value in state_dict["state"].items():
1168            param = self._index_to_param[index]
1169            if self._param_to_rank[param] != self.rank:
1170                # Clear any state irrelevant to this rank
1171                state_dict["state"][index] = None
1172            else:
1173                # Load the parameter state to the local optimizer
1174                self.optim.state[param] = _recursive_copy_to_device(
1175                    value, non_blocking=True, device=param.device
1176                )
1177                # Force zero-dimensional tensors (like Adam "step") on CPU
1178                for state_name, state_value in self.optim.state[param].items():
1179                    if torch.is_tensor(state_value) and state_value.dim() == 0:
1180                        self.optim.state[param][state_name] = state_value.cpu()
1181
1182        super().load_state_dict(state_dict)
1183
1184        # Sync the input state with the exposed and local optimizer states
1185        self._sync_param_groups(state_dict["param_groups"], self.param_groups)
1186        self._sync_param_groups(self.param_groups, self.optim.param_groups)
1187
1188    def state_dict(self) -> Dict[str, Any]:
1189        r"""
1190        Return the last global optimizer state known to this rank.
1191
1192        .. warning::
1193            If the state has not been consolidated to this rank, this raises a
1194            runtime error, and even if it has, the state may not be up-to-date,
1195            depending on when :meth:`consolidate_state_dict` was last called.
1196
1197        Raises:
1198            RuntimeError: if ``overlap_with_ddp=True`` and this method is
1199                called before this :class:`ZeroRedundancyOptimizer` instance
1200                has been fully initialized, which happens once
1201                :class:`DistributedDataParallel` gradient buckets have been
1202                rebuilt; or if this method is called without a preceding call
1203                to :meth:`consolidate_state_dict`.
1204        """
1205        self._check_overlap_initialized()
1206
1207        if len(self._all_state_dicts) == 0:
1208            raise RuntimeError(
1209                "Optimizer state has not been consolidated on this rank. "
1210                f"Please call `consolidate_state_dict(to={self.rank})` on "
1211                "all ranks beforehand if you meant to save the global state."
1212            )
1213
1214        # Get the possibly-stale global optimizer state that uses global
1215        # parameter indexing
1216        state_dict = super().state_dict()
1217
1218        # Update the global optimizer state with local state information,
1219        # factoring in the translation from local to global indexing
1220        for rank, local_state_dict in enumerate(self._all_state_dicts):
1221            local_param_groups = local_state_dict["param_groups"]
1222            global_param_groups = self._partition_parameters()[rank]
1223            assert len(local_param_groups) == len(
1224                global_param_groups
1225            ), "Mismatch between number of local and global parameter groups"
1226
1227            for local_param_group, global_param_group in zip(
1228                local_param_groups, global_param_groups
1229            ):
1230                # `local_param_group` stores local indices, while
1231                # `global_param_group` stores the tensors directly
1232                local_param_indices = local_param_group["params"]
1233                global_params = global_param_group["params"]
1234
1235                assert len(local_param_indices) == len(
1236                    global_params
1237                ), "Mismatch between number of local and global parameters in parameter group"
1238                for local_param_index, global_param in zip(
1239                    local_param_indices, global_params
1240                ):
1241                    # Update the global parameter state, if any
1242                    if local_param_index in local_state_dict["state"]:
1243                        global_param_index = self._param_to_index[global_param]
1244                        state_dict["state"][global_param_index] = local_state_dict[
1245                            "state"
1246                        ][local_param_index]
1247
1248        # Sort the parameters in the state
1249        state_dict["state"] = dict(sorted(state_dict["state"].items()))
1250        return state_dict
1251
1252    @staticmethod
1253    def _sync_param_groups(
1254        src_param_groups: List[Dict[Any, Any]],
1255        dst_param_groups: List[Dict[Any, Any]],
1256    ) -> None:
1257        r"""
1258        Sync the attributes from the source parameter groups to the destination parameter groups.
1259
1260        Example attributes include learning rate or scheduler attributes. The
1261        two parameter group lists should have the same length (i.e. the same
1262        number of parameter groups).
1263
1264        Arguments:
1265            src_param_groups (list[dict]): parameter groups giving the
1266                attribute settings to copy.
1267            dst_param_groups (list[dict]): parameter groups giving the
1268                attribute settings to set.
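
        Example (illustrative only, using plain dictionaries in place of real
        parameter groups)::

            >>> src = [{"params": ["p0"], "lr": 0.1}, {"params": ["p1"], "lr": 0.01}]
            >>> dst = [{"params": ["p0"], "lr": 1.0}, {"params": ["p1"], "lr": 1.0}]
            >>> ZeroRedundancyOptimizer._sync_param_groups(src, dst)
            >>> [group["lr"] for group in dst]
            [0.1, 0.01]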
1269        """
1270        assert len(src_param_groups) == len(
1271            dst_param_groups
1272        ), "Mismatch between number of source and destination parameter groups"
1273        for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
1274            # Sync all attributes except the parameters
1275            for attr in filter(lambda x: x != "params", src_param_group.keys()):
1276                dst_param_group[attr] = src_param_group[attr]
1277
1278    def _build_param_buckets(self) -> None:
1279        r"""
1280        Build parameter buckets if ``parameters_as_bucket_view=True``.
1281
1282        For each device that stores this rank's parameters, there is a
1283        bucket (represented as a tensor) containing all of the parameters on
1284        that device that are assigned to a given rank in the parameter update
1285        partition.
1286
1287        This method is called in the constructor and any time parameter
1288        trainability is changed.
1289
1290        .. warning::
1291            The current implementation assumes that all of the parameters in a
1292            bucket are of the same dense type when allocating the bucket's
1293            tensor.
1294
1295        .. warning::
1296            If the model parameters are stored across more than one device,
1297            then the storage partitioning must be the same across all
1298            processes in order for parameter synchronization to work.
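
        The flatten-and-alias idiom used below, shown in isolation (a
        standalone sketch, not part of this class's API)::

            >>> import torch
            >>> params = [torch.randn(2, 2), torch.randn(3)]
            >>> bucket = torch.empty(sum(p.numel() for p in params))
            >>> offset = 0
            >>> for p in params:
            ...     offset_next = offset + p.numel()
            ...     bucket[offset:offset_next].copy_(p.data.flatten())
            ...     p.data = bucket[offset:offset_next].view_as(p.data)
            ...     offset = offset_next
            >>> # Each parameter now aliases a contiguous slice of `bucket`, so
            >>> # broadcasting `bucket` synchronizes all of them at once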
1299        """
1300        if not self.parameters_as_bucket_view or self._overlap_with_ddp:
1301            return
1302
1303        # `self._buckets[i][j]` are the parameters stored on device i and
1304        # assigned to rank j
1305        num_devices = len(self._device_to_params_per_rank)
1306        self._buckets = [[] for _ in range(num_devices)]  # type: ignore[assignment]
1307
1308        for dev_i, (device, params_per_rank) in enumerate(
1309            self._device_to_params_per_rank.items()
1310        ):
1311            for params in params_per_rank:
1312                bucket_size = 0
1313                dtype = None
1314                trainable_params = []
1315                for param in params:
1316                    if not _is_trainable(param):
1317                        # Clone in case the parameter was previously part of
1318                        # a bucket, so that its data is not destroyed
1319                        param.data = param.data.detach().clone()
1320                    else:
1321                        bucket_size += param.numel()
1322                        trainable_params.append(param)
1323                    dtype = param.dtype  # assumes all same dtype
1324
1325                if bucket_size == 0:
1326                    # Create a dummy bucket if there are no parameters
1327                    bucket = torch.zeros(1, device=device)
1328                else:
1329                    # Construct the bucket (assuming all dense and same dtype)
1330                    bucket = torch.empty(bucket_size, dtype=dtype, device=device)
1331                    offset = 0
1332                    for param in trainable_params:
1333                        offset_next = offset + param.numel()
1334                        bucket[offset:offset_next].copy_(param.data.flatten())
1335                        param.data = bucket[offset:offset_next].view_as(param.data)
1336                        offset = offset_next
1337                self._buckets[dev_i].append(bucket)  # type: ignore[arg-type]
1338
1339    def _build_ddp_param_buckets(self) -> None:
1340        r"""
1341        Build the DDP buckets with parameters assigned to this rank.
1342
1343        For each DDP bucket with parameters assigned to this rank, flatten the
1344        data of those parameters into a single tensor and save the tensor to
1345        the ``tensor`` attribute in the corresponding
1346        :class:`_DDPBucketAssignment` instance stored in
1347        ``self._bucket_assignments_per_rank``.
1348
1349        :class:`DistributedDataParallel` guarantees that the parameters
1350        corresponding to a gradient bucket have the same device and the same
1351        dtype.
1352        """
1353        for bucket_assignments in self._bucket_assignments_per_rank:
1354            for bucket_assignment in bucket_assignments.values():
1355                params = bucket_assignment.parameters
1356                bucket_size = 0
1357                dtype = None
1358                for param in params:
1359                    assert _is_trainable(param), (
1360                        "Model parameter "
1361                        "corresponding to a gradient in a DDP bucket should "
1362                        "require a gradient"
1363                    )
1364                    bucket_size += param.numel()
1365                    dtype = param.dtype  # assumes all same dtype
1366                assert bucket_size > 0, "Empty bucket"
1367
1368                # Construct the bucket tensor (assuming all dense and same dtype)
1369                tensor = torch.empty(
1370                    bucket_size, dtype=dtype, device=bucket_assignment.device
1371                )
1372                offset = 0
1373                for param in params:
1374                    offset_next = offset + param.numel()
1375                    tensor[offset:offset_next].copy_(param.data.flatten())
1376                    param.data = tensor[offset:offset_next].view_as(param.data)
1377                    offset = offset_next
1378                bucket_assignment.tensor = tensor
1379
1380    def _verify_and_init_params(
1381        self,
1382        params: Any,
1383    ) -> Union[List[torch.Tensor], List[dict]]:
1384        r"""
1385        Verify the type of ``params`` and initialize ``self._all_params`` as a :class:`list` of all parameters.
1386
1387        The initialization first ensures that the provided ``params`` is valid.
1388
1389        Arguments:
1390            params (Any): Candidate parameter list or parameter groups to verify.
1391
1392        Raises:
1393            TypeError: ``params`` has an invalid type.
1394            ValueError: ``params`` is empty.
1395
1396        Returns:
1397            The persistent form of ``params`` to be passed into the parent
1398            :class:`Optimizer` constructor -- i.e. returns ``params`` as a
1399            :class:`list` to ensure that it can be iterated over again.
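
        Both of the following input forms are accepted (illustrative only;
        ``p0`` and ``p1`` stand for arbitrary :class:`torch.Tensor`
        parameters), while a bare tensor or a mix of tensors and dicts raises
        a ``TypeError``::

            >>> params = [p0, p1]                               # list of Tensors
            >>> params = [{"params": [p0]}, {"params": [p1], "lr": 0.01}]  # parameter groups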
1400        """
1401        if isinstance(params, torch.Tensor):
1402            raise TypeError(
1403                "`params` argument should be an iterable of "
1404                f"Tensors, but got {torch.typename(params)}"
1405            )
1406        try:
1407            all_params = list(params)
1408        except TypeError as e:
1409            raise TypeError(
1410                "`params` argument should be an iterable of Tensors"
1411                f" or dicts, but got {torch.typename(params)}"
1412            ) from e
1413        if len(all_params) == 0:
1414            raise ValueError("ZeroRedundancyOptimizer got an empty parameter list")
1415        all_tensors = True
1416        all_dicts = True
1417        for param in all_params:
1418            all_tensors &= isinstance(param, torch.Tensor)
1419            all_dicts &= isinstance(param, dict)
1420        if not all_tensors and not all_dicts:
1421            raise TypeError(
1422                "`params` argument should be an iterable of Tensors or dicts"
1423            )
1424        # Ensure that `self._all_params` contains a list of all parameters
1425        if all_tensors:
1426            self._all_params = all_params
1427        elif all_dicts:
1428            self._all_params = []
1429            # `all_params` contains parameter groups (not parameters)
1430            for param_group in all_params:
1431                if "params" not in param_group:
1432                    raise ValueError(
1433                        "Each parameter group passed-in via `params` must "
1434                        "have a 'params' key mapping to the parameters in "
1435                        "the group"
1436                    )
1437                self._all_params.extend(param_group["params"])
1438        return all_params
1439
1440    def _verify_same_dense_param_type(self) -> None:
1441        r"""
1442        Verify that all parameters are of the same dense type.
1443
1444        The method assumes that ``self._all_params`` has been initialized
1445        and is non-empty.
1446
1447        Raises:
1448            ValueError: ``params`` contains sparse parameters or parameters
1449                of varying dense types.
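
        For example, mixing ``torch.cuda.FloatTensor`` and
        ``torch.cuda.HalfTensor`` parameters raises a ``ValueError`` since
        their dense types differ.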
1450
1451        NOTE: This method can be removed once support for sparse parameters
1452        and varying parameter types is added.
1453        """
1454        typename = torch.typename(self._all_params[0])
1455        if self._all_params[0].is_sparse:
1456            raise ValueError(
1457                "ZeroRedundancyOptimizer only supports using "
1458                "the same dense type for all parameters but got "
1459                f"{typename}"
1460            )
1461        for param in self._all_params[1:]:
1462            other_typename = torch.typename(param)
1463            if other_typename != typename:
1464                raise ValueError(
1465                    "ZeroRedundancyOptimizer only supports "
1466                    "using the same dense type for all "
1467                    f"parameters but got both {typename} and "
1468                    f"{other_typename}"
1469                )
1470
1471    def _get_is_trainable_mask(self) -> List[bool]:
1472        r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not."""
1473        return list(map(_is_trainable, self._all_params))
1474
1475    def _init_local_optimizer(self) -> None:
1476        r"""
1477        Initialize this rank's local optimizer, responsible for its subset of the parameters.
1478
1479        The local optimizer is saved in ``self.optim``.
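
        The signature-introspection check used below for
        ``_allow_empty_param_list``, shown in isolation (a minimal sketch with
        a hypothetical constructor)::

            >>> import inspect
            >>> def ctor(params, lr=0.01, _allow_empty_param_list=False):
            ...     pass
            >>> "_allow_empty_param_list" in inspect.signature(ctor).parameters
            True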
1480        """
1481        assert (
1482            self._optim_constructor is not None
1483        ), "The local optimizer class has not been set"
1484
1485        param_groups = self._partition_parameters()[self.rank]
1486        # `overlap_with_ddp=True` requires a local functional optimizer
1487        if self._overlap_with_ddp:
1488            # Functional optimizers only support a single parameter group and
1489            # require passing in the parameters as a list
1490            assert len(param_groups) == 1, (
1491                "Initializing the local "
1492                "functional optimizer with more than one parameter group"
1493            )
1494            params = param_groups[0]["params"]
1495            # Try to pass `_allow_empty_param_list=True` to avoid erroring
1496            if (
1497                "_allow_empty_param_list"
1498                in inspect.signature(self._optim_constructor).parameters
1499            ):
1500                self.optim: Any = self._optim_constructor(
1501                    params, **self._optim_defaults, _allow_empty_param_list=True
1502                )
1503            else:
1504                logger.warning(
1505                    "%s does not support the argument "
1506                    "`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
1507                    "error due to an empty parameter list",
1508                    self._optim_constructor,
1509                )
1510                self.optim: Any = self._optim_constructor(params, **self._optim_defaults)  # type: ignore[no-redef]
1511
1512            # Log information about the DDP and ZeRO bucketing
1513            if dist.get_debug_level() != dist.DebugLevel.OFF:
1514                local_numel = sum(p.numel() for p in params)
1515                num_assigned_buckets = len(
1516                    self._bucket_assignments_per_rank[self.global_rank]
1517                )
1518                logger.info(
1519                    "rank %s with %s parameters across %s buckets",
1520                    self.global_rank,
1521                    local_numel,
1522                    num_assigned_buckets,
1523                )
1524                if self.global_rank == 0:
1525                    logger.info(
1526                        "%s DDP buckets and %s bucket assignments",
1527                        len(self._overlap_info.params_per_bucket),
1528                        self._overlap_info.num_bucket_assignments,
1529                    )
1530        else:
1531            # NOTE: Passing `param_groups` into the local optimizer constructor
1532            # bypasses the empty parameter list check
1533            self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults)  # type: ignore[no-redef]
1534
1535        # TODO: Manually add `self.param_groups` if using a functional
1536        # optimizer; remove this if/when the functional optimizers support
1537        # multiple parameter groups
1538        if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"):
1539            assert hasattr(self.optim, "param_group"), (
1540                "The functional optimizer should set at least one of the "
1541                "attributes `param_group` or `param_groups`"
1542            )
1543            self.optim.param_groups = [self.optim.param_group]  # type: ignore[attr-defined]
1544
1545        self._sync_param_groups(self.optim.param_groups, self.param_groups)
1546
1547    def _init_zero_for_overlap(self) -> None:
1548        r"""Perform a delayed initialization of the local optimizer and the supporting data structures."""
1549        assert self._overlap_with_ddp, (
1550            "`_init_zero_for_overlap()` should only be called when "
1551            "`overlap_with_ddp=True`"
1552        )
1553        self._overlap_info.status = _OverlapStatus.INITIALIZED
1554        self._clear_cache()
1555        self._partition_parameters(self._overlap_info.params_per_rank)
1556        self._build_ddp_param_buckets()
1557        self._init_local_optimizer()
1558
1559    def _get_assigned_rank(self, bucket_index: int) -> int:
1560        r"""
1561        Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket.
1562
1563        Arguments:
1564            bucket_index (int): index of the :class:`DistributedDataParallel`
1565                bucket for which to get the assigned rank.
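
        For example, with ``world_size=4``, bucket indices 0 through 5 map to
        ranks 0, 1, 2, 3, 0, and 1, respectively (i.e. round-robin by
        ``bucket_index % world_size``).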
1566        """
1567        assert not self._overlap_info.shard_buckets, (
1568            "The bucket assignment requires global bucket information and "
1569            "will be computed later; there should be no need to use this "
1570            "method"
1571        )
1572        return bucket_index % self.world_size
1573
1574    def _check_overlap_initialized(self):
1575        r"""
1576        Check the delayed initialization depending on the value of ``overlap_with_ddp``.
1577
1578        If ``overlap_with_ddp=True``, this checks that the delayed
1579        initialization (see :meth:`_init_zero_for_overlap`) has occurred and
1580        raises a ``RuntimeError`` if not. This check should preface methods
1581        that must not run before that delayed initialization.
1582
1583        Raises:
1584            RuntimeError: if ``overlap_with_ddp=True`` and
1585                :meth:`_init_zero_for_overlap` has not been called.
1586        """
1587        if (
1588            self._overlap_with_ddp
1589            and self._overlap_info.status != _OverlapStatus.INITIALIZED
1590        ):
1591            raise RuntimeError(
1592                "This method should not be called until this "
1593                "ZeroRedundancyOptimizer instance has been fully "
1594                "initialized"
1595            )
1596
1597    def _get_optimizer_constructor(self, optimizer_class: Any) -> Any:
1598        r"""
1599        Return the optimizer constructor, validating it and translating it to its functional equivalent depending on ``overlap_with_ddp``.
1600
1601        Returns:
1602            - ``optimizer_class`` if ``overlap_with_ddp=False`` and
1603                ``optimizer_class`` is not a functional optimizer.
1604            - ``optimizer_class`` if ``overlap_with_ddp=True`` and
1605                ``optimizer_class`` is already a functional optimizer.
1606            - The functional equivalent of ``optimizer_class`` if
1607                ``overlap_with_ddp=True`` and ``optimizer_class`` is not
1608                already a functional optimizer (assuming the equivalent
1609                exists).
1610
1611        Raises:
1612            ValueError:
1613
1614                - if ``overlap_with_ddp=True`` but ``optimizer_class`` is
1615                    neither a functional optimizer nor translatable to a
1616                    functional optimizer.
1617                - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a
1618                    functional optimizer.
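
        For example, with ``overlap_with_ddp=False``, ``torch.optim.Adam`` is
        returned unchanged, while with ``overlap_with_ddp=True`` it is
        translated to its functional counterpart registered in
        ``functional_optim_map`` (e.g. ``_FunctionalAdam``), assuming such an
        entry exists.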
1619        """
1620        functional_optims = functional_optim_map.values()
1621        if not self._overlap_with_ddp:
1622            if optimizer_class in functional_optims:
1623                # Using a functional optimizer is only supported when
1624                # `overlap_with_ddp=True`
1625                raise ValueError(
1626                    f"Passing in a functional optimizer {optimizer_class} "
1627                    "when `overlap_with_ddp=False`"
1628                )
1629            else:
1630                return optimizer_class
1631        else:
1632            if optimizer_class in functional_optims:
1633                # Already a functional optimizer
1634                return optimizer_class
1635            elif optimizer_class in functional_optim_map:
1636                # Translate the passed-in optimizer class to its functional
1637                # equivalent if `overlap_with_ddp=True`
1638                optim_constructor = functional_optim_map[optimizer_class]
1639                logger.info(
1640                    "Using the functional optimizer %s "
1641                    "instead of %s since "
1642                    "`overlap_with_ddp=True`",
1643                    optim_constructor,
1644                    optimizer_class,
1645                )
1646                return optim_constructor
1647            else:
1648                raise ValueError(
1649                    "Using `overlap_with_ddp=True` requires using a "
1650                    "functional optimizer, but there is no supported functional "
1651                    f"optimizer equivalent for {optimizer_class}"
1652                )
1653