# mypy: allow-untyped-defs
"""
The following example demonstrates how to represent torchrec's embedding
sharding with the DTensor API.
"""
import argparse
import os
from functools import cached_property
from typing import List, TYPE_CHECKING

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.tensor import (
    DeviceMesh,
    DTensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import visualize_sharding


if TYPE_CHECKING:
    from torch.distributed.tensor.placement_types import Placement


def get_device_type():
    return (
        "cuda"
        if torch.cuda.is_available() and torch.cuda.device_count() >= 4
        else "cpu"
    )


aten = torch.ops.aten
supported_ops = [aten.view.default, aten._to_copy.default]


# this torch.Tensor subclass is a wrapper around all local shards associated
# with a single sharded embedding table.
class LocalShardsWrapper(torch.Tensor):
    local_shards: List[torch.Tensor]
    storage_meta: TensorStorageMetadata

    @staticmethod
    def __new__(
        cls, local_shards: List[torch.Tensor], offsets: List[torch.Size]
    ) -> "LocalShardsWrapper":
        assert len(local_shards) > 0
        assert len(local_shards) == len(offsets)
        assert local_shards[0].ndim == 2
        # we compute the total tensor size by concatenating along the second
        # tensor dimension
        cat_tensor_shape = list(local_shards[0].shape)
        if len(local_shards) > 1:  # column-wise sharding
            for shard_size in [s.shape for s in local_shards[1:]]:
                cat_tensor_shape[1] += shard_size[1]

        # according to DCP, each chunk is expected to have the same properties
        # as the TensorStorageMetadata that contains it; conversely, the
        # wrapper's properties should match those of its first chunk.
        wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
        wrapper_shape = torch.Size(cat_tensor_shape)
        chunks_meta = [
            ChunkStorageMetadata(o, s.shape) for s, o in zip(local_shards, offsets)
        ]

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            wrapper_shape,
        )
        r.shards = local_shards
        r.storage_meta = TensorStorageMetadata(
            properties=wrapper_properties,
            size=wrapper_shape,
            chunks=chunks_meta,
        )

        return r

    # necessary for ops dispatching from this subclass to its local shards
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # TODO: extend this function to support more ops as needed
        if func in supported_ops:
            res_shards_list = [
                func(shard, *args[1:], **kwargs) for shard in args[0].shards
            ]
            return LocalShardsWrapper(res_shards_list, args[0].shard_offsets)
        else:
            raise NotImplementedError(
                f"{func} is not supported for LocalShardsWrapper!"
            )
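
    # Illustrative usage (hypothetical, not exercised elsewhere in this script):
    # calling an op from ``supported_ops`` on the wrapper dispatches it to every
    # local shard and re-wraps the results, e.g.
    #   wrapper = LocalShardsWrapper([torch.randn(2, 16)], [torch.Size((0, 0))])
    #   viewed = wrapper.view(2, 16)  # runs aten.view.default on each local shard
    # Any op outside ``supported_ops`` raises NotImplementedError above.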

    @property
    def shards(self) -> List[torch.Tensor]:
        return self.local_shards

    @shards.setter
    def shards(self, local_shards: List[torch.Tensor]):
        self.local_shards = local_shards

    @cached_property
    def shard_sizes(self) -> List[torch.Size]:
        return [chunk.sizes for chunk in self.storage_meta.chunks]

    @cached_property
    def shard_offsets(self) -> List[torch.Size]:
        return [chunk.offsets for chunk in self.storage_meta.chunks]


def run_torchrec_row_wise_even_sharding_example(rank, world_size):
    # row-wise even sharding example:
    #   One table is evenly sharded by rows within the global ProcessGroup.
    #   In our example, the table's num_embeddings is 8 and the embedding dim is 16.
    #   The global ProcessGroup has 4 ranks, so each rank will have one 2 by 16
    #   local shard.

    # a device mesh is a representation of the worker ranks;
    # create a 1-D device mesh that includes every rank
    device_type = get_device_type()
    device = torch.device(device_type)
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    # manually create the embedding table's local shards
    num_embeddings = 8
    embedding_dim = 16
    emb_table_shape = torch.Size([num_embeddings, embedding_dim])
    # tensor shape
    local_shard_shape = torch.Size(
        [num_embeddings // world_size, embedding_dim]  # (local_rows, local_cols)
    )
    # tensor offset: (row_offset, col_offset) of this shard within the global table
    local_shard_offset = torch.Size((rank * 2, 0))
    # tensor
    local_tensor = torch.randn(local_shard_shape, device=device)
    # row-wise sharding: one shard per rank
    # create the local shards wrapper
    local_shards_wrapper = LocalShardsWrapper(
        local_shards=[local_tensor],
        offsets=[local_shard_offset],
    )

    ###########################################################################
    # example 1: transform local_shards into DTensor
    # usage in TorchRec:
    #   ShardedEmbeddingCollection stores model-parallel params in
    #   _model_parallel_name_to_sharded_tensor, which is initialized in
    #   _initialize_torch_state(), where torch.Tensor params are transformed
    #   into ShardedTensor by ShardedTensor._init_from_local_shards().
    #
    #   This allows state_dict() to always return ShardedTensor objects.

    # this is the sharding placement we use in DTensor to represent row-wise sharding:
    # row_wise_sharding_placements means that the global tensor is sharded by its
    # first dim over the 1-D mesh.
    row_wise_sharding_placements: List[Placement] = [Shard(0)]

    # create a DTensor from the local shard
    dtensor = DTensor.from_local(
        local_shards_wrapper, device_mesh, row_wise_sharding_placements, run_check=False
    )

    # display the DTensor's sharding
    visualize_sharding(dtensor, header="Row-wise even sharding example in DTensor")
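
    # Quick sanity sketch (assuming the even row-wise sharding set up above):
    # with Shard(0) on a 1-D mesh, DTensor infers the global shape by scaling
    # the sharded dim by the mesh size, so each rank's 2 x 16 local shard yields
    # a global 8 x 16 tensor without passing shape/stride explicitly.
    assert dtensor.shape == torch.Size(
        [local_shard_shape[0] * world_size, embedding_dim]
    )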

    ###########################################################################
    # example 2: transform DTensor into local_shards
    # usage in TorchRec:
    #   In ShardedEmbeddingCollection's load_state_dict pre-hook
    #   _pre_load_state_dict_hook, if the source param is a ShardedTensor
    #   then we need to transform it into its local_shards.

    # transform DTensor into LocalShardsWrapper
    dtensor_local_shards = dtensor.to_local()
    assert isinstance(dtensor_local_shards, LocalShardsWrapper)
    shard_tensor = dtensor_local_shards.shards[0]
    assert torch.equal(shard_tensor, local_tensor)
    assert dtensor_local_shards.shard_sizes[0] == local_shard_shape  # unwrap shape
    assert dtensor_local_shards.shard_offsets[0] == local_shard_offset  # unwrap offset


def run_torchrec_row_wise_uneven_sharding_example(rank, world_size):
    # row-wise uneven sharding example:
    #   One table is unevenly sharded by rows within the global ProcessGroup.
    #   In our example, the table's num_embeddings is 8 and the embedding dim is 16.
    #   The global ProcessGroup has 4 ranks, and each rank will have a local shard
    #   of shape:
    #       rank 0: [1, 16]
    #       rank 1: [3, 16]
    #       rank 2: [1, 16]
    #       rank 3: [3, 16]

    # a device mesh is a representation of the worker ranks;
    # create a 1-D device mesh that includes every rank
    device_type = get_device_type()
    device = torch.device(device_type)
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    # manually create the embedding table's local shards
    num_embeddings = 8
    embedding_dim = 16
    emb_table_shape = torch.Size([num_embeddings, embedding_dim])
    # tensor shape
    local_shard_shape = (
        torch.Size([1, embedding_dim])
        if rank % 2 == 0
        else torch.Size([3, embedding_dim])
    )
    # tensor offset: (row_offset, col_offset) of this shard within the global table
    local_shard_offset = torch.Size((rank // 2 * 4 + rank % 2 * 1, 0))
    # tensor
    local_tensor = torch.randn(local_shard_shape, device=device)
    # local shards
    # row-wise sharding: one shard per rank
    # create the local shards wrapper
    local_shards_wrapper = LocalShardsWrapper(
        local_shards=[local_tensor],
        offsets=[local_shard_offset],
    )

    ###########################################################################
    # example 1: transform local_shards into DTensor
    # create the DTensorMetadata which torchrec should provide
    row_wise_sharding_placements: List[Placement] = [Shard(0)]

    # note: for uneven sharding, we need to specify the shape and stride because
    #   DTensor would otherwise assume even sharding and compute the shape/stride
    #   based on that assumption. TorchRec needs to pass in this information
    #   explicitly. shape/stride are the global tensor's shape and stride.
    dtensor = DTensor.from_local(
        local_shards_wrapper,  # a torch.Tensor subclass
        device_mesh,  # DeviceMesh
        row_wise_sharding_placements,  # List[Placement]
        run_check=False,
        shape=emb_table_shape,  # this is required for uneven sharding
        stride=(embedding_dim, 1),
    )
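
    # Quick sketch of why the explicit shape is required above: under the
    # even-sharding assumption DTensor would infer the global row count as
    # local_rows * world_size, i.e. 4 on even ranks and 12 on odd ranks,
    # neither of which matches the true global row count of 8.
    naively_inferred_rows = local_shard_shape[0] * world_size
    assert naively_inferred_rows != num_embeddings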

    # so far visualize_sharding() cannot print correctly for an unevenly sharded
    # DTensor because it relies on an offset computation that assumes even sharding.
    visualize_sharding(dtensor, header="Row-wise uneven sharding example in DTensor")
    # check that the dtensor has the correct shape and stride on all ranks
    assert dtensor.shape == emb_table_shape
    assert dtensor.stride() == (embedding_dim, 1)

    ###########################################################################
    # example 2: transform DTensor into local_shards
    # note: DTensor.to_local() always returns a LocalShardsWrapper
    dtensor_local_shards = dtensor.to_local()
    assert isinstance(dtensor_local_shards, LocalShardsWrapper)
    shard_tensor = dtensor_local_shards.shards[0]
    assert torch.equal(shard_tensor, local_tensor)
    assert dtensor_local_shards.shard_sizes[0] == local_shard_shape  # unwrap shape
    assert dtensor_local_shards.shard_offsets[0] == local_shard_offset  # unwrap offset


def run_torchrec_table_wise_sharding_example(rank, world_size):
    # table-wise sharding example:
    #   each rank in the global ProcessGroup holds one different table.
    #   In our example, each table's num_embeddings is 8 and the embedding dim is 16.
    #   The global ProcessGroup has 4 ranks, so each rank will have one complete
    #   8 by 16 table as its local shard.

    device_type = get_device_type()
    device = torch.device(device_type)
    # note: without initializing this mesh, the following local_tensor will be put
    #   on device cuda:0.
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    # manually create the embedding table's local shards
    num_embeddings = 8
    embedding_dim = 16
    emb_table_shape = torch.Size([num_embeddings, embedding_dim])

    # for table i, if the current rank holds the table, then its local shard is
    # a LocalShardsWrapper containing the tensor; otherwise its local shard is
    # an empty torch.Tensor
    table_to_shards = {}  # map {table_id: local shard of table_id}
    table_to_local_tensor = {}  # map {table_id: local tensor of table_id}
    # create 4 embedding tables and place them on different ranks;
    # each rank will hold one complete table, and the dicts will store
    # the corresponding local shards.
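    # For illustration (hypothetical values, looking at rank 1 after the loop
    # below has run): table_to_shards[1] is a LocalShardsWrapper holding the
    # full 8 x 16 table, while table_to_shards[0], [2], and [3] are empty
    # torch.Tensors; table_to_local_tensor mirrors the underlying local tensors.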
    for i in range(world_size):
        # tensor
        local_tensor = (
            torch.randn(*emb_table_shape, device=device)
            if rank == i
            else torch.empty(0, device=device)
        )
        table_to_local_tensor[i] = local_tensor
        # tensor shape
        local_shard_shape = local_tensor.shape
        # tensor offset
        local_shard_offset = torch.Size((0, 0))
        # wrap local shards into a wrapper
        local_shards_wrapper = (
            LocalShardsWrapper(
                local_shards=[local_tensor],
                offsets=[local_shard_offset],
            )
            if rank == i
            else local_tensor
        )
        table_to_shards[i] = local_shards_wrapper

    ###########################################################################
    # example 1: transform local_shards into DTensor
    table_to_dtensor = {}  # same purpose as _model_parallel_name_to_sharded_tensor
    table_wise_sharding_placements = [Replicate()]  # table-wise sharding

    for table_id, local_shards in table_to_shards.items():
        # create a submesh that only contains the rank on which we place the table;
        # note that we cannot use ``init_device_mesh`` to create a submesh,
        # so we use the `DeviceMesh` API to create the DeviceMesh directly
        device_submesh = DeviceMesh(
            device_type=device_type,
            mesh=torch.tensor(
                [table_id], dtype=torch.int64
            ),  # table ``table_id`` is placed on rank ``table_id``
        )
        # create a DTensor from the local shard for the current table
        # note: for uneven sharding, we need to specify the shape and stride because
        #   DTensor would otherwise assume even sharding and compute the shape/stride
        #   based on that assumption. TorchRec needs to pass in this information
        #   explicitly.
        dtensor = DTensor.from_local(
            local_shards,
            device_submesh,
            table_wise_sharding_placements,
            run_check=False,
            shape=emb_table_shape,  # this is required for uneven sharding
            stride=(embedding_dim, 1),
        )
        table_to_dtensor[table_id] = dtensor

    # print each table's sharding
    for table_id, dtensor in table_to_dtensor.items():
        visualize_sharding(
            dtensor,
            header=f"Table-wise sharding example in DTensor for Table {table_id}",
        )
        # check that the dtensor has the correct shape and stride on all ranks
        assert dtensor.shape == emb_table_shape
        assert dtensor.stride() == (embedding_dim, 1)
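
    # Sanity sketch (assuming the single-rank submeshes built above): each
    # table's DTensor is replicated over a submesh that contains only the rank
    # holding that table, so table ``table_id``'s mesh is exactly [table_id].
    for table_id, dtensor in table_to_dtensor.items():
        assert dtensor.device_mesh.mesh.tolist() == [table_id]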

    ###########################################################################
    # example 2: transform DTensor into torch.Tensor
    for table_id, local_tensor in table_to_local_tensor.items():
        # important: on ranks outside this table's submesh, DTensor.to_local()
        #   always returns an empty torch.Tensor no matter what was passed as
        #   DTensor._local_tensor.
        dtensor_local_shards = table_to_dtensor[table_id].to_local()
        if rank == table_id:
            assert isinstance(dtensor_local_shards, LocalShardsWrapper)
            shard_tensor = dtensor_local_shards.shards[0]
            assert torch.equal(shard_tensor, local_tensor)  # unwrap tensor
            assert (
                dtensor_local_shards.shard_sizes[0] == emb_table_shape
            )  # unwrap shape
            assert dtensor_local_shards.shard_offsets[0] == torch.Size(
                (0, 0)
            )  # unwrap offset
        else:
            assert dtensor_local_shards.numel() == 0


def run_example(rank, world_size, example_name):
    # the dict that stores the example functions
    name_to_example_code = {
        "row-wise-even": run_torchrec_row_wise_even_sharding_example,
        "row-wise-uneven": run_torchrec_row_wise_uneven_sharding_example,
        "table-wise": run_torchrec_table_wise_sharding_example,
    }
    if example_name not in name_to_example_code:
        print(f"example for {example_name} does not exist!")
        return

    # the example to run
    example_func = name_to_example_code[example_name]

    # set manual seed
    torch.manual_seed(0)

    # run the example
    example_func(rank, world_size)


if __name__ == "__main__":
    # this script is launched via torchrun, which automatically manages the ProcessGroup
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size == 4  # our example uses 4 worker ranks
    # parse the arguments
    parser = argparse.ArgumentParser(
        description="torchrec sharding examples",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    example_prompt = (
        "choose one sharding example from below:\n"
        "\t1. row-wise-even\n"
        "\t2. row-wise-uneven\n"
        "\t3. table-wise\n"
        "e.g. to try the row-wise even sharding example, pass 'row-wise-even'\n"
    )
    parser.add_argument("-e", "--example", help=example_prompt, required=True)
    args = parser.parse_args()
    run_example(rank, world_size, args.example)
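
# Illustrative launch command (adjust the path to wherever this file lives);
# torchrun spawns the 4 worker ranks the example assumes and sets the
# RANK / WORLD_SIZE environment variables read above:
#   torchrun --standalone --nproc-per-node=4 <path/to/this_file>.py -e row-wise-even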