# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

r"""Zero Redundancy Optimizer."""
import collections
import copy
import enum
import inspect
import io
import logging
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim.utils import functional_optim_map
from torch.optim import Optimizer


__all__ = ["ZeroRedundancyOptimizer"]


logger = logging.getLogger(__name__)


# Credits: classy_vision/generic/distributed_util.py
def _recursive_copy_to_device(
    value: Any,
    non_blocking: bool,
    device: torch.device,
) -> Any:
    r"""
    Recursively search lists, tuples, and dicts, copying tensors to the device where possible.

    Non-tensor values are passed as-is in the result.

    .. note:: These are all copies, so if two entries reference the same
        tensor, then after this call they will reference two different copies
        on the device.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)

    if isinstance(value, (list, tuple)):
        values = [
            _recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
            for val in value
        ]
        return values if isinstance(value, list) else tuple(values)

    if isinstance(value, collections.abc.Mapping):
        return {
            key: _recursive_copy_to_device(
                val, non_blocking=non_blocking, device=device
            )
            for key, val in value.items()
        }

    return value


def _is_trainable(param: torch.Tensor) -> bool:
    r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient."""
    return param.requires_grad


def _broadcast_object(
    obj: Any,
    src_rank: int,
    group: object = dist.group.WORLD,
    device: torch.device = torch.device("cpu"),
) -> Any:
    r"""
    Broadcast an object to the given group.

    It sends the object if called on the source rank and receives the object
    otherwise.

    Arguments:
        obj: object to broadcast; only used if called on the source rank.
        src_rank (int): source rank.
        group (``ProcessGroup``, optional): group used for the broadcast
            (default: ``dist.group.WORLD``).
        device (``torch.device``, optional): device to send from or receive
            to (default: ``torch.device("cpu")``).

    Returns:
        The broadcasted object.
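
    The following is a minimal usage sketch (``some_obj`` is a placeholder,
    and the default process group is assumed to be initialized with rank 0 as
    the sender)::

        >>> # xdoctest: +SKIP
        >>> if dist.get_rank() == 0:
        >>>     _broadcast_object(some_obj, src_rank=0)
        >>> else:
        >>>     some_obj = _broadcast_object(None, src_rank=0)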
    """
    if dist.get_rank() == src_rank:
        # Send the object
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(device)
        data_send_tensor = torch.ByteTensor(data).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Receive the object
        length_tensor = torch.LongTensor([0]).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty(
            [int(length_tensor.item())], dtype=torch.uint8, device=device
        )
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        obj = torch.load(buffer, map_location=device, weights_only=False)
    return obj


class _ZeROJoinHook(JoinHook):
    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), (
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer "
            "instance as the state"
        )
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Perform an optimizer step.

        This step updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()


class _DDPBucketAssignment:
    r"""
    Represent a :class:`DistributedDataParallel` bucket assignment.

    This means that a (possibly non-strict) subset of the parameters corresponding to
    a DDP bucket is assigned to a rank to update.

    Attributes:
        bucket_index (int): index of the bucket determined by the DDP gradient
            bucket all-reduce order.
        parameters (List[torch.Tensor]): model parameters in the bucket
            assigned to this rank.
        offset (int): offset into the :class:`GradBucket` 's :meth:`parameters`
            giving the index of the first element in the passed-in
            ``parameters``; this equivalently indexes into the
            :class:`GradBucket` 's :meth:`gradients`.
        device (torch.device): device on which the parameters are stored.
        tensor (torch.Tensor): flattened tensor giving the data of the
            parameter subset assigned to the rank.
    """

    def __init__(
        self,
        bucket_index: int,
        parameters: List[torch.Tensor],
        offset: int,
    ):
        self.bucket_index = bucket_index
        self.parameters = parameters
        self.offset = offset
        if len(self.parameters) == 0:
            raise ValueError("Empty bucket assignment")
        # DDP guarantees all parameters in the bucket have the same device
        self.device: torch.device = self.parameters[0].device
        self.tensor: Optional[torch.Tensor] = None


class _OverlapStatus(enum.IntEnum):
    r"""
    Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`.

    Attributes:
        ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and
            is waiting for DDP to finalize its bucketing.
        ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that
            its bucketing is finalized. The ZeRO instance can now collect the
            necessary information about the DDP bucketing.
        ``INITIALIZED``: The ZeRO instance is fully initialized and can now
            optimize parameters.
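
    A rough sketch of the intended transition (the actual transitions are
    driven by the DDP communication hooks in ``ddp_zero_hook.py``; ``zero``
    stands for a :class:`ZeroRedundancyOptimizer` constructed with
    ``overlap_with_ddp=True``)::

        >>> # xdoctest: +SKIP
        >>> if zero._overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
        >>>     # Collect the DDP bucketing information, then finalize
        >>>     zero._init_zero_for_overlap()  # sets the status to INITIALIZED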
    """

    UNINITIALIZED = 0
    DDP_HAS_REBUILT_BUCKETS = 1
    INITIALIZED = 2


class _OverlapInfo:
    r"""
    Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`.

    Arguments:
        world_size (int): world size of the process group being used.

    Attributes:
        shard_buckets (bool): if ``True``, then the assignment of each
            :class:`DistributedDataParallel` bucket is partitioned across
            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
            across possibly multiple ranks) to approximate uniformity following
            a threshold given by the total parameter size divided by the world
            size; if ``False``, then each bucket is wholly assigned to a single
            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
            this should be set to the value passed into the hook constructor.
        status (_OverlapStatus): current status; see :class:`_OverlapStatus`
            for more information.
        params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
            gives the model parameters in the ``i``th bucket.
        params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]``
            gives the model parameters assigned to the ``i``th rank, where the
            parameters are grouped by increasing bucket indices.
        offsets (Dict[int, int]): maps from bucket index to the offset in
            ``self.params_per_rank[rank]`` giving the index of the first
            parameter in that bucket, where ``rank`` is this process's own
            rank; the keys of this :class:`dict` are the bucket indices
            assigned to this rank.
        num_bucket_assignments (int): total number of bucket assignments across
            all ranks; this is equal to the number of
            :class:`DistributedDataParallel` gradient buckets if
            ``shard_buckets=False`` and possibly greater otherwise.
        total_size (int, optional): total size of all buckets (i.e. sum of
            ``param.numel()`` for all ``param`` across all buckets) if
            ``shard_buckets=True``; otherwise, ``None``.
        broadcast_handles (List[Work]): :class:`list` of async work handles for
            the parameter broadcasts.
        bucket_index_to_future (Dict[int, torch.futures.Future]):
            :class:`dict` mapping bucket index to the corresponding all-reduce
            future.
        bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict`
            mapping bucket index to the corresponding bucket.
        bucket_indices_seen (List[int]): :class:`list` of the bucket indices
            seen on this iteration.
    """

    def __init__(self, world_size) -> None:
        self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
        self.shard_buckets: bool = False

        # Modified per bucket reconstruction
        self.params_per_bucket: List[List[torch.Tensor]] = []
        self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
        self.offsets: Dict[int, int] = {}
        # Group ranks assigned to each bucket
        self.assigned_ranks_per_bucket: List[Set[int]] = []
        self.num_bucket_assignments: int = 0
        self.total_size: Optional[int] = None

        # Modified per iteration
        self.broadcast_handles: List[Any] = []
        self.bucket_indices_seen: List[int] = []
        # Used by `hook_with_zero_step()`
        self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
        self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}

    def wait_for_broadcasts(self) -> None:
        r"""
        Wait for all parameter broadcasts.

        This function should be called once all broadcasts have been scheduled,
        meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
        in preparation for the next iteration.
        """
        assert (
            len(self.broadcast_handles) == self.num_bucket_assignments
        ), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
        _ = [x.wait() for x in self.broadcast_handles]
        self.broadcast_handles.clear()

    def clear_per_iter_info(self) -> None:
        r"""
        Clear the data structures that are modified per-iteration.

        This function should be called at the end of an iteration.
        """
        self.bucket_indices_seen.clear()
        self.bucket_index_to_future.clear()
        self.bucket_index_to_bucket.clear()


class ZeroRedundancyOptimizer(Optimizer, Joinable):
    r"""
    Wrap an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>` and shard its states across ranks in the group.

    The sharding is done as described by ZeRO_.

    The local optimizer instance in each rank is only
    responsible for updating approximately ``1 / world_size`` parameters and
    hence only needs to keep ``1 / world_size`` optimizer states. After
    parameters are updated locally, each rank will broadcast its parameters to
    all other peers to keep all model replicas in the same state.
    ``ZeroRedundancyOptimizer`` can be used in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
    memory consumption.

    ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
    of parameters at each rank. Each parameter belongs to a single rank and is
    not divided among ranks. The partition is arbitrary and might not match
    the parameter registration or usage order.

    Arguments:
        params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
            or :class:`dict` s giving all parameters, which will be sharded
            across ranks.

    Keyword Args:
        optimizer_class (:class:`torch.optim.Optimizer`): the class of the
            local optimizer.
        process_group (``ProcessGroup``, optional): ``torch.distributed``
            ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by
            :meth:`torch.distributed.init_process_group`).
        parameters_as_bucket_view (bool, optional): if ``True``, parameters are
            packed into buckets to speed up communication, and ``param.data``
            fields point to bucket views at different offsets; if ``False``,
            each individual parameter is communicated separately, and each
            ``param.data`` stays intact (default: ``False``).
        overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is
            overlapped with :class:`DistributedDataParallel` 's gradient
            synchronization; this requires (1) either a functional optimizer
            for the ``optimizer_class`` argument or one with a functional
            equivalent and (2) registering a DDP communication hook
            constructed from one of the functions in ``ddp_zero_hook.py``
            (see the setup sketch following this list);
            parameters are packed into buckets matching those in
            :class:`DistributedDataParallel`, meaning that the
            ``parameters_as_bucket_view`` argument is ignored.
            If ``False``, :meth:`step` runs disjointly after the backward pass
            (per normal).
            (default: ``False``)
        **defaults: any trailing arguments, which are forwarded to the local
            optimizer.
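
    The following is a sketch of an ``overlap_with_ddp=True`` setup, not a
    definitive recipe: ``hook_with_zero_step`` is the hook constructor from
    ``ddp_zero_hook.py`` mentioned above, while ``allreduce_hook``, the import
    paths, ``model``, and ``rank`` are assumptions standing in for the user's
    own setup::

        >>> # xdoctest: +SKIP
        >>> from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import hook_with_zero_step
        >>> from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
        >>> ddp = DDP(model, device_ids=[rank])
        >>> zero = ZeroRedundancyOptimizer(
        >>>     ddp.parameters(),
        >>>     optimizer_class=torch.optim.Adam,
        >>>     overlap_with_ddp=True,
        >>>     lr=0.01,
        >>> )
        >>> ddp.register_comm_hook(None, hook_with_zero_step(allreduce_hook, ddp, zero))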

    Example::

        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>> from torch.distributed.optim import ZeroRedundancyOptimizer
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
        >>> ddp = DDP(model, device_ids=[rank])
        >>> opt = ZeroRedundancyOptimizer(
        >>>     ddp.parameters(),
        >>>     optimizer_class=torch.optim.Adam,
        >>>     lr=0.01
        >>> )
        >>> ddp(inputs).sum().backward()
        >>> opt.step()

    .. warning::
        Currently, ``ZeroRedundancyOptimizer`` requires that all of the
        passed-in parameters are the same dense type.

    .. warning::
        If you pass ``overlap_with_ddp=True``, be wary of the following: Given
        the way that overlapping :class:`DistributedDataParallel` with
        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
        two or three training iterations do not perform parameter updates in
        the optimizer step, depending on whether ``static_graph=False`` or
        ``static_graph=True``, respectively. This is because the optimizer
        needs information about the gradient bucketing strategy used by
        :class:`DistributedDataParallel`, which is not finalized until the
        second forward pass if ``static_graph=False`` or until the third
        forward pass if ``static_graph=True``. To adjust for this, one option
        is to prepend dummy inputs.

    .. warning:: ZeroRedundancyOptimizer is experimental and subject to change.

    .. _ZeRO: https://arxiv.org/abs/1910.02054

    """

    def __init__(
        self,
        params,
        optimizer_class: Type[Optimizer],
        process_group: Optional[Any] = None,
        parameters_as_bucket_view: bool = False,
        overlap_with_ddp: bool = False,
        **defaults: Any,
    ):
        r"""Init."""
        # Perform type and assumption checks on the input parameters
        params = self._verify_and_init_params(params)
        self._verify_same_dense_param_type()

        # NOTE: The parent constructor uses `add_param_group()` which is
        # partially overloaded in ZeroRedundancyOptimizer, so we use the
        # `initialized` flag to dissociate the behaviour of `add_param_group()`
        # between the parent and child.
        self.initialized = False

        Optimizer.__init__(self, params, defaults)
        Joinable.__init__(self)
        # Now, all parameters are held in both `self._all_params` and
        # `self.param_groups`

        # Internal data structures (`_cache` indicates lazily evaluated)
        self._param_to_rank_cache: Dict[torch.Tensor, int] = {}
        self._param_to_index_cache: Dict[torch.Tensor, int] = {}
        self._partition_parameters_cache: List[List[Dict]] = []
        self._index_to_param_cache: List[torch.Tensor] = []
        self._device_to_params_per_rank_cache: Dict[
            torch.device, List[List[torch.Tensor]]
        ] = {}
        self._bucket_assignments_per_rank_cache: List[
            Dict[int, _DDPBucketAssignment]
        ] = []
        self._is_trainable_mask = self._get_is_trainable_mask()

        # Default device for collective communication and buckets
        self._default_device = self._all_params[0].device

        self.process_group = (
            process_group if process_group is not None else dist.group.WORLD
        )
        self.world_size: int = dist.get_world_size(self.process_group)
        self.rank: int = dist.get_rank(self.process_group)
        self.global_rank: int = dist.distributed_c10d.get_global_rank(
            self.process_group, self.rank
        )

        self._overlap_with_ddp: bool = overlap_with_ddp
        self._optim_defaults = defaults
        self._optim_constructor = self._get_optimizer_constructor(optimizer_class)

        # If `overlap_with_ddp=True`, local optimizer initialization is delayed
        # to run time after the necessary information has been collected
        if not overlap_with_ddp:
            self._init_local_optimizer()
        else:
            self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
            if parameters_as_bucket_view:
                logger.warning(
                    "`parameters_as_bucket_view=True` will be ignored since "
                    "`overlap_with_ddp=True`; instead, a different bucketing "
                    "strategy will be used"
                )

        # `self._buckets` is used if `parameters_as_bucket_view=True`, in
        # which case parameter data is flattened into contiguous bucket tensors
        self.parameters_as_bucket_view = parameters_as_bucket_view
        self._buckets: List[List[torch.Tensor]] = []
        self._build_param_buckets()

        # Optional consolidated optimizer state, only populated if this rank
        # is the target in `consolidate_state_dict()`
        self._all_state_dicts: List[Dict[str, Any]] = []

        self.initialized = True

    def _clear_cache(self) -> None:
        r"""Clear the cached data structures giving partition information."""
        self._partition_parameters_cache.clear()
        self._param_to_rank_cache.clear()
        self._index_to_param_cache.clear()
        self._param_to_index_cache.clear()
        self._device_to_params_per_rank_cache.clear()
        self._bucket_assignments_per_rank_cache.clear()

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        r"""
        Add a parameter group to the :class:`Optimizer` 's ``param_groups``.

        This can be useful when fine-tuning a pre-trained network, as frozen
        layers can be made trainable and added to the :class:`Optimizer` as
        training progresses.

        Arguments:
            param_group (dict): specifies the parameters to be optimized and
                group-specific optimization options.
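
        Example (a sketch only; ``zero_optim`` and ``new_layer`` are
        placeholders for the user's optimizer and a newly unfrozen module, and
        every rank must make the same call)::

            >>> # xdoctest: +SKIP
            >>> zero_optim.add_param_group(
            >>>     {"params": new_layer.parameters(), "lr": 1e-4}
            >>> )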

        .. warning:: This method handles updating the shards on all partitions
            but needs to be called on all ranks. Calling this on a subset of
            the ranks will cause the training to hang because communication
            primitives are called depending on the managed parameters and
            expect all ranks to participate with the same set of parameters.
        """
        if self.initialized and self._overlap_with_ddp:
            raise RuntimeError(
                "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only "
                "supports a single parameter group"
            )

        super().add_param_group(param_group)
        # NOTE: The rest of the method assumes that the call to the parent's
        # `add_param_group()` appends the new parameter group and preserves
        # the previous parameter-group ordering

        if self.initialized:
            # Force a re-partitioning of the parameters
            self._clear_cache()
            param_groups = self._partition_parameters()[self.rank]
            # NOTE: All parameters in the old parameter groups should be
            # assigned to the same ranks so that the local optimizers do not
            # need to be reinitialized

            # Add the parameters assigned to this rank from the new parameter
            # group to the local optimizer, if any
            if len(param_groups) == len(self.optim.param_groups) + 1:
                self.optim.add_param_group(param_groups[-1])

            # Update the bucketing strategy accordingly
            if self.parameters_as_bucket_view:
                self._build_param_buckets()

    def consolidate_state_dict(self, to: int = 0) -> None:
        r"""
        Consolidate a list of ``state_dict`` s (one per rank) on the target rank.

        Arguments:
            to (int): the rank that receives the optimizer states (default: 0).

        Raises:
            RuntimeError: if ``overlap_with_ddp=True`` and this method is
                called before this :class:`ZeroRedundancyOptimizer` instance
                has been fully initialized, which happens once
                :class:`DistributedDataParallel` gradient buckets have been
                rebuilt.

        .. warning:: This needs to be called on all ranks.
        """
        self._check_overlap_initialized()

        # Sync the exposed `param_groups` attributes to the local optimizer in
        # case they have been updated
        self._sync_param_groups(self.param_groups, self.optim.param_groups)

        # Pull the sharded state from all ranks and store them in rank order
        empty_messenger = torch.tensor(
            [0], dtype=torch.uint8, device=self._default_device
        )

        # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
        # due to compatibility issues with the NCCL backend; a possible
        # follow-up is to move all sharded state management to RPC RRef
        self._all_state_dicts = []
        for rank in range(self.world_size):
            global_rank = dist.distributed_c10d.get_global_rank(
                self.process_group, rank
            )
            if self.rank == to:
                # Consolidate all local `state_dict`s on this rank, storing on
                # CPU to save GPU memory
                if rank == self.rank:
                    # Directly append own optimizer state
                    self._all_state_dicts.append(
                        _recursive_copy_to_device(
                            self.optim.state_dict(),
                            non_blocking=True,
                            device=torch.device("cpu"),
                        )
                    )
                else:
                    # Receive the optimizer state from the source rank
                    local_state_dict = _broadcast_object(
                        empty_messenger,
                        src_rank=global_rank,
                        group=self.process_group,
                        device=self._default_device,
                    )
                    self._all_state_dicts.append(
                        _recursive_copy_to_device(
                            local_state_dict,
                            non_blocking=True,
                            device=torch.device("cpu"),
                        )
                    )
            else:
                if rank == self.rank:
                    # Send the optimizer state to the target rank
                    _ = _broadcast_object(
                        self.optim.state_dict(),
                        src_rank=self.global_rank,
                        group=self.process_group,
                        device=self._default_device,
                    )
                elif rank != to:
                    # Discard the received object; `broadcast()` is used for
                    # compatibility reasons
                    _ = _broadcast_object(
                        empty_messenger,
                        src_rank=global_rank,
                        group=self.process_group,
                        device=self._default_device,
                    )

    def _verify_params_per_rank(
        self,
        params_per_rank: List[List[torch.Tensor]],
    ) -> None:
        r"""
        Verify ``params_per_rank`` for :meth:`_partition_parameters`.

        The verification is done by checking that ``params_per_rank`` has length equal
        to the world size and that it does not contain any parameters not passed into the
        :class:`ZeroRedundancyOptimizer` constructor.

        The parameters in ``params_per_rank`` being a strict subset of those
        passed into the constructor is valid since some parameters may be
        frozen.

        Raises:
            ValueError: if ``params_per_rank`` does not have length equal to
                the world size or if it contains a parameter that was not
                passed into the :class:`ZeroRedundancyOptimizer` constructor.
        """
        if len(params_per_rank) != self.world_size:
            raise ValueError(
                "`params_per_rank` must have length equal to the world size"
            )
        all_params_set = set(self._all_params)
        for params in params_per_rank:
            for param in params:
                if param not in all_params_set:
                    raise ValueError(
                        "Passing a new parameter in `params_per_rank` that "
                        "was not passed into the ZeroRedundancyOptimizer "
                        "constructor"
                    )

    def _partition_param_group(
        self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]]
    ) -> None:
        r"""
        Partition the parameter group ``param_group`` according to ``params_per_rank``.

        The partition modifies ``self._partition_parameters_cache``. This method should
        only be used as a subroutine for :meth:`_partition_parameters`.

        Arguments:
            param_group (dict[str, Any]): a parameter group as normally defined
                in an optimizer state.
            params_per_rank (list[list[torch.Tensor]]): a :class:`list` of
                length world size containing :class:`list` s of parameters to
                assign to each rank.
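
        For example (a sketch, assuming a world size of 2 and hypothetical
        parameters ``p0``, ``p1``, and ``p2`` belonging to ``param_group``)::

            >>> # xdoctest: +SKIP
            >>> self._partition_param_group(param_group, [[p0, p2], [p1]])
            >>> # Rank 0's copy of the group now holds [p0, p2] and rank 1's
            >>> # copy holds [p1]:
            >>> self._partition_parameters_cache[0][-1]["params"]
            [p0, p2]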
        """
        for rank, params in enumerate(params_per_rank):
            rank_param_group = copy.copy(param_group)
            rank_param_group["params"] = params
            self._partition_parameters_cache[rank].append(rank_param_group)

    def _partition_parameters(
        self,
        params_per_rank: Optional[List[List[torch.Tensor]]] = None,
    ) -> List[List[Dict]]:
        r"""
        Partition parameters across distributed data parallel ranks.

        Arguments:
            params_per_rank (list[list[torch.Tensor]], optional): a
                :class:`list` of length world size containing :class:`list` s
                of parameters to assign to each rank; this provides a way to
                specify a partition manually.
                If ``None``, the parameters are partitioned according to an
                internal algorithm.
                (default: ``None``)

        Returns:
            A :class:`list` where each element of the list contains the
            ``param_groups`` for a rank (which itself is a :class:`list` of
            :class:`dict`); element 0 corresponds to rank 0, etc.; each rank
            stores the ``param_groups`` for all ranks for the collective
            communication in :meth:`step`.

        Raises:
            ValueError: see :meth:`_verify_params_per_rank`.
            RuntimeError: if ``params_per_rank`` is not ``None`` and this
                :class:`ZeroRedundancyOptimizer` instance is using more than
                one parameter group.
        """
        if params_per_rank is None:
            # Partition the parameters optimizing for uniformity
            if len(self._partition_parameters_cache) == 0:
                self._partition_parameters_cache = [[] for _ in range(self.world_size)]
                sizes = [0] * self.world_size
                for param_group in self.param_groups:
                    param_group_params_per_rank: List[List] = [
                        [] for _ in range(self.world_size)
                    ]
                    # Sort the parameters by size (largest first)
                    params_sorted = sorted(
                        param_group["params"], key=lambda t: t.numel(), reverse=True
                    )
                    for param in params_sorted:
                        # Greedily add the parameter to the rank with the
                        # smallest size so far
                        rank = self._get_min_index(sizes)
                        param_group_params_per_rank[rank].append(param)
                        sizes[rank] += param.numel()
                    # Apply the constructed partition of the parameter group
                    self._partition_param_group(
                        param_group, param_group_params_per_rank
                    )

            return self._partition_parameters_cache

        # Partition the parameters according to `params_per_rank`
        assert len(self._partition_parameters_cache) == 0, (
            "Specifying `params_per_rank` should only be done when the "
            "parameters have not been partitioned yet"
        )
        if len(self.param_groups) != 1:
            raise RuntimeError(
                "Specifying `params_per_rank` only supports a single parameter group"
            )
        self._verify_params_per_rank(params_per_rank)
        self._partition_parameters_cache = [[] for _ in range(self.world_size)]

        # Apply the passed-in partition of the parameter group
        param_group = self.param_groups[0]
        self._partition_param_group(param_group, params_per_rank)

        return self._partition_parameters_cache

    @property
    def _param_to_rank(self) -> Dict[torch.Tensor, int]:
        r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition."""
        if len(self._param_to_rank_cache) == 0:
            for rank, param_groups in enumerate(self._partition_parameters()):
                for param_group in param_groups:
                    for param in param_group["params"]:
                        self._param_to_rank_cache[param] = rank
        return self._param_to_rank_cache

    @property
    def _param_to_index(self) -> Dict[torch.Tensor, int]:
        r"""
        :class:`dict` mapping parameters to their indices in the global optimizer state.

        NOTE: This assumes that the global optimizer state's indexing (in
        ``state_dict``) follows a linear ordering over the parameter groups.
        """
        if len(self._param_to_index_cache) == 0:
            self._param_to_index_cache = {
                p: i
                for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))
            }
        return self._param_to_index_cache

    @property
    def _index_to_param(self) -> List[torch.Tensor]:
        r"""List mapping parameter indices in the global optimizer scheme to the actual params."""
        if len(self._index_to_param_cache) == 0:
            self._index_to_param_cache = list(
                chain(*(g["params"] for g in self.param_groups))
            )
        return self._index_to_param_cache

    def _broadcast_params_from_rank(self, rank: int):
        r"""
        Broadcast the shard of parameters from a given rank to all other ranks asynchronously.

        Arguments:
            rank (int): the source rank.

        Returns:
            A :class:`list` of async work handles for the ``broadcast()`` s
            performed to synchronize the parameters.
        """
        assert not self._overlap_with_ddp, (
            "`_broadcast_params_from_rank()` should not be used if "
            "`overlap_with_ddp=True`; instead, the broadcasting should "
            "happen in the DDP communication hook"
        )
        handles = []
        if self.parameters_as_bucket_view:
            for dev_i_buckets in self._buckets:
                bucket = dev_i_buckets[rank]
                global_rank = dist.distributed_c10d.get_global_rank(
                    self.process_group, rank
                )
                handles.append(
                    dist.broadcast(
                        tensor=bucket,
                        src=global_rank,
                        group=self.process_group,
                        async_op=True,
                    )
                )
        else:
            param_groups = self._partition_parameters()[rank]
            global_rank = dist.distributed_c10d.get_global_rank(
                self.process_group, rank
            )
            for param_group in param_groups:
                for param in param_group["params"]:
                    handles.append(
                        dist.broadcast(
                            tensor=param.data,
                            src=global_rank,
                            group=self.process_group,
                            async_op=True,
                        )
                    )
        return handles

    def _sync_params(self):
        r"""
        Sync all parameter shards across the ranks.

        This rank sends its shard of the parameters to all other ranks and
        receives a shard from each other rank. This is done using
        ``broadcast()``. Parameters are sent bucket-by-bucket if
        ``parameters_as_bucket_view=True`` and sent parameter-by-parameter
        otherwise.
        """
        handles = []
        for rank in range(self.world_size):
            handles.extend(self._broadcast_params_from_rank(rank))
        _ = [x.wait() for x in handles]

    @property
    def _device_to_params_per_rank(
        self,
    ) -> Dict[torch.device, List[List[torch.Tensor]]]:
        r"""
        Return device parameters assigned per rank.

        :class:`dict` mapping each device to a :class:`list` of the per-rank parameter
        lists filtered to only include the parameters stored on that device.
        Each per-rank parameter list gives the parameters assigned to that rank
        to update.

        This is used for constructing the parameter buckets if
        ``parameters_as_bucket_view=True``.

        Let ``dev_i`` denote the ``i``th device for this rank. Then:
        ``dev_0`` maps to a list containing:
            rank 0's assigned parameters stored on ``dev_0``,
            rank 1's assigned parameters stored on ``dev_0``,
            ...
        ``dev_1`` maps to a list containing:
            rank 0's assigned parameters stored on ``dev_1``,
            rank 1's assigned parameters stored on ``dev_1``,
            ...
        ...
        """
        assert self.parameters_as_bucket_view, (
            "`_device_to_params_per_rank` should only be used if "
            "`parameters_as_bucket_view=True`"
        )
        if len(self._device_to_params_per_rank_cache) == 0:
            for rank, param_groups in enumerate(self._partition_parameters()):
                for param_group in param_groups:
                    for param in param_group["params"]:
                        device = param.device
                        if device not in self._device_to_params_per_rank_cache:
                            self._device_to_params_per_rank_cache[device] = [
                                [] for _ in range(self.world_size)
                            ]
                        self._device_to_params_per_rank_cache[device][rank].append(
                            param
                        )
        return self._device_to_params_per_rank_cache

    def _get_min_index(
        self,
        values: List[int],
        disallowed_indices: Optional[Set[int]] = None,
    ) -> int:
        r"""
        Return ``values.index(min(values))``, except only uses one pass.

        It also excludes any indices in ``disallowed_indices`` if provided.

        Arguments:
            values (List[int]): :class:`list` of values.
            disallowed_indices (Optional[Set[int]]): indices that are
                disallowed from being the returned min index.
        """
        min_index = -1
        min_value = float("inf")
        for i, value in enumerate(values):
            if disallowed_indices and i in disallowed_indices:
                continue
            if value < min_value:
                min_value = value
                min_index = i
        assert min_index >= 0, "All indices are disallowed"
        return min_index

    def _assign_bucket_subset_to_rank(
        self,
        bucket_index: int,
        bucket_params: List[torch.Tensor],
        bucket_offset: int,
        assigned_rank: int,
        assigned_ranks_per_bucket: List[Set[int]],
    ) -> None:
        r"""
        Assign ``bucket_params`` to the rank with the least size assigned so far and collect relevant information.

        The model parameters given by ``bucket_params`` represent a (possibly non-strict)
        subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket.

        Arguments:
            bucket_index (int): index of the :class:`DistributedDataParallel`
                gradient bucket.
            bucket_params (List[torch.Tensor]): subset of the parameters
                corresponding to the bucket to assign.
            bucket_offset (int): offset giving the index of the first element
                in ``bucket_params`` in the bucket's full parameter list.
            assigned_rank (int): group rank to assign to.
            assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of group ranks
                assigned to each bucket.
        """
        overlap_info = self._overlap_info
        if len(bucket_params) == 0:
            raise ValueError("Empty bucket assignment")
        params_per_rank = overlap_info.params_per_rank
        offsets = overlap_info.offsets

        self._bucket_assignments_per_rank_cache[assigned_rank][
            bucket_index
        ] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
        if self.global_rank == assigned_rank:
            offsets[bucket_index] = len(params_per_rank[assigned_rank])
        params_per_rank[assigned_rank].extend(bucket_params)
        assigned_ranks_per_bucket[bucket_index].add(assigned_rank)
        self._overlap_info.num_bucket_assignments += 1

    @property
    def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]:
        r"""
        Return DDP bucket parameters assigned per rank.

        :class:`list` of length world size consisting of :class:`dict` s
        mapping bucket indices to :class:`_DDPBucketAssignment` s for each
        rank.
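
        For example (a sketch with a world size of 2, two DDP buckets, and
        ``shard_buckets=False``, so bucket ``i`` is wholly assigned to rank
        ``i % 2``)::

            >>> # xdoctest: +SKIP
            >>> assignments = self._bucket_assignments_per_rank
            >>> # assignments[0] == {0: <_DDPBucketAssignment for bucket 0>}
            >>> # assignments[1] == {1: <_DDPBucketAssignment for bucket 1>}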
        """
        assert (
            self._overlap_with_ddp
        ), "`_bucket_assignments_per_rank` should only be used if `overlap_with_ddp=True`"
        if len(self._bucket_assignments_per_rank_cache) > 0:
            return self._bucket_assignments_per_rank_cache

        overlap_info = self._overlap_info
        assert overlap_info.status == _OverlapStatus.INITIALIZED

        self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
        params_per_bucket = overlap_info.params_per_bucket

        if overlap_info.shard_buckets:
            # Define the assignment threshold to approximate uniformity
            assert overlap_info.total_size is not None, "`total_size` was not computed"
            threshold = overlap_info.total_size / self.world_size  # type: ignore[operator]
            size_per_rank = [0 for _ in range(self.world_size)]

        num_buckets = len(params_per_bucket)
        overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
        assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket
        if not overlap_info.shard_buckets:
            # Assign each DDP bucket entirely to a single rank
            for bucket_index, bucket_params in enumerate(params_per_bucket):
                assert len(bucket_params) > 0, "Empty bucket"
                assigned_rank = self._get_assigned_rank(bucket_index)
                self._assign_bucket_subset_to_rank(
                    bucket_index,
                    bucket_params,
                    0,
                    assigned_rank,
                    assigned_ranks_per_bucket,
                )
        else:
            # Assign each DDP bucket to possibly multiple ranks
            # Specifically, sort the DDP buckets by increasing size, and for
            # each bucket, iteratively assign the maximal unassigned subset
            # with size less than `threshold` to the rank with the least total
            # size so far -- each such assignment is represented by a
            # `_DDPBucketAssignment` instance and only contains parameters from
            # a single DDP bucket
            params_per_bucket_enum = sorted(
                enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1])
            )
            for bucket_index, bucket_params in params_per_bucket_enum:
                assert len(bucket_params) > 0, "Empty bucket"
                bucket_offset = 0
                assignment_size = 0
                for param_index, param in enumerate(bucket_params):
                    param_numel = param.numel()
                    if (
                        assignment_size + param_numel >= threshold
                        and param_index > bucket_offset
                    ):
                        assigned_rank = self._get_min_index(
                            size_per_rank, assigned_ranks_per_bucket[bucket_index]
                        )
                        # Include up to but not including the parameter that
                        # exceeded the threshold
                        self._assign_bucket_subset_to_rank(
                            bucket_index,
                            bucket_params[bucket_offset:param_index],
                            bucket_offset,
                            assigned_rank,
                            assigned_ranks_per_bucket,
                        )
                        size_per_rank[assigned_rank] += assignment_size
                        bucket_offset = param_index
                        assignment_size = 0
                    assignment_size += param_numel
                # Assign the remainder of the bucket so that no assignment
                # spans across two buckets
                assigned_rank = self._get_min_index(
                    size_per_rank, assigned_ranks_per_bucket[bucket_index]
                )
                self._assign_bucket_subset_to_rank(
                    bucket_index,
                    bucket_params[bucket_offset:],
                    bucket_offset,
                    assigned_rank,
                    assigned_ranks_per_bucket,
                )
                size_per_rank[assigned_rank] += assignment_size

        return self._bucket_assignments_per_rank_cache

    def _local_step(
        self,
        gradients: Optional[List[Optional[torch.Tensor]]] = None,
        closure: Optional[Callable[[], float]] = None,
        **kwargs: Any,
    ) -> Optional[float]:
        r"""
        Perform a single optimizer step without syncing parameters across ranks.

        Arguments:
            gradients (list[Optional[torch.Tensor]], optional): a :class:`list`
                of length equal to the number of parameters assigned to this
                rank containing gradient tensors or ``None`` as its elements;
                a ``None`` in the :class:`list` indicates that the
                corresponding parameter should not be updated.
                If the argument itself is ``None``, then all parameters are
                updated, and the gradients are assumed to be already populated.
                (default: ``None``)
            closure (Callable): a closure that re-evaluates the model and
                returns the loss; optional for most optimizers and should be
                ``None`` if ``gradients`` is not ``None``. (default: ``None``)
        Returns:
            Optional loss depending on the underlying local optimizer.

        .. warning::
            The argument ``gradients`` should only be specified (i.e. not
            ``None``) if ``overlap_with_ddp=True``, in which case
            :class:`ZeroRedundancyOptimizer` wraps a functional optimizer.
        """
        Join.notify_join_context(self)
        # Check if the model trainability has changed
        is_trainable_mask = self._get_is_trainable_mask()
        if is_trainable_mask != self._is_trainable_mask:
            if self._overlap_with_ddp:
                raise RuntimeError(
                    "ZeroRedundancyOptimizer with `overlap_with_ddp=True` "
                    "does not support changing parameter trainability at run "
                    "time"
                )
            logger.warning(
                "ZeroRedundancyOptimizer detected that the trainable "
                "parameters changed; rebuilding the parameter buckets if "
                "enabled"
            )
            self._build_param_buckets()
            self._is_trainable_mask = is_trainable_mask

        # Sync the exposed `param_groups` attributes to the local optimizer in
        # case they have been updated
        self._sync_param_groups(self.param_groups, self.optim.param_groups)

        # Run the optimizer step on this shard only
        if gradients is None:
            loss = (
                self.optim.step(**kwargs)
                if closure is None
                else self.optim.step(closure=closure, **kwargs)
            )
        else:
            assert self._overlap_with_ddp, (
                "Specifying `gradients` should not "
                "be used when `overlap_with_ddp=False`"
            )
            assert (
                closure is None
            ), "`closure` is not supported when using a local functional optimizer"
            loss = self.optim.step(gradients=gradients)

        # Sync any updated attributes in the local optimizer to the exposed
        # `param_groups`
        self._sync_param_groups(self.optim.param_groups, self.param_groups)

        return loss

    def step(
        self,
        closure: Optional[Callable[[], float]] = None,
        **kwargs: Any,
    ) -> Optional[float]:
        r"""
        Perform a single optimizer step and sync parameters across all ranks.

        Arguments:
            closure (Callable): a closure that re-evaluates the model and
                returns the loss; optional for most optimizers.
        Returns:
            Optional loss depending on the underlying local optimizer.

        .. note:: Any extra parameters are passed to the base optimizer as-is.
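
        Example (a sketch of using a closure; ``zero_optim``, ``model``,
        ``inputs``, and ``loss_fn`` are placeholders defined by the user)::

            >>> # xdoctest: +SKIP
            >>> def closure():
            >>>     zero_optim.zero_grad()
            >>>     loss = loss_fn(model(inputs))
            >>>     loss.backward()
            >>>     return loss
            >>> loss = zero_optim.step(closure=closure)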
        """
        if self._overlap_with_ddp:
            logger.warning(
                "`step()` should not be included in the training loop when "
                "`overlap_with_ddp=True`"
            )
            return None

        # Perform the local optimizer step
        loss = self._local_step(closure=closure, **kwargs)

        # Sync all of the updated parameter shards across the ranks
        self._sync_params()

        return loss

    def join_hook(self, **kwargs):
        r"""
        Return the ZeRO join hook.

        It enables training on uneven inputs by
        shadowing the collective communications in the optimizer step.

        Gradients must be properly set before this hook is called.

        Arguments:
            kwargs (dict): a :class:`dict` containing any keyword arguments
                to modify the behavior of the join hook at run time; all
                :class:`Joinable` instances sharing the same join context
                manager are forwarded the same value for ``kwargs``.

        This hook does not support any keyword arguments; i.e. ``kwargs`` is
        unused.
        """
        return _ZeROJoinHook(self)

    @property
    def join_device(self) -> torch.device:
        r"""Return default device."""
        return self._default_device

    @property
    def join_process_group(self) -> Any:
        r"""Return process group."""
        return self.process_group

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed.

        Arguments:
            state_dict (dict): optimizer state; should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            RuntimeError: if ``overlap_with_ddp=True`` and this method is
                called before this :class:`ZeroRedundancyOptimizer` instance
                has been fully initialized, which happens once
                :class:`DistributedDataParallel` gradient buckets have been
                rebuilt.
        """
        self._check_overlap_initialized()

        for index, value in state_dict["state"].items():
            param = self._index_to_param[index]
            if self._param_to_rank[param] != self.rank:
                # Clear any state irrelevant to this rank
                state_dict["state"][index] = None
            else:
                # Load the parameter state to the local optimizer
                self.optim.state[param] = _recursive_copy_to_device(
                    value, non_blocking=True, device=param.device
                )
                # Force zero-dimensional tensors (like Adam "step") on CPU
                for state_name, state_value in self.optim.state[param].items():
                    if torch.is_tensor(state_value) and state_value.dim() == 0:
                        self.optim.state[param][state_name] = state_value.cpu()

        super().load_state_dict(state_dict)

        # Sync the input state with the exposed and local optimizer states
        self._sync_param_groups(state_dict["param_groups"], self.param_groups)
        self._sync_param_groups(self.param_groups, self.optim.param_groups)

    def state_dict(self) -> Dict[str, Any]:
        r"""
        Return the last global optimizer state known to this rank.

        .. warning::
            If the state has not been consolidated to this rank, this raises a
            runtime error, and even if it has, the state may not be up-to-date,
            depending on when :meth:`consolidate_state_dict` was last called.
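
        Example (a sketch of saving a consolidated checkpoint; ``zero_optim``
        is a placeholder for a fully constructed ``ZeroRedundancyOptimizer``)::

            >>> # xdoctest: +SKIP
            >>> zero_optim.consolidate_state_dict(to=0)  # must run on all ranks
            >>> if dist.get_rank() == 0:
            >>>     torch.save(zero_optim.state_dict(), "zero_checkpoint.pt")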

        Raises:
            RuntimeError: if ``overlap_with_ddp=True`` and this method is
                called before this :class:`ZeroRedundancyOptimizer` instance
                has been fully initialized, which happens once
                :class:`DistributedDataParallel` gradient buckets have been
                rebuilt; or if this method is called without a preceding call
                to :meth:`consolidate_state_dict`.
        """
        self._check_overlap_initialized()

        if len(self._all_state_dicts) == 0:
            raise RuntimeError(
                "Optimizer state has not been consolidated on this rank. "
                f"Please call `consolidate_state_dict(to={self.rank})` on "
                "all ranks beforehand if you meant to save the global state."
            )

        # Get the possibly-stale global optimizer state that uses global
        # parameter indexing
        state_dict = super().state_dict()

        # Update the global optimizer state with local state information,
        # factoring in the translation from local to global indexing
        for rank, local_state_dict in enumerate(self._all_state_dicts):
            local_param_groups = local_state_dict["param_groups"]
            global_param_groups = self._partition_parameters()[rank]
            assert len(local_param_groups) == len(
                global_param_groups
            ), "Mismatch between number of local and global parameter groups"

            for local_param_group, global_param_group in zip(
                local_param_groups, global_param_groups
            ):
                # `local_param_group` stores local indices, while
                # `global_param_group` stores the tensors directly
                local_param_indices = local_param_group["params"]
                global_params = global_param_group["params"]

                assert len(local_param_indices) == len(
                    global_params
                ), "Mismatch between number of local and global parameters in parameter group"
                for local_param_index, global_param in zip(
                    local_param_indices, global_params
                ):
                    # Update the global parameter state, if any
                    if local_param_index in local_state_dict["state"]:
                        global_param_index = self._param_to_index[global_param]
                        state_dict["state"][global_param_index] = local_state_dict[
                            "state"
                        ][local_param_index]

        # Sort the parameters in the state
        state_dict["state"] = dict(sorted(state_dict["state"].items()))
        return state_dict

    @staticmethod
    def _sync_param_groups(
        src_param_groups: List[Dict[Any, Any]],
        dst_param_groups: List[Dict[Any, Any]],
    ) -> None:
        r"""
        Sync the attributes from the source parameter groups to the destination parameter groups.

        Example attributes include learning rate or scheduler attributes. The
        two inputs should contain the same number of parameter groups.

        Arguments:
            src_param_groups (list[dict]): parameter groups giving the
                attribute settings to copy.
            dst_param_groups (list[dict]): parameter groups giving the
                attribute settings to set.
        """
        assert len(src_param_groups) == len(
            dst_param_groups
        ), "Mismatch between number of source and destination parameter groups"
        for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
            # Sync all attributes except the parameters
            for attr in filter(lambda x: x != "params", src_param_group.keys()):
                dst_param_group[attr] = src_param_group[attr]

    def _build_param_buckets(self) -> None:
        r"""
        Build parameter buckets if ``parameters_as_bucket_view=True``.

        For each device that stores this rank's parameters, there is a
        bucket (represented as a tensor) containing all of the parameters on
        that device that are assigned to a given rank in the parameter update
        partition.

        This method is called in the constructor and any time parameter
        trainability is changed.

        .. warning::
            The current implementation assumes that all of the parameters in a
            bucket are of the same dense type when allocating the bucket's
            tensor.

        .. warning::
            If the model parameters are stored across more than one device,
            then the storage partitioning must be the same across all
            processes in order for parameter synchronization to work.
        """
        if not self.parameters_as_bucket_view or self._overlap_with_ddp:
            return

        # `self._buckets[i][j]` are the parameters stored on device i and
        # assigned to rank j
        num_devices = len(self._device_to_params_per_rank)
        self._buckets = [[] for _ in range(num_devices)]  # type: ignore[assignment]

        for dev_i, (device, params_per_rank) in enumerate(
            self._device_to_params_per_rank.items()
        ):
            for params in params_per_rank:
                bucket_size = 0
                dtype = None
                trainable_params = []
                for param in params:
                    if not _is_trainable(param):
                        # Clone in case the parameter was previously part of
                        # a bucket to avoid its data being destroyed
                        param.data = param.data.detach().clone()
                    else:
                        bucket_size += param.numel()
                        trainable_params.append(param)
                    dtype = param.dtype  # assumes all same dtype

                if bucket_size == 0:
                    # Create a dummy bucket if there are no parameters
                    bucket = torch.zeros(1, device=device)
                else:
                    # Construct the bucket (assuming all dense and same dtype)
                    bucket = torch.empty(bucket_size, dtype=dtype, device=device)
                    offset = 0
                    for param in trainable_params:
                        offset_next = offset + param.numel()
                        bucket[offset:offset_next].copy_(param.data.flatten())
                        param.data = bucket[offset:offset_next].view_as(param.data)
                        offset = offset_next
                self._buckets[dev_i].append(bucket)  # type: ignore[arg-type]

    def _build_ddp_param_buckets(self) -> None:
        r"""
        Build the bucket tensors for the DDP bucket assignments.

        For each bucket assignment, this flattens the data of its parameters
        into a single tensor and saves the tensor to the ``tensor`` attribute
        of the corresponding :class:`_DDPBucketAssignment` instance stored in
        ``self._bucket_assignments_per_rank``.

        :class:`DistributedDataParallel` guarantees that the parameters
        corresponding to a gradient bucket have the same device and the same
        dtype.
        """
        for bucket_assignments in self._bucket_assignments_per_rank:
            for bucket_assignment in bucket_assignments.values():
                params = bucket_assignment.parameters
                bucket_size = 0
                dtype = None
                for param in params:
                    assert _is_trainable(param), (
                        "Model parameter "
                        "corresponding to a gradient in a DDP bucket should "
                        "require a gradient"
                    )
                    bucket_size += param.numel()
                    dtype = param.dtype  # assumes all same dtype
                assert bucket_size > 0, "Empty bucket"

                # Construct the bucket tensor (assuming all dense and same dtype)
                tensor = torch.empty(
                    bucket_size, dtype=dtype, device=bucket_assignment.device
                )
                offset = 0
                for param in params:
                    offset_next = offset + param.numel()
                    tensor[offset:offset_next].copy_(param.data.flatten())
                    param.data = tensor[offset:offset_next].view_as(param.data)
                    offset = offset_next
                bucket_assignment.tensor = tensor

    def _verify_and_init_params(
        self,
        params: Any,
    ) -> Union[List[torch.Tensor], List[dict]]:
        r"""
        Verify the type of ``params`` and initialize ``self._all_params`` as a :class:`list` of all parameters.

        The initialization first makes sure that the provided ``params`` is valid.

        Arguments:
            params (Any): Candidate parameter list or parameter groups to verify.

        Raises:
            TypeError: ``params`` has an invalid type.
            ValueError: ``params`` is empty.

        Returns:
            The persistent form of ``params`` to be passed into the parent
            :class:`Optimizer` constructor -- i.e. returns ``params`` as a
            :class:`list` to ensure that it can be iterated over again.
        """
        if isinstance(params, torch.Tensor):
            raise TypeError(
                "`params` argument should be an iterable of "
                f"Tensors, but got {torch.typename(params)}"
            )
        try:
            all_params = list(params)
        except TypeError as e:
            raise TypeError(
                "`params` argument should be an iterable of Tensors"
                f" or dicts, but got {torch.typename(params)}"
            ) from e
        if len(all_params) == 0:
            raise ValueError("ZeroRedundancyOptimizer got an empty parameter list")
        all_tensors = True
        all_dicts = True
        for param in all_params:
            all_tensors &= isinstance(param, torch.Tensor)
            all_dicts &= isinstance(param, dict)
        if not all_tensors and not all_dicts:
            raise TypeError(
                "`params` argument should be an iterable of Tensors or dicts"
            )
        # Ensure that `self._all_params` contains a list of all parameters
        if all_tensors:
            self._all_params = all_params
        elif all_dicts:
            self._all_params = []
            # `all_params` contains parameter groups (not parameters)
            for param_group in all_params:
                if "params" not in param_group:
                    raise ValueError(
                        "Each parameter group passed-in via `params` must "
                        "have a 'params' key mapping to the parameters in "
                        "the group"
                    )
                self._all_params.extend(param_group["params"])
        return all_params

    def _verify_same_dense_param_type(self) -> None:
        r"""
        Verify that all parameters are of the same dense type.

        The method assumes that ``self._all_params`` has been initialized
        and is non-empty.

        Raises:
            ValueError: ``params`` contains sparse parameters or parameters
                of varying dense types.

        NOTE: This method can be removed once support for sparse parameters
        and varying parameter types is added.
        """
        typename = torch.typename(self._all_params[0])
        if self._all_params[0].is_sparse:
            raise ValueError(
                "ZeroRedundancyOptimizer only supports using "
                "the same dense type for all parameters but got "
                f"{typename}"
            )
        for param in self._all_params[1:]:
            other_typename = torch.typename(param)
            if other_typename != typename:
                raise ValueError(
                    "ZeroRedundancyOptimizer only supports "
                    "using the same dense type for all "
                    f"parameters but got both {typename} and "
                    f"{other_typename}"
                )

    def _get_is_trainable_mask(self) -> List[bool]:
        r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not."""
        return list(map(_is_trainable, self._all_params))

    def _init_local_optimizer(self) -> None:
        r"""
        Initialize this rank's local optimizer, responsible for its subset of the parameters.

        The local optimizer is saved in ``self.optim``.
        """
        assert (
            self._optim_constructor is not None
        ), "The local optimizer class has not been set"

        param_groups = self._partition_parameters()[self.rank]
        # `overlap_with_ddp=True` requires a local functional optimizer
        if self._overlap_with_ddp:
            # Functional optimizers only support a single parameter group and
            # require passing in the parameters as a list
            assert len(param_groups) == 1, (
                "Initializing the local "
                "functional optimizer with more than one parameter group"
            )
            params = param_groups[0]["params"]
            # Try to pass `_allow_empty_param_list=True` to avoid erroring
            if (
                "_allow_empty_param_list"
                in inspect.signature(self._optim_constructor).parameters
            ):
                self.optim: Any = self._optim_constructor(
                    params, **self._optim_defaults, _allow_empty_param_list=True
                )
            else:
                logger.warning(
                    "%s does not support the argument "
                    "`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
                    "error due to an empty parameter list",
                    self._optim_constructor,
                )
                self.optim: Any = self._optim_constructor(params, **self._optim_defaults)  # type: ignore[no-redef]

            # Log information about the DDP and ZeRO bucketing
            if dist.get_debug_level() != dist.DebugLevel.OFF:
                local_numel = sum(p.numel() for p in params)
                num_assigned_buckets = len(
                    self._bucket_assignments_per_rank[self.global_rank]
                )
                logger.info(
                    "rank %s with %s parameters across %s buckets",
                    self.global_rank,
                    local_numel,
                    num_assigned_buckets,
                )
                if self.global_rank == 0:
                    logger.info(
                        "%s DDP buckets and %s bucket assignments",
                        len(self._overlap_info.params_per_bucket),
                        self._overlap_info.num_bucket_assignments,
                    )
        else:
            # NOTE: Passing `param_groups` into the local optimizer constructor
            # bypasses the empty parameter list check
            self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults)  # type: ignore[no-redef]

        # TODO: Manually add `self.param_groups` if using a functional
        # optimizer; remove this if/when the functional optimizers support
        # multiple parameter groups
        if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"):
            assert hasattr(self.optim, "param_group"), (
                "The functional optimizer should set at least one of the "
                "attributes `param_group` or `param_groups`"
            )
            self.optim.param_groups = [self.optim.param_group]  # type: ignore[attr-defined]

        self._sync_param_groups(self.optim.param_groups, self.param_groups)

    def _init_zero_for_overlap(self) -> None:
        r"""Perform a delayed initialization of the local optimizer and the supporting data structures."""
        assert self._overlap_with_ddp, (
            "`_init_zero_for_overlap()` should only be called when "
            "`overlap_with_ddp=True`"
        )
        self._overlap_info.status = _OverlapStatus.INITIALIZED
        self._clear_cache()
        self._partition_parameters(self._overlap_info.params_per_rank)
        self._build_ddp_param_buckets()
        self._init_local_optimizer()

    def _get_assigned_rank(self, bucket_index: int) -> int:
        r"""
        Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket.

        Arguments:
            bucket_index (int): index of the :class:`DistributedDataParallel`
                bucket for which to get the assigned rank.
        """
        assert not self._overlap_info.shard_buckets, (
            "The bucket assignment requires global bucket information and "
            "will be computed later; there should be no need to use this "
            "method"
        )
        return bucket_index % self.world_size

    def _check_overlap_initialized(self):
        r"""
        Check the delayed initialization depending on the value of ``overlap_with_ddp``.

        If ``overlap_with_ddp=True``, this checks that the delayed
        initialization (see :meth:`_init_zero_for_overlap`) has occurred and
        raises a ``RuntimeError`` if not. This should preface methods that
        should not be run before that delayed initialization.

        Raises:
            RuntimeError: if ``overlap_with_ddp=True`` and
                :meth:`_init_zero_for_overlap` has not been called.
        """
        if (
            self._overlap_with_ddp
            and self._overlap_info.status != _OverlapStatus.INITIALIZED
        ):
            raise RuntimeError(
                "This method should not be called until this "
                "ZeroRedundancyOptimizer instance has been fully "
                "initialized"
            )

    def _get_optimizer_constructor(self, optimizer_class: Any) -> Any:
        r"""
        Return the optimizer constructor to use, validating and translating ``optimizer_class`` depending on ``overlap_with_ddp``.

        Returns:
            - ``optimizer_class`` if ``overlap_with_ddp=False`` and
              ``optimizer_class`` is not a functional optimizer.
            - ``optimizer_class`` if ``overlap_with_ddp=True`` and
              ``optimizer_class`` is already a functional optimizer.
            - The functional equivalent of ``optimizer_class`` if
              ``overlap_with_ddp=True`` and ``optimizer_class`` is not
              already a functional optimizer (assuming the equivalent
              exists).

        Raises:
            ValueError:

                - if ``overlap_with_ddp=True`` but ``optimizer_class`` is
                  neither a functional optimizer nor translatable to a
                  functional optimizer.
                - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a
                  functional optimizer.
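
        For example (a sketch; the exact functional class and its module path
        come from ``functional_optim_map`` and are assumptions here)::

            >>> # xdoctest: +SKIP
            >>> # With `overlap_with_ddp=True`, a standard optimizer class is
            >>> # translated to its functional counterpart, e.g.
            >>> self._get_optimizer_constructor(torch.optim.Adam)
            <class 'torch.distributed.optim.functional_adam._FunctionalAdam'>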
        """
        functional_optims = functional_optim_map.values()
        if not self._overlap_with_ddp:
            if optimizer_class in functional_optims:
                # Using a functional optimizer is only supported when
                # `overlap_with_ddp=True`
                raise ValueError(
                    f"Passing in a functional optimizer {optimizer_class} "
                    "when `overlap_with_ddp=False`"
                )
            else:
                return optimizer_class
        else:
            if optimizer_class in functional_optims:
                # Already a functional optimizer
                return optimizer_class
            elif optimizer_class in functional_optim_map:
                # Translate the passed-in optimizer class to its functional
                # equivalent if `overlap_with_ddp=True`
                optim_constructor = functional_optim_map[optimizer_class]
                logger.info(
                    "Using the functional optimizer %s instead of %s since "
                    "`overlap_with_ddp=True`",
                    optim_constructor,
                    optimizer_class,
                )
                return optim_constructor
            else:
                raise ValueError(
                    "Using `overlap_with_ddp=True` requires using a "
                    "functional optimizer, but there is no supported functional "
                    f"optimizer equivalent for {optimizer_class}"
                )