# mypy: allow-untyped-defs
import contextlib
import logging
import math
import warnings
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterator,
    List,
    no_type_check,
    Tuple,
)

import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_module_fsdp_state_if_fully_sharded_module,
    _has_fsdp_params,
    _is_composable,
    _module_handle,
    clean_tensor_name,
    FSDP_PREFIX,
    FSDP_WRAPPED_MODULE,
)
from torch.distributed.fsdp._debug_utils import SimpleProfiler
from torch.distributed.fsdp._runtime_utils import (
    _cast_buffers_to_dtype_and_device,
    _get_orig_buffer_dtypes,
    _lazy_init,
    _reset_flat_param_grad_info_if_needed,
)
from torch.distributed.fsdp.api import (
    FullStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.tensor import DTensor
from torch.distributed.utils import _replace_by_prefix

from ._fsdp_extensions import (
    _ext_all_gather_dtensor,
    _ext_chunk_dtensor,
    _ext_chunk_tensor,
    _ext_post_unflatten_transform,
    _ext_pre_load_state_dict_transform,
)
from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM


logger = logging.getLogger(__name__)


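# NOTE: A hedged reading of the check below: unsharding is skipped only when the
# strategy is ``NO_SHARD`` *and* the original parameters are already exposed
# (composable ``fully_shard`` or ``use_orig_params=True``), since in that case the
# module already holds the full, unflattened parameters and there is nothing to
# unshard for state_dict purposes.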
def _should_unshard_params(fsdp_state: _FSDPState) -> bool:
    return not (
        fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
        and (_is_composable(fsdp_state) or fsdp_state._use_orig_params)
    )


def _convert_to_wrapped_module_name(module_name: str) -> str:
    module_name = module_name.replace(f"{FSDP_PREFIX}", "")
    module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "")
    if module_name:
        module_name = f"{module_name}."
    # `CheckpointWrapper` adds a prefix that has to be removed as well.
    module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "")
    return module_name


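# Illustrative sketch for ``_convert_to_wrapped_module_name`` (assumption: the
# constants from ``torch.distributed.fsdp._common_utils`` are
# ``FSDP_WRAPPED_MODULE == "_fsdp_wrapped_module"`` and
# ``FSDP_PREFIX == "_fsdp_wrapped_module."``):
#
#     _convert_to_wrapped_module_name("_fsdp_wrapped_module.layer1")
#     # -> "layer1.", so a clean FQN can be built as f"{module_name}{param_name}",
#     # e.g. "layer1.weight".

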
def _param_name_infos(
    module: nn.Module, fsdp_state: _FSDPState
) -> Iterator[Tuple[str, str, str]]:
    if not _has_fsdp_params(fsdp_state, module):
        return
    for param_name, module_name in _module_handle(
        fsdp_state, module
    ).param_module_names():
        module_name = _convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


def _shared_param_name_infos(
    module: nn.Module, fsdp_state
) -> Iterator[Tuple[str, str, str]]:
    for param_name, module_name in _module_handle(
        fsdp_state, module
    ).shared_param_module_names():
        module_name = _convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


@no_type_check
def _enter_unshard_params_ctx(
    module: nn.Module,
    fsdp_state: _FSDPState,
    writeback: bool = False,
    rank0_only: bool = False,
    offload_to_cpu: bool = False,
    with_grads: bool = False,
) -> None:
    """
    state_dict hooks cannot use a plain ``with`` statement because the checkpoint
    flow requires entering the context in the pre-hook and exiting it in the
    post-hook. This API enters the ``_unshard_fsdp_state_params`` context.
    """
    assert module not in fsdp_state._unshard_params_ctx, (
        "Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] "
        "is not None."
    )
    fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
        module,
        fsdp_state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    )
    fsdp_state._unshard_params_ctx[module].__enter__()


@no_type_check
def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
    """A helper function to exit the ``_unshard_fsdp_state_params`` context."""
    fsdp_state._unshard_params_ctx[module].__exit__(None, None, None)
    fsdp_state._unshard_params_ctx.pop(module)


def _common_pre_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
) -> None:
    """Performs the pre-state_dict tasks shared by all state_dict types."""
    if fsdp_state._device_handle.is_available():
        fsdp_state._device_handle.synchronize()
    # TODO: need to check if this is always correct for composable FSDP.
    _lazy_init(fsdp_state, module)
    if fsdp_state._is_root:
        _reset_flat_param_grad_info_if_needed(fsdp_state._all_handles)


def _common_unshard_pre_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    offload_to_cpu: bool,
    rank0_only: bool,
) -> None:
    """
    Performs the pre-state_dict tasks shared by all state_dict types that require
    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook.
    """
    # Composable `fully_shard` does not need to unshard parameters for `NO_SHARD` cases.
    if not _should_unshard_params(fsdp_state):
        return
    _enter_unshard_params_ctx(
        module,
        fsdp_state,
        writeback=False,
        offload_to_cpu=offload_to_cpu,
        rank0_only=rank0_only,
    )


@no_type_check
def _common_unshard_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
    param_hook: Callable,
) -> Dict[str, Any]:
    """
    The post-state_dict flow shared by all state_dict types that require
    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this
    hook.
    """
    _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
    # Return early for trivial cases
    if not state_dict or not _has_fsdp_params(fsdp_state, module):
        if _should_unshard_params(fsdp_state):
            _exit_unshard_params_ctx(module, fsdp_state)
        return state_dict

    # If a rank does not have unsharded parameters (when `rank0_only=True`
    # and `rank != 0`), then the rank only needs to participate in the
    # all-gather and does not need to save the state dict. We simply check
    # `rank0_only` to handle this case.
    rank0_only = (
        fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT
        and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only
    )
    # no_fsdp_return means the state_dict returned by this rank should contain
    # only non-FSDP-controlled parameters and buffers.
    no_fsdp_return = rank0_only and fsdp_state.rank != 0
    if no_fsdp_return and not fsdp_state._use_orig_params:
        for clean_key in fsdp_state._buffer_names:
            # This is a hack to support activation checkpointing.
            clean_key = clean_key.replace(
                f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
            )
            state_dict.pop(f"{prefix}{clean_key}", None)
        # Nonzero ranks have the flat_param key when rank0_only=True because
        # rank0_only=True is passed into the unshard context, but nonzero ranks
        # reshard early, leaving this flat_param in the state_dict.
        state_dict.pop(f"{prefix}{FLAT_PARAM}")
        _exit_unshard_params_ctx(module, fsdp_state)
        return state_dict

    # Loop over only the parameters saved in this instance's wrapped module to
    # avoid processing buffers.
    for fqn, param_name, module_name in _param_name_infos(module, fsdp_state):
        fqn = f"{prefix}{fqn}"
        if no_fsdp_return:
            state_dict.pop(fqn)
            continue
        assert fqn in state_dict, (
            f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
            f"has {state_dict.keys()}. "
            f"prefix={prefix}, module_name={module_name}, "
            f"param_name={param_name} rank={fsdp_state.rank}."
        )

        param_hook(state_dict, prefix, fqn)

    if _should_unshard_params(fsdp_state):
        _exit_unshard_params_ctx(module, fsdp_state)

    cpu_device = torch.device("cpu")
    buffer_clean_fqns = []
    buffers = []
    for clean_key in fsdp_state._buffer_names:
        # This is a hack to support activation checkpointing.
        clean_key = clean_tensor_name(clean_key)
        fqn = f"{prefix}{clean_key}"
        if fqn not in state_dict:
            # A buffer can be registered as non-persistent.
            continue
        if no_fsdp_return:
            state_dict.pop(fqn)
        else:
            buffer = state_dict[fqn]
            if (
                fsdp_state._state_dict_config.offload_to_cpu
                and buffer.device != cpu_device
            ):
                state_dict[fqn] = buffer.to(cpu_device)
            # Skip upcasting for ignored buffers.
            if clean_key not in fsdp_state._ignored_buffer_names:
                buffer_clean_fqns.append(clean_key)
                buffers.append(state_dict[fqn])

    if buffers:
        mixed_precision_enabled_for_buffers = (
            fsdp_state._mixed_precision_enabled_for_buffers()
            if not _is_composable(fsdp_state)
            else (fsdp_state.mixed_precision.buffer_dtype is not None)
        )
        if mixed_precision_enabled_for_buffers:
            buffer_dtypes = _get_orig_buffer_dtypes(fsdp_state, buffer_clean_fqns)
            _cast_buffers_to_dtype_and_device(
                buffers, buffer_dtypes, fsdp_state.compute_device
            )
            for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
                fqn = f"{prefix}{clean_fqn}"
                logger.info("FSDP is casting the dtype of %s to %s", fqn, buffer.dtype)
                state_dict[fqn] = buffer.clone()
    return state_dict


@no_type_check
def _full_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. A pre-state_dict hook is
    not actually supported by ``nn.Module``, so this API is called from
    ``_full_post_state_dict_hook()`` to simulate the behavior. Once pre-state_dict
    is supported in ``nn.Module``, this hook will be registered as a hook in
    ``nn.Module``.
    """
    if getattr(fsdp_state, "_device_mesh", False):
        root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)

    _common_pre_state_dict_hook(module, fsdp_state)
    _common_unshard_pre_state_dict_hook(
        module,
        fsdp_state,
        offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,
        rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only,
    )


@no_type_check
def _full_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    Hook that runs after model.state_dict() is called and before the result is
    returned to the user. For FSDP, we may have to clone the tensors in the
    state_dict since the parameters go back to their sharded version after
    ``_unshard_fsdp_state_params`` ends, and we also remove the
    ``FSDP_WRAPPED_MODULE`` prefix.
    """

    def param_hook(
        state_dict: Dict[str, Any],
        prefix: str,
        fqn: str,
    ) -> None:
        clean_key = fqn
        clean_prefix = clean_tensor_name(prefix)
        # Strip the prefix out of the key if needed, since buffer names and
        # param names do not include the prefix (they are not computed in the
        # `state_dict` call).
        if clean_key.startswith(clean_prefix):
            clean_key = clean_key[len(clean_prefix) :]

        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
        if not getattr(state_dict[fqn], "_has_been_cloned", False):
            try:
                state_dict[fqn] = state_dict[fqn].clone().detach()
                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
            except BaseException as e:
                warnings.warn(
                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
                    "This may mean that this state_dict entry could point to invalid "
                    "memory regions after returning from state_dict() call if this "
                    "parameter is managed by FSDP. Please check clone "
                    f"implementation of {fqn}. Error: {str(e)}"
                )

    return _common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


def _full_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    _lazy_init(fsdp_state, module)
    if _should_unshard_params(fsdp_state):
        with SimpleProfiler.profile("_enter_unshard_params_ctx"):
            _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
    # Add FSDP_PREFIX only for wrapper-based FSDP.
    if not _is_composable(fsdp_state):
        _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")


def _full_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    if _should_unshard_params(fsdp_state):
        with SimpleProfiler.profile("_exit_unshard_params_ctx"):
            _exit_unshard_params_ctx(module, fsdp_state)


def _local_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. Right now, a pre-state_dict
    hook is not supported by PyTorch core, so this API is called from
    `_local_post_state_dict_hook()` to simulate the case.
    """
    if (
        _has_fsdp_params(fsdp_state, module)
        and not _module_handle(fsdp_state, module).uses_sharded_strategy
    ):
        raise RuntimeError(
            "``local_state_dict`` can only be used when parameters are flattened "
            "and sharded."
        )
    _common_pre_state_dict_hook(module, fsdp_state)


@no_type_check
def _local_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    This hook creates a ShardedTensor from the local flat_param and replaces
    state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
    happens; the underlying storage is the same.
    """

    _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix)
    if not _has_fsdp_params(fsdp_state, module):
        return state_dict

    # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
    # value as the flat_param, but it is a plain Tensor because
    # nn.Module.state_dict() detaches the parameter. Therefore, we need
    # the flat_param itself to get the metadata.
    assert _module_handle(fsdp_state, module), "Should have returned early"
    flat_param = _module_handle(fsdp_state, module).flat_param
    # Construct a ShardedTensor from the flat_param *without* padding.
    # Removing the padding allows users to change the number of ranks
    # when loading the local_state_dict.
    full_numel = flat_param._unpadded_unsharded_size.numel()  # type: ignore[attr-defined]
    shard_offset = flat_param.numel() * fsdp_state.rank
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
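    # Worked example (illustrative numbers only): with ``full_numel=10`` and
    # ``world_size=4``, each rank holds a padded shard of 3 elements, so
    # ``shard_offset`` is 0, 3, 6, 9 on ranks 0..3, and the last rank has
    # ``valid_data_size = 3 - 2 = 1`` (its 2 padding elements are dropped below).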
    if valid_data_size > 0:
        # If the FlatParameter were returned, FlatParameter._local_shard would cause
        # a pickling issue (it can be torch.save'd but not torch.load'ed). Since
        # there is no benefit for state_dict to return the actual FlatParameter
        # class, a view (which is a plain tensor) of the FlatParameter is returned.
        flat_param = flat_param[:valid_data_size].view(valid_data_size)
        local_shards = [
            Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank)
        ]
    else:
        local_shards = []
    sharded_tensor = init_from_local_shards(
        local_shards, full_numel, process_group=fsdp_state.process_group
    )  # type: ignore[assignment]
    # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
    if fsdp_state._state_dict_config.offload_to_cpu:
        sharded_tensor = sharded_tensor.cpu()
    state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
    return state_dict


def _local_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    pass


def _local_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    This hook finds the local flat_param for this FSDP module in the
    state_dict. The flat_param should be a ShardedTensor. This hook converts
    the ShardedTensor to a tensor. No copy happens unless padding is required.
    """
    _lazy_init(fsdp_state, module)
    _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}")
    fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"
    if fqn not in state_dict:
        assert not _has_fsdp_params(fsdp_state, module), (
            "No `FlatParameter` in `state_dict` for this FSDP instance "
            "but it has parameters"
        )
        return
    load_tensor = state_dict[fqn]
    assert isinstance(
        load_tensor, ShardedTensor
    ), "Tensors in local_state_dict should be ShardedTensor."

    # Convert the ShardedTensor to a Tensor.
    flat_param = _module_handle(fsdp_state, module).flat_param
    assert flat_param is not None
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
    shards = load_tensor.local_shards()
    if valid_data_size > 0:
        assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
        load_tensor = shards[0].tensor

        # Get the metadata of the flat_param to decide whether to pad the loaded
        # tensor.
        if flat_param._shard_numel_padded > 0:
            assert load_tensor.numel() < flat_param.numel(), (
                f"Local shard size = {flat_param.numel()} and the tensor in "
                f"the state_dict is {load_tensor.numel()}."
            )
            load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
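            # Illustrative example: if ``flat_param.numel() == 12`` with
            # ``_shard_numel_padded == 2``, the saved shard holds 10 elements and
            # is zero-padded on the right back to 12 so it matches the local
            # flat_param shape.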
    else:
        load_tensor = flat_param
    # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT.
    state_dict[fqn] = load_tensor


def _sharded_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. See
    ``_full_pre_state_dict_hook`` for details.
    """
    if (
        _has_fsdp_params(fsdp_state, module)
        and not _module_handle(fsdp_state, module).uses_sharded_strategy
    ):
        raise RuntimeError(
            "``sharded_state_dict`` can only be used when parameters are flattened "
            "and sharded."
        )
    _common_pre_state_dict_hook(module, fsdp_state)
    # Setting offload_to_cpu here does not work even if offload_to_cpu is True.
    # We have to create the ShardedTensor first and then move it to CPU.
    _common_unshard_pre_state_dict_hook(
        module,
        fsdp_state,
        offload_to_cpu=False,
        rank0_only=False,
    )


@no_type_check
def _sharded_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    The hook replaces the unflattened, unsharded parameter in the state_dict
    with an unflattened, sharded parameter (a ShardedTensor).
    """

    def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
        param = state_dict[fqn]
        if not fsdp_state._state_dict_config._use_dtensor:
            sharded_tensor = _ext_chunk_tensor(
                tensor=param,
                rank=fsdp_state.rank,
                world_size=fsdp_state.world_size,
                num_devices_per_node=fsdp_state._device_handle.device_count(),
                pg=fsdp_state.process_group,
                fsdp_extension=fsdp_state._fsdp_extension,
            )
        else:
            sharded_tensor = _ext_chunk_dtensor(
                tensor=param,
                rank=fsdp_state.rank,
                device_mesh=fsdp_state._device_mesh,
                fsdp_extension=fsdp_state._fsdp_extension,
            )
        if fsdp_state._state_dict_config.offload_to_cpu:
            sharded_tensor = sharded_tensor.cpu()
        state_dict[fqn] = sharded_tensor

    return _common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


@no_type_check
def _sharded_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    if _has_fsdp_params(fsdp_state, module):
        with SimpleProfiler.profile("_exit_unshard_params_ctx"):
            _exit_unshard_params_ctx(module, fsdp_state)


@no_type_check
def _sharded_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    The hook combines the unflattened, sharded parameters (ShardedTensors) into
    a new FlatParameter and shards the new FlatParameter into the local chunk.
    """
    _lazy_init(fsdp_state, module)
    if not _is_composable(fsdp_state):
        _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
    if not _has_fsdp_params(fsdp_state, module):
        return

    handle = _module_handle(fsdp_state, module)
    if not handle.uses_sharded_strategy:
        raise RuntimeError(
            "load_sharded_state_dict can only be called when parameters "
            "are flattened and sharded."
        )
    fqn_to_param_ext = dict(
        zip(handle.flat_param._fqns, handle.flat_param._param_extensions)
    )

    for fqn, _, _ in _param_name_infos(module, fsdp_state):
        if not _is_composable(fsdp_state):
            fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}"
        else:
            fqn_from_global_root = f"{prefix}{fqn}"
        try:
            param = state_dict.pop(fqn_from_global_root)
        except KeyError:
            logger.warning(
                f"Did not find param with FQN {fqn_from_global_root}, skipping it. "  # noqa: G004
                "The weight will not be filled if you expect it to be."
            )
            continue  # TODO: Improve unittesting for state_dict finetuning
            # cases: https://github.com/pytorch/pytorch/issues/109134

        if not fsdp_state._state_dict_config._use_dtensor:
            # All-gather the param (ShardedTensor)
            param, shards = _ext_pre_load_state_dict_transform(
                param, fsdp_state._fsdp_extension
            )

            assert len(shards) < 2, (
                "Expects 0 or 1 shard per rank "
                f"but got {len(shards)} shards on rank {fsdp_state.rank}."
            )
            param_numel = param.size().numel()
            dim_0_size = param.size()[0]
            chunk_size = (
                math.ceil(dim_0_size / fsdp_state.world_size)
                * param_numel
                // dim_0_size
            )
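            # Worked example (illustrative numbers only): for a parameter of shape
            # (5, 3) with ``world_size=2``, ``param_numel=15`` and ``dim_0_size=5``,
            # so ``chunk_size = ceil(5 / 2) * 15 // 5 = 9``. Rank 0 contributes 9
            # elements, rank 1 contributes 6 and pads to 9; the all-gather below
            # produces 18 elements, which are narrowed back to 15 and reshaped.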
            if len(shards) == 1:
                local_tensor = shards[0].tensor.flatten()
                with SimpleProfiler.profile(SimpleProfiler.Type.H2D):
                    local_tensor = local_tensor.to(fsdp_state.compute_device)
                num_padding = chunk_size - local_tensor.numel()
                if num_padding > 0:
                    local_tensor = F.pad(local_tensor, [0, num_padding])
            else:
                local_tensor = torch.zeros(
                    chunk_size, dtype=param.dtype, device=fsdp_state.compute_device
                )
            tensor = torch.empty(
                chunk_size * fsdp_state.world_size,
                dtype=local_tensor.dtype,
                device=fsdp_state.compute_device,
            )
            with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
                dist.all_gather_into_tensor(
                    tensor, local_tensor, group=fsdp_state.process_group
                )
            tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
            state_dict[fqn_from_global_root] = tensor
        else:
            if param.device != fsdp_state._device_mesh.device_type:
                param = param.to(fsdp_state._device_mesh.device_type)

            root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh)
            local_tensor = _ext_all_gather_dtensor(
                param, root_mesh, fsdp_state._fsdp_extension
            )

            if fqn_to_param_ext.get(fqn) is not None:
                ext = fqn_to_param_ext[fqn]
                local_tensor = _ext_post_unflatten_transform(
                    local_tensor, ext, fsdp_state._fsdp_extension
                )
            state_dict[fqn_from_global_root] = local_tensor

    with SimpleProfiler.profile("_enter_unshard_params_ctx"):
        _enter_unshard_params_ctx(module, fsdp_state, writeback=True)


@contextlib.contextmanager
def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
    old_state_dict_config = fsdp_state._state_dict_config
    old_state_dict_type = fsdp_state._state_dict_type
    fsdp_state._state_dict_config = FullStateDictConfig()
    fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
    yield
    fsdp_state._state_dict_config = old_state_dict_config
    fsdp_state._state_dict_type = old_state_dict_type

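# Minimal usage sketch (illustrative only; mirrors how the hooks below use it):
#
#     with _replace_with_full_state_dict_type(fsdp_state):
#         ...  # fsdp_state temporarily reports FULL_STATE_DICT semantics
#
# The previous config and type are restored when the ``with`` block exits
# normally.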

@no_type_check
@torch.no_grad()
def _post_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    """
    _post_state_dict_hook() is called after the state_dict() of this FSDP module
    is executed. ``fsdp_state._state_dict_type`` is used to decide what
    postprocessing will be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        context = contextlib.nullcontext()

    with context:
        _post_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
        }
        processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
            module, fsdp_state, state_dict, prefix
        )

    if fsdp_state._is_root:
        logger.info("FSDP finished processing state_dict(), prefix=%s", prefix)
        for key, tensor in sorted(processed_state_dict.items()):
            if key.startswith(prefix) and isinstance(tensor, torch.Tensor):
                local_shape = tensor.shape
                if isinstance(tensor, ShardedTensor):
                    local_shape = None
                    shards = tensor.local_shards()
                    if shards:
                        local_shape = shards[0].tensor.shape
                elif isinstance(tensor, DTensor):
                    local_shape = tensor.to_local().shape
                logger.info(
                    "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s",
                    key,
                    type(tensor),
                    tensor.shape,
                    local_shape,
                    tensor.dtype,
                    tensor.device,
                )

    return processed_state_dict


@no_type_check
@torch.no_grad()
def _pre_state_dict_hook(
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    This is called before the core state dict saving logic of ``module``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        _set_use_dtensor(fsdp_state)
        context = contextlib.nullcontext()

    with context:
        _pre_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
        }
        _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
            fsdp_state,
            module,
            *args,
            **kwargs,
        )


@no_type_check
def _set_use_dtensor(fsdp_state: _FSDPState) -> None:
    # If a device_mesh is passed in when initializing FSDP, we automatically turn
    # on the _use_dtensor flag for ShardedStateDictConfig().
    if getattr(fsdp_state, "_device_mesh", None):
        state_dict_type = fsdp_state._state_dict_type
        if state_dict_type == StateDictType.LOCAL_STATE_DICT:
            raise RuntimeError(
                "Found state_dict_type LOCAL_STATE_DICT. "
                "DeviceMesh is not compatible with LOCAL_STATE_DICT. "
                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict."
            )
        else:
            fsdp_state._state_dict_config._use_dtensor = True


@no_type_check
@torch.no_grad()
def _pre_load_state_dict_hook(
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    This is called before ``module._load_from_state_dict()``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
    """
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        _set_use_dtensor(fsdp_state)
        context = contextlib.nullcontext()

    _lazy_init(fsdp_state, module)
    if fsdp_state._is_root:
        SimpleProfiler.reset()

    with context:
        _pre_load_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
        }
        # Code that is common to all state_dict impls.
        if fsdp_state._device_handle.is_available():
            fsdp_state._device_handle.synchronize()
        # Dispatch into the state_dict-type-specific implementation of the pre-hook.
        _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
            module, fsdp_state, state_dict, prefix
        )


@no_type_check
@torch.no_grad()
def _post_load_state_dict_hook(
    module: nn.Module,
    incompatible_keys: Tuple[List[str], List[str]],
    *args: Any,
) -> None:
    fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
    if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
        context = _replace_with_full_state_dict_type(fsdp_state)
        warnings.warn(
            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
            "be returned."
        )
    else:
        context = contextlib.nullcontext()

    with context:
        _post_load_state_dict_hook_fn = {
            StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
            StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
            StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
        }
        # Code that is common to all state_dict impls.
        # Dispatch into the state_dict-type-specific implementation of the
        # post-hook for loading the state_dict.
        _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)

    # When reporting incompatible keys, trim FSDP prefixes.
    missing_keys = incompatible_keys[0]
    unexpected_keys = incompatible_keys[1]
    for i in range(len(missing_keys)):
        missing_keys[i] = clean_tensor_name(missing_keys[i])

    for i in range(len(unexpected_keys)):
        unexpected_keys[i] = clean_tensor_name(unexpected_keys[i])

    if fsdp_state._is_root:
        SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ")


def _register_all_state_dict_hooks(state: _FSDPState):
    """
    Registers pre-save, post-save, pre-load, and post-load state dict hooks.
    """
    for hook_registration_fn_str, hook, hook_registration_fn_kwargs in (
        ("register_state_dict_pre_hook", _pre_state_dict_hook, {}),
        ("_register_state_dict_hook", _post_state_dict_hook, {}),
        (
            "_register_load_state_dict_pre_hook",
            _pre_load_state_dict_hook,
            {"with_module": True},
        ),
        ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}),
    ):
        _register_state_dict_hooks_base(
            state, hook_registration_fn_str, hook, hook_registration_fn_kwargs
        )


@no_type_check
def _register_state_dict_hooks_base(
    state: _FSDPState,
    hook_registration_fn_name: str,
    hook: Callable,
    hook_registration_fn_kwargs: Dict[str, Any],
) -> None:
    """Registers ``hook`` using the method named ``hook_registration_fn_name``."""
    if not _is_composable(state):
        getattr(state, hook_registration_fn_name)(hook, **hook_registration_fn_kwargs)
    else:
        handle = state._handle
        if handle:
            getattr(handle._fully_sharded_module, hook_registration_fn_name)(
                hook, **hook_registration_fn_kwargs
            )

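
# Usage sketch (hedged, user-facing side): once registered, the hooks above run
# implicitly through ``nn.Module.state_dict()`` / ``load_state_dict()``, typically
# under the public ``FSDP.state_dict_type`` context manager, e.g.:
#
#     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
#         sd = model.state_dict()
#         model.load_state_dict(sd)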