• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import collections
3import functools
4import inspect
5import warnings
6from functools import partial
7from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
8
9import torch.nn as nn
10from torch.distributed.fsdp._common_utils import (
11    _get_module_fsdp_state,
12    _override_module_mixed_precision,
13)
14from torch.distributed.fsdp.wrap import (
15    _construct_wrap_fn,
16    _or_policy,
17    _Policy,
18    _post_order_apply,
19    _recursive_wrap,
20    _run_mixed_precision_override_policy,
21    _wrap_module_cls_individually,
22)
23
24
def _auto_wrap(
    root_module: nn.Module,
    policy: Union[Callable, _Policy],
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    root_kwargs: Dict[str, Any],
    fsdp_fn: Callable,  # e.g. `FullyShardedDataParallel` or `fully_shard`
):
    """
    Apply ``fsdp_fn`` to the modules of ``root_module``'s tree selected by
    ``policy``, visiting the tree in post-order.

    Precondition: ``root_kwargs`` holds every ``fsdp_fn`` argument except
    ``module``. It must at least contain a ``"mixed_precision"`` entry. The
    dict is accepted as-is (rather than unpacked) because it is forwarded to
    the traversal helpers. On the ``_Policy`` path its policy entry is
    overwritten with ``None`` in place.

    Two code paths exist:

    * ``policy`` is a ``_Policy``: it computes a mapping from target module to
      wrapping kwargs up front. Frozen-parameter uniformity is then validated,
      and the wrapping is applied via ``_post_order_apply``.
    * Otherwise ``policy`` is a callable handed to ``_recursive_wrap``, which
      wraps the children of ``root_module``.

    Raises:
        ValueError: if any module in the tree already has FSDP applied.
    """
    mixed_precision = root_kwargs["mixed_precision"]
    # Class wrappers (e.g. ``FullyShardedDataParallel``) receive the policy as
    # ``auto_wrap_policy``; non-class ``fsdp_fn``s receive it as ``policy``.
    policy_kwarg = "auto_wrap_policy" if inspect.isclass(fsdp_fn) else "policy"
    # TODO: We may relax this no-nested-wrapping constraint to support manual
    # wrapping followed by auto wrapping.
    _check_nested_wrapping(root_module)

    if not isinstance(policy, _Policy):
        auto_wrap_policy = policy
        if mixed_precision is not None:
            classes_to_ignore = mixed_precision._module_classes_to_ignore
            # Wrap modules of the ignored types separately and register forward
            # hooks to cast to fp32 and back to the original dtype, respectively
            overridden = _override_module_mixed_precision(
                root_module, classes_to_ignore
            )
            # Wrap a module if either the user policy or the
            # "individually wrap ignored classes" policy selects it.
            auto_wrap_policy = functools.partial(
                _or_policy,
                policies=[
                    policy,
                    partial(
                        _wrap_module_cls_individually,
                        module_classes=classes_to_ignore,
                    ),
                ],
            )
            _warn_on_overridden_mixed_precision(overridden)
        recursive_wrap_kwargs = {
            "module": root_module,
            "auto_wrap_policy": auto_wrap_policy,
            "wrapper_cls": fsdp_fn,
            "ignored_modules": ignored_modules,
            "ignored_params": ignored_params,
            "only_wrap_children": True,
        }
        _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
        return

    # Null out the policy kwarg before the policy derives per-module kwargs
    # from ``root_kwargs``.
    root_kwargs[policy_kwarg] = None
    target_module_to_kwargs = policy._run_policy(
        root_module, ignored_modules, root_kwargs
    )
    if mixed_precision is not None:
        target_module_to_kwargs = _run_mixed_precision_override_policy(
            root_module,
            mixed_precision._module_classes_to_ignore,
            ignored_modules,
            root_kwargs,
            target_module_to_kwargs,
        )
        _warn_on_overridden_mixed_precision(
            _override_module_mixed_precision(
                root_module, mixed_precision._module_classes_to_ignore
            )
        )
    _validate_frozen_params(
        root_module,
        set(target_module_to_kwargs),
        ignored_params,
        root_kwargs.get("use_orig_params", False),
    )
    wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
    _post_order_apply(root_module, wrap_fn)
