# mypy: allow-untyped-defs
import logging
from collections import abc, defaultdict
from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union

import torch
import torch.distributed as dist
from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup


logger = logging.getLogger(__name__)


def _refresh_per_optimizer_state() -> Dict[str, Any]:
    return {"stage": OptState.READY, "found_inf_per_device": {}}


def _is_supported_device(tensor: torch.Tensor) -> bool:
    return tensor.is_cuda or tensor.device.type in (
        "xla",
        "cpu",
        "hpu",
        "mtia",
        torch._C._get_privateuse1_backend_name(),
    )


class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
    """
    Lazily serves copies of a tensor to each requested device. This class
    extends ``_MultiDeviceReplicator`` to also support "cpu" as a device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert _is_supported_device(master_tensor)
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
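        # Per-device copies of ``master`` are created lazily by the inherited
        # ``get(device)`` and cached in ``_per_device_tensors``.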


class ShardedGradScaler(GradScaler):
    """
    ShardedGradScaler helps perform gradient scaling in a shard-aware manner. It extends
    the functionality of GradScaler:
    * Supports PyTorch DDP and FSDP implementations
    * Supports CPU-offloaded tensors (as used in Fully Sharded Data Parallel [FSDP])
    * Supports the custom mixed-precision loss dtype (fp16, bf16) that FSDP returns
    * Syncs the inf/NaN status of scaled gradient tensors on any torch.device (wherever
    the tensors are placed) across ranks

    Example::

        # Creates a ShardedGradScaler once at the beginning of training.
        scaler = ShardedGradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

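    A minimal sketch of the intended FSDP usage (illustrative; assumes a working
    ``torch.distributed`` process group and that ``model``, ``optimizer``, and the
    training loop are defined as in the example above)::

        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        model = FSDP(model)            # shard the model's parameters across ranks
        scaler = ShardedGradScaler()   # then run the scale/step/update loop shown above
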
    See :class:`GradScaler` for explanation of scaling/unscaling and more use cases.

    Args:
        device (str, optional, default="cuda"):  Device type to use for the scaler's
            internal state. See :class:`GradScaler`.
        init_scale (float, optional, default=2.**16):  Initial scale factor.
        growth_factor (float, optional, default=2.0):  Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5):  Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000):  Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional):  If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
        process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
            Process group across which the inf/NaN status of gradients is synchronized.
    """

    def __init__(
        self,
        device: str = "cuda",
        init_scale: float = 2.0**16,
        backoff_factor: float = 0.5,
        growth_factor: float = 2.0,
        growth_interval: int = 2000,
        enabled: bool = True,
        process_group: Optional[ProcessGroup] = dist.group.WORLD,
    ) -> None:
        super().__init__(
            device,
            init_scale=init_scale,
            backoff_factor=backoff_factor,
            growth_factor=growth_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
        if self._enabled:
            self.process_group = process_group
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    @overload
    def scale(self, outputs: torch.Tensor) -> torch.Tensor:
        ...

    @overload
    def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
        ...

    @overload
    def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        ...

    @overload
    def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
        ...

    def scale(
        self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]
    ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
        if not self._enabled:
            return outputs

        if isinstance(outputs, torch.Tensor):
            assert _is_supported_device(outputs)
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            scaled_output = outputs * self._scale.to(
                device=outputs.device, non_blocking=True
            )
            # Here we ensure the return dtype is the same as the outputs dtype.
            # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
            # format (fp16, bf16) and so the scaled loss should be of the same dtype.
            return scaled_output.type(outputs.dtype)

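        # Cache a single replicator in ``stash`` so the scale tensor is copied to
        # each output's device at most once while ``apply_scale`` recurses below.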
        stash: List[_GeneralMultiDeviceReplicator] = []

        def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
            if isinstance(val, torch.Tensor):
                assert _is_supported_device(val)
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_GeneralMultiDeviceReplicator(self._scale))
                scaled_val = val * stash[0].get(val.device)
                # Here we ensure the return dtype is the same as the outputs dtype.
                # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
                # format (fp16, bf16) and so the scaled loss should be of the same dtype.
                return scaled_val.type(val.dtype)
            if isinstance(val, abc.Iterable):
                iterator = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterator)
                return iterator
            raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _foreach_non_finite_check_and_unscale_cpu_(
        self,
        grads: Sequence[torch.Tensor],
        found_inf: torch.Tensor,
        inv_scale: torch.Tensor,
    ) -> None:
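        """
        Check CPU gradients for infs/NaNs and unscale them in place.

        This is the CPU counterpart of ``torch._amp_foreach_non_finite_check_and_unscale_``
        used for CPU-offloaded gradients: ``found_inf`` is set to 1.0 if any grad contains
        an inf/NaN, otherwise each grad is multiplied by ``inv_scale``.
        """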
        if len(grads) == 0:
            return
        assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."
        assert found_inf.numel() == 1, "found_inf must be a 1-element tensor."

        for grad in grads:
            if grad.device.type != "cpu":
                logger.error(
                    "tensor device is %s but was expected to be ``cpu``",
                    grad.device,
                )
                raise ValueError(
                    "Gradients were found on a non-CPU device when"
                    " expected to be on CPU."
                )
            if (
                torch.isinf(grad).any().item()
                or torch.isnan(grad).any().item()
            ):
                found_inf.data = torch.tensor([1.0])
                break
            else:
                grad.data *= inv_scale.item()

    def _unscale_grads_(
        self,
        optimizer: torch.optim.Optimizer,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool = True,
    ) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
        per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be thousands of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # mypy struggles with nested defaultdict type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            # coalesce is not supported in torch.float16
                            param_grad_fp32 = param.grad.type(torch.float32).coalesce()
                            param.grad = param_grad_fp32.type(torch.float16)
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    if grads[0].device.type == "cpu":
                        self._foreach_non_finite_check_and_unscale_cpu_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
                    else:
                        torch._amp_foreach_non_finite_check_and_unscale_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
        # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some
        # ranks may have no (non-zero sized) parameter shards, necessitating the
        # initialization of `per_device_found_inf._per_device_tensors` here
        if not per_device_found_inf._per_device_tensors:
            assert self._scale is not None
            per_device_found_inf.get(self._scale.device)
        return per_device_found_inf._per_device_tensors

    def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
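        """
        Divide ("unscale") the optimizer's gradient tensors by the scale factor,
        check them for infs/NaNs, and all-reduce the found-inf status across
        ``self.process_group`` so that every rank agrees on whether to skip the step.

        See :meth:`GradScaler.unscale_` for when and how to use this (e.g. before
        gradient clipping).
        """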
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full(
            (1,), 0.0, dtype=torch.float32, device=self._scale.device
        )

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, True
        )
        optimizer_state["stage"] = OptState.UNSCALED

        # Synchronize the detected inf across the ranks
        optimizer_state = self._per_optimizer_states[id(optimizer)]
        works = []
        found_inf_on_cpus = []
        found_inf_on_devices = []

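        # found_inf tensors that live on CPU (e.g. from CPU-offloaded gradients) are
        # all-reduced via a temporary copy on ``self._device``; the reduced result is
        # copied back to the CPU tensors once all async collectives have completed.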
        for found_inf in optimizer_state["found_inf_per_device"].values():
            if self._device != "cpu" and found_inf.device.type == "cpu":
                found_inf_on_cpus.append(found_inf)
                found_inf_on_device = found_inf.to(self._device)
                found_inf_on_devices.append(found_inf_on_device)
                works.append(
                    dist.all_reduce(
                        found_inf_on_device, async_op=True, group=self.process_group
                    )
                )
            else:
                works.append(
                    dist.all_reduce(found_inf, async_op=True, group=self.process_group)
                )
        for work in works:
            work.wait()
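        # Write the all-reduced values back into the original CPU found_inf tensors so
        # that ``found_inf_per_device`` (and hence ``update()``) sees the global status.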
        if found_inf_on_cpus:
            torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices)

    def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None:
        """
        If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
        Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
        """
        assert self._scale is not None and self._growth_tracker is not None

        if found_inf.item() >= 1.0:
            self._scale *= self._backoff_factor
            self._growth_tracker.fill_(0)
        else:
            successful = self._growth_tracker + 1
            if successful == self._growth_interval:
                self._scale *= self._growth_factor
                self._growth_tracker.fill_(0)
            else:
                self._growth_tracker = successful

    def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.Tensor`, optional, default=None):  New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """

        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")  # type: ignore[var-annotated]

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = (
                    "new_scale should be a float or a 1-element torch.cuda.FloatTensor or "
                    "torch.FloatTensor with requires_grad=False."
                )
                assert new_scale.device.type == self._device, reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            if _scale.device.type == "cpu":
                self._amp_update_scale_cpu_(found_inf_combined)
            else:
                torch._amp_update_scale_(
                    self._scale,  # type: ignore[arg-type]
                    self._growth_tracker,  # type: ignore[arg-type]
                    found_inf_combined,
                    self._growth_factor,  # type: ignore[arg-type]
                    self._backoff_factor,  # type: ignore[arg-type]
                    self._growth_interval,  # type: ignore[arg-type]
                )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
397