# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import os
import warnings
from typing import Any, cast, Dict, Optional, Set, Union
from typing_extensions import deprecated

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.stateful import Stateful

from ._storage_utils import _storage_setup
from .default_planner import DefaultLoadPlanner
from .planner import LoadPlan, LoadPlanner
from .storage import StorageReader
from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile


__all__ = ["load_state_dict", "load"]


@deprecated(
    "`load_state_dict` is deprecated and will be removed in future versions. "
    "Please use `load` instead.",
    category=FutureWarning,
)
def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
    """This method is deprecated. Please switch to 'load'."""
    storage_reader.reset()
    with _profile():
        # TODO: test returning `load` here instead.
        return _load_state_dict(
            state_dict,
            storage_reader,
            process_group,
            coordinator_rank,
            no_dist,
            planner,
        )


@_dcp_method_logger(log_exceptions=True)
@_api_bc_check
def load(
    state_dict: Dict[str, Any],
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_reader: Optional[StorageReader] = None,
    planner: Optional[LoadPlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> None:
    """
    Load a distributed ``state_dict`` in SPMD style.

    Each rank will try to read the least amount of data necessary
    to fulfill the requested ``state_dict``. When loading :class:`ShardedTensor`
    or :class:`DTensor` instances, each rank only reads data for its local shards.

    For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
    load will first call ``state_dict`` before attempting deserialization, followed by
    ``load_state_dict`` once the deserialization is complete, as sketched below.
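
    For example, a hypothetical application-level ``Stateful`` object (the class,
    the ``"app"`` key, and the checkpoint path below are illustrative assumptions,
    not part of this module) would participate in loading like this::

        >>> # xdoctest: +SKIP
        >>> class AppState(Stateful):
        >>>     def __init__(self, step: int = 0):
        >>>         self.step = step
        >>>     def state_dict(self):
        >>>         return {"step": self.step}
        >>>     def load_state_dict(self, state_dict):
        >>>         self.step = state_dict["step"]
        >>> app = AppState()
        >>> torch.distributed.checkpoint.load({"app": app}, checkpoint_id="/checkpoint/1")
        >>> # app.load_state_dict() has been called with the loaded "app" entry.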

    .. warning::
        All tensors in ``state_dict`` must be allocated on their
        destination device *prior to* calling this function.

        All non-tensor data is loaded using `torch.load()` and modified in place
        on state_dict.

    .. warning::
        Users must call `load_state_dict` on the root module to ensure load
        post-processing and non-tensor data properly propagate.

    .. note::
        If no process group is initialized, this function will assume the intent
        is to load a checkpoint into the local process. This can be useful in the
        case of local inference, and when using regular Tensors (as opposed to DTensor
        or ShardedTensor).

    .. note::
        Rank 0 is assumed to be the coordinator rank.

    Args:
        state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_reader (Optional[StorageReader]):
            Instance of StorageReader used to perform reads. If this is not
            specified, DCP will automatically infer the reader based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[LoadPlanner]):
            Instance of LoadPlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        None.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")

        >>> torch.distributed.checkpoint.load(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_reader,
        >>> )

        >>> # module.load_state_dict() may have customized steps
        >>> # to flush the state_dict; call it to ensure correct behavior.
        >>> my_model.load_state_dict(model_state_dict)
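
        >>> # Alternatively (a minimal sketch assuming the same checkpoint lives at
        >>> # "/checkpoint/1"), pass ``checkpoint_id`` and let DCP infer the reader:
        >>> torch.distributed.checkpoint.load(
        >>>     state_dict=model_state_dict,
        >>>     checkpoint_id="/checkpoint/1",
        >>> )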

    .. note::
        ``load`` uses collectives to coordinate reads across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that each
        rank has an individual GPU, via ``torch.cuda.set_device()``.
    """

    no_dist = not (dist.is_available() and dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
        )

    with _profile():
        storage_reader = cast(
            StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
        )

        if no_dist:
            keys = list(state_dict.keys())
        else:
            keys = _all_gather_keys(state_dict, process_group)
            if keys != sorted(state_dict.keys()):
                warnings.warn(
                    "Detected mismatched keys in state dict after all gather!"
                    " This behavior is unsupported and may cause errors."
                )

        stateful_sd = {}
        for key in keys:
            if key not in state_dict:
                continue
            elem = state_dict[key]
            stateful_sd[key] = (
                elem.state_dict() if isinstance(elem, Stateful) else elem
            )

        _load_state_dict(
            state_dict=stateful_sd,
            storage_reader=storage_reader,
            process_group=process_group,
            no_dist=no_dist,
            planner=planner,
        )
        for key in keys:
            if key not in state_dict:
                continue
            elem = state_dict[key]
            if isinstance(elem, Stateful):
                elem.load_state_dict(stateful_sd[key])
            state_dict[key] = stateful_sd[key]


def _load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
    torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultLoadPlanner()

    ckpt_kwargs = {}
    if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None:
        ckpt_kwargs["checkpoint_id"] = ckpt_id

    @_dcp_method_logger(**ckpt_kwargs)
    def local_step():
        # Each rank builds a load plan covering its own portion of the state_dict.
        assert planner is not None
        metadata = storage_reader.read_metadata()
        planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)

        local_plan = planner.create_local_plan()
        local_plan = storage_reader.prepare_local_plan(local_plan)
        return local_plan

    @_dcp_method_logger(**ckpt_kwargs)
    def global_step(all_local_plans):
        # Run on the coordinator: merge all local plans into a global plan.
        assert planner is not None
        all_local_plans = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)

    @_dcp_method_logger(**ckpt_kwargs)
    def read_data():
        # Each rank performs the reads assigned to it by the central plan.
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_reads = storage_reader.read_data(final_local_plan, planner)

        all_reads.wait()
        return None

    _ = distW.all_gather("read", read_data)


def _load_state_dict_from_keys(
    keys: Optional[Union[Set[str], str]] = None,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_reader: Optional[StorageReader] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Load only the specified keys from the checkpoint; if no keys are specified, the entire
    checkpoint is loaded. Note that this method completely loads the checkpoint into the
    current process and is not distributed.

    .. warning::
        All non-tensor data is loaded using `torch.load()`.

    .. note::
        As opposed to the usual pattern, this function does not take a state dict as input
        and does not load in place. Instead, a new state dict is directly initialized and read
        from file.

    .. note::
        If no process group is initialized, this function will assume the intent
        is to load a checkpoint into the local process. This can be useful in the
        case of local inference, and when using regular Tensors (as opposed to DTensor
        or ShardedTensor).

    .. note::
        Rank 0 is assumed to be the coordinator rank.

    Args:
        keys (Optional[Union[Set[str], str]]):
            Loads any key specified in this set. If no keys are specified, the entire checkpoint
            is loaded.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_reader (Optional[StorageReader]):
            Instance of StorageReader used to perform reads. If this is not
            specified, DCP will automatically infer the reader based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        State dict containing the specified keys.
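
    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch; the checkpoint path and the "optimizer" key are
        >>> # illustrative assumptions, not actual contents of any checkpoint.
        >>> optimizer_sd = _load_state_dict_from_keys(
        >>>     keys={"optimizer"},
        >>>     checkpoint_id="/checkpoint/1",
        >>> )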
    """
    torch._C._log_api_usage_once(
        "torch.distributed.checkpoint._load_state_dict_from_keys"
    )

    no_dist = not (dist.is_available() and dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
        )

    storage_reader = cast(
        StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
    )

    if isinstance(keys, str):
        keys = {keys}

    sd: Dict[str, Any] = {}
    _load_state_dict(
        state_dict=sd,
        storage_reader=storage_reader,
        process_group=process_group,
        no_dist=no_dist,
        planner=_EmptyStateDictLoadPlanner(keys=keys or set()),
    )

    return sd