102
103
104def _check_nested_wrapping(root_module: nn.Module):
105    for module_name, module in root_module.named_modules():
106        if _get_module_fsdp_state(module) is not None:
107            raise ValueError(
108                "FSDP auto wrapping requires modules to not already have "
109                f"FSDP applied but found {module_name} in\n{root_module}"
110            )
111
112
113def _warn_on_overridden_mixed_precision(
114    overridden_module_classes: Set[Type[nn.Module]],
115):
116    if len(overridden_module_classes) == 0:
117        return
118    warnings.warn(
119        "Both mixed precision and an auto_wrap_policy were specified to FSDP, "
120        f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
121        "These modules will be wrapped as separate FSDP instacnes with mixed "
122        "precision disabled."
123    )
124
125
def _validate_frozen_params(
    root_module: nn.Module,
    modules_to_wrap: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    use_orig_params: bool,
):
    """
    Check that each module in ``modules_to_wrap`` would manage parameters that
    are either all frozen (``requires_grad=False``) or all trainable.

    Modules are visited in post-order with a shared visited set. Each module is
    therefore attributed only the parameters outside its already-processed
    wrapped descendants, excluding ``ignored_params``.

    A module that mixes frozen and trainable parameters is a hard error
    (``ValueError``) when ``use_orig_params=False``. It is only a user warning
    when ``use_orig_params=True``.
    """
    claimed_modules: Set[nn.Module] = set()
    for module_name, module in _get_post_order_named_modules(root_module):
        if module not in modules_to_wrap:
            continue
        param_to_fqn = _get_managed_param_to_fqn(
            module, ignored_params, claimed_modules, module_name
        )
        trainable_fqns: List[str] = []
        trainable_numel = 0
        frozen_fqns: List[str] = []
        frozen_numel = 0
        for param, fqn in param_to_fqn.items():
            if param.requires_grad:
                trainable_fqns.append(fqn)
                trainable_numel += param.numel()
            else:
                frozen_fqns.append(fqn)
                frozen_numel += param.numel()
        # Uniformly frozen or uniformly trainable: nothing to report.
        if not (frozen_fqns and trainable_fqns):
            continue
        if use_orig_params:
            total_numel = frozen_numel + trainable_numel
            reason = (
                " We do not recommend wrapping such modules since "
                "the gradient memory usage will be higher than expected "
                f"({total_numel} numel instead of {trainable_numel} numel "
                "before sharding via reduce-scatter). "
            )
        else:
            reason = (
                " FSDP does not support wrapping such modules "
                "when use_orig_params=False. "
            )
        msg = (
            f"{module_name} has both parameters with requires_grad=True and False."
            + reason
            + "If possible, wrap the frozen parameters with FSDP separately.\n"
            f"The following parameters have requires_grad=True:\n{trainable_fqns}\n"
            f"The following parameters have requires_grad=False:\n{frozen_fqns}"
        )
        if not use_orig_params:
            raise ValueError(msg)
        warnings.warn(msg)
177
178
179def _get_post_order_named_modules(
180    root_module: nn.Module,
181) -> List[Tuple[str, nn.Module]]:
182    """
183    This returns the named modules following a post-order traversal, which is a
184    valid reverse topological sort. We achieve this using the reverse of a
185    stack-based DFS order instead of reversing ``root_module.named_modules()``
186    since the former gives the modules in registration order at each level in
187    the module tree (as opposed to the reverse), which allows us to error/warn
188    on the first registered module that violates the condition.
189
190    For example, consider the following module structure:
191        M(
192          S1(),
193          S2(
194            SS1(),
195            SS2(),
196          ),
197          S3(),
198        )
199    The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse
200    ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M].
201    """
202    visited_modules = {root_module}
203    stack = [("", root_module)]
204    # Append and reverse at the end for linear-time algorithm
205    reverse_post_order_named_modules: List[Tuple[str, nn.Module]] = []
206    while stack:
207        module_name, module = stack.pop()
208        reverse_post_order_named_modules.append((module_name, module))
209        for child_module_name, child_module in module.named_children():
210            if child_module is None:  # only for overrides of `named_children()`
211                continue
212            if child_module not in visited_modules:
213                visited_modules.add(child_module)
214                if module_name != "":
215                    child_module_name = module_name + "." + child_module_name
216                stack.append((child_module_name, child_module))
217    post_order_named_modules = list(reversed(reverse_post_order_named_modules))
218    return post_order_named_modules
219
220
221def _get_managed_param_to_fqn(
222    module_to_wrap: nn.Module,
223    ignored_params: Set[nn.Parameter],
224    visited_modules: Set[nn.Module],
225    root_prefix: str,
226) -> Dict[nn.Parameter, str]:
227    """
228    This returns a dict that maps managed parameter to its FQN for the given
229    ``module_to_wrap``. The dict's keys are exactly the parameters that would
230    be managed by the module, where this is achieved by calling this function
231    on the modules to wrap in reverse topological order, destructively updating
232    ``visited_modules``, and not traversing into those modules. The FQNs are
233    prefixed from the root (via ``root_prefix``) to be more informative.
234
235    NOTE: This function is meant to be called pre-wrapping and iteratively in
236    reverse topological order to cover the full module tree. This differs from
237    the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
238    on the full module tree in one shot. Given those differences, we do not try
239    to unify the two.
240    """
241    param_to_fqn: Dict[nn.Parameter, str] = {}
242    # Run BFS (or any tree traversal works)
243    queue = collections.deque([(module_to_wrap, root_prefix)])
244    visited_modules.add(module_to_wrap)
245    while queue:
246        module, prefix = queue.popleft()
247        for param_name, param in module.named_parameters(recurse=False):
248            if param not in ignored_params:
249                fqn = param_name if prefix == "" else prefix + "." + param_name
250                param_to_fqn[param] = fqn
251        for child_module_name, child_module in module.named_children():
252            if child_module is None:  # only for overrides of `named_children()`
253                continue
254            if child_module not in visited_modules:
255                visited_modules.add(child_module)
256                child_prefix = (
257                    child_module_name
258                    if prefix == ""
259                    else prefix + "." + child_module_name
260                )
261                queue.append((child_module, child_prefix))
262    return param_to_fqn
263