from copy import deepcopy
from datetime import timedelta
from functools import partial, wraps
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
from torch import nn, optim
from torch._guards import active_fake_mode
from torch.distributed._composable.fsdp import FSDPModule
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker
from torch.distributed.distributed_c10d import (
    _IllegalWork,
    ProcessGroup,
    ReduceOp,
    Work,
)
from torch.futures import Future
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref


_TOTAL_KEY = "Total"

__all__ = ["FSDPMemTracker"]


class _FSDPRefType(_RefType):
    """
    Enumerates categories of memory usage in FSDP modules, including parameters, gradients, activations,
    buffers, and optimizer states.

    Attributes:
        SHARDED_PARAM (str): Memory usage of sharded parameters.
        UNSHARDED_PARAM (str): Memory usage of unsharded parameters.
        BUFFER (str): Memory usage of module buffers.
        SHARDED_GRAD (str): Memory usage of sharded gradients corresponding to the sharded parameters.
        UNSHARDED_GRAD (str): Memory usage of unsharded gradients corresponding to the unsharded parameters.
        ACT (str): Memory usage of activations and tensors from forward and AC recomputation.
        TEMP (str): Memory usage of temporary tensors during the backward pass, including gradients of activations.
        ALL_GATHER (str): Memory usage of the all_gather output tensor.
        REDUCE_SCATTER (str): Memory usage of the reduce_scatter input tensor.
        OPT (str): Memory usage of tensors storing optimizer states.
        INP (str): Memory usage of input tensors.
    """

    SHARDED_PARAM = "Sharded Param"
    UNSHARDED_PARAM = "Unsharded Param"
    BUFFER = "Buffer"
    SHARDED_GRAD = "Sharded Grad"
    UNSHARDED_GRAD = "Unsharded Grad"
    ACT = "Activation"
    TEMP = "Temp"
    ALL_GATHER = "All Gather"
    REDUCE_SCATTER = "Reduce Scatter"
    OPT = "OptState"
    INP = "Inputs"


class _SavedFSDPMethods(NamedTuple):
    pre_backward: Callable
    post_backward: Callable


class _SavedCollectives(NamedTuple):
    all_gather_into_tensor: Callable
    reduce_scatter_tensor: Callable
    all_reduce: Callable
    barrier: Callable


class _FSDPModState(_State):
    """
    Enumerates the states of FSDP modules during the forward and backward passes.
    """

    BEF_PRE_FW = "Before Pre-Forward"
    AFT_PRE_FW = "After Pre-Forward"
    BEF_POST_FW = "Before Post-Forward"
    AFT_POST_FW = "After Post-Forward"
    BEF_PRE_BW = "Before Pre-Backward"
    AFT_PRE_BW = "After Pre-Backward"
    BEF_POST_BW = "Before Post-Backward"
    AFT_POST_BW = "After Post-Backward"
    PRE_FW_AC = "Pre-Forward AC"
    POST_FW_AC = "Post-Forward AC"
    PEAK_FW = "Peak Forward"
    PEAK_BW = "Peak Backward"


class _FSDPModMemStats:
    """
    A class to store the memory statistics of an FSDP module.

    Args:
        mod_fqn (str): The fully qualified name of the FSDP module.

    Attributes:
        snapshots (Dict[_FSDPModState, List[Dict[torch.device, Dict[str, int]]]]): A dictionary of lists of
            memory snapshots of the module, keyed by the states defined in ``_FSDPModState``. Each snapshot
            maps a device to another dictionary whose keys are the memory reference types defined by
            ``_FSDPRefType`` and whose values are the memory consumed in bytes.
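
            For example, a single snapshot for one device might look like the following, where the keys
            are the ``_FSDPRefType`` values plus the ``Total`` key and the byte counts are purely
            illustrative::

                {
                    torch.device("cuda:0"): {
                        "Sharded Param": 1024,
                        "Unsharded Param": 4096,
                        "Activation": 2048,
                        "Total": 7168,
                    }
                }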

    """

    def __init__(self, mod_fqn: str) -> None:
        self.mod_fqn = mod_fqn
        self.local_peak: Dict[torch.device, int] = {}
        self.snapshots: Dict[
            _FSDPModState, List[Dict[torch.device, Dict[str, int]]]
        ] = {}


class FSDPMemTracker(MemTracker):
    """
    A ``TorchDispatchMode`` based context manager that extends ``torch.distributed._tools.mem_tracker.MemTracker`` to track
    and categorize the peak memory and module-wise memory usage of FSDP modules.

    It tracks the peak memory usage across all the devices of all the FSDP modules in the module tree and categorizes
    the tensor memory usage as defined by ``_FSDPRefType``. Further, it captures memory `snapshots` at different stages of
    the module execution defined by ``_FSDPModState``.

    Attributes:
        memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key is a reference
            to a module, and each value is a ``_FSDPModMemStats`` object that stores the memory statistics of the module.

    Args:
        mod (torch.nn.Module): The root FSDP module to be tracked.
        optm (torch.optim.Optimizer, optional): The optimizer to be tracked.

    Note: Please refer to ``torch.distributed._tools.mem_tracker.MemTracker`` to learn about the limitations.

    Example usage

    .. code-block:: python

        module = ...
        optimizer = ...
        inp = ...
        fmt = FSDPMemTracker(module, optimizer)
        fmt.track_inputs((inp,))
        with fmt:
            optimizer.zero_grad()
            loss = module(inp)
            print("After Forward:")
            fmt.display_snapshot("current")
            loss.backward()
            optimizer.step()
        fmt.display_snapshot("peak")
        fmt.display_modulewise_snapshots(depth=3, units="MB")

    """

    def __init__(
        self,
        mod: torch.nn.Module,
        optm: Optional[torch.optim.Optimizer] = None,
    ) -> None:
        super().__init__()
        assert isinstance(mod, FSDPModule), "FSDPMemTracker only supports FSDP modules"
        self._root_mod = mod
        self._optm = optm
        self._in_fake_mode: bool = False
        self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary()
        self._saved_collectives: _SavedCollectives
        self._ref_class: Type[_RefType] = _FSDPRefType

    def _instrument_fsdp_sharded_params_grads(
        self, fsdp_param_group: FSDPParamGroup
    ) -> None:
        # Track sharded params and grads after initialization
        for fsdp_param in fsdp_param_group.fsdp_params:
            self._update_and_maybe_create_winfos(
                fsdp_param.sharded_param,
                _FSDPRefType.SHARDED_PARAM,
            )
            sharded_grad = fsdp_param.sharded_param.grad
            if sharded_grad is not None:
                self._update_and_maybe_create_winfos(
                    sharded_grad,
                    _FSDPRefType.SHARDED_GRAD,
                )

    def _fsdp_state_pre_forward(
        self,
        fsdp_mod: FSDPModule,
        orig_fsdp_state_pre_fw: Callable,
    ) -> Callable:
        # We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params
        # and `all_gather` buffers. There are three cases:
        # Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats``
        #         instance for the module and add it to the ``memory_tracking`` dictionary.
        # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means
        #         we are in the AC region. We check if this is the topmost module in the AC region. If it is,
        #         we store a weak reference and set the flag ``_in_ac`` to True.
        # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means
        #         this module is called for the second time. If it is a root module, that means we are in the next
        #         iteration and we error out. If it is not a root module, that means it's a submodule that is being
        #         used multiple times in the same iteration, which we allow and track.
        # For Cases 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module.
        # For Case 2 we only capture 1 snapshot after ``FSDPState._pre_forward`` runs because it is a no-op.
        @wraps(orig_fsdp_state_pre_fw)
        def inner(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
            mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod)
            assert mod_fqn is not None
            if fsdp_mod not in self.memory_tracking:
                mod_stat = _FSDPModMemStats(mod_fqn)
                self.memory_tracking[fsdp_mod] = mod_stat
                snapshot = self.get_tracker_snapshot()
                mod_stat.local_peak = {
                    dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items()
                }
                mod_stat.snapshots.setdefault(_FSDPModState.PEAK_FW, []).append(
                    snapshot
                )
                mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_FW, []).append(
                    deepcopy(snapshot)
                )
            elif not self._mod_tracker.is_bw:
                parents = self._mod_tracker.parents - {mod_fqn}
                if len(parents) == 1 and "Global" in parents:
                    raise NotImplementedError(
                        "FSDPMemTracker does not support memory tracking for multiple iterative calls."
                        " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration"
                        " or file a github issue if you need this feature."
                    )

            args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs)

            fsdp_state = fsdp_mod._get_fsdp_state()
            if fsdp_param_group := fsdp_state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    self._update_and_maybe_create_winfos(
                        fsdp_param.unsharded_param,
                        _FSDPRefType.UNSHARDED_PARAM,
                    )
            mod_stat = self.memory_tracking[fsdp_mod]
            if self._mod_tracker.is_bw:
                state = _FSDPModState.PRE_FW_AC
                if self._ac_mod is None:
                    self._ac_mod = weakref.ref(fsdp_mod)
                    self._in_ac = True
            else:
                state = _FSDPModState.AFT_PRE_FW
            mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())
            return args, kwargs

        return inner

    def _fsdp_state_post_forward(
        self,
        fsdp_mod: FSDPModule,
        orig_fsdp_state_post_fw: Callable,
    ) -> Callable:
        # We capture memory snapshots before and after ``FSDPState._post_forward`` to capture the resharded state
        # if ``reshard_after_forward`` is not ``False``. There are two cases:
        # Case 1: This is called in backward, which means we are in the AC region. If this is the topmost module
        #         in the AC region, we set the flag ``_in_ac`` to False.
        # Case 2: This is called in forward.
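        # In Case 1 (the AC case) only one snapshot, ``POST_FW_AC``, is captured, and it is taken before
        # the wrapped call, mirroring the pre-forward handling; ``AFT_POST_FW`` is recorded in forward only.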
        @wraps(orig_fsdp_state_post_fw)
        def inner(*args: Any, **kwargs: Any) -> Any:
            mod_stat = self.memory_tracking[fsdp_mod]
            if self._mod_tracker.is_bw:
                state = _FSDPModState.POST_FW_AC
                if self._ac_mod is not None and self._ac_mod() is fsdp_mod:
                    self._ac_mod = None
                    self._in_ac = False
            else:
                state = _FSDPModState.BEF_POST_FW
            mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot())

            output = orig_fsdp_state_post_fw(*args, **kwargs)

            if not self._mod_tracker.is_bw:
                mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_FW, []).append(
                    self.get_tracker_snapshot()
                )
            return output

        return inner

    def _fsdp_param_group_pre_backward(
        self,
        fsdp_mod: FSDPModule,
        orig_fsdp_param_group_pre_backward: Callable,
    ) -> Callable:
        # We capture memory snapshots before and after ``FSDPParamGroup.pre_backward`` to capture the pre-fetching
        # and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module.
        @wraps(orig_fsdp_param_group_pre_backward)
        def inner(*args: Any, **kwargs: Any) -> None:
            mod_stat = self.memory_tracking[fsdp_mod]
            snapshot = self.get_tracker_snapshot()
            mod_stat.local_peak = {
                dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items()
            }
            mod_stat.snapshots.setdefault(_FSDPModState.PEAK_BW, []).append(snapshot)
            mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_BW, []).append(
                deepcopy(snapshot)
            )
            orig_fsdp_param_group_pre_backward(*args, **kwargs)

            mod_stat.snapshots.setdefault(_FSDPModState.AFT_PRE_BW, []).append(
                self.get_tracker_snapshot()
            )

        return inner

    def _fsdp_param_group_post_backward(
        self,
        fsdp_mod: FSDPModule,
        orig_fsdp_param_group_post_backward: Callable,
    ) -> Callable:
        # We capture the memory snapshots before and after ``FSDPParamGroup.post_backward`` to track and attribute
        # the `unsharded` grads before the post backward and then `sharded` grads and `reduce_scatter` buffers
        # after the post backward.
        @wraps(orig_fsdp_param_group_post_backward)
        def inner(*args: Any, **kwargs: Any) -> None:
            fsdp_state = fsdp_mod._get_fsdp_state()
            if fsdp_param_group := fsdp_state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    unsharded_grad = fsdp_param._unsharded_param.grad
                    if unsharded_grad is not None:
                        self._update_and_maybe_create_winfos(
                            unsharded_grad,
                            _FSDPRefType.UNSHARDED_GRAD,
                            update_existing=True,
                        )

            mod_stat = self.memory_tracking[fsdp_mod]
            mod_stat.snapshots.setdefault(_FSDPModState.BEF_POST_BW, []).append(
                self.get_tracker_snapshot()
            )

            orig_fsdp_param_group_post_backward(*args, **kwargs)

            if fsdp_param_group := fsdp_state._fsdp_param_group:
                for fsdp_param in fsdp_param_group.fsdp_params:
                    sharded_grad = fsdp_param.sharded_param.grad
                    if sharded_grad is not None:
                        self._update_and_maybe_create_winfos(
                            sharded_grad,
                            _FSDPRefType.SHARDED_GRAD,
                        )

            mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_BW, []).append(
                self.get_tracker_snapshot()
            )

        return inner

    def _instrument_fsdp_module(self) -> None:
        # We uninstall the existing `FSDPState._pre_forward` and `FSDPState._post_forward` hooks and install
        # our own hooks that wrap them.
        # We choose this over monkey-patching `FSDPParamGroup.pre_forward` and
        # `FSDPParamGroup.post_forward` because during AC these won't be called.
        # For backward we monkey-patch `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
        # TODO(@sanketpurandare): This will need to be modified after this PR
        # (https://github.com/pytorch/pytorch/pull/127786) lands.
        for module in self._root_mod.modules():
            if isinstance(module, FSDPModule):
                fsdp_state = module._get_fsdp_state()
                if fsdp_param_group := fsdp_state._fsdp_param_group:
                    self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
                    fsdp_state._pre_forward_hook_handle.remove()
                    fsdp_state._post_forward_hook_handle.remove()
                    fsdp_state._pre_forward_hook_handle = (
                        module.register_forward_pre_hook(
                            self._fsdp_state_pre_forward(
                                module, fsdp_state._pre_forward
                            ),
                            prepend=True,
                            with_kwargs=True,
                        )
                    )
                    fsdp_state._post_forward_hook_handle = module.register_forward_hook(
                        self._fsdp_state_post_forward(module, fsdp_state._post_forward),
                        prepend=False,
                        always_call=True,
                    )
                    self._fsdp_mod_to_saved_methods[module] = _SavedFSDPMethods(
                        fsdp_param_group.pre_backward,
                        fsdp_param_group.post_backward,
                    )
                    fsdp_param_group.pre_backward = self._fsdp_param_group_pre_backward(  # type: ignore[assignment]
                        module, fsdp_param_group.pre_backward
                    )
                    fsdp_param_group.post_backward = (  # type: ignore[assignment]
                        self._fsdp_param_group_post_backward(
                            module, fsdp_param_group.post_backward
                        )
                    )

        for buffer in self._root_mod.buffers():
            self._update_and_maybe_create_winfos(
                buffer,
                _FSDPRefType.BUFFER,
            )

    def _instrument_optimizer(self) -> None:
        # Register a hook on the optimizer step to track the optimizer states.
        # The pre-hook sets the flag ``_in_opt`` to True. The post-hook unsets the flag,
        # and also tracks any optimizer states that are created during the optimizer step.
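        # If no optimizer was passed, this is a no-op: no states are tracked, no hooks are registered,
        # and ``_optimizer_hook_handles`` stays ``None``, so deregistration skips it.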
        if self._optm is not None:
            self._track_optimizer_states(_FSDPRefType.OPT, self._optm)

            def _opt_step_pre_hook(
                optimizer: optim.Optimizer, args: Any, kwargs: Any
            ) -> None:
                self._in_opt = True

            def _opt_step_post_hook(
                optimizer: optim.Optimizer, args: Any, kwargs: Any
            ) -> None:
                self._track_optimizer_states(_FSDPRefType.OPT, optimizer)
                self._in_opt = False

            self._optimizer_hook_handles = (
                self._optm.register_step_pre_hook(_opt_step_pre_hook),
                self._optm.register_step_post_hook(_opt_step_post_hook),
            )

    def _register_module_and_optimizer_hooks(self) -> None:
        self._instrument_fsdp_module()
        self._instrument_optimizer()

    def _deregister_module_and_optimizer_hooks(self) -> None:
        for (
            fsdp_mod,
            saved_methods,
        ) in self._fsdp_mod_to_saved_methods.items():
            fsdp_state = fsdp_mod._get_fsdp_state()
            fsdp_state._pre_forward_hook_handle.remove()
            fsdp_state._post_forward_hook_handle.remove()
            fsdp_state._pre_forward_hook_handle = fsdp_mod.register_forward_pre_hook(
                fsdp_state._pre_forward, prepend=True, with_kwargs=True
            )
            fsdp_state._post_forward_hook_handle = fsdp_mod.register_forward_hook(
                fsdp_state._post_forward, prepend=False
            )
            if fsdp_param_group := fsdp_state._fsdp_param_group:
                fsdp_param_group.pre_backward = saved_methods.pre_backward
                fsdp_param_group.post_backward = saved_methods.post_backward
        self._fsdp_mod_to_saved_methods.clear()

        if self._optimizer_hook_handles is not None:
            for handle in self._optimizer_hook_handles:
                handle.remove()
            self._optimizer_hook_handles = None

    def _instrument_and_maybe_bypass_collectives(self) -> None:
        # Monkey-patching collectives is required because they do not work with `FakeTensorMode`.
        # It's also easier to track `all_gather` and `reduce_scatter` buffers faithfully.
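        # Under fake mode the patched collectives skip communication entirely: async calls return a
        # ``FakeWork`` whose ``wait``/``get_future`` complete immediately, and sync calls return ``None``.
        # Outside fake mode they delegate to the saved originals; the `all_gather`/`reduce_scatter`
        # wrappers additionally attribute their output/input buffers before dispatching.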
        self._saved_collectives = _SavedCollectives(
            dist.all_gather_into_tensor,
            dist.reduce_scatter_tensor,
            dist.all_reduce,
            dist.barrier,
        )

        class FakeWork(Work):
            def __init__(self) -> None:
                super().__init__()

            def get_future(self) -> Future:
                future: Future = Future()
                future.set_result(None)
                return future

            def wait(self, timeout: Optional[timedelta] = None) -> bool:
                return True

        @wraps(dist.all_gather_into_tensor)
        def all_gather_into_tensor(
            output_tensor: torch.Tensor,
            input_tensor: torch.Tensor,
            group: Union[ProcessGroup, None] = None,
            async_op: bool = False,
        ) -> Union[Work, _IllegalWork, None]:
            self._update_and_maybe_create_winfos(
                output_tensor,
                _FSDPRefType.ALL_GATHER,
                update_existing=True,
            )

            if self._in_fake_mode:
                if async_op:
                    return FakeWork()
                return None
            else:
                return self._saved_collectives.all_gather_into_tensor(
                    output_tensor, input_tensor, group, async_op
                )

        @wraps(dist.reduce_scatter_tensor)
        def reduce_scatter_tensor(
            output: torch.Tensor,
            input: torch.Tensor,
            op: ReduceOp.RedOpType = dist.ReduceOp.SUM,
            group: Union[ProcessGroup, None] = None,
            async_op: bool = False,
        ) -> Union[Work, _IllegalWork, None]:
            self._update_and_maybe_create_winfos(
                input,
                _FSDPRefType.REDUCE_SCATTER,
                update_existing=True,
            )

            if self._in_fake_mode:
                if async_op:
                    return FakeWork()
                return None
            else:
                return self._saved_collectives.reduce_scatter_tensor(
                    output, input, op, group, async_op
                )

        @wraps(dist.all_reduce)
        def all_reduce(
            tensor: torch.Tensor,
            op: ReduceOp.RedOpType = dist.ReduceOp.SUM,
            group: Union[ProcessGroup, None] = None,
            async_op: bool = False,
        ) -> Union[Work, _IllegalWork, None]:
            if self._in_fake_mode:
                if async_op:
                    return FakeWork()
                return None
            else:
                return self._saved_collectives.all_reduce(tensor, op, group, async_op)

        @wraps(dist.barrier)
        def barrier(
            group: Union[ProcessGroup, None] = dist.GroupMember.WORLD,
            async_op: bool = False,
            device_ids: Union[List[int], None] = None,
        ) -> Union[Work, None]:
            if self._in_fake_mode:
                return None
            else:
                return self._saved_collectives.barrier(group, async_op, device_ids)

        dist.all_gather_into_tensor = all_gather_into_tensor
        dist.reduce_scatter_tensor = reduce_scatter_tensor
        dist.all_reduce = all_reduce
        dist.barrier = barrier

    def _restore_collectives(self) -> None:
        dist.all_gather_into_tensor = self._saved_collectives.all_gather_into_tensor
        dist.reduce_scatter_tensor = self._saved_collectives.reduce_scatter_tensor
        dist.all_reduce = self._saved_collectives.all_reduce
        dist.barrier = self._saved_collectives.barrier
        del self._saved_collectives

    def track_inputs(self, inputs: Tuple[Any, ...]) -> None:
        """
        This is used to track the input tensors to the model and annotate them as ``Inputs``.

        Args:
            inputs (Tuple[Any]): A tuple containing the input data. This can include tensors
                as well as other data types. Only tensors will be tracked.
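
        Example (``module``, ``optimizer`` and ``inp`` are placeholders, as in the class-level
        usage example)::

            fmt = FSDPMemTracker(module, optimizer)
            fmt.track_inputs((inp,))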
559 """ 560 561 def _track_inputs(t: torch.Tensor) -> None: 562 self._update_and_maybe_create_winfos( 563 t, 564 _FSDPRefType.INP, 565 ) 566 567 tree_map_only(torch.Tensor, _track_inputs, inputs) 568 569 def track_external( 570 self, *external: Union[nn.Module, optim.Optimizer, torch.Tensor] 571 ) -> None: 572 """This is no-op for ``FSDPMemTracker``""" 573 574 def __enter__(self) -> "FSDPMemTracker": 575 self._in_fake_mode = True if active_fake_mode() else False 576 self._register_module_and_optimizer_hooks() 577 self._instrument_and_maybe_bypass_collectives() 578 self._track_resize() 579 self._peak_mem_snap = self.get_tracker_snapshot() 580 self._peak_mem = { 581 dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in self._peak_mem_snap.items() 582 } 583 self._mod_tracker.__enter__() 584 TorchDispatchMode.__enter__(self) 585 return self 586 587 def __exit__(self, *args: Any) -> None: 588 self._deregister_module_and_optimizer_hooks() 589 self._restore_collectives() 590 self._restore_resize() 591 TorchDispatchMode.__exit__(self, *args) 592 self._mod_tracker.__exit__(*args) 593 594 def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] 595 res = func(*args, **kwargs or {}) 596 # If we are tracking an optimizer state, we use the optimizer reference type. 597 # If we are in backward region and not in AC region, we use the backward reference type. 598 # Else we use the forward reference type. 599 if self._in_opt: 600 reftype = _FSDPRefType.OPT 601 elif self._mod_tracker.is_bw and not self._in_ac: 602 reftype = _FSDPRefType.TEMP 603 else: 604 reftype = _FSDPRefType.ACT 605 tree_map_only(torch.Tensor, partial(self._track, reftype), res) 606 peak_state = ( 607 _FSDPModState.PEAK_BW if self._mod_tracker.is_bw else _FSDPModState.PEAK_FW 608 ) 609 self._update_peak_stats(peak_state) 610 return res 611