# Mypy will not try inferring the types of any 3rd party libraries installed.
# mypy: ignore-errors

import io
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, Union

import fsspec
from fsspec import AbstractFileSystem
from fsspec.core import url_to_fs

from torch.distributed.checkpoint.filesystem import (
    FileSystemBase,
    FileSystemReader,
    FileSystemWriter,
)


__all__ = [
    "FsspecWriter",
    "FsspecReader",
]


class FileSystem(FileSystemBase):
    """
    ``FileSystemBase`` implementation that delegates to an fsspec filesystem.

    ``self.fs`` is ``None`` until :meth:`init_path` resolves a concrete
    ``AbstractFileSystem`` from a path/URL; the other instance methods use
    ``self.fs`` and therefore expect :meth:`init_path` to have been called.
    """

    def __init__(self) -> None:
        # Resolved lazily by init_path().
        self.fs: Optional[AbstractFileSystem] = None

    @contextmanager
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        """
        Open ``path`` in ``mode`` and yield the resulting stream.

        The file is opened while ``self.fs.transaction`` is active, and is
        closed when the context exits.

        Args:
            path: location of the file to open; converted with ``str()``.
            mode: open mode forwarded to ``fsspec.open``.

        Raises:
            AssertionError: if :meth:`init_path` has not been called yet.
        """
        assert self.fs is not None
        with self.fs.transaction:
            # NOTE(review): the stream is opened through the module-level
            # ``fsspec.open`` (which resolves a filesystem from the path
            # string), not through ``self.fs.open`` -- confirm this is the
            # intended filesystem instance for the surrounding transaction.
            with fsspec.open(str(path), mode) as stream:
                yield stream

    def concat_path(
        self, path: Union[str, os.PathLike], suffix: str
    ) -> Union[str, os.PathLike]:
        """Return ``path`` joined with ``suffix`` using ``os.path.join``."""
        return os.path.join(path, suffix)

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        """
        Resolve the fsspec filesystem for ``path`` and store it on ``self.fs``.

        Returns:
            ``path`` unchanged.
        """
        self.fs, _ = url_to_fs(path)
        return path

    def rename(
        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        """Rename ``path`` to ``new_path`` on the underlying filesystem."""
        self.fs.rename(path, new_path)

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        """Create ``path`` via ``makedirs``; an existing directory is not an error."""
        self.fs.makedirs(path, exist_ok=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Report whether ``checkpoint_id`` can be handled by this backend.

        Returns:
            ``False`` for ``pathlib.Path`` instances and for ids on which
            ``url_to_fs`` raises ``ValueError``; ``True`` otherwise.
        """
        if isinstance(checkpoint_id, Path):
            return False

        try:
            url_to_fs(checkpoint_id)
        except ValueError:
            return False

        return True

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        """Return whether ``path`` exists on the underlying filesystem."""
        return self.fs.exists(path)

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        """Remove ``path`` from the underlying filesystem."""
        self.fs.rm(path)


# TODO: add the dcp.async_save mixin
class FsspecWriter(FileSystemWriter):
    """
    Basic implementation of StorageWriter using fsspec.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.

    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True.
            sync_files : force files to be synced to permanent storage. Default to True.
            thread_count: Number of IO threads to use to write. Default to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default 10 MB.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.

        N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__(
            path,
            single_file_per_rank,
            sync_files,
            thread_count,
            per_thread_copy_ahead,
            overwrite=overwrite,
        )
        # Replace whatever the base class configured with the fsspec-backed
        # filesystem, and re-derive the path through it.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Delegate to :meth:`FileSystem.validate_checkpoint_id`."""
        return FileSystem.validate_checkpoint_id(checkpoint_id)


class FsspecReader(FileSystemReader):
    """``FileSystemReader`` that reads through the fsspec-backed ``FileSystem``."""

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """
        Initialize the reader pointing to `path`.

        Args:
            path: location of the checkpoint to read.
        """
        super().__init__(path)
        # Replace whatever the base class configured with the fsspec-backed
        # filesystem, and re-derive the path through it.
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Delegate to :meth:`FileSystem.validate_checkpoint_id`."""
        return FileSystem.validate_checkpoint_id(checkpoint_id)