1# mypy: allow-untyped-defs 2import contextlib 3import warnings 4from typing import cast, Generator 5 6import torch 7import torch.distributed.fsdp._traversal_utils as traversal_utils 8import torch.nn as nn 9from torch.distributed.fsdp._common_utils import ( 10 _FSDPState, 11 _get_module_fsdp_state, 12 _has_fsdp_params, 13 _module_handle, 14 HandleTrainingState, 15 TrainingState, 16) 17from torch.distributed.fsdp._runtime_utils import ( 18 _lazy_init, 19 _reset_flat_param_grad_info_if_needed, 20 _reshard, 21 _reshard_grads, 22 _unshard, 23 _unshard_grads, 24) 25from torch.distributed.utils import _p_assert 26 27from ._flat_param import FlatParamHandle 28 29 30FLAT_PARAM = "_flat_param" 31 32 33@torch.no_grad() 34def _writeback_to_local_shard( 35 handle: FlatParamHandle, 36 writeback_grad: bool, 37): 38 """ 39 For the handle, writes back the this rank's shard of the unsharded 40 flattened parameter to the sharded flattened parameter. If 41 ``writeback_grad=True``, then writes back to the sharded gradient as 42 well. 43 44 Precondition: The handle's ``FlatParameter`` 's data points to the 45 padded unsharded flattened parameter. 46 """ 47 48 def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor: 49 if handle.uses_sharded_strategy: 50 # For sharded strategies, get the *unpadded* shard instead of 51 # the *padded* shard to persist user changes to the padding 52 # (though FSDP does not explicitly support this) 53 shard, _ = FlatParamHandle._get_unpadded_shard( 54 flat_param_or_grad, 55 handle.rank, 56 handle.world_size, 57 ) 58 return shard 59 # For `NO_SHARD`, the `flat_param` or its gradient may be modified, 60 # so we write it back directly 61 return flat_param_or_grad 62 63 param_shard = _get_shard(handle.flat_param) 64 handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) # type: ignore[attr-defined] 65 if writeback_grad: 66 existing_grad = handle.sharded_grad 67 if existing_grad is not None: 68 assert handle.flat_param.grad is not None 69 grad_shard = _get_shard(handle.flat_param.grad) 70 existing_grad[: grad_shard.numel()].copy_(grad_shard) 71 72 73def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None: 74 """ 75 De-registers the flattened parameter from the wrapped module, hiding it 76 from ``nn.Module`` methods. 77 78 We do not use ``del`` because we want ``FLAT_PARAM`` to always be an 79 attribute but dynamically change whether it is visible to ``nn.Module`` 80 methods. 81 """ 82 if _has_fsdp_params(state, module): 83 # TODO: figure out the case for the composable APIs. 84 cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None) 85 86 87def _register_flat_param(state: _FSDPState, module: nn.Module) -> None: 88 """ 89 Registers the flattened parameter to the wrapped module, making it 90 visible to ``nn.Module`` methods. 91 92 We do not use :meth:`nn.Module.register_parameter` because we want 93 ``FLAT_PARAM`` to always be an attribute but dynamically change whether 94 it is visible to ``nn.Module`` methods. 95 """ 96 handle = _module_handle(state, module) 97 if _has_fsdp_params(state, module): 98 # TODO: figure out the case for the composable APIs. 99 cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param 100 101 102@contextlib.contextmanager 103def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator: 104 """ 105 Assumes that the flattened parameter is unsharded. When in the context, 106 de-registers the flattened parameter and unflattens the original 107 parameters as ``nn.Parameter`` views into the flattened parameter. 108 After the context, re-registers the flattened parameter and restores 109 the original parameters as ``Tensor`` views into the flattened 110 parameter. 111 """ 112 handle = _module_handle(state, module) 113 if not handle: 114 yield 115 else: 116 _deregister_flat_param(state, module) 117 try: 118 with handle.unflatten_as_params(): 119 yield 120 finally: 121 if not handle._use_orig_params: 122 _register_flat_param(state, module) 123 124 125def _validate_unshard_params_args( 126 state: _FSDPState, 127 writeback: bool, 128 rank0_only: bool, 129 offload_to_cpu: bool, 130 with_grads: bool, 131) -> None: 132 if with_grads and (offload_to_cpu or not state._use_orig_params): 133 raise NotImplementedError( 134 f"with_grads={with_grads}, " 135 f"use_orig_params={state._use_orig_params}, " 136 f"offload_to_cpu={offload_to_cpu} " 137 f"is not supported yet" 138 ) 139 if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy): 140 raise NotImplementedError( 141 "offload_to_cpu=True and NO_SHARD is not supported yet" 142 ) 143 if writeback and rank0_only: 144 # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to 145 # persist the changes. 146 raise NotImplementedError( 147 "writeback=True and rank0_only=True is not supported yet" 148 ) 149 if offload_to_cpu and not rank0_only: 150 warnings.warn( 151 "offload_to_cpu=True and rank0_only=False may result in the" 152 "unsharded parameters being redundantly copied to CPU memory for " 153 "GPUs sharing the same CPU memory, which risks CPU OOM. We " 154 "recommend using offload_to_cpu=True with rank0_only=True." 155 ) 156 157 158@contextlib.contextmanager 159def _unshard_fsdp_state_params( 160 module: nn.Module, 161 state: _FSDPState, 162 writeback: bool, 163 rank0_only: bool, 164 offload_to_cpu: bool, 165 with_grads: bool, 166): 167 """ 168 This unshards the parameters for a single FSDP state ``state`` that 169 corresponds to ``module``. 170 """ 171 _validate_unshard_params_args( 172 state, writeback, rank0_only, offload_to_cpu, with_grads 173 ) 174 state._device_handle.synchronize() 175 # If handles are shared by other module(s), the handle may be already unsharded. 176 maybe_handle = _module_handle(state, module) 177 handle = None 178 if ( 179 maybe_handle 180 and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS 181 ): 182 handle = maybe_handle 183 if not handle: 184 yield 185 return 186 187 assert ( 188 handle._training_state == HandleTrainingState.IDLE 189 ), f"Expects the handle training to be IDLE but got {handle._training_state}" 190 191 handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS 192 193 _reset_flat_param_grad_info_if_needed(handle) 194 free_unsharded_flat_param = handle.needs_unshard() 195 # No need to call `wait_stream()` since we unshard in the computation 196 # stream directly 197 computation_stream = state._device_handle.current_stream() 198 _unshard(state, handle, computation_stream, computation_stream) 199 if with_grads: 200 _unshard_grads(handle) 201 202 if rank0_only and state.rank != 0: 203 # Free the unsharded flattened parameter early 204 _reshard(state, handle, free_unsharded_flat_param) 205 if with_grads: 206 _reshard_grads(handle) 207 try: 208 yield 209 finally: 210 handle._training_state = HandleTrainingState.IDLE 211 else: 212 # Unflatten the unsharded flattened parameters 213 with contextlib.ExitStack() as stack: 214 # Invariant: rank == 0 or !rank0_only 215 if offload_to_cpu and handle.uses_sharded_strategy: 216 stack.enter_context(handle.to_cpu()) 217 # NOTE: Since PyTorch enforces that a parameter and its 218 # gradients need to match metadata (e.g. device), we must 219 # move gradients to CPU *after* we move parameters. 220 # NOTE: This assumes 1 `FlatParameter` 221 if not state._use_orig_params: 222 stack.enter_context(_unflatten_as_params(state, module)) 223 try: 224 yield 225 finally: 226 stack.close() 227 if writeback: 228 _writeback_to_local_shard(handle, with_grads) 229 _reshard(state, handle, free_unsharded_flat_param) 230 if with_grads: 231 _reshard_grads(handle) 232 handle._training_state = HandleTrainingState.IDLE 233 234 235@contextlib.contextmanager 236def _unshard_params_for_summon( 237 module: nn.Module, 238 state: _FSDPState, 239 writeback: bool, 240 rank0_only: bool, 241 offload_to_cpu: bool, 242 with_grads: bool, 243): 244 _validate_unshard_params_args( 245 state, writeback, rank0_only, offload_to_cpu, with_grads 246 ) 247 _lazy_init(state, module) 248 if state.training_state == TrainingState.FORWARD_BACKWARD: 249 raise AssertionError( 250 "Cannot manually unshard parameters during forward/backward" 251 ) 252 elif state.training_state == TrainingState.SUMMON_FULL_PARAMS: 253 raise AssertionError( 254 "Cannot manually unshard parameters when already unsharding parameters" 255 ) 256 with _unshard_fsdp_state_params( 257 module=module, 258 state=state, 259 writeback=writeback, 260 rank0_only=rank0_only, 261 offload_to_cpu=offload_to_cpu, 262 with_grads=with_grads, 263 ): 264 try: 265 state.training_state = TrainingState.SUMMON_FULL_PARAMS 266 yield 267 finally: 268 state.training_state = TrainingState.IDLE 269 270 271@contextlib.contextmanager 272def _unshard_params( 273 module: nn.Module, 274 recurse: bool, 275 writeback: bool, 276 rank0_only: bool, 277 offload_to_cpu: bool, 278 with_grads: bool, 279): 280 """ 281 This unshards FSDP-managed parameters for all modules with FSDP applied in 282 the module tree rooted at ``module``. 283 """ 284 if not recurse: 285 optional_state = _get_module_fsdp_state(module) 286 if optional_state is None: 287 with contextlib.nullcontext(): 288 yield 289 return 290 states_and_modules = ([optional_state], [module]) 291 else: 292 states_and_modules = traversal_utils._get_fsdp_states_with_modules(module) 293 with contextlib.ExitStack() as stack: 294 for state, module in zip(*states_and_modules): 295 stack.enter_context( 296 _unshard_params_for_summon( 297 module=module, 298 state=state, 299 writeback=writeback, 300 rank0_only=rank0_only, 301 offload_to_cpu=offload_to_cpu, 302 with_grads=with_grads, 303 ) 304 ) 305 yield 306 307 308def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: 309 """ 310 Deregisters the original parameters; registers the ``FlatParameter``. 311 """ 312 handle = _module_handle(state, module) 313 if not handle: 314 return 315 _p_assert( 316 handle._use_orig_params, 317 f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} " 318 f"handle: {handle._use_orig_params}", 319 ) 320 handle._deregister_orig_params() 321 _register_flat_param(state, module) 322 323 324def _register_orig_params(state: _FSDPState, module: nn.Module) -> None: 325 """ 326 Deregisters the ``FlatParameter``; registers the original parameters. 327 """ 328 handle = _module_handle(state, module) 329 if not handle: 330 return 331 _deregister_flat_param(state, module) 332 if handle.is_sharded(handle.flat_param): 333 handle._use_sharded_views() 334 handle._use_sharded_grad_views() 335 else: 336 handle._use_unsharded_views(as_params=True) 337