# mypy: allow-untyped-defs
import weakref
from typing import Any, Callable, List, Optional

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus
from torch.nn.parallel.distributed import DistributedDataParallel


__all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"]

# Functional optimizers require passing a list of gradients to their `step()`
# method, and ZeRO requires a functional optimizer to overlap with DDP.
# Passing `None` instead of an actual gradient indicates to the optimizer
# that the corresponding parameter should not be updated.
_NO_PARAM_UPDATE: None = None
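
# For illustration only (a hedged sketch, not executed here): if the wrapped
# functional optimizer manages three parameters and only the second one is
# assigned to this rank's shard, the local step would pass
#
#     gradients = [_NO_PARAM_UPDATE, grad_for_param_1, _NO_PARAM_UPDATE]
#     functional_optim.step(gradients)
#
# where the `None` entries tell the functional optimizer to skip those
# parameters; `functional_optim` and `grad_for_param_1` are placeholder names,
# not objects defined in this module.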


def _perform_local_step(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
    rank: int,
):
    r"""
    Perform a local optimizer step using the gradients provided by ``bucket``.

    Arguments:
        bucket (dist.GradBucket): the bucket providing the gradients.
        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
            instance to perform the :meth:`_local_step`.
        rank (int): the calling process's rank.

    .. warning::
        This function assumes that appropriate synchronization has taken place
        so that the bucket's gradients can be used.
    """
    overlap_info = zero._overlap_info
    bucket_index = bucket.index()
    assert (
        len(zero.optim.param_groups) == 1
    ), "Overlapping DDP with ZeRO only supports a single parameter group"

    # Construct the `gradients` input for the local optimizer step, which
    # expects `None` in a list position to indicate that the corresponding
    # parameter should not be updated
    num_local_optim_params = len(zero.optim.param_groups[0]["params"])
    gradients: List[Optional[torch.Tensor]] = [
        _NO_PARAM_UPDATE for _ in range(num_local_optim_params)
    ]
    assert (
        bucket_index in overlap_info.offsets
    ), f"Bucket index {bucket_index} was not assigned to rank {rank}"
    gradients_offset = overlap_info.offsets[bucket_index]
    bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index]
    bucket_offset = bucket_assignment.offset
    length = len(bucket_assignment.parameters)
    bucket_gradients = bucket.gradients()[bucket_offset : bucket_offset + length]
    for i, grad in enumerate(bucket_gradients):
        gradients[gradients_offset + i] = grad

    zero._local_step(gradients)


def _broadcast_bucket(
    bucket_index: int,
    zero: ZeroRedundancyOptimizer,
):
    r"""
    Broadcast a bucket's parameters.

    Arguments:
        bucket_index (int): the index of the bucket corresponding to the
            parameters to broadcast.
        zero (ZeroRedundancyOptimizer): the calling process's
            :class:`ZeroRedundancyOptimizer` instance.
    """
    overlap_info = zero._overlap_info
    assert (
        len(overlap_info.assigned_ranks_per_bucket) > bucket_index
    ), "`assigned_ranks_per_bucket` is not fully constructed"
    # Sort to ensure the same ordering across ranks
    assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index])
    assert len(assigned_ranks) > 0, (
        f"Bucket {bucket_index} should be assigned to at least one rank"
    )
    for assigned_rank in assigned_ranks:
        bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
        if bucket_index in bucket_assignments:
            overlap_info.broadcast_handles.append(
                dist.broadcast(
                    bucket_assignments[bucket_index].tensor,
                    src=dist.get_global_rank(zero.process_group, assigned_rank),
                    group=zero.process_group,
                    async_op=True,
                )
            )


def _save_ddp_bucket_info(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
):
    r"""
    Save :class:`DistributedDataParallel` gradient bucket information for the :class:`ZeroRedundancyOptimizer` instance ``zero``.

    This function is meant to be called once for each gradient bucket used
    when overlapping, so it saves only per-bucket information and does not
    compute any global information.

    Arguments:
        bucket (dist.GradBucket): the current gradient bucket.
        zero (ZeroRedundancyOptimizer): the calling process's
            :class:`ZeroRedundancyOptimizer` instance.
    """
    overlap_info = zero._overlap_info
    bucket_params = bucket.parameters()
    assert len(bucket_params) > 0, "Empty bucket"

    # Save the parameters in the bucket
    overlap_info.params_per_bucket.append(bucket_params)
    if overlap_info.shard_buckets:
        # Additionally save the bucket size for the assignment heuristic to use
        bucket_size = 0
        for param in bucket_params:
            bucket_size += param.numel()
        assert overlap_info.total_size is not None
        overlap_info.total_size += bucket_size


def _hook_with_zero_step_setup(
    ddp_ref: weakref.ReferenceType,
    zero: ZeroRedundancyOptimizer,
    bucket: dist.GradBucket,
):
    r"""
    Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.

    This is the logic that runs in the hook before the backward pass and
    optimizer step can actually be overlapped. It is factored out since it is
    common to both :func:`hook_with_zero_step` and
    :func:`hook_with_zero_step_interleaved`.

    Arguments:
        ddp_ref (weakref.ReferenceType): weak reference to the process's
            :class:`DistributedDataParallel` instance.
        zero (ZeroRedundancyOptimizer): the calling process's
            :class:`ZeroRedundancyOptimizer` instance.
        bucket (dist.GradBucket): the current gradient bucket.
    """
    # Proceed as normal until the DDP buckets have been rebuilt
    if not ddp_ref()._has_rebuilt_buckets:  # type: ignore[union-attr]
        assert zero._overlap_info.status == _OverlapStatus.UNINITIALIZED
        return

    bucket_index = bucket.index()
    overlap_info = zero._overlap_info
    if overlap_info.status == _OverlapStatus.UNINITIALIZED:
        overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS

    if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
        if bucket_index == 0 and len(overlap_info.params_per_bucket) > 0:
            # This corresponds to the first bucket of the backward pass
            # immediately after all information has been saved, so we
            # can perform the delayed ZeRO initialization
            zero._init_zero_for_overlap()
        else:
            # Once DDP buckets have been rebuilt but ZeRO has not been
            # properly initialized yet, save the information needed
            _save_ddp_bucket_info(bucket, zero)


def hook_with_zero_step(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
    ddp: DistributedDataParallel,
    zero: ZeroRedundancyOptimizer,
    shard_buckets: bool = False,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Modify ``hook`` to overlap the :class:`ZeroRedundancyOptimizer` optimizer step with the :class:`DistributedDataParallel` backward pass.

    This approach overlaps the optimizer computation and communication with the
    backward communication. In particular, the backward computation proceeds
    contiguously, and the optimizer computation follows, overlapping with
    outstanding backward communication (i.e. all-reduces) and possibly other
    optimizer communication (i.e. broadcasts). The optimizer step computation
    begins after the last gradient bucket computation has finished.

    This approach may be preferred over :meth:`hook_with_zero_step_interleaved`
    if communication is relatively slow compared to computation.

    Arguments:
        hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook
            to modify.
        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
            instance to use.
        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
            instance to use.
        shard_buckets (bool): if ``True``, then the assignment of each
            :class:`DistributedDataParallel` bucket is partitioned across
            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
            across possibly multiple ranks) to approximate uniformity; if
            ``False``, then each bucket is wholly assigned to a single
            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).

    Returns:
        The modified hook.

    Raises:
        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
        RuntimeError: if using any backend other than NCCL/HCCL since currently
            Gloo may hang.

    .. warning::
        Given the way that overlapping :class:`DistributedDataParallel` with
        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
        two or three training iterations do not perform parameter updates in
        the optimizer step, depending on if ``static_graph=False`` or
        ``static_graph=True``, respectively. This is because it needs
        information about the gradient bucketing strategy used by
        :class:`DistributedDataParallel`, which is not finalized until the
        second forward pass if ``static_graph=False`` or until the third
        forward pass if ``static_graph=True``.
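
    Example (a minimal usage sketch, not a complete training script)::

        # Assumes the default process group is initialized with the NCCL
        # backend and that ``model``, ``rank``, and ``allreduce_hook`` (the
        # standard all-reduce communication hook) are already defined.
        ddp_model = DistributedDataParallel(model, device_ids=[rank])
        zero_optim = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            overlap_with_ddp=True,
            lr=0.01,
        )
        ddp_model.register_comm_hook(
            None, hook_with_zero_step(allreduce_hook, ddp_model, zero_optim)
        )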
    """
    if not zero._overlap_with_ddp:
        raise ValueError(
            "ZeroRedundancyOptimizer must be constructed with "
            "`overlap_with_ddp=True` to use this hook properly"
        )
    ddp_ref = weakref.ref(ddp)

    # NOTE: Gloo may hang with this overlapping approach, so we require
    # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
    if (pg != dist.Backend.NCCL) and (pg != "hccl"):
        raise RuntimeError(
            "Overlapping DDP with ZeRO using this approach currently requires "
            "NCCL/HCCL backend to avoid hangs"
        )

    if shard_buckets:
        zero._overlap_info.shard_buckets = True
        zero._overlap_info.total_size = 0

    def hook_with_zero_fn(
        state: Any,
        bucket: dist.GradBucket,
    ) -> torch.futures.Future[torch.Tensor]:
        r"""
        Return a :class:`Future` that gives the gradient bucket tensor and runs the optimizer step if this is the last gradient bucket.

        If ``bucket`` is the last gradient bucket, this performs the
        equivalent of a :class:`ZeroRedundancyOptimizer` :meth:`step`. On the
        iteration in which the :class:`DistributedDataParallel` buckets are
        rebuilt, this function additionally collects the information needed
        to implement the modified hook.

        Arguments:
            state (Any): any state for the hook.
            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
                gradient bucket.
        """
        fut = hook(state, bucket)
        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
            return fut

        overlap_info = zero._overlap_info
        bucket_index = bucket.index()
        rank = zero.global_rank

        assert overlap_info.status == _OverlapStatus.INITIALIZED
        assert (
            len(overlap_info.assigned_ranks_per_bucket) > bucket_index
        ), "`assigned_ranks_per_bucket` is not fully constructed"
        assigned_to_bucket = (
            rank in overlap_info.assigned_ranks_per_bucket[bucket_index]
        )

        # Save the bucket reference and all-reduce future for the final bucket
        if assigned_to_bucket:
            overlap_info.bucket_index_to_bucket[bucket_index] = bucket
            overlap_info.bucket_index_to_future[bucket_index] = fut

        # Check that buckets are indexed incrementally starting from 0 in the
        # order of their autograd hooks firing
        if len(overlap_info.bucket_indices_seen) > 0:
            assert (
                overlap_info.bucket_indices_seen[-1] == bucket_index - 1
            ), "Bucket indices are not in incremental order"
        else:
            assert bucket_index == 0, "Bucket indices do not start from 0"
        overlap_info.bucket_indices_seen.append(bucket_index)

        # Directly return the future without any optimizer computation if this
        # is not the last bucket
        num_buckets = len(overlap_info.params_per_bucket)
        is_last_bucket = bucket_index == num_buckets - 1
        if not is_last_bucket:
            return fut

        # Perform partial optimizer step on all buckets after the final
        # bucket has been computed
        # NOTE: This should not be chained as a callback to the last bucket's
        # all-reduce future since that would add synchronization that delays
        # all optimizer computation to wait for that last all-reduce
        for bucket_index in range(num_buckets):
            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
            if rank in assigned_ranks:
                # Wait on the bucket's all-reduce future to ensure correct
                # gradients
                assert bucket_index in overlap_info.bucket_index_to_future, (
                    f"All-reduce future for bucket {bucket_index} not saved "
                    f"on rank {rank}"
                )
                allreduce_future = overlap_info.bucket_index_to_future[bucket_index]
                allreduce_future.wait()

                # Perform the partial optimizer step
                curr_bucket = overlap_info.bucket_index_to_bucket[bucket_index]
                _perform_local_step(curr_bucket, zero, rank)

            _broadcast_bucket(bucket_index, zero)

        # Ensure that all parameter updates are finished before the
        # next forward pass
        overlap_info.wait_for_broadcasts()
        overlap_info.clear_per_iter_info()

        return fut

    return hook_with_zero_fn


def hook_with_zero_step_interleaved(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future],
    ddp: DistributedDataParallel,
    zero: ZeroRedundancyOptimizer,
    shard_buckets: bool = False,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    r"""
    Modify ``hook`` to overlap the :class:`ZeroRedundancyOptimizer` optimizer step with the :class:`DistributedDataParallel` backward pass.

    This approach overlaps the optimizer computation and communication with the
    backward computation and communication. In particular, once a bucket's
    gradients have been computed, the optimizer computation using those
    gradients is launched (though the actual computation must wait for the
    bucket's all-reduce to complete). This yields an interleaving of
    all-reduces and broadcasts in the communication stream.

    This approach may be preferred over :meth:`hook_with_zero_step` if
    communication is relatively fast compared to computation.

    Arguments:
        hook (Callable[[Any, dist.GradBucket], torch.futures.Future]): the hook
            to modify.
        ddp (DistributedDataParallel): the :class:`DistributedDataParallel`
            instance to use.
        zero (ZeroRedundancyOptimizer): the :class:`ZeroRedundancyOptimizer`
            instance to use.
        shard_buckets (bool): if ``True``, then the assignment of each
            :class:`DistributedDataParallel` bucket is partitioned across
            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
            across possibly multiple ranks) to approximate uniformity; if
            ``False``, then each bucket is wholly assigned to a single
            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank).

    Returns:
        The modified hook.

    Raises:
        ValueError: if ``zero`` was constructed with ``overlap_with_ddp=False``.
        RuntimeError: if using any backend other than NCCL/HCCL since currently
            Gloo may hang.

    .. warning::
        Given the way that overlapping :class:`DistributedDataParallel` with
        :class:`ZeroRedundancyOptimizer` is currently implemented, the first
        two or three training iterations do not perform parameter updates in
        the optimizer step, depending on if ``static_graph=False`` or
        ``static_graph=True``, respectively. This is because it needs
        information about the gradient bucketing strategy used by
        :class:`DistributedDataParallel`, which is not finalized until the
        second forward pass if ``static_graph=False`` or until the third
        forward pass if ``static_graph=True``.
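
    Example (a minimal sketch; setup mirrors :func:`hook_with_zero_step`, with
    ``ddp_model``, ``zero_optim``, and ``allreduce_hook`` assumed to be defined
    as in that example)::

        ddp_model.register_comm_hook(
            None,
            hook_with_zero_step_interleaved(allreduce_hook, ddp_model, zero_optim),
        )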
    """
    if not zero._overlap_with_ddp:
        raise ValueError(
            "ZeroRedundancyOptimizer must be constructed with "
            "`overlap_with_ddp=True` to use this hook properly"
        )
    ddp_ref = weakref.ref(ddp)

    # NOTE: Gloo may hang with this overlapping approach, so we require
    # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300
    pg = dist.get_backend(ddp_ref().process_group)  # type: ignore[union-attr]
    if (pg != dist.Backend.NCCL) and (pg != "hccl"):
        raise RuntimeError(
            "Overlapping DDP with ZeRO using this approach currently requires "
            "NCCL/HCCL backend to avoid hangs"
        )

    if shard_buckets:
        zero._overlap_info.shard_buckets = True
        zero._overlap_info.total_size = 0

    def hook_with_zero_interleaved_fn(
        state,
        bucket: dist.GradBucket,
    ) -> torch.futures.Future[torch.Tensor]:
        r"""
        Return a :class:`Future` that gives the gradient bucket tensor and performs a partial :class:`ZeroRedundancyOptimizer` :meth:`step`.

        This function uses the gradients in the given bucket to perform a
        partial :class:`ZeroRedundancyOptimizer` :meth:`step`.

        Arguments:
            state: any state for the hook.
            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
                gradient bucket.
        """
        fut = hook(state, bucket)
        _hook_with_zero_step_setup(ddp_ref, zero, bucket)
        if zero._overlap_info.status != _OverlapStatus.INITIALIZED:
            return fut

        def zero_step(fut: torch.futures.Future) -> torch.Tensor:
            r"""
            Perform a partial :class:`ZeroRedundancyOptimizer` :meth:`step` using the gradients in the :class:`DistributedDataParallel` gradient bucket.

            Returns:
                A :class:`torch.Tensor` representing the contents of the
                gradient bucket.
            """
            overlap_info = zero._overlap_info
            bucket_index = bucket.index()
            rank = zero.global_rank

            assigned_ranks = overlap_info.assigned_ranks_per_bucket[bucket_index]
            overlap_info.bucket_indices_seen.append(bucket_index)
            if rank in assigned_ranks:
                _perform_local_step(bucket, zero, rank)

            _broadcast_bucket(bucket_index, zero)

            num_buckets = len(overlap_info.params_per_bucket)
            if len(overlap_info.bucket_indices_seen) == num_buckets:
                # Ensure that all parameter updates are finished before the
                # next forward pass
                overlap_info.wait_for_broadcasts()
                overlap_info.clear_per_iter_info()

            return bucket.buffer()

        return fut.then(zero_step)

    return hook_with_zero_interleaved_fn