1import abc 2import os 3from dataclasses import dataclass 4from typing import Any, List, Optional, Union 5 6from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta 7from torch.distributed.checkpoint.planner import ( 8 LoadPlan, 9 LoadPlanner, 10 SavePlan, 11 SavePlanner, 12) 13from torch.futures import Future 14 15 16__all__ = ["WriteResult", "StorageWriter", "StorageReader"] 17 18 19@dataclass(frozen=True) 20class WriteResult: 21 index: MetadataIndex 22 23 size_in_bytes: int 24 storage_data: Any 25 26 27class StorageWriter(abc.ABC): 28 """ 29 Interface used by ``save_state_dict`` to write to storage. 30 31 One StorageWriter instance acts as both the coordinator and the follower 32 in a distributed checkpoint. As part of initialization, each instance 33 is told its role. 34 35 A subclass should expect the following sequence of calls. 36 37 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. 38 1) (all ranks) set_up_storage_writer() 39 2) (all ranks) prepare_local_plan() 40 3) (coordinator) prepare_global_plan() 41 4) (all ranks) write_data() 42 5) (coordinator) finish() 43 """ 44 45 @abc.abstractmethod 46 def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: 47 """ 48 Calls to indicates a brand new checkpoint write is going to happen. 49 A checkpoint_id may be present if users set the checkpoint_id for 50 this checkpoint write. The meaning of the checkpiont_id is 51 storage-dependent. It can be a path to a folder/file or a key for 52 a key-value storage. 53 54 Args: 55 checkpoint_id (Union[str, os.PathLike, None]): 56 The ID of this checkpoint instance. The meaning of the checkpoint_id 57 depends on the storage. It can be a path to a folder or to a file. 58 It can also be a key if the storage is a key-value store. 59 (Default: ``None``) 60 """ 61 ... 62 63 @abc.abstractmethod 64 def set_up_storage_writer(self, is_coordinator: bool) -> None: 65 """ 66 Initialize this instance. 67 68 Args: 69 is_coordinator (bool): Whether this instance is responsible for coordinating 70 the checkpoint. 71 """ 72 73 @abc.abstractmethod 74 def prepare_local_plan(self, plan: SavePlan) -> SavePlan: 75 """ 76 Perform storage-specific local planning. 77 78 While this method can produce a completely different plan, the recommended 79 way is to store storage specific data in SavePlan::storage_data. 80 81 Args: 82 plan (SavePlan): The local plan from the ``SavePlanner`` in use. 83 84 Returns: 85 A transformed ``SavePlan`` after storage local planning 86 """ 87 88 @abc.abstractmethod 89 def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: 90 """ 91 Perform centralized planning of storage. 92 93 This method is only called on the coordinator instance. 94 95 While this method can produce a completely different plan, the preferred 96 way is to store storage specific data in SavePlan::storage_data. 97 98 Args: 99 plans: A list of ``SavePlan`` instances, one for each rank. 100 101 Returns: 102 A list of transformed ``SavePlan`` after storage global planning 103 """ 104 105 @abc.abstractmethod 106 def write_data( 107 self, plan: SavePlan, planner: SavePlanner 108 ) -> Future[List[WriteResult]]: 109 """ 110 Write all items from ``plan`` using ``planner`` to resolve the data. 111 112 A subclass should call ``SavePlanner::resolve_data`` on each item 113 from the plan to get access to the underlying object to write. 114 115 Subclasses should lazily call `resolve_data` as it can allocate memory. 116 In case of tensors, make following assumptions: 117 118 - They might be on any device, including not matching the one on ``WriteItem::tensor_data`` 119 - They might be views or not contiguous. Only the projection needs to be saved. 120 121 Args: 122 plan (SavePlan): The save plan to execute. 123 planner (SavePlanner): Planner object to be used to resolve items to data. 124 125 Returns: 126 A future that completes to a list of WriteResult 127 """ 128 129 @abc.abstractmethod 130 def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: 131 """ 132 Write the metadata and marks the current checkpoint as successful. 133 134 The actual format/schema used for serializing `metadata` is an 135 implementation detail. The only requirement is that it's recoverable 136 in to the same object graph. 137 138 Args: 139 metadata (Metadata): metadata for the new checkpoint 140 results: A list of WriteResults from all ranks. 141 142 Returns: 143 None 144 """ 145 146 @classmethod 147 @abc.abstractmethod 148 def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: 149 """ 150 Check if the given checkpoint_id is supported by the stroage. This allow 151 us to enable automatic storage selection. 152 """ 153 ... 154 155 def storage_meta(self) -> Optional[StorageMeta]: 156 """ 157 Return the storage-specific metadata. This is used to store additional information 158 in a checkpoint that can be useful for providing request-level observability. StorageMeta 159 is passed to the ``SavePlanner`` during save calls. Returns None by default. 160 161 TODO: provide an example 162 """ 163 return None 164 165 166class StorageReader(abc.ABC): 167 """ 168 Interface used by ``load_state_dict`` to read from storage. 169 170 One StorageReader instance acts as both the coordinator and the follower 171 in a distributed checkpoint. As part of initialization, each instance 172 is told its role. 173 174 A subclass should expected the following sequence of calls by ``load_state_dict``: 175 176 0) (all ranks) set checkpoint_id if users pass a valid checkpoint_id. 177 1) (all ranks) read_metadata() 178 2) (all ranks) set_up_storage_reader() 179 3) (all ranks) prepare_local_plan() 180 4) (coordinator) prepare_global_plan() 181 5) (all ranks) read_data() 182 """ 183 184 @abc.abstractmethod 185 def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: 186 """ 187 Calls to indicates a brand new checkpoint read is going to happen. 188 A checkpoint_id may be present if users set the checkpoint_id for 189 this checkpoint read. The meaning of the checkpiont_id is 190 storage-dependent. It can be a path to a folder/file or a key for 191 a key-value storage. 192 193 Args: 194 checkpoint_id (Union[str, os.PathLike, None]): 195 The ID of this checkpoint instance. The meaning of the checkpoint_id 196 depends on the storage. It can be a path to a folder or to a file. 197 It can also be a key if the storage is more like a key-value store. 198 (Default: ``None``) 199 """ 200 ... 201 202 @abc.abstractmethod 203 def read_metadata(self) -> Metadata: 204 """ 205 Read the checkpoint metadata. 206 207 Returns: 208 The metadata object associated with the checkpoint being loaded. 209 210 """ 211 212 @abc.abstractmethod 213 def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: 214 """ 215 Initialize this instance. 216 217 Args: 218 metadata (Metadata): The metadata schema to use. 219 is_coordinator (bool): Whether this instance is responsible for coordinating 220 the checkpoint. 221 """ 222 223 @abc.abstractmethod 224 def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: 225 """ 226 Perform storage-specific local planning. 227 228 While this method can produce a completely different plan, the recommended 229 way is to store storage specific data in LoadPlan::storage_data. 230 231 Args: 232 plan (LoadPlan): The local plan from the ``LoadPlan`` in use. 233 234 Returns: 235 A transformed ``LoadPlan`` after storage local planning 236 """ 237 238 @abc.abstractmethod 239 def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]: 240 """ 241 Perform centralized planning of storage loading. 242 243 This method is only called on the coordinator instance. 244 245 While this method can produce a completely different plan, the preferred 246 way is to store storage specific data in LoadPlan::storage_data. 247 248 Args: 249 plans: A list of ``LoadPlan`` instances, one for each rank. 250 251 Returns: 252 A list of transformed ``LoadPlan`` after storage global planning 253 """ 254 255 @abc.abstractmethod 256 def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: 257 """ 258 Read all items from ``plan`` using ``planner`` to resolve the data. 259 260 A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO 261 object into the right place. 262 263 A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the 264 tensors that in should load data into. 265 266 It's the StorageLayer responsibility to properly schedule any cross device copies 267 required. 268 269 Args: 270 plan (LoadPlan): The local plan to execute on 271 planner (LoadPlanner): The planner object to use to resolve items. 272 273 Returns: 274 A future that completes once all reads are finished. 275 """ 276 277 @classmethod 278 @abc.abstractmethod 279 def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: 280 """ 281 Check if the given checkpoint_id is supported by the stroage. This allow 282 us to enable automatic storage selection. 283 """ 284 ... 285