.. role:: hidden
    :class: hidden-section

Distributed Checkpoint - torch.distributed.checkpoint
=====================================================

Distributed Checkpoint (DCP) supports loading and saving models from multiple ranks in parallel.
It handles load-time resharding, which enables saving in one cluster topology and loading into another.

DCP differs from `torch.save` and `torch.load` in a few significant ways:

* It produces multiple files per checkpoint, with at least one per rank.
* It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
The entrypoints to load and save a checkpoint are the following:


.. automodule:: torch.distributed.checkpoint

.. currentmodule:: torch.distributed.checkpoint.state_dict_saver

.. autofunction::  save
.. autofunction::  async_save
.. autofunction::  save_state_dict

.. currentmodule:: torch.distributed.checkpoint.state_dict_loader

.. autofunction::  load
.. autofunction::  load_state_dict

The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (`torch.distributed.checkpoint.async_save`):

.. automodule:: torch.distributed.checkpoint.staging

.. autoclass:: torch.distributed.checkpoint.staging.AsyncStager
  :members:

.. autoclass:: torch.distributed.checkpoint.staging.BlockingAsyncStager
  :members:

In addition to the above entrypoints, `Stateful` objects, as described below, provide additional customization during saving/loading.

.. automodule:: torch.distributed.checkpoint.stateful

.. autoclass:: torch.distributed.checkpoint.stateful.Stateful
  :members:

This `example <https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py>`_ shows how to use PyTorch Distributed Checkpoint to save an FSDP model.

The following types define the IO interface used during checkpointing:

.. autoclass:: torch.distributed.checkpoint.StorageReader
  :members:

.. autoclass:: torch.distributed.checkpoint.StorageWriter
  :members:

The following types define the planner interface used during checkpointing:

.. autoclass:: torch.distributed.checkpoint.LoadPlanner
  :members:

.. autoclass:: torch.distributed.checkpoint.LoadPlan
  :members:

.. autoclass:: torch.distributed.checkpoint.ReadItem
  :members:

.. autoclass:: torch.distributed.checkpoint.SavePlanner
  :members:

.. autoclass:: torch.distributed.checkpoint.SavePlan
  :members:

.. autoclass:: torch.distributed.checkpoint.planner.WriteItem
  :members:

We provide a filesystem-based storage layer:

.. autoclass:: torch.distributed.checkpoint.FileSystemReader
  :members:

.. autoclass:: torch.distributed.checkpoint.FileSystemWriter
  :members:

We provide default implementations of `LoadPlanner` and `SavePlanner` that
can handle all torch.distributed constructs such as FSDP, DDP, ShardedTensor and DistributedTensor.

.. autoclass:: torch.distributed.checkpoint.DefaultSavePlanner
  :members:

.. autoclass:: torch.distributed.checkpoint.DefaultLoadPlanner
  :members:


Due to legacy design decisions, the state dictionaries of `FSDP` and `DDP` may have different keys or fully qualified names (e.g., layer1.weight) even when the original unparallelized model is identical. Moreover, `FSDP` offers various types of model state dictionaries, such as full and sharded state dictionaries. Additionally, optimizer state dictionaries employ parameter IDs instead of fully qualified names to identify parameters, potentially causing issues when parallelisms are used (e.g., pipeline parallelism).

To tackle these challenges, we offer a collection of APIs for users to easily manage state_dicts. `get_model_state_dict` returns a model state dictionary with keys consistent with those returned by the unparallelized model state dictionary. Similarly, `get_optimizer_state_dict` provides the optimizer state dictionary with keys uniform across all parallelisms applied. To achieve this consistency, `get_optimizer_state_dict` converts parameter IDs to fully qualified names identical to those found in the unparallelized model state dictionary.

Note that results returned by these APIs can be used directly with the `torch.distributed.checkpoint.save()` and `torch.distributed.checkpoint.load()` methods without requiring any additional conversions.

Note that this feature is experimental, and API signatures might change in the future.


.. autofunction:: torch.distributed.checkpoint.state_dict.get_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.get_model_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.get_optimizer_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.set_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.set_model_state_dict

.. autofunction:: torch.distributed.checkpoint.state_dict.set_optimizer_state_dict

.. autoclass:: torch.distributed.checkpoint.state_dict.StateDictOptions
   :members:

For users accustomed to using and sharing models in the `torch.save` format, the following methods provide offline utilities for converting between formats.

.. automodule:: torch.distributed.checkpoint.format_utils

.. currentmodule:: torch.distributed.checkpoint.format_utils

.. autofunction:: dcp_to_torch_save
.. autofunction:: torch_save_to_dcp

The following classes can also be utilized for online loading and resharding of models from the `torch.save` format.

.. autoclass:: torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader
   :members:

.. autoclass:: torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner
   :members:

The following experimental interfaces are provided for improved observability in production environments:

.. py:module:: torch.distributed.checkpoint.logger
.. py:module:: torch.distributed.checkpoint.logging_handlers