# mypy: allow-untyped-defs
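"""
Utilities for temporarily unsharding FSDP-managed parameters (e.g., to support
``summon_full_params``-style contexts) and for registering/de-registering the
``FlatParameter`` and the original parameters on FSDP-wrapped modules.
"""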
import contextlib
import warnings
from typing import cast, Generator

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_module_fsdp_state,
    _has_fsdp_params,
    _module_handle,
    HandleTrainingState,
    TrainingState,
)
from torch.distributed.fsdp._runtime_utils import (
    _lazy_init,
    _reset_flat_param_grad_info_if_needed,
    _reshard,
    _reshard_grads,
    _unshard,
    _unshard_grads,
)
from torch.distributed.utils import _p_assert

from ._flat_param import FlatParamHandle


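# Attribute name under which the handle's ``FlatParameter`` is registered in
# the wrapped module's ``_parameters``.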
FLAT_PARAM = "_flat_param"


@torch.no_grad()
def _writeback_to_local_shard(
    handle: FlatParamHandle,
    writeback_grad: bool,
):
    """
    For the handle, writes back this rank's shard of the unsharded
    flattened parameter to the sharded flattened parameter. If
    ``writeback_grad=True``, then also writes back to the sharded gradient.

    Precondition: The handle's ``FlatParameter`` 's data points to the
    padded unsharded flattened parameter.
    """

    def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor:
        if handle.uses_sharded_strategy:
            # For sharded strategies, get the *unpadded* shard instead of
            # the *padded* shard to persist user changes to the padding
            # (though FSDP does not explicitly support this)
            shard, _ = FlatParamHandle._get_unpadded_shard(
                flat_param_or_grad,
                handle.rank,
                handle.world_size,
            )
            return shard
        # For `NO_SHARD`, the `flat_param` or its gradient may be modified,
        # so we write it back directly
        return flat_param_or_grad

    param_shard = _get_shard(handle.flat_param)
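    # The sharded ``_local_shard`` may include trailing padding, so copy only
    # the unpadded shard's ``numel()`` elements.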
    handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)  # type: ignore[attr-defined]
    if writeback_grad:
        existing_grad = handle.sharded_grad
        if existing_grad is not None:
            assert handle.flat_param.grad is not None
            grad_shard = _get_shard(handle.flat_param.grad)
            existing_grad[: grad_shard.numel()].copy_(grad_shard)


def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    De-registers the flattened parameter from the wrapped module, hiding it
    from ``nn.Module`` methods.

    We do not use ``del`` because we want ``FLAT_PARAM`` to always be an
    attribute but dynamically change whether it is visible to ``nn.Module``
    methods.
    """
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None)


def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    Registers the flattened parameter to the wrapped module, making it
    visible to ``nn.Module`` methods.

    We do not use :meth:`nn.Module.register_parameter` because we want
    ``FLAT_PARAM`` to always be an attribute but dynamically change whether
    it is visible to ``nn.Module`` methods.
    """
    handle = _module_handle(state, module)
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param


@contextlib.contextmanager
def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
    """
    Assumes that the flattened parameter is unsharded. When in the context,
    de-registers the flattened parameter and unflattens the original
    parameters as ``nn.Parameter`` views into the flattened parameter.
    After the context, re-registers the flattened parameter and restores
    the original parameters as ``Tensor`` views into the flattened
    parameter.
    """
    handle = _module_handle(state, module)
    if not handle:
        yield
    else:
        _deregister_flat_param(state, module)
        try:
            with handle.unflatten_as_params():
                yield
        finally:
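            # With ``use_orig_params=True``, the flat parameter is not kept
            # registered in ``_parameters`` (the original parameters are), so
            # skip re-registering it.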
            if not handle._use_orig_params:
                _register_flat_param(state, module)


def _validate_unshard_params_args(
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
) -> None:
    if with_grads and (offload_to_cpu or not state._use_orig_params):
        raise NotImplementedError(
            f"with_grads={with_grads}, "
            f"use_orig_params={state._use_orig_params}, "
            f"offload_to_cpu={offload_to_cpu} "
            f"is not supported yet"
        )
    if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy):
        raise NotImplementedError(
            "offload_to_cpu=True and NO_SHARD is not supported yet"
        )
    if writeback and rank0_only:
        # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
        # persist the changes.
        raise NotImplementedError(
            "writeback=True and rank0_only=True is not supported yet"
        )
    if offload_to_cpu and not rank0_only:
        warnings.warn(
            "offload_to_cpu=True and rank0_only=False may result in the "
            "unsharded parameters being redundantly copied to CPU memory for "
            "GPUs sharing the same CPU memory, which risks CPU OOM. We "
            "recommend using offload_to_cpu=True with rank0_only=True."
        )


@contextlib.contextmanager
def _unshard_fsdp_state_params(
    module: nn.Module,
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    Unshards the parameters for a single FSDP state ``state`` that
    corresponds to ``module``.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    state._device_handle.synchronize()
    # If the handle is shared by other module(s), it may already be unsharded.
    maybe_handle = _module_handle(state, module)
    handle = None
    if (
        maybe_handle
        and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS
    ):
        handle = maybe_handle
    if not handle:
        yield
        return

    assert (
        handle._training_state == HandleTrainingState.IDLE
    ), f"Expects the handle training state to be IDLE but got {handle._training_state}"

    handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

    _reset_flat_param_grad_info_if_needed(handle)
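    # Record whether this context is responsible for freeing the unsharded
    # flat parameter on exit; if the handle is already unsharded (e.g. shared
    # with another module), it should not be freed here.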
    free_unsharded_flat_param = handle.needs_unshard()
    # No need to call `wait_stream()` since we unshard in the computation
    # stream directly
    computation_stream = state._device_handle.current_stream()
    _unshard(state, handle, computation_stream, computation_stream)
    if with_grads:
        _unshard_grads(handle)

    if rank0_only and state.rank != 0:
        # Free the unsharded flattened parameter early
        _reshard(state, handle, free_unsharded_flat_param)
        if with_grads:
            _reshard_grads(handle)
        try:
            yield
        finally:
            handle._training_state = HandleTrainingState.IDLE
    else:
        # Unflatten the unsharded flattened parameters
        with contextlib.ExitStack() as stack:
            # Invariant: rank == 0 or !rank0_only
            if offload_to_cpu and handle.uses_sharded_strategy:
                stack.enter_context(handle.to_cpu())
                # NOTE: Since PyTorch enforces that a parameter and its
                # gradients need to match metadata (e.g. device), we must
                # move gradients to CPU *after* we move parameters.
            # NOTE: This assumes 1 `FlatParameter`
            if not state._use_orig_params:
                stack.enter_context(_unflatten_as_params(state, module))
            try:
                yield
            finally:
                stack.close()
                if writeback:
                    _writeback_to_local_shard(handle, with_grads)
                _reshard(state, handle, free_unsharded_flat_param)
                if with_grads:
                    _reshard_grads(handle)
                handle._training_state = HandleTrainingState.IDLE


@contextlib.contextmanager
def _unshard_params_for_summon(
    module: nn.Module,
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
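    """
    Validates the unshard arguments and unshards the parameters for the single
    FSDP state ``state``, transitioning its ``TrainingState`` to
    ``SUMMON_FULL_PARAMS`` for the duration of the context.
    """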
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    _lazy_init(state, module)
    if state.training_state == TrainingState.FORWARD_BACKWARD:
        raise AssertionError(
            "Cannot manually unshard parameters during forward/backward"
        )
    elif state.training_state == TrainingState.SUMMON_FULL_PARAMS:
        raise AssertionError(
            "Cannot manually unshard parameters when already unsharding parameters"
        )
    with _unshard_fsdp_state_params(
        module=module,
        state=state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    ):
        try:
            state.training_state = TrainingState.SUMMON_FULL_PARAMS
            yield
        finally:
            state.training_state = TrainingState.IDLE


@contextlib.contextmanager
def _unshard_params(
    module: nn.Module,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    Unshards FSDP-managed parameters for all modules with FSDP applied in
    the module tree rooted at ``module``.
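
    A minimal usage sketch (``fsdp_model`` is a hypothetical FSDP-wrapped
    module; in practice this is typically driven through the public
    ``summon_full_params`` context manager)::

        with _unshard_params(
            fsdp_model,
            recurse=True,
            writeback=True,
            rank0_only=False,
            offload_to_cpu=False,
            with_grads=False,
        ):
            # The original, unsharded parameters are visible on the wrapped
            # modules inside this context.
            ...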
    """
    if not recurse:
        optional_state = _get_module_fsdp_state(module)
        if optional_state is None:
            with contextlib.nullcontext():
                yield
            return
        states_and_modules = ([optional_state], [module])
    else:
        states_and_modules = traversal_utils._get_fsdp_states_with_modules(module)
    with contextlib.ExitStack() as stack:
        for state, module in zip(*states_and_modules):
            stack.enter_context(
                _unshard_params_for_summon(
                    module=module,
                    state=state,
                    writeback=writeback,
                    rank0_only=rank0_only,
                    offload_to_cpu=offload_to_cpu,
                    with_grads=with_grads,
                )
            )
        yield


def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the original parameters; registers the ``FlatParameter``.
    """
    handle = _module_handle(state, module)
    if not handle:
        return
    _p_assert(
        handle._use_orig_params,
        f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
        f"handle: {handle._use_orig_params}",
    )
    handle._deregister_orig_params()
    _register_flat_param(state, module)


def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the ``FlatParameter``; registers the original parameters.
    """
    handle = _module_handle(state, module)
    if not handle:
        return
    _deregister_flat_param(state, module)
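    # Restore the original parameters as views into the flat parameter,
    # matching whichever form (sharded or unsharded) its data currently takes.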
    if handle.is_sharded(handle.flat_param):
        handle._use_sharded_views()
        handle._use_sharded_grad_views()
    else:
        handle._use_unsharded_views(as_params=True)