# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Tuple

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    TensorWriteData,
    WriteItem,
    WriteItemType,
)


aten = (
    torch.ops.aten
)  # pyre-ignore[5]: Globally accessible variable `aten` has no type specified.


class LocalShardsWrapper(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
    """
    A wrapper class to hold local shards of a DTensor.
    This class is used largely for checkpointing purposes and implicitly subtypes
    the _Checkpointable protocol.
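
    Example (illustrative; shard shapes and offsets are arbitrary)::

        >>> shards = [torch.randn(4, 4), torch.randn(4, 4)]
        >>> wrapped = LocalShardsWrapper(shards, [(0, 0), (0, 4)])
        >>> wrapped.local_sizes()
        [torch.Size([4, 4]), torch.Size([4, 4])]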
    """

    __slots__ = ["_local_shards", "_storage_meta"]
    _local_shards: List[torch.Tensor]
    _storage_meta: TensorStorageMetadata

    @staticmethod
    def __new__(
        cls, local_shards: List[torch.Tensor], local_offsets: List[Tuple[int, ...]]
    ) -> "LocalShardsWrapper":
        assert len(local_shards) > 0
        assert len(local_shards) == len(local_offsets)
        assert all(
            tensor.device == local_shards[0].device for tensor in local_shards[1:]
        )

        # the total tensor size is computed by concatenating along the second tensor dimension
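        # e.g. two shards of size (4, 4) and (4, 4) yield a wrapper shape of (4, 8)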
        cat_tensor_shape = list(local_shards[0].size())
        if len(local_shards) > 1:  # column-wise sharding
            for shard in local_shards[1:]:
                cat_tensor_shape[1] += shard.size()[1]

        wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
        wrapper_shape = torch.Size(cat_tensor_shape)
        chunks_meta = [
            ChunkStorageMetadata(
                offsets=torch.Size(offset),
                sizes=shard.size(),
            )
            for shard, offset in zip(local_shards, local_offsets)
        ]

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            torch.Size(cat_tensor_shape),
        )
        r._local_shards = local_shards
        r._storage_meta = TensorStorageMetadata(
            properties=wrapper_properties,
            size=wrapper_shape,
            chunks=chunks_meta,
        )

        return r

    # necessary for ops dispatching from this subclass to its local shards
    @classmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        dispatcher = {
            torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor,
            torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor,
            aten._to_copy.default: cls.handle_to_copy,
            aten.view.default: cls.handle_view,
            aten.equal.default: cls.handle_equal,
            aten.detach.default: cls.handle_detach,
            aten.clone.default: cls.handle_clone,
        }

        if func in dispatcher:
            return dispatcher[func](
                args, kwargs
            )  # pyre-ignore [29] - `Variable[_VT]` is not a function.
        else:
            raise NotImplementedError(
                f"{func} is not supported for LocalShardsWrapper!"
            )

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_all_gather_into_tensor(args, kwargs):
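        # args[0] is the LocalShardsWrapper: flatten each local shard,
        # concatenate them, reshape the result to (-1, dim) using the second
        # dimension of the first shard, then issue the all-gather on it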
        dim = args[0].local_sizes()[0][1]
        cat_tensor = torch.cat(
            [t.view(-1) for t in args[0].local_shards()], dim=0
        ).view(-1, dim)
        return torch.ops._c10d_functional.all_gather_into_tensor.default(
            cat_tensor, *args[1:], **kwargs
        )

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_wait_tensor(args, kwargs):
        return torch.ops._c10d_functional.wait_tensor(args[0])

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_to_copy(args, kwargs):
        res_shards_list = [
            aten._to_copy.default(shard, *args[1:], **kwargs)
            for shard in args[0].local_shards()
        ]
        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_view(args, kwargs):
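        # args[1] is the requested view shape; it is applied to every local
        # shard while the existing offsets are carried over unchanged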
        # TODO, do we need to change the shape of associated offsets?
        res_shards_list = [
            aten.view.default(shard, args[1], **kwargs)
            for shard in args[0].local_shards()
        ]
        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_equal(args, kwargs):
        """
        The LocalShardsWrapper equality implementation also checks for equality
        of the storage metadata and the order of shards.
        """
        a, b = args[0], args[1]
        if len(a.local_shards()) != len(b.local_shards()):
            return False
        if not all(
            aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards())
        ):
            return False
        if not a.storage_metadata() == b.storage_metadata():
            return False
        return True

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_detach(args, kwargs):
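        # detach mutates the wrapper in place: the local shards are replaced by
        # detached versions and requires_grad is cleared on the storage metadata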
        self_ls = args[0]
        detached_local_shards = [
            aten.detach.default(shard) for shard in self_ls.local_shards()
        ]
        self_ls._local_shards = detached_local_shards
        self_ls._storage_meta.properties.requires_grad = False
        return self_ls

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_clone(args, kwargs):
        self_ls = args[0]
        desired_memory_format = kwargs.get("memory_format", None)
        if desired_memory_format and desired_memory_format != torch.preserve_format:
            raise NotImplementedError(
                f"{desired_memory_format} is not supported for LocalShardsWrapper!"
            )
        cloned_local_shards = [
            shard.clone(memory_format=desired_memory_format)
            for shard in self_ls._local_shards
        ]
        return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())

    @property
    def device(self) -> torch._C.device:  # type: ignore[override]
        return self._local_shards[0].device

    @property
    def is_meta(self) -> bool:  # type: ignore[override]
        return self._local_shards[0].is_meta

    # pyre-ignore[14]
    def is_pinned(self) -> bool:  # type: ignore[override]
        return self._storage_meta.properties.pin_memory

    # pyre-ignore[14]
    def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper":
        self._storage_meta.properties.requires_grad = requires_grad
        for shard in self._local_shards:
            shard.requires_grad_(requires_grad)
        return self

    def local_shards(self) -> List[torch.Tensor]:
        """
        Returns a list of :class:`torch.Tensor` corresponding to the
        local shards for this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return self._local_shards

    def local_sizes(self) -> List[torch.Size]:
        """
        Returns a list of :class:`torch.Size` corresponding to the
        local sizes for the shards on this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return [chunk.sizes for chunk in self._storage_meta.chunks]

    def local_offsets(self) -> List[torch.Size]:
        """
        Returns a list of :class:`torch.Size` corresponding to the
        local offsets for the shards on this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
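
        Example (illustrative; a single shard whose offset is (0, 4))::

            >>> wrapped = LocalShardsWrapper([torch.randn(4, 4)], [(0, 4)])
            >>> wrapped.local_offsets()
            [torch.Size([0, 4])]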
        """
        return [chunk.offsets for chunk in self._storage_meta.chunks]

    @property
    def local_chunks(self) -> List[ChunkStorageMetadata]:
        """
        Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the
        metadata for each tensor shard
        """
        return self._storage_meta.chunks

    def storage_metadata(self) -> TensorStorageMetadata:
        """
        Returns a :class:`TensorStorageMetadata` object corresponding to the
        metadata for the local tensor on the current rank.
        """
        return self._storage_meta

    def __create_write_items__(
        self, fqn: str, object: Any
    ) -> List[WriteItem]:  # pyre-ignore[2]
        """
        For compatibility with DCP, we support creation of WriteItems
        such that they can be saved properly.
        """
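        # one WriteItem of type SHARD is emitted per local chunk, indexed by
        # the tensor's fqn and the chunk's offsets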
        return [
            WriteItem(
                index=MetadataIndex(fqn, chunks.offsets),
                type=WriteItemType.SHARD,
                tensor_data=TensorWriteData(
                    chunk=ChunkStorageMetadata(
                        offsets=chunks.offsets,
                        sizes=chunks.sizes,
                    ),
                    properties=self._storage_meta.properties,
                    size=object.size(),
                ),
            )
            for tensor, chunks in zip(self.local_shards(), self.local_chunks)
        ]

    def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
        """
        For compatibility with DCP, we support creation of chunk lists
        such that they can be saved properly.
        """
        return self._storage_meta.chunks

    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
        """
        For compatibility with DCP, we support finding a shard based on its index.
        Returns a :class:`torch.Tensor` shard for the given :class:`MetadataIndex`.
        """
        # Fast lookup path
        if index.index is not None:
            if (
                len(self._local_shards) > index.index
                and self._storage_meta.chunks[index.index].offsets == index.offset
            ):
                return self._local_shards[index.index]

        if index.offset is not None:
            for shard, chunk in zip(self._local_shards, self._storage_meta.chunks):
                if chunk.offsets == index.offset:
                    return shard

        raise ValueError(
            f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
        )

    def _get_tensor_size_bytes(self) -> int:
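        # sum of nelement * element_size over all local shards on this rank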
        object_size = 0
        for shard in self.local_shards():
            object_size += shard.nelement() * shard.element_size()
        return object_size

    # pyre-fixme[3]: Return type must be annotated.
    def __hash__(self):
        return id(self)

    # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently.
    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"

    def __str__(self) -> str:
        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"
317