• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import abc
2import os
3from dataclasses import dataclass
4from typing import Any, List, Optional, Union
5
6from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
7from torch.distributed.checkpoint.planner import (
8    LoadPlan,
9    LoadPlanner,
10    SavePlan,
11    SavePlanner,
12)
13from torch.futures import Future
14
15
# Names exported by ``from ... import *``: the write-result record and the
# two storage interfaces defined in this module.
__all__ = [
    "WriteResult",
    "StorageWriter",
    "StorageReader",
]
17
18
@dataclass(eq=True, frozen=True)
class WriteResult:
    """
    Immutable record describing one item written by ``StorageWriter.write_data``.

    ``StorageWriter.write_data`` produces a list of these per rank, and the
    lists gathered from all ranks are handed to ``StorageWriter.finish``.
    """

    # Metadata index of the written item.
    index: MetadataIndex
    # Size of the written item, in bytes.
    size_in_bytes: int
    # Storage-specific payload; typed ``Any`` so each StorageWriter
    # implementation can decide what it records here.
    storage_data: Any
25
26
27class StorageWriter(abc.ABC):
28    """
29    Interface used by ``save_state_dict`` to write to storage.
30
31    One StorageWriter instance acts as both the coordinator and the follower
32    in a distributed checkpoint. As part of initialization, each instance
33    is told its role.
34
35    A subclass should expect the following sequence of calls.
36
37    0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id.
38    1) (all ranks) set_up_storage_writer()
39    2) (all ranks) prepare_local_plan()
40    3) (coordinator) prepare_global_plan()
41    4) (all ranks) write_data()
42    5) (coordinator) finish()
43    """
44
45    @abc.abstractmethod
46    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
47        """
48        Calls to indicates a brand new checkpoint write is going to happen.
49        A checkpoint_id may be present if users set the checkpoint_id for
50        this checkpoint write. The meaning of the checkpiont_id is
51        storage-dependent. It can be a path to a folder/file or a key for
52        a key-value storage.
53
54        Args:
55            checkpoint_id (Union[str, os.PathLike, None]):
56                The ID of this checkpoint instance. The meaning of the checkpoint_id
57                depends on the storage. It can be a path to a folder or to a file.
58                It can also be a key if the storage is a key-value store.
59                (Default: ``None``)
60        """
61        ...
62
63    @abc.abstractmethod
64    def set_up_storage_writer(self, is_coordinator: bool) -> None:
65        """
66        Initialize this instance.
67
68        Args:
69            is_coordinator (bool): Whether this instance is responsible for coordinating
70              the checkpoint.
71        """
72
73    @abc.abstractmethod
74    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
75        """
76        Perform storage-specific local planning.
77
78        While this method can produce a completely different plan, the recommended
79        way is to store storage specific data in SavePlan::storage_data.
80
81        Args:
82            plan (SavePlan): The local plan from the ``SavePlanner`` in use.
83
84        Returns:
85            A transformed ``SavePlan`` after storage local planning
86        """
87
88    @abc.abstractmethod
89    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
90        """
91        Perform centralized planning of storage.
92
93        This method is only called on the coordinator instance.
94
95        While this method can produce a completely different plan, the preferred
96        way is to store storage specific data in SavePlan::storage_data.
97
98        Args:
99            plans: A list of ``SavePlan`` instances, one for each rank.
100
101        Returns:
102            A list of transformed ``SavePlan`` after storage global planning
103        """
104
105    @abc.abstractmethod
106    def write_data(
107        self, plan: SavePlan, planner: SavePlanner
108    ) -> Future[List[WriteResult]]:
109        """
110        Write all items from ``plan`` using ``planner`` to resolve the data.
111
112        A subclass should call ``SavePlanner::resolve_data`` on each item
113        from the plan to get access to the underlying object to write.
114
115        Subclasses should lazily call `resolve_data` as it can allocate memory.
116        In case of tensors, make following assumptions:
117
118        - They might be on any device, including not matching the one on ``WriteItem::tensor_data``
119        - They might be views or not contiguous. Only the projection needs to be saved.
120
121        Args:
122            plan (SavePlan): The save plan to execute.
123            planner (SavePlanner): Planner object to be used to resolve items to data.
124
125        Returns:
126            A future that completes to a list of WriteResult
127        """
128
129    @abc.abstractmethod
130    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
131        """
132        Write the metadata and marks the current checkpoint as successful.
133
134        The actual format/schema used for serializing `metadata` is an
135        implementation detail. The only requirement is that it's recoverable
136        in to the same object graph.
137
138        Args:
139            metadata (Metadata): metadata for the new checkpoint
140            results: A list of WriteResults from all ranks.
141
142        Returns:
143            None
144        """
145
146    @classmethod
147    @abc.abstractmethod
148    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
149        """
150        Check if the given checkpoint_id is supported by the stroage. This allow
151        us to enable automatic storage selection.
152        """
153        ...
154
155    def storage_meta(self) -> Optional[StorageMeta]:
156        """
157        Return the storage-specific metadata. This is used to store additional information
158        in a checkpoint that can be useful for providing request-level observability. StorageMeta
159        is passed to the ``SavePlanner`` during save calls. Returns None by default.
160
161        TODO: provide an example
162        """
163        return None
164
165
class StorageReader(abc.ABC):
    """
    Interface used by ``load_state_dict`` to read from storage.

    One StorageReader instance acts as both the coordinator and the follower
    in a distributed checkpoint. As part of initialization, each instance
    is told its role.

    A subclass should expect the following sequence of calls by ``load_state_dict``:

    0) (all ranks) set checkpoint_id via ``reset()`` if users pass a valid
       checkpoint_id.
    1) (all ranks) read_metadata()
    2) (all ranks) set_up_storage_reader()
    3) (all ranks) prepare_local_plan()
    4) (coordinator) prepare_global_plan()
    5) (all ranks) read_data()
    """

    @abc.abstractmethod
    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """
        Signal that a brand new checkpoint read is going to happen.

        A checkpoint_id may be present if users set the checkpoint_id for
        this checkpoint read. The meaning of the checkpoint_id is
        storage-dependent. It can be a path to a folder/file or a key for
        a key-value storage.

        Args:
            checkpoint_id (Union[str, os.PathLike, None]):
                The ID of this checkpoint instance. The meaning of the checkpoint_id
                depends on the storage. It can be a path to a folder or to a file.
                It can also be a key if the storage is more like a key-value store.
                (Default: ``None``)
        """
        ...

    @abc.abstractmethod
    def read_metadata(self) -> Metadata:
        """
        Read the checkpoint metadata.

        Returns:
            The metadata object associated with the checkpoint being loaded.
        """

    @abc.abstractmethod
    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            metadata (Metadata): The metadata schema to use.
            is_coordinator (bool): Whether this instance is responsible for coordinating
              the checkpoint.
        """

    @abc.abstractmethod
    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Perform storage-specific local planning.

        While this method can produce a completely different plan, the recommended
        way is to store storage-specific data in LoadPlan::storage_data.

        Args:
            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.

        Returns:
            A transformed ``LoadPlan`` after storage local planning.
        """

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        """
        Perform centralized planning of storage loading.

        This method is only called on the coordinator instance.

        While this method can produce a completely different plan, the preferred
        way is to store storage-specific data in LoadPlan::storage_data.

        Args:
            plans: A list of ``LoadPlan`` instances, one for each rank.

        Returns:
            A list of transformed ``LoadPlan`` after storage global planning.
        """

    @abc.abstractmethod
    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Read all items from ``plan`` using ``planner`` to resolve the data.

        A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
        object into the right place.

        A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
        tensors that it should load data into.

        It is the storage layer's responsibility to properly schedule any
        cross-device copies required.

        Args:
            plan (LoadPlan): The local plan to execute on.
            planner (LoadPlanner): The planner object to use to resolve items.

        Returns:
            A future that completes once all reads are finished.
        """

    @classmethod
    @abc.abstractmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Check if the given checkpoint_id is supported by the storage. This allows
        us to enable automatic storage selection.
        """
        ...
285