# mypy: allow-untyped-defs
import contextlib
import logging
import math
import warnings
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterator,
    List,
    no_type_check,
    Tuple,
)

import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_module_fsdp_state_if_fully_sharded_module,
    _has_fsdp_params,
    _is_composable,
    _module_handle,
    clean_tensor_name,
    FSDP_PREFIX,
    FSDP_WRAPPED_MODULE,
)
from torch.distributed.fsdp._debug_utils import SimpleProfiler
from torch.distributed.fsdp._runtime_utils import (
    _cast_buffers_to_dtype_and_device,
    _get_orig_buffer_dtypes,
    _lazy_init,
    _reset_flat_param_grad_info_if_needed,
)
from torch.distributed.fsdp.api import (
    FullStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.tensor import DTensor
from torch.distributed.utils import _replace_by_prefix

from ._fsdp_extensions import (
    _ext_all_gather_dtensor,
    _ext_chunk_dtensor,
    _ext_chunk_tensor,
    _ext_post_unflatten_transform,
    _ext_pre_load_state_dict_transform,
)
from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM


logger = logging.getLogger(__name__)


def _should_unshard_params(fsdp_state: _FSDPState) -> bool:
    return not (
        fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
        and (_is_composable(fsdp_state) or fsdp_state._use_orig_params)
    )


def _convert_to_wrapped_module_name(module_name: str) -> str:
    module_name = module_name.replace(f"{FSDP_PREFIX}", "")
    module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "")
    if module_name:
        module_name = f"{module_name}."
    # `CheckpointWrapper` adds a prefix that has to be removed as well.
    module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "")
    return module_name


def _param_name_infos(
    module: nn.Module, fsdp_state: _FSDPState
) -> Iterator[Tuple[str, str, str]]:
    if not _has_fsdp_params(fsdp_state, module):
        return
    for param_name, module_name in _module_handle(
        fsdp_state, module
    ).param_module_names():
        module_name = _convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


def _shared_param_name_infos(
    module: nn.Module, fsdp_state
) -> Iterator[Tuple[str, str, str]]:
    for param_name, module_name in _module_handle(
        fsdp_state, module
    ).shared_param_module_names():
        module_name = _convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


@no_type_check
def _enter_unshard_params_ctx(
    module: nn.Module,
    fsdp_state: _FSDPState,
    writeback: bool = False,
    rank0_only: bool = False,
    offload_to_cpu: bool = False,
    with_grads: bool = False,
) -> None:
    """
    state_dict hooks cannot use a plain context manager call because the
    checkpoint flow requires entering the context in the pre-hook and exiting
    it in the post-hook. This API enters the context of
    ``_unshard_fsdp_state_params``.
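    The stored context manager must later be exited via
    ``_exit_unshard_params_ctx`` from the corresponding post-hook.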
    """
    assert module not in fsdp_state._unshard_params_ctx, (
        "Entering the ``_unshard_fsdp_state_params`` context but a context for "
        "this module already exists."
    )
    fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
        module,
        fsdp_state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    )
    fsdp_state._unshard_params_ctx[module].__enter__()


@no_type_check
def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
    """A helper function to exit the ``_unshard_fsdp_state_params`` context."""
    fsdp_state._unshard_params_ctx[module].__exit__(None, None, None)
    fsdp_state._unshard_params_ctx.pop(module)


def _common_pre_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
) -> None:
    """Performs the pre-state_dict tasks shared by all state_dict types."""
    if fsdp_state._device_handle.is_available():
        fsdp_state._device_handle.synchronize()
    # TODO: need to check if this is always correct for composable FSDP.
    _lazy_init(fsdp_state, module)
    if fsdp_state._is_root:
        _reset_flat_param_grad_info_if_needed(fsdp_state._all_handles)


def _common_unshard_pre_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    offload_to_cpu: bool,
    rank0_only: bool,
) -> None:
    """
    Performs the pre-state_dict tasks shared by all state_dict types that require
    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use
    this hook.
    """
    # Composable `fully_shard` does not need to unshard parameters for `NO_SHARD` cases.
    if not _should_unshard_params(fsdp_state):
        return
    _enter_unshard_params_ctx(
        module,
        fsdp_state,
        writeback=False,
        offload_to_cpu=offload_to_cpu,
        rank0_only=rank0_only,
    )


@no_type_check
def _common_unshard_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
    param_hook: Callable,
) -> Dict[str, Any]:
    """
    The post-state_dict flow shared by all state_dict types that require
    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use
    this hook.
    """
    _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
    # Return early for trivial cases.
    if not state_dict or not _has_fsdp_params(fsdp_state, module):
        if _should_unshard_params(fsdp_state):
            _exit_unshard_params_ctx(module, fsdp_state)
        return state_dict

    # If a rank does not have unsharded parameters (when `rank0_only=True`
    # and `rank != 0`), then the rank only needs to participate in the
    # all-gather and does not need to save the state dict. We simply check
    # `rank0_only` to detect this case.
    rank0_only = (
        fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT
        and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only
    )
    # no_fsdp_return means the state_dict returned by this rank should contain
    # only non-FSDP-controlled parameters and buffers.
    no_fsdp_return = rank0_only and fsdp_state.rank != 0
    if no_fsdp_return and not fsdp_state._use_orig_params:
        for clean_key in fsdp_state._buffer_names:
            # This is a hack to support activation checkpointing.
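            # Strip the checkpoint wrapper prefix so that the cleaned buffer
            # name matches the corresponding key popped from `state_dict` below.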
            clean_key = clean_key.replace(
                f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
            )
            state_dict.pop(f"{prefix}{clean_key}", None)
        # Nonzero ranks have the `flat_param` key when `rank0_only=True` because
        # `rank0_only=True` is passed into the unshard context, but nonzero ranks
        # reshard early, causing the `flat_param` to appear in the state_dict.
        state_dict.pop(f"{prefix}{FLAT_PARAM}")
        _exit_unshard_params_ctx(module, fsdp_state)
        return state_dict

    # Loop over only the parameters saved in this instance's wrapped module to
    # avoid processing buffers.
    for fqn, param_name, module_name in _param_name_infos(module, fsdp_state):
        fqn = f"{prefix}{fqn}"
        if no_fsdp_return:
            state_dict.pop(fqn)
            continue
        assert fqn in state_dict, (
            f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
            f"has {state_dict.keys()}. "
            f"prefix={prefix}, module_name={module_name}, "
            f"param_name={param_name}, rank={fsdp_state.rank}."
        )

        param_hook(state_dict, prefix, fqn)

    if _should_unshard_params(fsdp_state):
        _exit_unshard_params_ctx(module, fsdp_state)

    cpu_device = torch.device("cpu")
    buffer_clean_fqns = []
    buffers = []
    for clean_key in fsdp_state._buffer_names:
        # This is a hack to support activation checkpointing.
        clean_key = clean_tensor_name(clean_key)
        fqn = f"{prefix}{clean_key}"
        if fqn not in state_dict:
            # A buffer can be registered as non-persistent.
            continue
        if no_fsdp_return:
            state_dict.pop(fqn)
        else:
            buffer = state_dict[fqn]
            if (
                fsdp_state._state_dict_config.offload_to_cpu
                and buffer.device != cpu_device
            ):
                state_dict[fqn] = buffer.to(cpu_device)
            # Skip upcasting for ignored buffers.
            if clean_key not in fsdp_state._ignored_buffer_names:
                buffer_clean_fqns.append(clean_key)
                buffers.append(state_dict[fqn])

    if buffers:
        mixed_precision_enabled_for_buffers = (
            fsdp_state._mixed_precision_enabled_for_buffers()
            if not _is_composable(fsdp_state)
            else (fsdp_state.mixed_precision.buffer_dtype is not None)
        )
        if mixed_precision_enabled_for_buffers:
            buffer_dtypes = _get_orig_buffer_dtypes(fsdp_state, buffer_clean_fqns)
            _cast_buffers_to_dtype_and_device(
                buffers, buffer_dtypes, fsdp_state.compute_device
            )
            for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
                fqn = f"{prefix}{clean_fqn}"
                logger.info("FSDP is casting the dtype of %s to %s", fqn, buffer.dtype)
                state_dict[fqn] = buffer.clone()
    return state_dict


@no_type_check
def _full_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before ``model.state_dict()`` is called for the
    FULL_STATE_DICT type. It is dispatched from ``_pre_state_dict_hook()``,
    which is registered on the module via ``register_state_dict_pre_hook``.
    """
    if getattr(fsdp_state, "_device_mesh", False):
        root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)

    _common_pre_state_dict_hook(module, fsdp_state)
    _common_unshard_pre_state_dict_hook(
        module,
        fsdp_state,
        offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,
        rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only,
    )


@no_type_check
def _full_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    Hook that runs after ``model.state_dict()`` is called and before the result
    is returned to the user. For FSDP, we may have to clone the tensors in the
    state_dict since parameters go back to their sharded versions after
    ``_unshard_fsdp_state_params`` ends, and we also remove the
    ``FSDP_WRAPPED_MODULE`` prefix.
    """

    def param_hook(
        state_dict: Dict[str, Any],
        prefix: str,
        fqn: str,
    ) -> None:
        clean_key = fqn
        clean_prefix = clean_tensor_name(prefix)
        # Strip the prefix from the key if needed, since buffer names and param
        # names do not include the prefix (they are not computed in the
        # `state_dict` call).
        if clean_key.startswith(clean_prefix):
            clean_key = clean_key[len(clean_prefix) :]

        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
        if not getattr(state_dict[fqn], "_has_been_cloned", False):
            try:
                state_dict[fqn] = state_dict[fqn].clone().detach()
                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
            except BaseException as e:
                warnings.warn(
                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
                    "This may mean that this state_dict entry could point to invalid "
                    "memory regions after returning from state_dict() call if this "
                    "parameter is managed by FSDP. Please check clone "
                    f"implementation of {fqn}. Error: {str(e)}"
                )

    return _common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


def _full_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    _lazy_init(fsdp_state, module)
    if _should_unshard_params(fsdp_state):
        with SimpleProfiler.profile("_enter_unshard_params_ctx"):
            _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
    # Add FSDP_PREFIX only for wrapper-based FSDP.
    if not _is_composable(fsdp_state):
        _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")


def _full_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    if _should_unshard_params(fsdp_state):
        with SimpleProfiler.profile("_exit_unshard_params_ctx"):
            _exit_unshard_params_ctx(module, fsdp_state)


def _local_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before ``model.state_dict()`` is called for the
    LOCAL_STATE_DICT type. It is dispatched from ``_pre_state_dict_hook()``,
    which is registered on the module via ``register_state_dict_pre_hook``.
    """
    if (
        _has_fsdp_params(fsdp_state, module)
        and not _module_handle(fsdp_state, module).uses_sharded_strategy
    ):
        raise RuntimeError(
            "``local_state_dict`` can only be used when parameters are flattened "
            "and sharded."
        )
    _common_pre_state_dict_hook(module, fsdp_state)


@no_type_check
def _local_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    This hook creates a ShardedTensor from the local flat_param and replaces
    state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
    happens; the underlying storage is the same.
    """

    _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix)
    if not _has_fsdp_params(fsdp_state, module):
        return state_dict

    # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
    # value as the flat_param, but it is a pure Tensor because
    # nn.Module.state_dict() detaches the parameter. Therefore, we need
    # the flat_param itself to get the metadata.
    assert _module_handle(fsdp_state, module), "Should have returned early"
    flat_param = _module_handle(fsdp_state, module).flat_param
    # Construct a ShardedTensor from the flat_param without padding.
    # Removing the padding allows users to change the number of ranks
    # when loading the local_state_dict.
    full_numel = flat_param._unpadded_unsharded_size.numel()  # type: ignore[attr-defined]
    shard_offset = flat_param.numel() * fsdp_state.rank
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
    if valid_data_size > 0:
        # If a FlatParameter is returned, FlatParameter._local_shard causes a
        # pickling issue (it can be torch.save'd but not torch.load'ed). Since
        # there is no benefit in having state_dict return the actual
        # FlatParameter class, a view (which is a plain tensor) of the
        # FlatParameter is returned instead.
        flat_param = flat_param[:valid_data_size].view(valid_data_size)
        local_shards = [
            Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank)
        ]
    else:
        local_shards = []
    sharded_tensor = init_from_local_shards(
        local_shards, full_numel, process_group=fsdp_state.process_group
    )  # type: ignore[assignment]
    # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
    if fsdp_state._state_dict_config.offload_to_cpu:
        sharded_tensor = sharded_tensor.cpu()
    state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
    return state_dict


def _local_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    pass


def _local_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    This hook finds the local flat_param for this FSDP module in the
    state_dict. The flat_param should be a ShardedTensor. This hook converts
    the ShardedTensor to a tensor. No copy happens unless padding is required.
    """
    _lazy_init(fsdp_state, module)
    _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}")
    fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"
    if fqn not in state_dict:
        assert not _has_fsdp_params(fsdp_state, module), (
            "No `FlatParameter` in `state_dict` for this FSDP instance "
            "but it has parameters"
        )
        return
    load_tensor = state_dict[fqn]
    assert isinstance(
        load_tensor, ShardedTensor
    ), "Tensors in local_state_dict should be ShardedTensor."

    # Convert the ShardedTensor to a Tensor.
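    # The loaded shard may hold fewer elements than the local flat_param if the
    # flat_param was padded when sharding; the padding is re-applied below.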
    flat_param = _module_handle(fsdp_state, module).flat_param
    assert flat_param is not None
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
    shards = load_tensor.local_shards()
    if valid_data_size > 0:
        assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
        load_tensor = shards[0].tensor

        # Get the metadata of the flat_param to decide whether to pad the loaded
        # tensor.
        if flat_param._shard_numel_padded > 0:
            assert load_tensor.numel() < flat_param.numel(), (
                f"Local shard size = {flat_param.numel()} but the tensor in "
                f"the state_dict has {load_tensor.numel()} elements."
            )
            load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
    else:
        load_tensor = flat_param
    # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
    state_dict[fqn] = load_tensor


def _sharded_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before ``model.state_dict()`` is called for the
    SHARDED_STATE_DICT type. See ``_full_pre_state_dict_hook`` for details.
    """
    if (
        _has_fsdp_params(fsdp_state, module)
        and not _module_handle(fsdp_state, module).uses_sharded_strategy
    ):
        raise RuntimeError(
            "``sharded_state_dict`` can only be used when parameters are flattened "
            "and sharded."
        )
    _common_pre_state_dict_hook(module, fsdp_state)
    # Setting offload_to_cpu here does not work even if offload_to_cpu is True.
    # We have to create the ShardedTensor first and then move it to CPU.
    _common_unshard_pre_state_dict_hook(
        module,
        fsdp_state,
        offload_to_cpu=False,
        rank0_only=False,
    )


@no_type_check
def _sharded_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    The hook replaces the unflattened, unsharded parameter in the state_dict
    with an unflattened, sharded parameter (a ShardedTensor).
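    Whether a ``ShardedTensor`` or a ``DTensor`` is produced depends on the
    ``_use_dtensor`` flag of the sharded state dict config (see ``param_hook``
    below).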
    """

    def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
        param = state_dict[fqn]
        if not fsdp_state._state_dict_config._use_dtensor:
            sharded_tensor = _ext_chunk_tensor(
                tensor=param,
                rank=fsdp_state.rank,
                world_size=fsdp_state.world_size,
                num_devices_per_node=fsdp_state._device_handle.device_count(),
                pg=fsdp_state.process_group,
                fsdp_extension=fsdp_state._fsdp_extension,
            )
        else:
            sharded_tensor = _ext_chunk_dtensor(
                tensor=param,
                rank=fsdp_state.rank,
                device_mesh=fsdp_state._device_mesh,
                fsdp_extension=fsdp_state._fsdp_extension,
            )
        if fsdp_state._state_dict_config.offload_to_cpu:
            sharded_tensor = sharded_tensor.cpu()
        state_dict[fqn] = sharded_tensor

    return _common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


@no_type_check
def _sharded_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    if _has_fsdp_params(fsdp_state, module):
        with SimpleProfiler.profile("_exit_unshard_params_ctx"):
            _exit_unshard_params_ctx(module, fsdp_state)


@no_type_check
def _sharded_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    The hook combines the unflattened, sharded parameters (ShardedTensor) into
    a new FlatParameter and shards the new FlatParameter into the local chunk.
    """
    _lazy_init(fsdp_state, module)
    if not _is_composable(fsdp_state):
        _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
    if not _has_fsdp_params(fsdp_state, module):
        return

    handle = _module_handle(fsdp_state, module)
    if not handle.uses_sharded_strategy:
        raise RuntimeError(
            "load_sharded_state_dict can only be called when parameters "
            "are flattened and sharded."
        )
    fqn_to_param_ext = dict(
        zip(handle.flat_param._fqns, handle.flat_param._param_extensions)
    )

    for fqn, _, _ in _param_name_infos(module, fsdp_state):
        if not _is_composable(fsdp_state):
            fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}"
        else:
            fqn_from_global_root = f"{prefix}{fqn}"
        try:
            param = state_dict.pop(fqn_from_global_root)
        except KeyError:
            logger.warning(
                f"Did not find param with FQN {fqn_from_global_root}, skipping it. "  # noqa: G004
                "The weight will not be filled if you expect it to be."
            )
            continue  # TODO: Improve unit testing for state_dict fine-tuning
            # cases: https://github.com/pytorch/pytorch/issues/109134

        if not fsdp_state._state_dict_config._use_dtensor:
            # All-gather the param (ShardedTensor).
            param, shards = _ext_pre_load_state_dict_transform(
                param, fsdp_state._fsdp_extension
            )

            assert len(shards) < 2, (
                "Expects 0 or 1 shard per rank "
                f"but got {len(shards)} shards on rank {fsdp_state.rank}."
            )
            param_numel = param.size().numel()
            dim_0_size = param.size()[0]
            chunk_size = (
                math.ceil(dim_0_size / fsdp_state.world_size)
                * param_numel
                // dim_0_size
            )
            if len(shards) == 1:
                local_tensor = shards[0].tensor.flatten()
                with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
                    local_tensor = local_tensor.to(fsdp_state.compute_device)
                num_padding = chunk_size - local_tensor.numel()
                if num_padding > 0:
                    local_tensor = F.pad(local_tensor, [0, num_padding])
            else:
                local_tensor = torch.zeros(
                    chunk_size, dtype=param.dtype, device=fsdp_state.compute_device
                )
            tensor = torch.empty(
                chunk_size * fsdp_state.world_size,
                dtype=local_tensor.dtype,
                device=fsdp_state.compute_device,
            )
            with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
                dist.all_gather_into_tensor(
                    tensor, local_tensor, group=fsdp_state.process_group
                )
            tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
            state_dict[fqn_from_global_root] = tensor
        else:
            if param.device != fsdp_state._device_mesh.device_type:
                param = param.to(fsdp_state._device_mesh.device_type)

            root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
            local_tensor = _ext_all_gather_dtensor(
                param, root_mesh, fsdp_state._fsdp_extension
            )

        if fqn_to_param_ext.get(fqn) is not None:
            ext = fqn_to_param_ext[fqn]
            local_tensor = _ext_post_unflatten_transform(
                local_tensor, ext, fsdp_state._fsdp_extension
            )
        state_dict[fqn_from_global_root] = local_tensor

    with SimpleProfiler.profile("_enter_unshard_params_ctx"):
        _enter_unshard_params_ctx(module, fsdp_state, writeback=True)


@contextlib.contextmanager
def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
    old_state_dict_config = fsdp_state._state_dict_config
    old_state_dict_type = fsdp_state._state_dict_type
    fsdp_state._state_dict_config = FullStateDictConfig()
    fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
    yield
    fsdp_state._state_dict_config = old_state_dict_config
    fsdp_state._state_dict_type = old_state_dict_type


@no_type_check
@torch.no_grad()
def _post_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    """
    _post_state_dict_hook() is called after the state_dict() of this
    FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide
    what postprocessing will be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        context = contextlib.nullcontext()

    with context:
        _post_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
        }
        processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
            module, fsdp_state, state_dict, prefix
        )

    if fsdp_state._is_root:
        logger.info("FSDP finished processing state_dict(), prefix=%s", prefix)
        for key, tensor in sorted(processed_state_dict.items()):
            if key.startswith(prefix) and isinstance(tensor, torch.Tensor):
                local_shape = tensor.shape
                if isinstance(tensor, ShardedTensor):
                    local_shape = None
                    shards = tensor.local_shards()
                    if shards:
                        local_shape = shards[0].tensor.shape
                elif isinstance(tensor, DTensor):
                    local_shape = tensor.to_local().shape
                logger.info(
                    "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s",
                    key,
                    type(tensor),
                    tensor.shape,
                    local_shape,
                    tensor.dtype,
                    tensor.device,
                )

    return processed_state_dict


@no_type_check
@torch.no_grad()
def _pre_state_dict_hook(
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    This is called before the core state dict saving logic of ``module``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        _set_use_dtensor(fsdp_state)
        context = contextlib.nullcontext()

    with context:
        _pre_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
        }
        _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
            fsdp_state,
            module,
            *args,
            **kwargs,
        )


@no_type_check
def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
    # If a device_mesh is passed in when initializing FSDP, we automatically set
    # the `_use_dtensor` flag to True for `ShardedStateDictConfig()`.
    if getattr(fsdp_state, "_device_mesh", None):
        state_dict_type = fsdp_state._state_dict_type
        if state_dict_type == StateDictType.LOCAL_STATE_DICT:
            raise RuntimeError(
                "Found state_dict_type LOCAL_STATE_DICT. "
                "DeviceMesh is not compatible with LOCAL_STATE_DICT. "
                "Please set state_dict_type to SHARDED_STATE_DICT to get a DTensor state_dict."
            )
        else:
            fsdp_state._state_dict_config._use_dtensor = True


@no_type_check
@torch.no_grad()
def _pre_load_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    This is called before ``module._load_from_state_dict()``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        _set_use_dtensor(fsdp_state)
        context = contextlib.nullcontext()

    _lazy_init(fsdp_state, module)
    if fsdp_state._is_root:
        SimpleProfiler.reset()

    with context:
        _pre_load_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
        }
        # Code that is common to all state_dict implementations.
        if fsdp_state._device_handle.is_available():
            fsdp_state._device_handle.synchronize()
        # Dispatch into the state_dict-type-specific implementation of the pre-hook.
        _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
            module, fsdp_state, state_dict, prefix
        )


@no_type_check
@torch.no_grad()
def _post_load_state_dict_hook(
    module: nn.Module,
    incompatible_keys: Tuple[List[str], List[str]],
    *args: Any,
) -> None:
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        context = contextlib.nullcontext()

    with context:
        _post_load_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
        }
        # Code that is common to all state_dict implementations.
        # Dispatch into the state_dict-type-specific implementation of the
        # post-hook for loading the state_dict.
        _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)

    # When reporting incompatible keys, trim FSDP prefixes.
    missing_keys = incompatible_keys[0]
    unexpected_keys = incompatible_keys[1]
    for i in range(len(missing_keys)):
        missing_keys[i] = clean_tensor_name(missing_keys[i])

    for i in range(len(unexpected_keys)):
        unexpected_keys[i] = clean_tensor_name(unexpected_keys[i])

    if fsdp_state._is_root:
        SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")


def _register_all_state_dict_hooks(state: _FSDPState):
    """
    Registers pre-save, post-save, pre-load, and post-load state dict hooks.
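
    The hooks are registered on the FSDP wrapper itself for wrapper-based FSDP
    and on the fully sharded module for composable ``fully_shard`` (see
    ``_register_state_dict_hooks_base``). Rough sketch of how these hooks are
    exercised through the public API, where ``FSDP`` refers to
    ``torch.distributed.fsdp.FullyShardedDataParallel`` and ``model`` is an
    FSDP-wrapped module (illustrative only)::

        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            sd = model.state_dict()    # runs the pre-/post-state_dict hooks
            model.load_state_dict(sd)  # runs the pre-/post-load hooks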
    """
    for hook_registration_fn_str, hook, hook_registration_fn_kwargs in (
        ("register_state_dict_pre_hook", _pre_state_dict_hook, {}),
        ("_register_state_dict_hook", _post_state_dict_hook, {}),
        (
            "_register_load_state_dict_pre_hook",
            _pre_load_state_dict_hook,
            {"with_module": True},
        ),
        ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}),
    ):
        _register_state_dict_hooks_base(
            state, hook_registration_fn_str, hook, hook_registration_fn_kwargs
        )


@no_type_check
def _register_state_dict_hooks_base(
    state: _FSDPState,
    hook_registration_fn_name: str,
    hook: Callable,
    hook_registration_fn_kwargs: Dict[str, Any],
) -> None:
    """Registers ``hook`` using the registration method named ``hook_registration_fn_name``."""
    if not _is_composable(state):
        getattr(state, hook_registration_fn_name)(hook, **hook_registration_fn_kwargs)
    else:
        handle = state._handle
        if handle:
            getattr(handle._fully_sharded_module, hook_registration_fn_name)(
                hook, **hook_registration_fn_kwargs
            )