# mypy: allow-untyped-defs
import logging
from collections import abc, defaultdict
from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union

import torch
import torch.distributed as dist
from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup


logger = logging.getLogger(__name__)


def _refresh_per_optimizer_state() -> Dict[str, Any]:
    return {"stage": OptState.READY, "found_inf_per_device": {}}


def _is_supported_device(tensor: torch.Tensor) -> bool:
    return tensor.is_cuda or tensor.device.type in (
        "xla",
        "cpu",
        "hpu",
        "mtia",
        torch._C._get_privateuse1_backend_name(),
    )


class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
    """
    Lazily serves copies of the master tensor to requested devices. This class
    extends _MultiDeviceReplicator to also support "cpu" as a device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert _is_supported_device(master_tensor)
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
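
# Note: _GeneralMultiDeviceReplicator.__init__ above deliberately skips the
# parent initializer so that CPU master tensors (e.g. the scale used with
# CPU-offloaded FSDP gradients) pass the device check; replicas are still
# served lazily per device via the inherited get().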


class ShardedGradScaler(GradScaler):
    """
    ShardedGradScaler helps perform gradient scaling in a shard-aware manner. It extends
    the functionality of GradScaler:
    * Supports PyTorch DDP and FSDP implementations
    * Supports CPU-offloaded tensors (as used in fully sharded data parallel [FSDP])
    * Supports the custom mixed-precision loss dtype (fp16, bf16) that FSDP returns
    * Syncs inf/nan for scaled gradient tensors on any torch.device (where tensors are placed) across
      nodes

    Example::

        # Creates a ShardedGradScaler once at the beginning of training.
        scaler = ShardedGradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See :class:`GradScaler` for an explanation of scaling/unscaling and more use cases.

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
        process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
            process group for sharding
    """

    def __init__(
        self,
        device: str = "cuda",
        init_scale: float = 2.0**16,
        backoff_factor: float = 0.5,
        growth_factor: float = 2.0,
        growth_interval: int = 2000,
        enabled: bool = True,
        process_group: Optional[ProcessGroup] = dist.group.WORLD,
    ) -> None:
        super().__init__(
            device,
            init_scale=init_scale,
            backoff_factor=backoff_factor,
            growth_factor=growth_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
        if self._enabled:
            self.process_group = process_group
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    @overload
    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
        ...

    @overload
    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
        ...

    @overload
    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        ...

    @overload
    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
        ...

    def scale(
        self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
        if not self._enabled:
            return outputs

        if isinstance(outputs, torch.Tensor):
            assert _is_supported_device(outputs)
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            scaled_output = outputs * self._scale.to(
                device=outputs.device, non_blocking=True
            )
            # Here we ensure the return dtype is the same as the outputs dtype.
            # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
            # format (fp16, bf16) and so the scaled loss should be of the same dtype.
            return scaled_output.type(outputs.dtype)

        stash: List[_GeneralMultiDeviceReplicator] = []

        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
            if isinstance(val, torch.Tensor):
                assert _is_supported_device(val)
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_GeneralMultiDeviceReplicator(self._scale))
                scaled_val = val * stash[0].get(val.device)
                # Here we ensure the return dtype is the same as the outputs dtype.
                # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
                # format (fp16, bf16) and so the scaled loss should be of the same dtype.
                return scaled_val.type(val.dtype)
            if isinstance(val, abc.Iterable):
                iterator = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterator)
                return iterator
            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _foreach_non_finite_check_and_unscale_cpu_(
        self,
        grads: Sequence[torch.Tensor],
        found_inf: torch.Tensor,
        inv_scale: torch.Tensor,
    ) -> None:
        if len(grads) == 0:
            return
        assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."
        assert found_inf.numel() == 1, "found_inf must be a 1-element tensor."

        for grad in grads:
            if grad.device.type != "cpu":
                logger.error(
                    "tensor device is %s but was expected to be ``cpu``",
                    grad.device,
                )
                raise ValueError(
                    "Gradients were found on a non-CPU device when"
                    " expected to be on CPU."
                )
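            # Mirror torch._amp_foreach_non_finite_check_and_unscale_ for CPU
            # grads: flag the first non-finite gradient and stop unscaling,
            # since the step will be skipped anyway once found_inf is set.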
            if torch.isinf(grad).any().item() or torch.isnan(grad).any().item():
                found_inf.data = torch.tensor([1.0])
                break
            else:
                grad.data *= inv_scale.item()

    def _unscale_grads_(
        self,
        optimizer: torch.optim.Optimizer,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool = True,
    ) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
        per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be thousands of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # mypy struggles with defaultdict type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            # coalesce is not supported for torch.float16
                            param_grad_fp32 = param.grad.type(torch.float32).coalesce()
                            param.grad = param_grad_fp32.type(torch.float16)
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    if grads[0].device.type == "cpu":
                        self._foreach_non_finite_check_and_unscale_cpu_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
                    else:
                        torch._amp_foreach_non_finite_check_and_unscale_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
        # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some
        # ranks may have no (non-zero sized) parameter shards, necessitating the
        # initialization of `per_device_found_inf._per_device_tensors` here
        if not per_device_found_inf._per_device_tensors:
            assert self._scale is not None
            per_device_found_inf.get(self._scale.device)
        return per_device_found_inf._per_device_tensors
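
    # unscale_() below divides the gradients by the current scale and records,
    # per device, whether any inf/NaN was found. Unlike the non-sharded
    # GradScaler, it then all-reduces (sums) the found_inf flags across
    # self.process_group, so every rank agrees on whether optimizer.step()
    # should be skipped.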
    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full(
            (1,), 0.0, dtype=torch.float32, device=self._scale.device
        )

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, True
        )
        optimizer_state["stage"] = OptState.UNSCALED

        # Synchronize the detected inf across the ranks
        optimizer_state = self._per_optimizer_states[id(optimizer)]
        works = []
        found_inf_on_cpus = []
        found_inf_on_devices = []

        for found_inf in optimizer_state["found_inf_per_device"].values():
            if self._device != "cpu" and found_inf.device.type == "cpu":
                found_inf_on_cpus.append(found_inf)
                found_inf_on_device = found_inf.to(self._device)
                found_inf_on_devices.append(found_inf_on_device)
                works.append(
                    dist.all_reduce(
                        found_inf_on_device, async_op=True, group=self.process_group
                    )
                )
            else:
                works.append(
                    dist.all_reduce(found_inf, async_op=True, group=self.process_group)
                )
        for work in works:
            work.wait()
        if found_inf_on_cpus:
            torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices)

    def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
        """
        If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
        Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
        """
        assert self._scale is not None and self._growth_tracker is not None

        if found_inf.item() >= 1.0:
            self._scale *= self._backoff_factor
            self._growth_tracker.fill_(0)
        else:
            successful = self._growth_tracker + 1
            if successful == self._growth_interval:
                self._scale *= self._growth_factor
                self._growth_tracker.fill_(0)
            else:
                self._growth_tracker = successful
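
    # Worked example of the schedule implemented above, assuming the defaults
    # (growth_factor=2.0, backoff_factor=0.5, growth_interval=2000): 2000
    # consecutive finite iterations double the scale, while a single inf/NaN
    # iteration halves the current scale and resets the growth tracker to zero.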
    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")  # type: ignore[var-annotated]

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = (
                    "new_scale should be a float or a 1-element torch.cuda.FloatTensor or "
                    "torch.FloatTensor with requires_grad=False."
                )
                assert new_scale.device.type == self._device, reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            if _scale.device.type == "cpu":
                self._amp_update_scale_cpu_(found_inf_combined)
            else:
                torch._amp_update_scale_(
                    self._scale,  # type: ignore[arg-type]
                    self._growth_tracker,  # type: ignore[arg-type]
                    found_inf_combined,
                    self._growth_factor,  # type: ignore[arg-type]
                    self._backoff_factor,  # type: ignore[arg-type]
                    self._growth_interval,  # type: ignore[arg-type]
                )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
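
# A minimal sketch of the intended FSDP integration (illustrative only:
# `model`, `optimizer`, `loss_fn`, and `data` are assumed to exist, and the
# process group to be initialized via torch.distributed.init_process_group):
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     model = FSDP(model)
#     scaler = ShardedGradScaler()
#     for input, target in data:
#         optimizer.zero_grad()
#         with torch.autocast("cuda", dtype=torch.float16):
#             loss = loss_fn(model(input), target)
#         scaler.scale(loss).backward()
#         # Optional: unscale first so gradients can be clipped at true scale.
#         scaler.unscale_(optimizer)
#         scaler.step(optimizer)
#         scaler.update()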