# mypy: allow-untyped-decorators
from typing import Callable, Iterable, Optional, Union
from typing_extensions import deprecated

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
    HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
    _register_post_forward_hook,
    _register_pre_forward_hook,
    _register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _Policy


@contract(state_cls=_FSDPState)
@deprecated(
    "`torch.distributed._composable.fully_shard` is being deprecated. "
    "You can continue to use the wrapper based FSDP. "
    "See usage in: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/fully_sharded_data_parallel.py. "
    "`torch.distributed._composable.fully_shard` will be removed after PyTorch 2.5.",
    category=FutureWarning,
)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_Policy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
    forward_prefetch: bool = False,
    ignored_states: Union[
        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
    ] = None,
) -> nn.Module:
    """Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``."""
    torch._C._log_api_usage_once("torch.distributed.fully_shard")
    # Enforce the new auto wrap policy
    if policy is not None and not isinstance(policy, _Policy):
        raise ValueError(f"Expects a `_Policy` but got {policy}")
    # Retrieve the per-module state object managed by the ``@contract`` decorator
    state = fully_shard.state(module)
    state = _init_ignored_module_states(state, module, ignored_modules, ignored_states)
    state = _init_device_handle(state, module, state._ignored_params, device_id)
    _annotate_modules_for_dynamo(module, state._ignored_modules, True)
    state = _init_process_group_state(state, process_group, strategy, policy)
    if policy is not None:
        root_kwargs = {
            "process_group": process_group,
            "strategy": strategy,
            "mixed_precision": mixed_precision,
            "cpu_offload": cpu_offload,
            "ignored_modules": ignored_modules,
            "device_id": device_id,
            "param_init_fn": param_init_fn,
            "sync_module_states": sync_module_states,
            "forward_prefetch": forward_prefetch,
            "ignored_states": ignored_states,
        }
        if strategy in HYBRID_SHARDING_STRATEGIES:
            # Hybrid sharding shards within a node and replicates across
            # nodes, so pass both the intra- and inter-node process groups
            root_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
        _auto_wrap(
            module,
            policy,
            state._ignored_modules,
            state._ignored_params,
            root_kwargs,
            fully_shard,
        )
    state = _init_core_state(
        state,
        strategy or ShardingStrategy.FULL_SHARD,
        mixed_precision,
        cpu_offload,
        limit_all_gathers=True,
        use_orig_params=True,
        backward_prefetch_limit=1,
        forward_prefetch_limit=1,
    )
    state = _init_runtime_state(state)
    state = _init_prefetching_state(
        state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
    )
    state = _init_buffer_state(state, module)
    state = _init_param_handle_from_module(
        state, module, device_id, param_init_fn, sync_module_states
    )
    state = _init_state_dict_state(state)
    # Register state dict hooks and the forward hooks that run FSDP's
    # unshard/reshard logic around ``module``'s forward
    _register_all_state_dict_hooks(state)
    _register_pre_forward_hook(state, module)
    _register_post_forward_hook(state, module)
    _register_root_pre_forward_hook(state, module)  # prepend last
    # Always insert the state for the passed-in module even if it has no
    # managed parameters, in which case it has no handles and does not appear
    # in `_fully_sharded_module_to_handle`
    _insert_module_state(module, state)
    for submodule in module.modules():
        if (
            submodule in state._fully_sharded_module_to_handle
            and _get_module_state(submodule) is None
        ):
            _insert_module_state(submodule, state)
    return module
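

# Usage sketch (illustrative, not part of the original module): applying the
# composable ``fully_shard`` to a small model. The two-layer ``nn.Sequential``,
# the NCCL backend, and launching under ``torchrun`` are assumptions made for
# the example, not requirements imposed by this file.
#
#     import torch.distributed as dist
#     import torch.nn as nn
#     from torch.distributed._composable import fully_shard
#     from torch.distributed.fsdp.wrap import ModuleWrapPolicy
#
#     dist.init_process_group("nccl")
#     model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
#     # Shard each ``nn.Linear`` submodule; the root module is handled as well.
#     fully_shard(model, policy=ModuleWrapPolicy({nn.Linear}))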