• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Mypy will not try inferring the types of any 3rd party libraries installed.
2# mypy: ignore-errors
3
4import io
5import os
6from contextlib import contextmanager
7from pathlib import Path
8from typing import Generator, Optional, Union
9
10import fsspec
11from fsspec import AbstractFileSystem
12from fsspec.core import url_to_fs
13
14from torch.distributed.checkpoint.filesystem import (
15    FileSystemBase,
16    FileSystemReader,
17    FileSystemWriter,
18)
19
20
21__all__ = [
22    "FsspecWriter",
23    "FsspecReader",
24]
25
26
27class FileSystem(FileSystemBase):
28    def __init__(self) -> None:
29        self.fs: Optional[AbstractFileSystem] = None
30
31    @contextmanager
32    def create_stream(
33        self, path: Union[str, os.PathLike], mode: str
34    ) -> Generator[io.IOBase, None, None]:
35        assert self.fs is not None
36        with self.fs.transaction:
37            with fsspec.open(str(path), mode) as stream:
38                yield stream
39
40    def concat_path(
41        self, path: Union[str, os.PathLike], suffix: str
42    ) -> Union[str, os.PathLike]:
43        return os.path.join(path, suffix)
44
45    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
46        self.fs, _ = url_to_fs(path)
47        return path
48
49    def rename(
50        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
51    ) -> None:
52        self.fs.rename(path, new_path)
53
54    def mkdir(self, path: [str, os.PathLike]) -> None:
55        self.fs.makedirs(path, exist_ok=True)
56
57    @classmethod
58    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
59        if isinstance(checkpoint_id, Path):
60            return False
61
62        try:
63            url_to_fs(checkpoint_id)
64        except ValueError:
65            return False
66
67        return True
68
69    def exists(self, path: Union[str, os.PathLike]) -> bool:
70        return self.fs.exists(path)
71
72    def rm_file(self, path: Union[str, os.PathLike]) -> None:
73        self.fs.rm(path)
74
75
76# TODO: add the dcp.async_save mixin
class FsspecWriter(FileSystemWriter):
    """
    Basic implementation of StorageWriter using fsspec.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic.

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.

    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: Force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10_000_000.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.

        N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__(
            path,
            single_file_per_rank,
            sync_files,
            thread_count,
            per_thread_copy_ahead,
            overwrite=overwrite,
        )
        # After the base initializer has run, swap in the fsspec-backed
        # filesystem and let it resolve the checkpoint path.
        fsspec_fs = FileSystem()
        self.fs = fsspec_fs
        self.path = fsspec_fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return ``FileSystem.validate_checkpoint_id(checkpoint_id)``."""
        is_supported = FileSystem.validate_checkpoint_id(checkpoint_id)
        return is_supported
127
128
class FsspecReader(FileSystemReader):
    """``FileSystemReader`` subclass that reads through the fsspec-backed ``FileSystem``."""

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """Initialize the reader pointing to `path`.

        The base initializer runs first; ``self.fs`` and ``self.path`` are then
        replaced with an fsspec-backed ``FileSystem`` and the path it returns
        from ``init_path``.
        """
        super().__init__(path)
        fsspec_fs = FileSystem()
        self.fs = fsspec_fs
        self.path = fsspec_fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return ``FileSystem.validate_checkpoint_id(checkpoint_id)``."""
        is_supported = FileSystem.validate_checkpoint_id(checkpoint_id)
        return is_supported
138