# mypy: allow-untyped-defs
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.default_planner import (
    _EmptyStateDictLoadPlanner,
    DefaultLoadPlanner,
)
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future


__all__ = [
    "dcp_to_torch_save",
    "torch_save_to_dcp",
    "BroadcastingTorchSaveReader",
    "DynamicMetaLoadPlanner",
]


class BroadcastingTorchSaveReader(StorageReader):
    """
    StorageReader for reading a Torch Save file. This reader reads the entire checkpoint
    on the coordinator rank, and then broadcasts and shards each tensor to all ranks.

    .. note:: Intended to be used with DynamicMetaLoadPlanner

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt"
    >>> )
    """

    def __init__(
        self,
        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
        coordinator_rank: int = 0,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.coordinator_rank = coordinator_rank

    def read_metadata(self) -> Metadata:
        """Extends the default StorageReader to support building the metadata file"""
        # Metadata is built in planner.set_up_planner, since we are not actually
        # reading metadata from disk
        return Metadata(state_dict_metadata={})

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads torch save data on the coordinator rank and broadcasts it afterwards.
        This incurs a communication cost, but avoids having to load the entire
        checkpoint on each rank, hopefully preventing OOM issues.
        """
        planner = cast(DefaultLoadPlanner, planner)

        # Data is read in on the coordinator rank and broadcast afterwards.
        # This incurs a communication cost, but it avoids having to load
        # the entire checkpoint on each rank, hopefully preventing OOM issues.
        # TODO: read on each host, instead of only the coordinator
        if self.is_coordinator:
            assert self.checkpoint_id is not None
            torch_state_dict = torch.load(
                self.checkpoint_id, map_location="cpu", weights_only=False
            )
            if planner.flatten_state_dict:
                torch_state_dict, _ = flatten_state_dict(torch_state_dict)
        else:
            torch_state_dict = None

        for req in plan.items:
            if req.type == LoadItemType.BYTE_IO:
                raise RuntimeError(
                    f"Non-tensor value identified at {req.storage_index.fqn}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            # Broadcast the tensor from the coordinator rank
            if self.is_coordinator:
                pg_device = dist.distributed_c10d._get_pg_default_device()
                tensor = torch_state_dict[req.storage_index.fqn].to(pg_device)
            else:
                tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])

            dist.broadcast(tensor, src=self.coordinator_rank, async_op=False)

            tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
            target_tensor = planner.resolve_tensor(req).detach()
            assert target_tensor.size() == tensor.size(), (
                f"req {req.storage_index} mismatched sizes, "
                f"{target_tensor.size()} vs {tensor.size()}"
            )
            target_tensor.copy_(tensor)
            planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Implementation of the StorageReader method"""
        self.is_coordinator = is_coordinator
        if self.is_coordinator:
            assert dist.get_rank() == self.coordinator_rank

        assert self.checkpoint_id is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """Implementation of the StorageReader method"""
        return plan

    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """Implementation of the StorageReader method"""
        return global_plan

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """Implementation of the StorageReader method"""
        self.checkpoint_id = checkpoint_id

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Implementation of the StorageReader method"""
        return os.path.isfile(checkpoint_id)


class DynamicMetaLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which creates a new Metadata object based on the
    passed-in state dict, avoiding the need to read metadata from disk. This is useful
    when reading formats which don't have a metadata file, like Torch Save files.

    .. note:: Intended to be used with BroadcastingTorchSaveReader

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt"
    >>> )
    """

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """Sets up the planner, extending default behavior by creating the Metadata object from the state dict"""
        super().set_up_planner(state_dict, metadata, is_coordinator)

        state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
        for key, tensor in self.state_dict.items():
            if not torch.is_tensor(tensor):
                raise RuntimeError(
                    f"Non-tensor value identified at {key}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            state_dict_metadata[key] = TensorStorageMetadata(
                TensorProperties(dtype=tensor.dtype),
                tensor.size(),
                _create_chunk_list(tensor),
            )
        self.metadata = Metadata(state_dict_metadata=state_dict_metadata)
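

# A slightly fuller, illustrative sketch of the docstring examples above, showing
# where the reader/planner pair fits in a distributed job. `build_model()`, the
# process group backend, and the checkpoint path are placeholders, not part of
# this module; the sketch also assumes the file was produced with
# `torch.save({"model": model.state_dict()}, "path_to_model.pt")` so that the
# flattened keys of the destination state dict line up with the saved file.
#
#   import torch.distributed as dist
#   import torch.distributed.checkpoint as dcp
#
#   dist.init_process_group(backend="gloo")
#   model = build_model()  # same architecture on every rank
#   state_dict = {"model": model.state_dict()}
#
#   # Every rank calls load(); only the coordinator rank reads the file, and each
#   # tensor is then broadcast and narrowed down to the shard the local rank needs.
#   dcp.load(
#       state_dict,
#       storage_reader=BroadcastingTorchSaveReader(),
#       planner=DynamicMetaLoadPlanner(),
#       checkpoint_id="path_to_model.pt",
#   )
#   model.load_state_dict(state_dict["model"])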


def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    sd: STATE_DICT_TYPE = {}
    _load_state_dict(
        sd,
        storage_reader=FileSystemReader(dcp_checkpoint_dir),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)


def torch_save_to_dcp(
    torch_save_path: Union[str, os.PathLike],
    dcp_checkpoint_dir: Union[str, os.PathLike],
):
    """
    Given the location of a torch save file, converts it into a DCP checkpoint.

    Args:
        torch_save_path: Filename of the Torch save file.
        dcp_checkpoint_dir: Directory to store the DCP checkpoint.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    state_dict = torch.load(torch_save_path, weights_only=False)
    # We don't need stateful behavior here because the expectation is that anything
    # loaded by torch.load would not contain stateful objects.
    _save_state_dict(
        state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True
    )
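

# A minimal round-trip sketch of the two converters above; the paths and `model`
# are illustrative. The same conversions are exposed as a CLI by the __main__
# block below, assuming this file is importable as
# torch.distributed.checkpoint.format_utils:
#
#   torch.save(model.state_dict(), "model.pt")        # plain torch.save file
#   torch_save_to_dcp("model.pt", "model_dcp/")       # single file -> DCP directory
#   dcp_to_torch_save("model_dcp/", "model_copy.pt")  # DCP directory -> single file
#
#   python -m torch.distributed.checkpoint.format_utils torch_to_dcp model.pt model_dcp/
#   python -m torch.distributed.checkpoint.format_utils dcp_to_torch model_dcp/ model.pt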


if __name__ == "__main__":

    class FormatMode(Enum):
        TORCH_TO_DCP = "torch_to_dcp"
        DCP_TO_TORCH = "dcp_to_torch"

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode",
        type=str,
        help="Conversion mode",
        choices=[m.value for m in FormatMode],
        default=FormatMode.TORCH_TO_DCP.value,
    )
    parser.add_argument("src", type=str, help="Path to the source model")
    parser.add_argument("dst", type=str, help="Path to the destination model")
    args = parser.parse_args()

    print(
        f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'"
    )
    checkpoint_missing_warning = (
        f"No checkpoint found at {args.src}. Skipping conversion."
    )
    if args.mode == FormatMode.TORCH_TO_DCP.value:
        if os.path.isfile(args.src):
            torch_save_to_dcp(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    elif args.mode == FormatMode.DCP_TO_TORCH.value:
        if os.path.isdir(args.src):
            dcp_to_torch_save(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    else:
        raise ValueError(f"Unknown conversion mode: {args.mode}")