# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
import os
import warnings
from concurrent.futures import Future, ThreadPoolExecutor
from typing import cast, Optional, Union
from typing_extensions import deprecated

import torch
import torch.distributed as dist
from torch.distributed._state_dict_utils import _offload_state_dict_to_cpu
from torch.distributed.checkpoint._storage_utils import _storage_setup
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.staging import AsyncStager
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.distributed_c10d import _get_default_group

from .utils import _api_bc_check, _DistWrapper, _profile


__all__ = ["save_state_dict", "save", "async_save"]


@deprecated(
    "`save_state_dict` is deprecated and will be removed in future versions. "
    "Please use `save` instead.",
    category=FutureWarning,
)
def save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
    """This method is deprecated. Please switch to ``save``."""
    storage_writer.reset()

    # TODO: test returning `save` here instead.
    with _profile():
        return _save_state_dict(
            state_dict,
            storage_writer,
            process_group,
            coordinator_rank,
            no_dist,
            planner,
        )


@_dcp_method_logger(log_exceptions=True)  # type: ignore[arg-type]
@_api_bc_check
def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Metadata:
    """
    Save a distributed model in SPMD style.

    This function differs from ``torch.save()`` in that it handles
    ``ShardedTensor`` and ``DTensor`` by having each rank save only its local shards.

    For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
    save will call ``state_dict`` before serialization.

    .. warning::
        There are no guarantees of backwards compatibility across PyTorch versions
        for saved state_dicts.

    .. warning::
        If using the `process_group` argument, make sure that only its ranks
        call `save_state_dict` and that all data in the state_dict belongs to it.

    .. note::
        When saving a checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of
        the shard groups should call `save_state_dict`, and the corresponding process
        group needs to be passed in.

    .. note::
        If no process group is available, this function assumes the intention is to save the
        state_dict in the local process.

    .. note::
        Rank 0 is assumed to be the coordinator rank.


    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_writer (Optional[StorageWriter]):
            Instance of StorageWriter used to perform writes. If this is not
            specified, DCP will automatically infer the writer based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        Metadata: Metadata object for the saved checkpoint.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()

        >>> state_dict = {"model": my_model}

        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> torch.distributed.checkpoint.save(
        >>>     state_dict=state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )

    .. note::
        save_state_dict uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that
        each rank has an individual GPU, via ``torch.cuda.set_device()``.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save")

    no_dist = not (dist.is_available() and dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to save in a single process."
        )

    with _profile():
        storage_writer = cast(
            StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
        )

        return _save_state_dict(
            state_dict=_stateful_to_state_dict(state_dict),
            storage_writer=storage_writer,
            process_group=process_group,
            no_dist=no_dist,
            planner=planner,
        )


@_dcp_method_logger(log_exceptions=True)
def async_save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Future:
    """Asynchronous version of ``save``. This function first stages the state_dict onto
    staging storage (defaults to CPU memory), and then calls ``save`` in a separate thread.

    .. warning::
        This feature is experimental and subject to change.

    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_writer (Optional[StorageWriter]):
            Instance of StorageWriter used to perform 'stage' and 'save'. If
            this is not specified, DCP will automatically infer the writer based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        Future: A future holding the resultant Metadata object from `save`.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()

        >>> state_dict = {"model": my_model}

        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> checkpoint_future = torch.distributed.checkpoint.async_save(
        >>>     state_dict=state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )
        >>>
        >>> # ... do some work ...
        >>>
        >>> checkpoint_future.result()

    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save")

    if dist.is_available() and dist.is_initialized():
        pg = process_group or _get_default_group()
        assert (
            torch.device("cpu") in pg._device_types  # type: ignore[attr-defined]
        ), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"

    storage_writer = cast(
        StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
    )

    state_dict = _stateful_to_state_dict(state_dict)
    if isinstance(storage_writer, AsyncStager):
        staged_state_dict = storage_writer.stage(state_dict)
    else:  # provides bwc for storage_writers not implementing AsyncStager
        staged_state_dict = _offload_state_dict_to_cpu(state_dict, type_check=False)

    executor = ThreadPoolExecutor(max_workers=1)
    f: Future = executor.submit(
        save,
        staged_state_dict,
        checkpoint_id=checkpoint_id,
        storage_writer=storage_writer,
        planner=planner,
        process_group=process_group,
    )
    f.add_done_callback(lambda f: executor.shutdown(wait=False))

    if (
        isinstance(storage_writer, AsyncStager)
        and storage_writer.should_synchronize_after_execute
    ):
        storage_writer.synchronize_staging()

    return f


def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
    """Creates a shallow copy of `state_dict` where `state_dict` is called for each Stateful object."""
    stateful_state_dict = {}
    for key, elem in state_dict.items():
        stateful_state_dict[key] = (
            elem.state_dict() if isinstance(elem, Stateful) else elem
        )
    return stateful_state_dict


def _save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
    """Core implementation shared by ``save`` and ``save_state_dict``.

    Builds local save plans via the planner, reduces them into a global plan on the
    coordinator, writes the data through the storage writer, and returns the global
    metadata produced by the coordinator.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultSavePlanner()
    assert planner is not None

    global_metadata = None

    ckpt_kwargs = {}
    if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
        ckpt_kwargs["checkpoint_id"] = ckpt_id

    @_dcp_method_logger(**ckpt_kwargs)
    def local_step():
        assert planner is not None
        storage_meta = storage_writer.storage_meta()
        if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
            warnings.warn(
                "The function definition for SavePlanner.set_up_planner has been updated"
                " to include the storage_meta argument. Please update your implementation"
                " to include this parameter."
            )
            planner.set_up_planner(state_dict, distW.is_coordinator)  # type: ignore[call-arg, arg-type]
        else:
            planner.set_up_planner(
                state_dict=state_dict,
                storage_meta=storage_meta,
                is_coordinator=distW.is_coordinator,
            )
        storage_writer.set_up_storage_writer(distW.is_coordinator)

        local_plan = planner.create_local_plan()
        local_plan = storage_writer.prepare_local_plan(local_plan)
        return local_plan

    @_dcp_method_logger(**ckpt_kwargs)
    def global_step(all_local_plans):
        nonlocal global_metadata

        assert planner is not None
        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)

    @_dcp_method_logger(**ckpt_kwargs)
    def write_data():
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_writes = storage_writer.write_data(final_local_plan, planner)

        all_writes.wait()
        return all_writes.value()

    @_dcp_method_logger(**ckpt_kwargs)
    def finish_checkpoint(all_results):
        assert global_metadata is not None
        storage_writer.finish(metadata=global_metadata, results=all_results)
        return global_metadata

    return distW.all_reduce("write", write_data, finish_checkpoint)