# mypy: allow-untyped-defs
import collections
import functools
import inspect
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union

import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state,
    _override_module_mixed_precision,
)
from torch.distributed.fsdp.wrap import (
    _construct_wrap_fn,
    _or_policy,
    _Policy,
    _post_order_apply,
    _recursive_wrap,
    _run_mixed_precision_override_policy,
    _wrap_module_cls_individually,
)


def _auto_wrap(
    root_module: nn.Module,
    policy: Union[Callable, _Policy],
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    root_kwargs: Dict[str, Any],
    fsdp_fn: Callable,  # e.g. `FullyShardedDataParallel` or `fully_shard`
):
    """
    Auto wraps modules in ``root_module`` 's tree according to ``policy``
    following a post-order traversal.

    Precondition: ``root_kwargs`` should contain all arguments except
    ``module``. This function accepts the kwargs dict directly since it gets
    forwarded into the post-order traversal function.
    """
    mixed_precision = root_kwargs["mixed_precision"]
    is_wrapper = inspect.isclass(fsdp_fn)
    # TODO: We may relax this no-nested-wrapping constraint to support manual
    # wrapping followed by auto wrapping.
    _check_nested_wrapping(root_module)

    if isinstance(policy, _Policy):
        root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
        target_module_to_kwargs = policy._run_policy(
            root_module, ignored_modules, root_kwargs
        )
        if mixed_precision is not None:
            target_module_to_kwargs = _run_mixed_precision_override_policy(
                root_module,
                mixed_precision._module_classes_to_ignore,
                ignored_modules,
                root_kwargs,
                target_module_to_kwargs,
            )
            overridden_module_classes = _override_module_mixed_precision(
                root_module, mixed_precision._module_classes_to_ignore
            )
            _warn_on_overridden_mixed_precision(overridden_module_classes)
        use_orig_params = root_kwargs.get("use_orig_params", False)
        _validate_frozen_params(
            root_module,
            set(target_module_to_kwargs.keys()),
            ignored_params,
            use_orig_params,
        )
        wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
        _post_order_apply(root_module, wrap_fn)
        return

    recursive_wrap_kwargs = {
        "module": root_module,
        "auto_wrap_policy": policy,
        "wrapper_cls": fsdp_fn,
        "ignored_modules": ignored_modules,
        "ignored_params": ignored_params,
        "only_wrap_children": True,
    }
    if mixed_precision is not None:
        # Wrap modules of the ignored types separately and register forward
        # hooks to cast to fp32 and back to the original dtype, respectively
        overridden_module_classes = _override_module_mixed_precision(
            root_module, mixed_precision._module_classes_to_ignore
        )
        policy = functools.partial(
            _or_policy,
            policies=[
                policy,
                partial(
                    _wrap_module_cls_individually,
                    module_classes=mixed_precision._module_classes_to_ignore,
                ),
            ],
        )
        recursive_wrap_kwargs["auto_wrap_policy"] = policy
        _warn_on_overridden_mixed_precision(overridden_module_classes)
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]


def _check_nested_wrapping(root_module: nn.Module):
    for module_name, module in root_module.named_modules():
        if _get_module_fsdp_state(module) is not None:
            raise ValueError(
                "FSDP auto wrapping requires modules to not already have "
                f"FSDP applied but found {module_name} in\n{root_module}"
            )

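
# The following is an illustrative sketch, not part of the original module:
# it shows the two `policy` shapes that `_auto_wrap` above dispatches on via
# `isinstance(policy, _Policy)`. `_demo_policy_shapes` and everything inside
# it are hypothetical names used only for this example.
def _demo_policy_shapes():
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    # New-style policy object: selects every `nn.Linear` as a wrap target,
    # taking the `_run_policy` -> `_post_order_apply` path.
    new_style = ModuleWrapPolicy({nn.Linear})
    assert isinstance(new_style, _Policy)

    # Legacy callable policy with the `(module, recurse, nonwrapped_numel)
    # -> bool` signature, taking the `_recursive_wrap` path.
    def legacy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
        return recurse or isinstance(module, nn.Linear)

    assert not isinstance(legacy, _Policy)
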
def _warn_on_overridden_mixed_precision(
    overridden_module_classes: Set[Type[nn.Module]],
):
    if len(overridden_module_classes) == 0:
        return
    warnings.warn(
        "Both mixed precision and an auto_wrap_policy were specified to FSDP, "
        f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
        "These modules will be wrapped as separate FSDP instances with mixed "
        "precision disabled."
    )


def _validate_frozen_params(
    root_module: nn.Module,
    modules_to_wrap: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    use_orig_params: bool,
):
    """
    This checks that, given ``modules_to_wrap``, each module would manage
    parameters that are uniformly frozen or non-frozen. This uniformity
    requirement is strict for ``use_orig_params=False`` (hard error) and highly
    recommended for ``use_orig_params=True`` (user warning).
    """
    post_order_named_modules = _get_post_order_named_modules(root_module)
    visited_modules: Set[nn.Module] = set()
    for module_name, module in post_order_named_modules:
        if module in modules_to_wrap:
            param_to_fqn = _get_managed_param_to_fqn(
                module, ignored_params, visited_modules, module_name
            )
            frozen_param_fqns: List[str] = []
            frozen_param_numel = 0
            nonfrozen_param_fqns: List[str] = []
            nonfrozen_param_numel = 0
            for param, fqn in param_to_fqn.items():
                if param.requires_grad:
                    nonfrozen_param_fqns.append(fqn)
                    nonfrozen_param_numel += param.numel()
                else:
                    frozen_param_fqns.append(fqn)
                    frozen_param_numel += param.numel()
            if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
                msg = f"{module_name} has both parameters with requires_grad=True and False."
                if use_orig_params:
                    total_param_numel = frozen_param_numel + nonfrozen_param_numel
                    msg += (
                        " We do not recommend wrapping such modules since "
                        "the gradient memory usage will be higher than expected "
                        f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel "
                        "before sharding via reduce-scatter). "
                    )
                else:
                    msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
                msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
                msg += (
                    f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n"
                    f"The following parameters have requires_grad=False:\n{frozen_param_fqns}"
                )
                if use_orig_params:
                    warnings.warn(msg)
                else:
                    raise ValueError(msg)


def _get_post_order_named_modules(
    root_module: nn.Module,
) -> List[Tuple[str, nn.Module]]:
    """
    This returns the named modules following a post-order traversal, which is a
    valid reverse topological sort. We achieve this using the reverse of a
    stack-based DFS order instead of reversing ``root_module.named_modules()``
    since the former gives the modules in registration order at each level in
    the module tree (as opposed to the reverse), which allows us to error/warn
    on the first registered module that violates the condition.

    For example, consider the following module structure:
        M(
            S1(),
            S2(
                SS1(),
                SS2(),
            ),
            S3(),
        )
    The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse
    ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M].
    """
    visited_modules = {root_module}
    stack = [("", root_module)]
    # Append and reverse at the end for a linear-time algorithm
    reverse_post_order_named_modules: List[Tuple[str, nn.Module]] = []
    while stack:
        module_name, module = stack.pop()
        reverse_post_order_named_modules.append((module_name, module))
        for child_module_name, child_module in module.named_children():
            if child_module is None:  # only for overrides of `named_children()`
                continue
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                if module_name != "":
                    child_module_name = module_name + "." + child_module_name
                stack.append((child_module_name, child_module))
    post_order_named_modules = list(reversed(reverse_post_order_named_modules))
    return post_order_named_modules

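
# Illustrative sketch, not part of the original module: this reproduces the
# docstring example above with real `nn.Module` subclasses to show the order
# that `_get_post_order_named_modules` returns. All names are demo-only.
def _demo_post_order_named_modules():
    class SS1(nn.Module):
        pass

    class SS2(nn.Module):
        pass

    class S1(nn.Module):
        pass

    class S3(nn.Module):
        pass

    class S2(nn.Module):
        def __init__(self):
            super().__init__()
            self.ss1 = SS1()
            self.ss2 = SS2()

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.s1 = S1()
            self.s2 = S2()
            self.s3 = S3()

    names = [name for name, _ in _get_post_order_named_modules(M())]
    # Children precede parents, and registration order is preserved at each
    # level; the root module has the empty name.
    assert names == ["s1", "s2.ss1", "s2.ss2", "s2", "s3", ""]
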
def _get_managed_param_to_fqn(
    module_to_wrap: nn.Module,
    ignored_params: Set[nn.Parameter],
    visited_modules: Set[nn.Module],
    root_prefix: str,
) -> Dict[nn.Parameter, str]:
    """
    This returns a dict that maps each managed parameter to its FQN for the
    given ``module_to_wrap``. The dict's keys are exactly the parameters that
    would be managed by the module, which we achieve by calling this function
    on the modules to wrap in reverse topological order, destructively updating
    ``visited_modules``, and not traversing into already-visited modules. The
    FQNs are prefixed from the root (via ``root_prefix``) to be more
    informative.

    NOTE: This function is meant to be called pre-wrapping and iteratively in
    reverse topological order to cover the full module tree. This differs from
    the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
    on the full module tree in one shot. Given those differences, we do not try
    to unify the two.
    """
    param_to_fqn: Dict[nn.Parameter, str] = {}
    # Run BFS (any tree traversal would work)
    queue = collections.deque([(module_to_wrap, root_prefix)])
    visited_modules.add(module_to_wrap)
    while queue:
        module, prefix = queue.popleft()
        for param_name, param in module.named_parameters(recurse=False):
            if param not in ignored_params:
                fqn = param_name if prefix == "" else prefix + "." + param_name
                param_to_fqn[param] = fqn
        for child_module_name, child_module in module.named_children():
            if child_module is None:  # only for overrides of `named_children()`
                continue
            if child_module not in visited_modules:
                visited_modules.add(child_module)
                child_prefix = (
                    child_module_name
                    if prefix == ""
                    else prefix + "." + child_module_name
                )
                queue.append((child_module, child_prefix))
    return param_to_fqn

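
# Illustrative sketch, not part of the original module: it shows how calling
# `_get_managed_param_to_fqn` in reverse topological order partitions the
# parameters, so that a prospective wrap point only claims parameters not
# already claimed by a deeper wrap point. `Inner`/`Outer` are demo-only names.
def _demo_managed_param_to_fqn():
    class Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(2, 2)

    class Outer(nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = Inner()
            self.head = nn.Linear(2, 1)

    root = Outer()
    visited: Set[nn.Module] = set()
    # Visiting the deeper wrap point first claims `inner.lin.*`...
    inner_fqns = set(
        _get_managed_param_to_fqn(root.inner, set(), visited, "inner").values()
    )
    assert inner_fqns == {"inner.lin.weight", "inner.lin.bias"}
    # ...so the root wrap point only manages the remaining `head.*` params.
    root_fqns = set(_get_managed_param_to_fqn(root, set(), visited, "").values())
    assert root_fqns == {"head.weight", "head.bias"}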