from copy import deepcopy
from datetime import timedelta
from functools import partial, wraps
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
from torch import nn, optim
from torch._guards import active_fake_mode
from torch.distributed._composable.fsdp import FSDPModule
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker
from torch.distributed.distributed_c10d import (
    _IllegalWork,
    ProcessGroup,
    ReduceOp,
    Work,
)
from torch.futures import Future
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref


_TOTAL_KEY = "Total"

__all__ = ["FSDPMemTracker"]


class _FSDPRefType(_RefType):
    """
    Enumerates categories of memory usage in FSDP modules, including parameters, gradients, activations,
    buffers, and optimizer states.

    Attributes:
        SHARDED_PARAM (str): Memory usage of sharded parameters.
        UNSHARDED_PARAM (str): Memory usage of unsharded parameters.
        BUFFER (str): Memory usage of module buffers.
        SHARDED_GRAD (str): Memory usage of sharded gradients corresponding to the sharded parameters.
        UNSHARDED_GRAD (str): Memory usage of unsharded gradients corresponding to the unsharded parameters.
        ACT (str): Memory usage of activations and tensors from forward and AC recomputation.
        TEMP (str): Memory usage of temporary tensors during the backward pass including gradients of activations.
        ALL_GATHER (str): Memory usage of all_gather output tensor.
        REDUCE_SCATTER (str): Memory usage of reduce_scatter input tensor.
        OPT (str): Memory usage of tensors storing optimizer states.
        INP (str): Memory usage of input tensors.
    """

    SHARDED_PARAM = "Sharded Param"
    UNSHARDED_PARAM = "Unsharded Param"
    BUFFER = "Buffer"
    SHARDED_GRAD = "Sharded Grad"
    UNSHARDED_GRAD = "Unsharded Grad"
    ACT = "Activation"
    TEMP = "Temp"
    ALL_GATHER = "All Gather"
    REDUCE_SCATTER = "Reduce Scatter"
    OPT = "OptState"
    INP = "Inputs"


class _SavedFSDPMethods(NamedTuple):
    pre_backward: Callable
    post_backward: Callable


class _SavedCollectives(NamedTuple):
    all_gather_into_tensor: Callable
    reduce_scatter_tensor: Callable
    all_reduce: Callable
    barrier: Callable


class _FSDPModState(_State):
    """
    Enumerates the states of FSDP modules during the forward and backward passes.
    """

    BEF_PRE_FW = "Before Pre-Forward"
    AFT_PRE_FW = "After Pre-Forward"
    BEF_POST_FW = "Before Post-Forward"
    AFT_POST_FW = "After Post-Forward"
    BEF_PRE_BW = "Before Pre-Backward"
    AFT_PRE_BW = "After Pre-Backward"
    BEF_POST_BW = "Before Post-Backward"
    AFT_POST_BW = "After Post-Backward"
    PRE_FW_AC = "Pre-Forward AC"
    POST_FW_AC = "Post-Forward AC"
    PEAK_FW = "Peak Forward"
    PEAK_BW = "Peak Backward"


class _FSDPModMemStats:
    """
    A class to store the memory statistics of an FSDP module.

    Args:
        mod_fqn (str): The fully qualified name of the FSDP module.

    Attributes:
        snapshots (Dict[_FSDPModState, List[Dict[torch.device, Dict[str, int]]]]): A dictionary mapping each state
            defined by ``_FSDPModState`` to the list of memory snapshots captured while the module was in that state.
            Each snapshot maps a device to a dictionary whose keys are the memory reference types defined by
            ``_FSDPRefType`` and whose values are the memory consumed in bytes.
        local_peak (Dict[torch.device, int]): The peak total memory observed on each device while the module was active.

    """

    def __init__(self, mod_fqn: str) -> None:
        self.mod_fqn = mod_fqn
        self.local_peak: Dict[torch.device, int] = {}
        self.snapshots: Dict[
            _FSDPModState, List[Dict[torch.device, Dict[str, int]]]
        ] = {}


class FSDPMemTracker(MemTracker):
    """
    A ``TorchDispatchMode``-based context manager that extends ``torch.distributed._tools.mem_tracker.MemTracker`` to track
    and categorize the peak memory and module-wise memory usage of FSDP modules.

    It tracks the peak memory usage across all the devices of all the FSDP modules in the module tree and categorizes
    the tensor memory usage as defined by ``_FSDPRefType``. Further, it captures memory `snapshots` at different stages of
    the module execution defined by ``_FSDPModState``.

    Attributes:
        memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key is a reference
        to a module, and each value is a ``_FSDPModMemStats`` object that stores the memory statistics of the module.

    Args:
        mod (torch.nn.Module): The root FSDP module to be tracked.
        optm (torch.optim.Optimizer, optional): The optimizer to be tracked.

    Note: Please refer to ``torch.distributed._tools.mem_tracker.MemTracker`` to learn about the limitations.

    Example usage:

    .. code-block:: python

        module = ...
        optimizer = ...
        inp = ...
        fmt = FSDPMemTracker(module, optimizer)
        fmt.track_inputs((inp,))
        with fmt:
            optimizer.zero_grad()
            loss = module(inp)
            print("After Forward:")
            fmt.display_snapshot("current")
            loss.backward()
            optimizer.step()
        fmt.display_snapshot("peak")
        fmt.display_modulewise_snapshots(depth=3, units="MB")

    """
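    # Illustrative sketch (an assumption for exposition, not part of the documented API; names are
    # placeholders): since the tracker bypasses collectives and returns ``FakeWork`` handles whenever
    # ``active_fake_mode()`` is active, it can also be run under ``FakeTensorMode`` to estimate peak
    # memory without allocating real device memory, roughly as follows:
    #
    #     from torch._subclasses.fake_tensor import FakeTensorMode
    #
    #     with FakeTensorMode():
    #         module = ...  # build and shard the model as usual
    #         inp = ...
    #         fmt = FSDPMemTracker(module)
    #         fmt.track_inputs((inp,))
    #         with fmt:
    #             module(inp).sum().backward()
    #     fmt.display_snapshot("peak")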

    def __init__(
        self,
        mod: torch.nn.Module,
        optm: Optional[torch.optim.Optimizer] = None,
    ) -> None:
        super().__init__()
        assert isinstance(mod, FSDPModule), "FSDPMemTracker only supports FSDP modules"
        self._root_mod = mod
        self._optm = optm
        self._in_fake_mode: bool = False
        self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary()
        self._saved_collectives: _SavedCollectives
        self._ref_class: Type[_RefType] = _FSDPRefType

    def _instrument_fsdp_sharded_params_grads(
        self, fsdp_param_group: FSDPParamGroup
    ) -> None:
        # Track sharded params and grads after initialization
        for fsdp_param in fsdp_param_group.fsdp_params:
            self._update_and_maybe_create_winfos(
                fsdp_param.sharded_param,
                _FSDPRefType.SHARDED_PARAM,
            )
            sharded_grad = fsdp_param.sharded_param.grad
            if sharded_grad is not None:
                self._update_and_maybe_create_winfos(
                    sharded_grad,
                    _FSDPRefType.SHARDED_GRAD,
                )

    def _fsdp_state_pre_forward(
        self,
        fsdp_mod: FSDPModule,
        orig_fsdp_state_pre_fw: Callable,
    ) -> Callable:
        # We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params
        # and `all_gather` buffers. There are three cases:
        # Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats``
        #         instance for the module and add it to the ``memory_tracking`` dictionary.
        # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means
        #         we are in the AC region. We check if this is the topmost module in the AC region. If it is,
        #         we store a weak reference and set the flag ``_in_ac`` to True.
        # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means
        #         this module is called for the second time. If it is a root module, that means we are in the next
        #         iteration and we error out. If it is not a root module, that means it's a submodule that is being
        #         used multiple times in the same iteration, which we allow and track.
        # For Cases 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module.
        # For Case 2 we only capture one snapshot after ``FSDPState._pre_forward`` runs because it is a no-op.
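        # For example (illustrative), if a non-root FSDP submodule runs twice in the same forward pass
        # (Case 3), its ``AFT_PRE_FW`` entry ends up holding two snapshots, one per call, which is why
        # each state maps to a list of snapshots.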
        @wraps(orig_fsdp_state_pre_fw)
        def inner(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
            mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod)
            assert mod_fqn is not None
            if fsdp_mod not in self.memory_tracking:
                mod_stat = _FSDPModMemStats(mod_fqn)
                self.memory_tracking[fsdp_mod] = mod_stat
                snapshot = self.get_tracker_snapshot()
                mod_stat.local_peak = {
                    dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items()
                }
                mod_stat.snapshots.setdefault(_FSDPModState.PEAK_FW, []).append(
                    snapshot
                )
                mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_FW, []).append(
                    deepcopy(snapshot)
                )
            elif not self._mod_tracker.is_bw:
                parents = self._mod_tracker.parents - {mod_fqn}
                if len(parents) == 1 and "Global" in parents:
                    raise NotImplementedError(
                        "FSDPMemTracker does not support memory tracking for multiple iterative calls."
                        " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration"
                        " or file a github issue if you need this feature."
                    )

            args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs)

            fsdp_state = fsdp_mod._get_fsdp_state()
            if fsdp_param_group := fsdp_state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    self._update_and_maybe_create_winfos(
                        fsdp_param.unsharded_param,
                        _FSDPRefType.UNSHARDED_PARAM,
                    )
            mod_stat = self.memory_tracking[fsdp_mod]
            if self._mod_tracker.is_bw:
                state = _FSDPModState.PRE_FW_AC
                if self._ac_mod is None:
                    self._ac_mod = weakref.ref(fsdp_mod)
                    self._in_ac = True
            else:
                state = _FSDPModState.AFT_PRE_FW
            mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())
            return args, kwargs

        return inner

    def _fsdp_state_post_forward(
        self,
        fsdp_mod: FSDPModule,
        orig_fsdp_state_post_fw: Callable,
    ) -> Callable:
        # We capture memory snapshots before and after ``FSDPState._post_forward`` to capture the resharded state
        # if ``reshard_after_forward`` is not ``False``. There are two cases:
        # Case 1: This is called in backward, which means we are in the AC region. If this is the topmost module
        #         in the AC region, we set the flag ``_in_ac`` to False.
        # Case 2: This is called in forward.
        @wraps(orig_fsdp_state_post_fw)
        def inner(*args: Any, **kwargs: Any) -> Any:
            mod_stat = self.memory_tracking[fsdp_mod]
            if self._mod_tracker.is_bw:
                state = _FSDPModState.POST_FW_AC
                if self._ac_mod is not None and self._ac_mod() is fsdp_mod:
                    self._ac_mod = None
                    self._in_ac = False
            else:
                state = _FSDPModState.BEF_POST_FW
            mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())

            output = orig_fsdp_state_post_fw(*args, **kwargs)

            if not self._mod_tracker.is_bw:
                mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_FW, []).append(
                    self.get_tracker_snapshot()
                )
            return output

        return inner

    def _fsdp_param_group_pre_backward(
        self,
        fsdp_mod: FSDPModule,
        orig_fsdp_param_group_pre_backward: Callable,
    ) -> Callable:
        # We capture memory snapshots before and after ``FSDPParamGroup.pre_backward`` to capture the pre-fetching
        # and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module.
        @wraps(orig_fsdp_param_group_pre_backward)
        def inner(*args: Any, **kwargs: Any) -> None:
            mod_stat = self.memory_tracking[fsdp_mod]
            snapshot = self.get_tracker_snapshot()
            mod_stat.local_peak = {
                dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items()
            }
            mod_stat.snapshots.setdefault(_FSDPModState.PEAK_BW, []).append(snapshot)
            mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_BW, []).append(
                deepcopy(snapshot)
            )
            orig_fsdp_param_group_pre_backward(*args, **kwargs)

            mod_stat.snapshots.setdefault(_FSDPModState.AFT_PRE_BW, []).append(
                self.get_tracker_snapshot()
            )

        return inner

    def _fsdp_param_group_post_backward(
        self,
        fsdp_mod: FSDPModule,
        orig_fsdp_param_group_post_backward: Callable,
    ) -> Callable:
        # We capture memory snapshots before and after ``FSDPParamGroup.post_backward`` to track and attribute
        # the `unsharded` grads before the post-backward, and then the `sharded` grads and `reduce_scatter` buffers
        # after the post-backward.
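        # Note: the unsharded grads already exist by this point, so they are re-categorized in place
        # (``update_existing=True``), whereas the sharded grads produced by ``post_backward`` are new
        # tensors and get fresh entries.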
        @wraps(orig_fsdp_param_group_post_backward)
        def inner(*args: Any, **kwargs: Any) -> None:
            fsdp_state = fsdp_mod._get_fsdp_state()
            if fsdp_param_group := fsdp_state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    unsharded_grad = fsdp_param._unsharded_param.grad
                    if unsharded_grad is not None:
                        self._update_and_maybe_create_winfos(
                            unsharded_grad,
                            _FSDPRefType.UNSHARDED_GRAD,
                            update_existing=True,
                        )

            mod_stat = self.memory_tracking[fsdp_mod]
            mod_stat.snapshots.setdefault(_FSDPModState.BEF_POST_BW, []).append(
                self.get_tracker_snapshot()
            )

            orig_fsdp_param_group_post_backward(*args, **kwargs)

            if fsdp_param_group := fsdp_state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    sharded_grad = fsdp_param.sharded_param.grad
                    if sharded_grad is not None:
                        self._update_and_maybe_create_winfos(
                            sharded_grad,
                            _FSDPRefType.SHARDED_GRAD,
                        )

            mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_BW, []).append(
                self.get_tracker_snapshot()
            )

        return inner

    def _instrument_fsdp_module(self) -> None:
        # We uninstall the existing `FSDPState._pre_forward` and `FSDPState._post_forward` hooks and install
        # our own hooks that wrap them. We choose this over monkey-patching `FSDPParamGroup.pre_forward` and
        # `FSDPParamGroup.post_forward` because during AC these won't be called.
        # TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
        # lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
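        # The original ``pre_backward``/``post_backward`` methods are stashed per module in
        # ``_fsdp_mod_to_saved_methods`` so that ``_deregister_module_and_optimizer_hooks`` can restore
        # them (and re-register the original forward hooks) when the tracker context exits.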
        for module in self._root_mod.modules():
            if isinstance(module, FSDPModule):
                fsdp_state = module._get_fsdp_state()
                if fsdp_param_group := fsdp_state._fsdp_param_group:
                    self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
                    fsdp_state._pre_forward_hook_handle.remove()
                    fsdp_state._post_forward_hook_handle.remove()
                    fsdp_state._pre_forward_hook_handle = (
                        module.register_forward_pre_hook(
                            self._fsdp_state_pre_forward(
                                module, fsdp_state._pre_forward
                            ),
                            prepend=True,
                            with_kwargs=True,
                        )
                    )
                    fsdp_state._post_forward_hook_handle = module.register_forward_hook(
                        self._fsdp_state_post_forward(module, fsdp_state._post_forward),
                        prepend=False,
                        always_call=True,
                    )
                    self._fsdp_mod_to_saved_methods[module] = _SavedFSDPMethods(
                        fsdp_param_group.pre_backward,
                        fsdp_param_group.post_backward,
                    )
                    fsdp_param_group.pre_backward = self._fsdp_param_group_pre_backward(  # type: ignore[assignment]
                        module, fsdp_param_group.pre_backward
                    )
                    fsdp_param_group.post_backward = (  # type: ignore[assignment]
                        self._fsdp_param_group_post_backward(
                            module, fsdp_param_group.post_backward
                        )
                    )

        for buffer in self._root_mod.buffers():
            self._update_and_maybe_create_winfos(
                buffer,
                _FSDPRefType.BUFFER,
            )

    def _instrument_optimizer(self) -> None:
        # Register hooks on the optimizer step to track the optimizer states.
        # The pre-hook sets the flag ``_in_opt`` to True. The post-hook unsets the flag
        # and also tracks any optimizer states that are created during the optimizer step.
        if self._optm is not None:
            self._track_optimizer_states(_FSDPRefType.OPT, self._optm)

            def _opt_step_pre_hook(
                optimizer: optim.Optimizer, args: Any, kwargs: Any
            ) -> None:
                self._in_opt = True

            def _opt_step_post_hook(
                optimizer: optim.Optimizer, args: Any, kwargs: Any
            ) -> None:
                self._track_optimizer_states(_FSDPRefType.OPT, optimizer)
                self._in_opt = False

            self._optimizer_hook_handles = (
                self._optm.register_step_pre_hook(_opt_step_pre_hook),
                self._optm.register_step_post_hook(_opt_step_post_hook),
            )

    def _register_module_and_optimizer_hooks(self) -> None:
        self._instrument_fsdp_module()
        self._instrument_optimizer()

    def _deregister_module_and_optimizer_hooks(self) -> None:
        for (
            fsdp_mod,
            saved_methods,
        ) in self._fsdp_mod_to_saved_methods.items():
            fsdp_state = fsdp_mod._get_fsdp_state()
            fsdp_state._pre_forward_hook_handle.remove()
            fsdp_state._post_forward_hook_handle.remove()
            fsdp_state._pre_forward_hook_handle = fsdp_mod.register_forward_pre_hook(
                fsdp_state._pre_forward, prepend=True, with_kwargs=True
            )
            fsdp_state._post_forward_hook_handle = fsdp_mod.register_forward_hook(
                fsdp_state._post_forward, prepend=False
            )
            if fsdp_param_group := fsdp_state._fsdp_param_group:
                fsdp_param_group.pre_backward = saved_methods.pre_backward
                fsdp_param_group.post_backward = saved_methods.post_backward
        self._fsdp_mod_to_saved_methods.clear()

        if self._optimizer_hook_handles is not None:
            for handle in self._optimizer_hook_handles:
                handle.remove()
            self._optimizer_hook_handles = None

    def _instrument_and_maybe_bypass_collectives(self) -> None:
        # Monkey-patching collectives is required because they do not work with `FakeTensorMode`.
        # It also makes it easier to track the `all_gather` and `reduce_scatter` buffers faithfully.
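        # The original functions are saved so ``_restore_collectives`` can put them back on ``__exit__``.
        # In fake mode the patched versions skip communication entirely and return a ``FakeWork`` (async)
        # or ``None`` (sync) so that callers can still ``wait()`` on the result.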
        self._saved_collectives = _SavedCollectives(
            dist.all_gather_into_tensor,
            dist.reduce_scatter_tensor,
            dist.all_reduce,
            dist.barrier,
        )

        class FakeWork(Work):
            def __init__(self) -> None:
                super().__init__()

            def get_future(self) -> Future:
                future: Future = Future()
                future.set_result(None)
                return future

            def wait(self, timeout: Optional[timedelta] = None) -> bool:
                return True

        @wraps(dist.all_gather_into_tensor)
        def all_gather_into_tensor(
            output_tensor: torch.Tensor,
            input_tensor: torch.Tensor,
            group: Union[ProcessGroup, None] = None,
            async_op: bool = False,
        ) -> Union[Work, _IllegalWork, None]:
            self._update_and_maybe_create_winfos(
                output_tensor,
                _FSDPRefType.ALL_GATHER,
                update_existing=True,
            )

            if self._in_fake_mode:
                if async_op:
                    return FakeWork()
                return None
            else:
                return self._saved_collectives.all_gather_into_tensor(
                    output_tensor, input_tensor, group, async_op
                )

        @wraps(dist.reduce_scatter_tensor)
        def reduce_scatter_tensor(
            output: torch.Tensor,
            input: torch.Tensor,
            op: ReduceOp.RedOpType = dist.ReduceOp.SUM,
            group: Union[ProcessGroup, None] = None,
            async_op: bool = False,
        ) -> Union[Work, _IllegalWork, None]:
            self._update_and_maybe_create_winfos(
                input,
                _FSDPRefType.REDUCE_SCATTER,
                update_existing=True,
            )

            if self._in_fake_mode:
                if async_op:
                    return FakeWork()
                return None
            else:
                return self._saved_collectives.reduce_scatter_tensor(
                    output, input, op, group, async_op
                )

        @wraps(dist.all_reduce)
        def all_reduce(
            tensor: torch.Tensor,
            op: ReduceOp.RedOpType = dist.ReduceOp.SUM,
            group: Union[ProcessGroup, None] = None,
            async_op: bool = False,
        ) -> Union[Work, _IllegalWork, None]:
            if self._in_fake_mode:
                if async_op:
                    return FakeWork()
                return None
            else:
                return self._saved_collectives.all_reduce(tensor, op, group, async_op)

        @wraps(dist.barrier)
        def barrier(
            group: Union[ProcessGroup, None] = dist.GroupMember.WORLD,
            async_op: bool = False,
            device_ids: Union[List[int], None] = None,
        ) -> Union[Work, None]:
            if self._in_fake_mode:
                return None
            else:
                return self._saved_collectives.barrier(group, async_op, device_ids)

        dist.all_gather_into_tensor = all_gather_into_tensor
        dist.reduce_scatter_tensor = reduce_scatter_tensor
        dist.all_reduce = all_reduce
        dist.barrier = barrier

    def _restore_collectives(self) -> None:
        dist.all_gather_into_tensor = self._saved_collectives.all_gather_into_tensor
        dist.reduce_scatter_tensor = self._saved_collectives.reduce_scatter_tensor
        dist.all_reduce = self._saved_collectives.all_reduce
        dist.barrier = self._saved_collectives.barrier
        del self._saved_collectives

    def track_inputs(self, inputs: Tuple[Any, ...]) -> None:
        """
        This is used to track the input tensors to the model and annotate them as ``Inputs``.

        Args:
            inputs (Tuple[Any]): A tuple containing the input data. This can include tensors
                as well as other data types. Only tensors will be tracked.
        """

        def _track_inputs(t: torch.Tensor) -> None:
            self._update_and_maybe_create_winfos(
                t,
                _FSDPRefType.INP,
            )

        tree_map_only(torch.Tensor, _track_inputs, inputs)

    def track_external(
        self, *external: Union[nn.Module, optim.Optimizer, torch.Tensor]
    ) -> None:
        """This is a no-op for ``FSDPMemTracker``."""

    def __enter__(self) -> "FSDPMemTracker":
        self._in_fake_mode = bool(active_fake_mode())
        self._register_module_and_optimizer_hooks()
        self._instrument_and_maybe_bypass_collectives()
        self._track_resize()
        self._peak_mem_snap = self.get_tracker_snapshot()
        self._peak_mem = {
            dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items()
        }
        self._mod_tracker.__enter__()
        TorchDispatchMode.__enter__(self)
        return self

    def __exit__(self, *args: Any) -> None:
        self._deregister_module_and_optimizer_hooks()
        self._restore_collectives()
        self._restore_resize()
        TorchDispatchMode.__exit__(self, *args)
        self._mod_tracker.__exit__(*args)

    def __torch_dispatch__(self, func, types, args=..., kwargs=None):  # type: ignore[no-untyped-def]
        res = func(*args, **kwargs or {})
        # If we are tracking an optimizer state, we use the optimizer reference type.
        # If we are in the backward region and not in an AC region, we use the backward reference type.
        # Else we use the forward reference type.
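        # For example (illustrative), tensors created inside ``optimizer.step()`` are tagged ``OPT``,
        # gradients of activations produced in backward outside an AC region are tagged ``TEMP``, and
        # activations recomputed inside an AC region fall under ``ACT``.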
        if self._in_opt:
            reftype = _FSDPRefType.OPT
        elif self._mod_tracker.is_bw and not self._in_ac:
            reftype = _FSDPRefType.TEMP
        else:
            reftype = _FSDPRefType.ACT
        tree_map_only(torch.Tensor, partial(self._track, reftype), res)
        peak_state = (
            _FSDPModState.PEAK_BW if self._mod_tracker.is_bw else _FSDPModState.PEAK_FW
        )
        self._update_peak_stats(peak_state)
        return res