# mypy: allow-untyped-defs
import contextlib
import logging
import math
import warnings
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterator,
    List,
    no_type_check,
    Tuple,
)

import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_module_fsdp_state_if_fully_sharded_module,
    _has_fsdp_params,
    _is_composable,
    _module_handle,
    clean_tensor_name,
    FSDP_PREFIX,
    FSDP_WRAPPED_MODULE,
)
from torch.distributed.fsdp._debug_utils import SimpleProfiler
from torch.distributed.fsdp._runtime_utils import (
    _cast_buffers_to_dtype_and_device,
    _get_orig_buffer_dtypes,
    _lazy_init,
    _reset_flat_param_grad_info_if_needed,
)
from torch.distributed.fsdp.api import (
    FullStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.tensor import DTensor
from torch.distributed.utils import _replace_by_prefix

from ._fsdp_extensions import (
    _ext_all_gather_dtensor,
    _ext_chunk_dtensor,
    _ext_chunk_tensor,
    _ext_post_unflatten_transform,
    _ext_pre_load_state_dict_transform,
)
from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM


logger = logging.getLogger(__name__)


def _should_unshard_params(fsdp_state: _FSDPState) -> bool:
    return not (
        fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
        and (_is_composable(fsdp_state) or fsdp_state._use_orig_params)
    )


def _convert_to_wrapped_module_name(module_name: str) -> str:
    module_name = module_name.replace(f"{FSDP_PREFIX}", "")
    module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "")
    if module_name:
        module_name = f"{module_name}."
    # `CheckpointWrapper` adds a prefix that has to be removed as well.
    module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "")
    return module_name


def _param_name_infos(
    module: nn.Module, fsdp_state: _FSDPState
) -> Iterator[Tuple[str, str, str]]:
    if not _has_fsdp_params(fsdp_state, module):
        return
    for param_name, module_name in _module_handle(
        fsdp_state, module
    ).param_module_names():
        module_name = _convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


def _shared_param_name_infos(
    module: nn.Module, fsdp_state
) -> Iterator[Tuple[str, str, str]]:
    for param_name, module_name in _module_handle(
        fsdp_state, module
    ).shared_param_module_names():
        module_name = _convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


@no_type_check
def _enter_unshard_params_ctx(
    module: nn.Module,
    fsdp_state: _FSDPState,
    writeback: bool = False,
    rank0_only: bool = False,
    offload_to_cpu: bool = False,
    with_grads: bool = False,
) -> None:
    """
    state_dict hooks cannot use the pure context call as the checkpoint flow
    requires entering the context in the pre-hook but leaving the context in
    the post-hook. This API enters the context of ``_unshard_fsdp_state_params``.
    """
    assert module not in fsdp_state._unshard_params_ctx, (
        "Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] "
        "is not None."
    )
    fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
        module,
        fsdp_state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    )
    fsdp_state._unshard_params_ctx[module].__enter__()


@no_type_check
def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
    """A helper function to exit the ``_unshard_fsdp_state_params`` context."""
    fsdp_state._unshard_params_ctx[module].__exit__(None, None, None)
    fsdp_state._unshard_params_ctx.pop(module)


def _common_pre_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
) -> None:
    """Performs the pre-state_dict tasks shared by all state_dict types."""
    if fsdp_state._device_handle.is_available():
        fsdp_state._device_handle.synchronize()
    # TODO: need to check if this is always correct for composable FSDP.
    _lazy_init(fsdp_state, module)
    if fsdp_state._is_root:
        _reset_flat_param_grad_info_if_needed(fsdp_state._all_handles)


def _common_unshard_pre_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    offload_to_cpu: bool,
    rank0_only: bool,
) -> None:
    """
    Performs the pre-state_dict tasks shared by all state_dict types that require
    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use
    this hook.
    """
    # Composable `fully_shard` does not need to unshard parameters for `NO_SHARD` cases.
    if not _should_unshard_params(fsdp_state):
        return
    _enter_unshard_params_ctx(
        module,
        fsdp_state,
        writeback=False,
        offload_to_cpu=offload_to_cpu,
        rank0_only=rank0_only,
    )


@no_type_check
def _common_unshard_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
    param_hook: Callable,
) -> Dict[str, Any]:
    """
    The post-state_dict flow that is shared by all state_dict types that require
    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use
    this hook.
    """
    _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
    # Return early for trivial cases
    if not state_dict or not _has_fsdp_params(fsdp_state, module):
        if _should_unshard_params(fsdp_state):
            _exit_unshard_params_ctx(module, fsdp_state)
        return state_dict

    # If a rank does not have unsharded parameters (when `rank0_only=True`
    # and `rank != 0`), then the rank only needs to participate in the
    # all-gather and does not need to save the state dict. We simply check
    # rank0_only to handle this case.
    rank0_only = (
        fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT
        and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only
    )
    # no_fsdp_return means the state_dict returned by this rank should contain
    # only non-FSDP controlled parameters and buffers.
    no_fsdp_return = rank0_only and fsdp_state.rank != 0
    if no_fsdp_return and not fsdp_state._use_orig_params:
        for clean_key in fsdp_state._buffer_names:
            # This is a hack to support activation checkpoint.
            clean_key = clean_key.replace(
                f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
            )
            state_dict.pop(f"{prefix}{clean_key}", None)
        # Non-zero ranks have a flat_param key when rank0_only=True, because
        # rank0_only=True is passed into the unshard context but non-zero ranks
        # reshard early, causing this flat_param to appear in the state_dict.
        state_dict.pop(f"{prefix}{FLAT_PARAM}")
        _exit_unshard_params_ctx(module, fsdp_state)
        return state_dict

    # Loop only the parameters saved in this instance's wrapped module to
    # avoid processing buffers.
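    # At this point ``_replace_by_prefix`` has already stripped the FSDP wrapper
    # prefix, so FSDP-managed parameters appear under their clean FQNs (e.g., a
    # hypothetical ``layer.weight`` rather than ``_fsdp_wrapped_module.layer.weight``),
    # and ``param_hook`` decides how each entry is materialized (a cloned full
    # tensor, a ShardedTensor, or a DTensor) for the requested state_dict type.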
    for fqn, param_name, module_name in _param_name_infos(module, fsdp_state):
        fqn = f"{prefix}{fqn}"
        if no_fsdp_return:
            state_dict.pop(fqn)
            continue
        assert fqn in state_dict, (
            f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
            f"has {state_dict.keys()}. "
            f"prefix={prefix}, module_name={module_name}, "
            f"param_name={param_name} rank={fsdp_state.rank}."
        )

        param_hook(state_dict, prefix, fqn)

    if _should_unshard_params(fsdp_state):
        _exit_unshard_params_ctx(module, fsdp_state)

    cpu_device = torch.device("cpu")
    buffer_clean_fqns = []
    buffers = []
    for clean_key in fsdp_state._buffer_names:
        # This is a hack to support activation checkpoint.
        clean_key = clean_tensor_name(clean_key)
        fqn = f"{prefix}{clean_key}"
        if fqn not in state_dict:
            # A buffer can be registered as non-persistent.
            continue
        if no_fsdp_return:
            state_dict.pop(fqn)
        else:
            buffer = state_dict[fqn]
            if (
                fsdp_state._state_dict_config.offload_to_cpu
                and buffer.device != cpu_device
            ):
                state_dict[fqn] = buffer.to(cpu_device)
            # skip upcasting for ignored buffers
            if clean_key not in fsdp_state._ignored_buffer_names:
                buffer_clean_fqns.append(clean_key)
                buffers.append(state_dict[fqn])

    if buffers:
        mixed_precision_enabled_for_buffers = (
            fsdp_state._mixed_precision_enabled_for_buffers()
            if not _is_composable(fsdp_state)
            else (fsdp_state.mixed_precision.buffer_dtype is not None)
        )
        if mixed_precision_enabled_for_buffers:
            buffer_dtypes = _get_orig_buffer_dtypes(fsdp_state, buffer_clean_fqns)
            _cast_buffers_to_dtype_and_device(
                buffers, buffer_dtypes, fsdp_state.compute_device
            )
        for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
            fqn = f"{prefix}{clean_fqn}"
            logger.info("FSDP is casting the dtype of %s to %s", fqn, buffer.dtype)
            state_dict[fqn] = buffer.clone()
    return state_dict


@no_type_check
def _full_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. A pre-state_dict hook is
    not actually supported by ``nn.Module``; as a result, this API is called from
    ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict is
    supported in ``nn.Module``, this hook will be registered as a hook in
    ``nn.Module``.
    """
    if getattr(fsdp_state, "_device_mesh", False):
        root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)

    _common_pre_state_dict_hook(module, fsdp_state)
    _common_unshard_pre_state_dict_hook(
        module,
        fsdp_state,
        offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,
        rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only,
    )


@no_type_check
def _full_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    Hook that runs after model.state_dict() is called, before the result is
    returned to the user. For FSDP, we may have to clone the tensors in the
    state_dict as the parameters go back to their sharded version after
    ``_unshard_fsdp_state_params`` ends; we also remove the
    ``FSDP_WRAPPED_MODULE`` prefix.
    """

    def param_hook(
        state_dict: Dict[str, Any],
        prefix: str,
        fqn: str,
    ) -> None:
        clean_key = fqn
        clean_prefix = clean_tensor_name(prefix)
        # Strip the prefix out of the key if needed, since buffer names and
        # param names do not include the prefix (it is not considered when they
        # are computed in the `state_dict` call).
        if clean_key.startswith(clean_prefix):
            clean_key = clean_key[len(clean_prefix) :]

        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
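        # Once the `_unshard_fsdp_state_params()` context exits, the unsharded
        # flat parameter storage is freed/resharded, so the views currently in
        # `state_dict` would alias invalid memory. Cloning materializes an
        # independent copy; the `_has_been_cloned` flag avoids cloning the same
        # (possibly shared) tensor more than once.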
        if not getattr(state_dict[fqn], "_has_been_cloned", False):
            try:
                state_dict[fqn] = state_dict[fqn].clone().detach()
                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
            except BaseException as e:
                warnings.warn(
                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
                    "This may mean that this state_dict entry could point to invalid "
                    "memory regions after returning from state_dict() call if this "
                    "parameter is managed by FSDP. Please check clone "
                    f"implementation of {fqn}. Error: {str(e)}"
                )

    return _common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


def _full_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    _lazy_init(fsdp_state, module)
    if _should_unshard_params(fsdp_state):
        with SimpleProfiler.profile("_enter_unshard_params_ctx"):
            _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
    # Add FSDP_PREFIX only for wrapper-based FSDP.
    if not _is_composable(fsdp_state):
        _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")


def _full_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    if _should_unshard_params(fsdp_state):
        with SimpleProfiler.profile("_exit_unshard_params_ctx"):
            _exit_unshard_params_ctx(module, fsdp_state)


def _local_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. Right now, a
    pre-state_dict hook is not supported by PyTorch core, so this API is called
    from `_local_post_state_dict_hook()` to simulate the case.
    """
    if (
        _has_fsdp_params(fsdp_state, module)
        and not _module_handle(fsdp_state, module).uses_sharded_strategy
    ):
        raise RuntimeError(
            "``local_state_dict`` can only be used when parameters are flattened "
            "and sharded."
        )
    _common_pre_state_dict_hook(module, fsdp_state)


@no_type_check
def _local_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    This hook creates a ShardedTensor from the local flat_param and replaces
    state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy will
    happen. The underlying storage is the same.
    """
    _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix)
    if not _has_fsdp_params(fsdp_state, module):
        return state_dict

    # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
    # value as the flat_param, but it is a pure Tensor because
    # nn.Module.state_dict() will detach the parameter. Therefore, we need
    # to access flat_param directly to get the metadata.
    assert _module_handle(fsdp_state, module), "Should have returned early"
    flat_param = _module_handle(fsdp_state, module).flat_param
    # Constructs a ShardedTensor from the flat_param "without" padding.
    # Removing the padding allows users to change the number of ranks
    # when loading the local_state_dict.
    full_numel = flat_param._unpadded_unsharded_size.numel()  # type: ignore[attr-defined]
    shard_offset = flat_param.numel() * fsdp_state.rank
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
    if valid_data_size > 0:
        # If the FlatParameter is returned, FlatParameter._local_shard causes a
        # pickling issue (it can be saved with torch.save but not loaded with
        # torch.load). Since there is no benefit for state_dict to return the
        # actual FlatParameter class, a view (which is a plain tensor) of the
        # FlatParameter is returned instead.
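        # Illustrative example (hypothetical sizes): for a full flat parameter of
        # 10 elements with world_size=4, each rank stores ceil(10 / 4) = 3 elements,
        # so rank 3 holds 1 valid element and 2 padding elements
        # (valid_data_size = 3 - 2 = 1) at shard_offset = 3 * 3 = 9.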
        flat_param = flat_param[:valid_data_size].view(valid_data_size)
        local_shards = [
            Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank)
        ]
    else:
        local_shards = []
    sharded_tensor = init_from_local_shards(
        local_shards, full_numel, process_group=fsdp_state.process_group
    )  # type: ignore[assignment]
    # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
    if fsdp_state._state_dict_config.offload_to_cpu:
        sharded_tensor = sharded_tensor.cpu()
    state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
    return state_dict


def _local_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    pass


def _local_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    This hook finds the local flat_param for this FSDP module from the
    state_dict. The flat_param should be a ShardedTensor. This hook converts
    the ShardedTensor to a tensor. No copy happens unless padding is required.
    """
    _lazy_init(fsdp_state, module)
    _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}")
    fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"
    if fqn not in state_dict:
        assert not _has_fsdp_params(fsdp_state, module), (
            "No `FlatParameter` in `state_dict` for this FSDP instance "
            "but it has parameters"
        )
        return
    load_tensor = state_dict[fqn]
    assert isinstance(
        load_tensor, ShardedTensor
    ), "Tensors in local_state_dict should be ShardedTensor."

    # Convert the ShardedTensor to a Tensor.
    flat_param = _module_handle(fsdp_state, module).flat_param
    assert flat_param is not None
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
    shards = load_tensor.local_shards()
    if valid_data_size > 0:
        assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
        load_tensor = shards[0].tensor

        # Get the metadata of the flat_param to decide whether to pad the loaded
        # tensor.
        if flat_param._shard_numel_padded > 0:
            assert load_tensor.numel() < flat_param.numel(), (
                f"Local shard size = {flat_param.numel()} and the tensor in "
                f"the state_dict is {load_tensor.numel()}."
            )
            load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
    else:
        load_tensor = flat_param
    # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
    state_dict[fqn] = load_tensor


def _sharded_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. Check
    ``_full_pre_state_dict_hook`` for details.
    """
    if (
        _has_fsdp_params(fsdp_state, module)
        and not _module_handle(fsdp_state, module).uses_sharded_strategy
    ):
        raise RuntimeError(
            "``sharded_state_dict`` can only be used when parameters are flattened "
            "and sharded."
        )
    _common_pre_state_dict_hook(module, fsdp_state)
    # Setting offload_to_cpu here does not work even if offload_to_cpu is True.
    # We have to create the ShardedTensor first and then move it to CPU.
    _common_unshard_pre_state_dict_hook(
        module,
        fsdp_state,
        offload_to_cpu=False,
        rank0_only=False,
    )


@no_type_check
def _sharded_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    The hook replaces the unflattened, unsharded parameter in the state_dict
    with an unflattened, sharded parameter (a ShardedTensor).
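
    For example (shapes are illustrative), an unsharded ``layer.weight`` of
    shape ``(1024, 1024)`` is replaced by a ``ShardedTensor`` (or a ``DTensor``
    when ``_use_dtensor`` is set) whose local shard holds only this rank's
    chunk of the parameter.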
""" def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str): param = state_dict[fqn] if not fsdp_state._state_dict_config._use_dtensor: sharded_tensor = _ext_chunk_tensor( tensor=param, rank=fsdp_state.rank, world_size=fsdp_state.world_size, num_devices_per_node=fsdp_state._device_handle.device_count(), pg=fsdp_state.process_group, fsdp_extension=fsdp_state._fsdp_extension, ) else: sharded_tensor = _ext_chunk_dtensor( tensor=param, rank=fsdp_state.rank, device_mesh=fsdp_state._device_mesh, fsdp_extension=fsdp_state._fsdp_extension, ) if fsdp_state._state_dict_config.offload_to_cpu: sharded_tensor = sharded_tensor.cpu() state_dict[fqn] = sharded_tensor return _common_unshard_post_state_dict_hook( module, fsdp_state, state_dict, prefix, param_hook ) @no_type_check def _sharded_post_load_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: if _has_fsdp_params(fsdp_state, module): with SimpleProfiler.profile("_exit_unshard_params_ctx"): _exit_unshard_params_ctx(module, fsdp_state) @no_type_check def _sharded_pre_load_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: """ The hook combines the unflattened, sharded parameters (ShardedTensor) to a new FlatParameter and shards the new FlatParameter to the local chunk. """ _lazy_init(fsdp_state, module) if not _is_composable(fsdp_state): _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") if not _has_fsdp_params(fsdp_state, module): return handle = _module_handle(fsdp_state, module) if not handle.uses_sharded_strategy: raise RuntimeError( "load_sharded_state_dict can only be called when parameters " "are flattened and sharded." ) fqn_to_param_ext = dict( zip(handle.flat_param._fqns, handle.flat_param._param_extensions) ) for fqn, _, _ in _param_name_infos(module, fsdp_state): if not _is_composable(fsdp_state): fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}" else: fqn_from_global_root = f"{prefix}{fqn}" try: param = state_dict.pop(fqn_from_global_root) except KeyError: logger.warning( f"Did not find param with FQN {fqn_from_global_root}, skipping it. " # noqa: G004 "The weight will not be filled if you expect it to be." ) continue # TODO: Improve unittesting for state_dict finetuning # cases: https://github.com/pytorch/pytorch/issues/109134 if not fsdp_state._state_dict_config._use_dtensor: # All-gather the param (ShardedTensor) param, shards = _ext_pre_load_state_dict_transform( param, fsdp_state._fsdp_extension ) assert len(shards) < 2, ( "Expects 0 or 1 shard per rank " f"but got {len(shards)} shards on rank {fsdp_state.rank}." 
            )
            param_numel = param.size().numel()
            dim_0_size = param.size()[0]
            chunk_size = (
                math.ceil(dim_0_size / fsdp_state.world_size)
                * param_numel
                // dim_0_size
            )
            if len(shards) == 1:
                local_tensor = shards[0].tensor.flatten()
                with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
                    local_tensor = local_tensor.to(fsdp_state.compute_device)
                num_padding = chunk_size - local_tensor.numel()
                if num_padding > 0:
                    local_tensor = F.pad(local_tensor, [0, num_padding])
            else:
                local_tensor = torch.zeros(
                    chunk_size, dtype=param.dtype, device=fsdp_state.compute_device
                )
            tensor = torch.empty(
                chunk_size * fsdp_state.world_size,
                dtype=local_tensor.dtype,
                device=fsdp_state.compute_device,
            )
            with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
                dist.all_gather_into_tensor(
                    tensor, local_tensor, group=fsdp_state.process_group
                )
            tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
            state_dict[fqn_from_global_root] = tensor
        else:
            if param.device != fsdp_state._device_mesh.device_type:
                param = param.to(fsdp_state._device_mesh.device_type)

            root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
            local_tensor = _ext_all_gather_dtensor(
                param, root_mesh, fsdp_state._fsdp_extension
            )

            if fqn_to_param_ext.get(fqn) is not None:
                ext = fqn_to_param_ext[fqn]
                local_tensor = _ext_post_unflatten_transform(
                    local_tensor, ext, fsdp_state._fsdp_extension
                )
            state_dict[fqn_from_global_root] = local_tensor

    with SimpleProfiler.profile("_enter_unshard_params_ctx"):
        _enter_unshard_params_ctx(module, fsdp_state, writeback=True)


@contextlib.contextmanager
def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
    old_state_dict_config = fsdp_state._state_dict_config
    old_state_dict_type = fsdp_state._state_dict_type
    fsdp_state._state_dict_config = FullStateDictConfig()
    fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
    yield
    fsdp_state._state_dict_config = old_state_dict_config
    fsdp_state._state_dict_type = old_state_dict_type


@no_type_check
@torch.no_grad()
def _post_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    """
    _post_state_dict_hook() is called after the state_dict() of this FSDP module
    is executed. ``fsdp_state._state_dict_type`` is used to decide what
    postprocessing will be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        context = contextlib.nullcontext()

    with context:
        _post_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
        }
        processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
            module, fsdp_state, state_dict, prefix
        )

    if fsdp_state._is_root:
        logger.info("FSDP finished processing state_dict(), prefix=%s", prefix)
        for key, tensor in sorted(processed_state_dict.items()):
            if key.startswith(prefix) and isinstance(tensor, torch.Tensor):
                local_shape = tensor.shape
                if isinstance(tensor, ShardedTensor):
                    local_shape = None
                    shards = tensor.local_shards()
                    if shards:
                        local_shape = shards[0].tensor.shape
                elif isinstance(tensor, DTensor):
                    local_shape = tensor.to_local().shape
                logger.info(
                    "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s",
                    key,
                    type(tensor),
                    tensor.shape,
                    local_shape,
                    tensor.dtype,
                    tensor.device,
                )

    return processed_state_dict


@no_type_check
@torch.no_grad()
def _pre_state_dict_hook(
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    This is called before the core state dict saving logic of ``module``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        _set_use_dtensor(fsdp_state)
        context = contextlib.nullcontext()

    with context:
        _pre_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
        }
        _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
            fsdp_state,
            module,
            *args,
            **kwargs,
        )


@no_type_check
def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
    # If a device_mesh is passed in when initializing FSDP, we automatically set
    # the _use_dtensor flag to True for ShardedStateDictConfig().
    if getattr(fsdp_state, "_device_mesh", None):
        state_dict_type = fsdp_state._state_dict_type
        if state_dict_type == StateDictType.LOCAL_STATE_DICT:
            raise RuntimeError(
                "Found state_dict_type LOCAL_STATE_DICT. "
                "DeviceMesh is not compatible with LOCAL_STATE_DICT. "
                "Please set state_dict_type to SHARDED_STATE_DICT to get a DTensor state_dict."
            )
        else:
            fsdp_state._state_dict_config._use_dtensor = True


@no_type_check
@torch.no_grad()
def _pre_load_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    This is called before ``module._load_from_state_dict()``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        _set_use_dtensor(fsdp_state)
        context = contextlib.nullcontext()

    _lazy_init(fsdp_state, module)
    if fsdp_state._is_root:
        SimpleProfiler.reset()

    with context:
        _pre_load_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
        }
        # Code that is common for all state_dict impls
        if fsdp_state._device_handle.is_available():
            fsdp_state._device_handle.synchronize()
        # Dispatch into state_dict specific implementation of pre-hook.
        _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
            module, fsdp_state, state_dict, prefix
        )


@no_type_check
@torch.no_grad()
def _post_load_state_dict_hook(
    module: nn.Module,
    incompatible_keys: Tuple[List[str], List[str]],
    *args: Any,
) -> None:
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        context = contextlib.nullcontext()

    with context:
        _post_load_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
        }
        # Code that is common for all state_dict impls
        # Dispatch into state_dict type specific implementation of post-hook for
        # loading state_dict.
        _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)

    # When reporting incompatible keys, trim FSDP prefixes.
    missing_keys = incompatible_keys[0]
    unexpected_keys = incompatible_keys[1]
    for i in range(len(missing_keys)):
        missing_keys[i] = clean_tensor_name(missing_keys[i])
    for i in range(len(unexpected_keys)):
        unexpected_keys[i] = clean_tensor_name(unexpected_keys[i])

    if fsdp_state._is_root:
        SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")


def _register_all_state_dict_hooks(state: _FSDPState):
    """
    Registers pre-save, post-save, pre-load, and post-load state dict hooks.
    """
    for hook_registration_fn_str, hook, hook_registration_fn_kwargs in (
        ("register_state_dict_pre_hook", _pre_state_dict_hook, {}),
        ("_register_state_dict_hook", _post_state_dict_hook, {}),
        (
            "_register_load_state_dict_pre_hook",
            _pre_load_state_dict_hook,
            {"with_module": True},
        ),
        ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}),
    ):
        _register_state_dict_hooks_base(
            state, hook_registration_fn_str, hook, hook_registration_fn_kwargs
        )


@no_type_check
def _register_state_dict_hooks_base(
    state: _FSDPState,
    hook_registration_fn_name: str,
    hook: Callable,
    hook_registration_fn_kwargs: Dict[str, Any],
) -> None:
    """Registers ``hook`` using the method named ``hook_registration_fn_name``."""
    if not _is_composable(state):
        getattr(state, hook_registration_fn_name)(hook, **hook_registration_fn_kwargs)
    else:
        handle = state._handle
        if handle:
            getattr(handle._fully_sharded_module, hook_registration_fn_name)(
                hook, **hook_registration_fn_kwargs
            )
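

# Illustrative usage (not part of this module's API): the hooks above are
# registered automatically when an FSDP instance is constructed, so a typical
# sharded-checkpoint round trip, assuming a hypothetical ``MyModel``, looks like:
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#     from torch.distributed.fsdp import StateDictType
#
#     model = FSDP(MyModel().cuda())
#     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
#         sharded_sd = model.state_dict()    # values are ShardedTensor/DTensor shards
#         model.load_state_dict(sharded_sd)  # dispatches through the load hooks above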