# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Tuple

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    TensorWriteData,
    WriteItem,
    WriteItemType,
)


aten = (
    torch.ops.aten
)  # pyre-ignore[5]: Globally accessible variable `aten` has no type specified.


class LocalShardsWrapper(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
    """
    A wrapper class to hold local shards of a DTensor.
    This class is used largely for checkpointing purposes and implicitly subtypes
    the _Checkpointable protocol.
    """

    __slots__ = ["_local_shards", "_storage_meta"]
    _local_shards: List[torch.Tensor]
    _storage_meta: TensorStorageMetadata

    @staticmethod
    def __new__(
        cls, local_shards: List[torch.Tensor], local_offsets: List[Tuple[int, ...]]
    ) -> "LocalShardsWrapper":
        assert len(local_shards) > 0
        assert len(local_shards) == len(local_offsets)
        assert all(
            tensor.device == local_shards[0].device for tensor in local_shards[1:]
        )

        # We compute the total tensor size by "concatenating" along the second tensor dimension.
        cat_tensor_shape = list(local_shards[0].size())
        if len(local_shards) > 1:  # column-wise sharding
            for shard in local_shards[1:]:
                cat_tensor_shape[1] += shard.size()[1]

        wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
        wrapper_shape = torch.Size(cat_tensor_shape)
        chunks_meta = [
            ChunkStorageMetadata(
                offsets=torch.Size(offset),
                sizes=shard.size(),
            )
            for shard, offset in zip(local_shards, local_offsets)
        ]

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            torch.Size(cat_tensor_shape),
        )
        r._local_shards = local_shards
        r._storage_meta = TensorStorageMetadata(
            properties=wrapper_properties,
            size=wrapper_shape,
            chunks=chunks_meta,
        )

        return r

    # necessary for ops dispatching from this subclass to its local shards
    @classmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        dispatcher = {
            torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor,
            torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor,
            aten._to_copy.default: cls.handle_to_copy,
            aten.view.default: cls.handle_view,
            aten.equal.default: cls.handle_equal,
            aten.detach.default: cls.handle_detach,
            aten.clone.default: cls.handle_clone,
        }

        if func in dispatcher:
            return dispatcher[func](
                args, kwargs
            )  # pyre-ignore [29] - `Variable[_VT]` is not a function.
        else:
            raise NotImplementedError(
                f"{func} is not supported for LocalShardsWrapper!"
            )

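    # NOTE: hedged illustration, not part of the original implementation. Any op on a
    # LocalShardsWrapper that reaches the ATen / functional-collective layer is routed
    # through __torch_dispatch__ above, and the dispatcher table maps the op overload
    # to one of the handle_* staticmethods below, e.g.:
    #
    #   lsw = LocalShardsWrapper([torch.ones(2, 2)], [(0, 0)])
    #   lsw.detach()             # func == aten.detach.default   -> handle_detach
    #   lsw.to(torch.float64)    # func == aten._to_copy.default -> handle_to_copy
    #   lsw.clone()              # func == aten.clone.default    -> handle_clone
    #
    # Any op not listed in the table raises NotImplementedError.
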
    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_all_gather_into_tensor(args, kwargs):
        dim = args[0].local_sizes()[0][1]
        cat_tensor = torch.cat(
            [t.view(-1) for t in args[0].local_shards()], dim=0
        ).view(-1, dim)
        return torch.ops._c10d_functional.all_gather_into_tensor.default(
            cat_tensor, *args[1:], **kwargs
        )

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_wait_tensor(args, kwargs):
        return torch.ops._c10d_functional.wait_tensor(args[0])

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_to_copy(args, kwargs):
        res_shards_list = [
            aten._to_copy.default(shard, *args[1:], **kwargs)
            for shard in args[0].local_shards()
        ]
        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_view(args, kwargs):
        # TODO: do we need to change the shape of the associated offsets?
        res_shards_list = [
            aten.view.default(shard, args[1], **kwargs)
            for shard in args[0].local_shards()
        ]
        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_equal(args, kwargs):
        """
        The LocalShardsWrapper equal implementation also checks for equality of
        the storage metadata and the order of shards.
        """
        a, b = args[0], args[1]
        if len(a.local_shards()) != len(b.local_shards()):
            return False
        if not all(
            aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards())
        ):
            return False
        if not a.storage_metadata() == b.storage_metadata():
            return False
        return True

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_detach(args, kwargs):
        self_ls = args[0]
        detached_local_shards = [
            aten.detach.default(shard) for shard in self_ls.local_shards()
        ]
        self_ls._local_shards = detached_local_shards
        self_ls._storage_meta.properties.requires_grad = False
        return self_ls

    @staticmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def handle_clone(args, kwargs):
        self_ls = args[0]
        desired_memory_format = kwargs.get("memory_format", None)
        if desired_memory_format and desired_memory_format != torch.preserve_format:
            raise NotImplementedError(
                f"{desired_memory_format} is not supported for LocalShardsWrapper!"
            )
        cloned_local_shards = [
            shard.clone(memory_format=desired_memory_format)
            for shard in self_ls._local_shards
        ]
        return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())

    @property
    def device(self) -> torch._C.device:  # type: ignore[override]
        return self._local_shards[0].device

    @property
    def is_meta(self) -> bool:  # type: ignore[override]
        return self._local_shards[0].is_meta

    # pyre-ignore[14]
    def is_pinned(self) -> bool:  # type: ignore[override]
        return self._storage_meta.properties.pin_memory

    # pyre-ignore[14]
    def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper":
        self._storage_meta.properties.requires_grad = requires_grad
        [shard.requires_grad_(requires_grad) for shard in self._local_shards]
        return self

    def local_shards(self) -> List[torch.Tensor]:
        """
        Returns a list of :class:`torch.Tensor` corresponding to the
        local shards for this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return self._local_shards

    def local_sizes(self) -> List[torch.Size]:
        """
        Returns a list of :class:`torch.Size` corresponding to the
        local sizes for the shards on this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return [chunk.sizes for chunk in self._storage_meta.chunks]

    def local_offsets(self) -> List[torch.Size]:
        """
        Returns a list of :class:`torch.Size` corresponding to the
        local offsets for the shards on this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return [chunk.offsets for chunk in self._storage_meta.chunks]

    @property
    def local_chunks(self) -> List[ChunkStorageMetadata]:
        """
        Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the
        metadata for each tensor shard.
        """
        return self._storage_meta.chunks

    def storage_metadata(self) -> TensorStorageMetadata:
        """
        Returns a :class:`TensorStorageMetadata` object corresponding to the
        metadata for the local tensor on the current rank.
        """
        return self._storage_meta

    def __create_write_items__(
        self, fqn: str, object: Any
    ) -> List[WriteItem]:  # pyre-ignore[2]
        """
        For compatibility with DCP, we support creation of WriteItems
        such that they can be saved properly.
        """
        return [
            WriteItem(
                index=MetadataIndex(fqn, chunks.offsets),
                type=WriteItemType.SHARD,
                tensor_data=TensorWriteData(
                    chunk=ChunkStorageMetadata(
                        offsets=chunks.offsets,
                        sizes=chunks.sizes,
                    ),
                    properties=self._storage_meta.properties,
                    size=object.size(),
                ),
            )
            for tensor, chunks in zip(self.local_shards(), self.local_chunks)
        ]

    def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
        """
        For compatibility with DCP, we support creation of chunk lists
        such that they can be saved properly.
        """
        return self._storage_meta.chunks

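    # NOTE: hedged sketch of how the DCP hooks above are consumed (based on the
    # torch.distributed.checkpoint planner helpers; exact call sites may differ
    # across releases). During save-plan creation, DCP duck-types on these dunder
    # attributes rather than on an explicit base class, roughly:
    #
    #   if hasattr(obj, "__create_write_items__"):
    #       write_items = obj.__create_write_items__(fqn, obj)
    #
    # and it later resolves the tensor data for each WriteItem through
    # __get_tensor_shard__(index) below.
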
    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
        """
        For compatibility with DCP, we support finding a shard based on its index.
        Returns a :class:`torch.Tensor` shard for the given :class:`MetadataIndex`.
        """
        # Fast lookup path
        if index.index is not None:
            if (
                len(self._local_shards) > index.index
                and self._storage_meta.chunks[index.index].offsets == index.offset
            ):
                return self._local_shards[index.index]

        if index.offset is not None:
            for shard, chunk in zip(self._local_shards, self._storage_meta.chunks):
                if chunk.offsets == index.offset:
                    return shard

        raise ValueError(
            f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
        )

    def _get_tensor_size_bytes(self) -> int:
        object_size = 0
        for shard in self.local_shards():
            object_size += shard.nelement() * shard.element_size()
        return object_size

    # pyre-fixme[3]: Return type must be annotated.
    def __hash__(self):
        return id(self)

    # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently.
    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"

    def __str__(self) -> str:
        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"
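
# Hedged usage sketch (illustration only, not part of the original module; the
# `_shard_a` / `_shard_b` / `_lsw` names are made up for the demo): build a wrapper
# around two column-wise shards hosted on this rank and inspect the metadata that
# __new__ derives from them.
if __name__ == "__main__":
    _shard_a = torch.zeros(4, 2)
    _shard_b = torch.zeros(4, 3)
    # Offsets are the (row, col) positions of each shard in the logical tensor.
    _lsw = LocalShardsWrapper([_shard_a, _shard_b], [(0, 0), (0, 2)])
    # The wrapper's global shape concatenates shard widths along dim 1: (4, 2 + 3).
    assert _lsw.size() == torch.Size([4, 5])
    assert _lsw.local_sizes() == [torch.Size([4, 2]), torch.Size([4, 3])]
    assert _lsw.local_offsets() == [torch.Size([0, 0]), torch.Size([0, 2])]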