.. role:: hidden
    :class: hidden-section

Distributed Checkpoint - torch.distributed.checkpoint
========================================================


Distributed Checkpoint (DCP) supports loading and saving models from multiple ranks in parallel.
It handles load-time resharding, which enables saving in one cluster topology and loading into another.

DCP differs from `torch.save` and `torch.load` in a few significant ways:

* It produces multiple files per checkpoint, with at least one per rank.
* It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.

The entrypoints to load and save a checkpoint are the following:


.. automodule:: torch.distributed.checkpoint

.. currentmodule:: torch.distributed.checkpoint.state_dict_saver

.. autofunction:: save
.. autofunction:: async_save
.. autofunction:: save_state_dict

.. currentmodule:: torch.distributed.checkpoint.state_dict_loader

.. autofunction:: load
.. autofunction:: load_state_dict

The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (`torch.distributed.checkpoint.async_save`):

.. automodule:: torch.distributed.checkpoint.staging

.. autoclass:: torch.distributed.checkpoint.staging.AsyncStager
  :members:

.. autoclass:: torch.distributed.checkpoint.staging.BlockingAsyncStager
  :members:

In addition to the above entrypoints, `Stateful` objects, as described below, provide additional customization during saving/loading.

.. automodule:: torch.distributed.checkpoint.stateful

.. autoclass:: torch.distributed.checkpoint.stateful.Stateful
  :members:

This `example <https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py>`_ shows how to use PyTorch Distributed Checkpoint to save an FSDP model.
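For orientation, the following is a minimal sketch of the basic save/load flow. It assumes a process group has already been initialized (for example, via ``torchrun``) and uses a small unparallelized ``nn.Linear`` as a stand-in for a real FSDP/DDP-wrapped model; the checkpoint path and tensor shapes are illustrative only.

.. code-block:: python

    import torch
    import torch.distributed.checkpoint as dcp

    CHECKPOINT_DIR = "/tmp/dcp_checkpoint"  # illustrative path

    model = torch.nn.Linear(8, 4)  # stand-in for an FSDP/DDP-wrapped model

    # Saving: DCP writes one or more files per rank under CHECKPOINT_DIR.
    state_dict = {"model": model.state_dict()}
    dcp.save(state_dict, storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR))

    # Loading: DCP operates in place, so allocate the destination tensors
    # first (here by constructing the model) and pass its state_dict in.
    new_model = torch.nn.Linear(8, 4)
    state_dict = {"model": new_model.state_dict()}
    dcp.load(state_dict, storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR))
    new_model.load_state_dict(state_dict["model"])

When the model is actually parallelized, the `get_state_dict`/`set_state_dict` helpers described later in this section can be used to produce state dictionaries with keys that are consistent across parallelism strategies.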
The following types define the IO interface used during checkpointing:

.. autoclass:: torch.distributed.checkpoint.StorageReader
  :members:

.. autoclass:: torch.distributed.checkpoint.StorageWriter
  :members:

The following types define the planner interface used during checkpointing:

.. autoclass:: torch.distributed.checkpoint.LoadPlanner
  :members:

.. autoclass:: torch.distributed.checkpoint.LoadPlan
  :members:

.. autoclass:: torch.distributed.checkpoint.ReadItem
  :members:

.. autoclass:: torch.distributed.checkpoint.SavePlanner
  :members:

.. autoclass:: torch.distributed.checkpoint.SavePlan
  :members:

.. autoclass:: torch.distributed.checkpoint.planner.WriteItem
  :members:

We provide a filesystem-based storage layer:

.. autoclass:: torch.distributed.checkpoint.FileSystemReader
  :members:

.. autoclass:: torch.distributed.checkpoint.FileSystemWriter
  :members:

We provide default implementations of `LoadPlanner` and `SavePlanner` that
can handle all of the torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor.

.. autoclass:: torch.distributed.checkpoint.DefaultSavePlanner
  :members:

.. autoclass:: torch.distributed.checkpoint.DefaultLoadPlanner
  :members:


Due to legacy design decisions, the state dictionaries of `FSDP` and `DDP` may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical.
Moreover, `FSDP` offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries use parameter IDs instead of fully qualified names to identify parameters, which can cause issues when parallelisms are used (e.g., pipeline parallelism).

To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. `get_model_state_dict` returns a model state dictionary with keys consistent with those returned by the unparallelized model state dictionary. Similarly, `get_optimizer_state_dict` provides the optimizer state dictionary with keys uniform across all parallelisms applied. To achieve this consistency, `get_optimizer_state_dict` converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary.

Note that the results returned by these APIs can be used directly with the `torch.distributed.checkpoint.save()` and `torch.distributed.checkpoint.load()` methods without requiring any additional conversions.

Note that this feature is experimental, and API signatures might change in the future.


.. autofunction:: torch.distributed.checkpoint.state_dict.get_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.get_model_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.get_optimizer_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.set_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.set_model_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.set_optimizer_state_dict

.. autoclass:: torch.distributed.checkpoint.state_dict.StateDictOptions
  :members:

For users who are accustomed to using and sharing models in the `torch.save` format, the following methods provide offline utilities for converting between formats (a short conversion sketch follows at the end of this section).

.. automodule:: torch.distributed.checkpoint.format_utils

.. currentmodule:: torch.distributed.checkpoint.format_utils

.. autofunction:: dcp_to_torch_save
.. autofunction:: torch_save_to_dcp

The following classes can also be utilized for online loading and resharding of models from the torch.save format.

.. autoclass:: torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader
  :members:

.. autoclass:: torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner
  :members:

The following experimental interfaces are provided for improved observability in production environments:

.. py:module:: torch.distributed.checkpoint.logger
.. py:module:: torch.distributed.checkpoint.logging_handlers
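As referenced in the format-conversion discussion above, the following is a minimal sketch of the offline conversion utilities. It assumes an existing DCP checkpoint directory produced by an earlier `torch.distributed.checkpoint.save` call; the paths are illustrative only, and the conversion is typically run as a single offline process.

.. code-block:: python

    import torch
    from torch.distributed.checkpoint.format_utils import (
        dcp_to_torch_save,
        torch_save_to_dcp,
    )

    DCP_DIR = "/tmp/dcp_checkpoint"    # illustrative: an existing DCP checkpoint
    TORCH_SAVE_PATH = "/tmp/model.pt"  # illustrative: target torch.save file

    # Consolidate a sharded DCP checkpoint into a single torch.save file.
    dcp_to_torch_save(DCP_DIR, TORCH_SAVE_PATH)

    # The resulting file can be read back with the regular torch.load API.
    state_dict = torch.load(TORCH_SAVE_PATH)

    # Convert a torch.save file back into the DCP on-disk format.
    torch_save_to_dcp(TORCH_SAVE_PATH, "/tmp/dcp_checkpoint_roundtrip")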