# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import os
import warnings
from typing import Any, cast, Dict, Optional, Set, Union
from typing_extensions import deprecated

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.stateful import Stateful

from ._storage_utils import _storage_setup
from .default_planner import DefaultLoadPlanner
from .planner import LoadPlan, LoadPlanner
from .storage import StorageReader
from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile


__all__ = ["load_state_dict", "load"]


@deprecated(
    "`load_state_dict` is deprecated and will be removed in future versions. "
    "Please use `load` instead.",
    category=FutureWarning,
)
def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
    """This method is deprecated. Please switch to 'load'."""
    storage_reader.reset()
    with _profile():
        # TODO: test returning `load` here instead.
        return _load_state_dict(
            state_dict,
            storage_reader,
            process_group,
            coordinator_rank,
            no_dist,
            planner,
        )


@_dcp_method_logger(log_exceptions=True)
@_api_bc_check
def load(
    state_dict: Dict[str, Any],
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_reader: Optional[StorageReader] = None,
    planner: Optional[LoadPlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> None:
    """
    Load a distributed ``state_dict`` in SPMD style.

    Each rank will try to read the least amount of data necessary
    to fulfill the requested ``state_dict``. When loading :class:`ShardedTensor`
    or :class:`DTensor` instances, each rank only reads data for its local shards.

    For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
    load will first call ``state_dict`` before attempting deserialization, followed by
    ``load_state_dict`` once the deserialization is complete.

    .. warning::
        All tensors in ``state_dict`` must be allocated on their
        destination device *prior to* calling this function.

        All non-tensor data is loaded using `torch.load()` and modified in place
        on ``state_dict``.

    .. warning::
        Users must call `load_state_dict` on the root module to ensure load
        post-processing and non-tensor data properly propagate.

    .. note::
        If no process group is initialized, this function will assume the intent
        is to load a checkpoint into the local process. This can be useful in the
        case of local inference, and when using regular Tensors (as opposed to
        DTensor or ShardedTensor).

    .. note::
        Rank 0 is assumed to be the coordinator rank.

    Args:
        state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_reader (Optional[StorageReader]):
            Instance of StorageReader used to perform reads. If this is not
            specified, DCP will automatically infer the reader based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[LoadPlanner]):
            Instance of LoadPlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        None.

    Examples:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")

        >>> torch.distributed.checkpoint.load(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_reader,
        >>> )

        >>> # module.load_state_dict() may have customized steps to flush the
        >>> # state_dict; it must be called to ensure correct behavior.
        >>> my_model.load_state_dict(model_state_dict)
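        >>> # The values of ``state_dict`` may themselves be ``Stateful`` objects
        >>> # (e.g. modules or optimizers); ``load`` then calls ``state_dict()`` /
        >>> # ``load_state_dict()`` around deserialization as described above.
        >>> # A minimal sketch, assuming the checkpoint was saved with matching
        >>> # "model" and "optim" keys (hypothetical layout):
        >>> torch.distributed.checkpoint.load(
        >>>     state_dict={"model": my_model, "optim": optimizer},
        >>>     storage_reader=fs_storage_reader,
        >>> )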

    .. note::
        ``load`` uses collectives to coordinate reads across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that each
        rank has an individual GPU, via ``torch.cuda.set_device()``.
    """

    no_dist = not (dist.is_available() and dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
        )

    with _profile():
        storage_reader = cast(
            StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
        )

        if no_dist:
            keys = list(state_dict.keys())
        else:
            keys = _all_gather_keys(state_dict, process_group)
            if keys != sorted(state_dict.keys()):
                warnings.warn(
                    "Detected mismatched keys in state dict after all gather!"
                    " This behavior is unsupported and may cause errors."
                )

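        # Unwrap ``Stateful`` values: deserialize into a snapshot produced by
        # ``state_dict()`` so that ``load_state_dict`` can be invoked afterwards
        # to apply any post-processing.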
        stateful_sd = {}
        for key in keys:
            if key not in state_dict:
                continue
            elem = state_dict[key]
            stateful_sd[key] = (
                elem.state_dict() if isinstance(elem, Stateful) else elem
            )

        _load_state_dict(
            state_dict=stateful_sd,
            storage_reader=storage_reader,
            process_group=process_group,
            no_dist=no_dist,
            planner=planner,
        )
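        # Hand the deserialized values back to the caller: ``Stateful`` objects
        # consume them via ``load_state_dict``; plain values are written back
        # into the caller's ``state_dict``.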
        for key in keys:
            if key not in state_dict:
                continue
            elem = state_dict[key]
            if isinstance(elem, Stateful):
                elem.load_state_dict(stateful_sd[key])
            state_dict[key] = stateful_sd[key]


def _load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
    torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultLoadPlanner()

    ckpt_kwargs = {}
    if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None:
        ckpt_kwargs["checkpoint_id"] = ckpt_id

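    # Planning phase: ``local_step`` builds this rank's plan from the checkpoint
    # metadata, and ``global_step`` (run on the coordinator) merges the plans
    # from all ranks before the storage reader prepares them.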
    @_dcp_method_logger(**ckpt_kwargs)
    def local_step():
        assert planner is not None
        metadata = storage_reader.read_metadata()
        planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)

        local_plan = planner.create_local_plan()
        local_plan = storage_reader.prepare_local_plan(local_plan)
        return local_plan

    @_dcp_method_logger(**ckpt_kwargs)
    def global_step(all_local_plans):
        assert planner is not None
        all_local_plans = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
        return all_local_plans

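    # ``reduce_scatter`` runs ``local_step`` on every rank, applies ``global_step``
    # on the coordinator, and hands each rank back its portion of the merged plan.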
    central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)

    @_dcp_method_logger(**ckpt_kwargs)
    def read_data():
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_reads = storage_reader.read_data(final_local_plan, planner)

        all_reads.wait()
        return None

    _ = distW.all_gather("read", read_data)


def _load_state_dict_from_keys(
    keys: Optional[Union[Set[str], str]] = None,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_reader: Optional[StorageReader] = None,
    process_group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Load only the specified keys from the checkpoint; if no keys are specified, the entire
    checkpoint is loaded. Note that this method completely loads the checkpoint into the
    current process and is not distributed.

    .. warning::
        All non-tensor data is loaded using `torch.load()`.

    .. note::
        As opposed to the usual pattern, this function does not take a state dict as input
        and does not load in place. Instead, a new state dict is directly initialized and read
        from the checkpoint.

    .. note::
        If no process group is initialized, this function will assume the intent
        is to load a checkpoint into the local process. This can be useful in the
        case of local inference, and when using regular Tensors (as opposed to
        DTensor or ShardedTensor).

    .. note::
        Rank 0 is assumed to be the coordinator rank.

    Args:
        keys (Optional[Union[Set[str], str]]):
            Loads any key specified in this set. If no keys are specified, the entire checkpoint
            is loaded.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_reader (Optional[StorageReader]):
            Instance of StorageReader used to perform reads. If this is not
            specified, DCP will automatically infer the reader based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)

    Returns:
        The state dict loaded from the specified keys.
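
    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch, assuming a checkpoint saved under the hypothetical
        >>> # path "/checkpoint/1" that contains an "optimizer" entry:
        >>> optim_sd = _load_state_dict_from_keys(
        >>>     keys={"optimizer"}, checkpoint_id="/checkpoint/1"
        >>> )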
    """
    torch._C._log_api_usage_once(
        "torch.distributed.checkpoint._load_state_dict_from_keys"
    )

    no_dist = not (dist.is_available() and dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is unavailable or uninitialized, assuming the intent is to load in a single process."
        )

    storage_reader = cast(
        StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
    )

    if isinstance(keys, str):
        keys = {keys}

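    # ``_EmptyStateDictLoadPlanner`` materializes entries from the checkpoint
    # metadata, so the empty ``sd`` dict below is populated in place by
    # ``_load_state_dict``.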
    sd: Dict[str, Any] = {}
    _load_state_dict(
        state_dict=sd,
        storage_reader=storage_reader,
        process_group=process_group,
        no_dist=no_dist,
        planner=_EmptyStateDictLoadPlanner(keys=keys or set()),
    )

    return sd