# mypy: allow-untyped-decorators
from typing import Callable, Iterable, Optional, Union
from typing_extensions import deprecated

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
    HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
    _register_post_forward_hook,
    _register_pre_forward_hook,
    _register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _Policy


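# ``contract`` attaches an ``_FSDPState`` to ``module`` so that it can be
# retrieved below via ``fully_shard.state(module)``.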
@contract(state_cls=_FSDPState)
@deprecated(
    "`torch.distributed._composable.fully_shard` is being deprecated. "
    "You can continue to use the wrapper-based FSDP. "
    "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. "
    "`torch.distributed._composable.fully_shard` will be removed after PyTorch 2.5.",
    category=FutureWarning,
)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_Policy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
    forward_prefetch: bool = False,
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> nn.Module:
    """Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``."""
    torch._C._log_api_usage_once("torch.distributed.fully_shard")
    # Enforce the new auto wrap policy
    if policy is not None and not isinstance(policy, _Policy):
        raise ValueError(f"Expects a `_Policy` but got {policy}")
    state = fully_shard.state(module)
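    # Record ignored modules/parameters, resolve the compute device handle, and
    # annotate the managed modules for TorchDynamo.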
    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
    state = _init_device_handle(state, module, state._ignored_params, device_id)
    _annotate_modules_for_dynamo(module, state._ignored_modules, True)
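    # Set up the process group(s); hybrid sharding strategies also maintain an
    # inter-node process group (``state._inter_node_pg`` below).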
    state = _init_process_group_state(state, process_group, strategy, policy)
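    # With an auto wrap policy, recursively apply ``fully_shard`` to submodules
    # via ``_auto_wrap``, using this call's keyword arguments as the defaults.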
    if policy is not None:
        root_kwargs = {
            "process_group": process_group,
            "strategy": strategy,
            "mixed_precision": mixed_precision,
            "cpu_offload": cpu_offload,
            "ignored_modules": ignored_modules,
            "device_id": device_id,
            "param_init_fn": param_init_fn,
            "sync_module_states": sync_module_states,
            "forward_prefetch": forward_prefetch,
            "ignored_states": ignored_states,
        }
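        # Hybrid sharding shards within a node and replicates across nodes, so
        # the root needs both the intra-node and inter-node process groups.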
        if strategy in HYBRID_SHARDING_STRATEGIES:
            root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
        _auto_wrap(
            module,
            policy,
            state._ignored_modules,
            state._ignored_params,
            root_kwargs,
            fully_shard,
        )
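    # Initialize the core FSDP state: sharding strategy (defaulting to
    # ``FULL_SHARD``), mixed precision, and CPU offload. The composable API
    # always uses ``use_orig_params=True`` and rate-limits all-gathers.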
    state = _init_core_state(
        state,
        strategy or ShardingStrategy.FULL_SHARD,
        mixed_precision,
        cpu_offload,
        limit_all_gathers=True,
        use_orig_params=True,
        backward_prefetch_limit=1,
        forward_prefetch_limit=1,
    )
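    # Initialize runtime bookkeeping, prefetching (backward prefetching defaults
    # to ``BACKWARD_PRE``), and buffer state.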
    state = _init_runtime_state(state)
    state = _init_prefetching_state(
        state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
    )
    state = _init_buffer_state(state, module)
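    # Flatten the managed parameters into handles, moving ``module`` to the
    # target device and optionally syncing module states across ranks.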
    state = _init_param_handle_from_module(
        state, module, device_id, param_init_fn, sync_module_states
    )
    state = _init_state_dict_state(state)
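    # Register the state dict hooks and the pre-/post-forward hooks that run
    # FSDP's unshard/reshard logic around each forward pass.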
    _register_all_state_dict_hooks(state)
    _register_pre_forward_hook(state, module)
    _register_post_forward_hook(state, module)
    _register_root_pre_forward_hook(state, module)  # prepend last
    # Always insert the state for the passed-in module even if it has no
    # managed parameters, in which case it has no handles and does not appear
    # in `_fully_sharded_module_to_handle`
    _insert_module_state(module, state)
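    # Also attach this state to every submodule that owns a flat-parameter
    # handle and does not already have a state of its own.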
    for submodule in module.modules():
        if (
            submodule in state._fully_sharded_module_to_handle
            and _get_module_state(submodule) is None
        ):
            _insert_module_state(submodule, state)
    return module