# mypy: allow-untyped-defs
"""
The following example demonstrates how to represent torchrec's embedding
sharding with the DTensor API.
"""
import argparse
import os
from functools import cached_property
from typing import List, TYPE_CHECKING

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.tensor import (
    DeviceMesh,
    DTensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import visualize_sharding


if TYPE_CHECKING:
    from torch.distributed.tensor.placement_types import Placement


def get_device_type():
    return (
        "cuda"
        if torch.cuda.is_available() and torch.cuda.device_count() >= 4
        else "cpu"
    )


aten = torch.ops.aten
supported_ops = [aten.view.default, aten._to_copy.default]


# this torch.Tensor subclass is a wrapper around all local shards associated
# with a single sharded embedding table (a small usage sketch follows the class
# definition below).
class LocalShardsWrapper(torch.Tensor):
    local_shards: List[torch.Tensor]
    storage_meta: TensorStorageMetadata

    @staticmethod
    def __new__(
        cls, local_shards: List[torch.Tensor], offsets: List[torch.Size]
    ) -> "LocalShardsWrapper":
        assert len(local_shards) > 0
        assert len(local_shards) == len(offsets)
        assert local_shards[0].ndim == 2
        # we calculate the total tensor size by concatenating the local shards
        # along the second tensor dimension
        cat_tensor_shape = list(local_shards[0].shape)
        if len(local_shards) > 1:  # column-wise sharding
            for shard_size in [s.shape for s in local_shards[1:]]:
                cat_tensor_shape[1] += shard_size[1]
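        # e.g. two column-wise local shards of shape [8, 8] each produce a
        # wrapper (global) shape of [8, 16]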

        # per DCP, each chunk is expected to have the same properties as the
        # TensorStorageMetadata that contains it; conversely, the wrapper's
        # properties should match those of its first chunk.
        wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
        wrapper_shape = torch.Size(cat_tensor_shape)
        chunks_meta = [
            ChunkStorageMetadata(o, s.shape) for s, o in zip(local_shards, offsets)
        ]

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            wrapper_shape,
        )
        r.shards = local_shards
        r.storage_meta = TensorStorageMetadata(
            properties=wrapper_properties,
            size=wrapper_shape,
            chunks=chunks_meta,
        )

        return r

    # necessary for ops dispatching from this subclass to its local shards
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # TODO: extend this function to support more ops as needed
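        # e.g. calling wrapper.to(torch.float16) dispatches aten._to_copy.default
        # here; we apply it to each local shard and re-wrap the results with the
        # original offsets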
        if func in supported_ops:
            res_shards_list = [
                func(shard, *args[1:], **kwargs) for shard in args[0].shards
            ]
            return LocalShardsWrapper(res_shards_list, args[0].shard_offsets)
        else:
            raise NotImplementedError(
                f"{func} is not supported for LocalShardsWrapper!"
            )

    @property
    def shards(self) -> List[torch.Tensor]:
        return self.local_shards

    @shards.setter
    def shards(self, local_shards: List[torch.Tensor]):
        self.local_shards = local_shards

    @cached_property
    def shard_sizes(self) -> List[torch.Size]:
        return [chunk.sizes for chunk in self.storage_meta.chunks]

    @cached_property
    def shard_offsets(self) -> List[torch.Size]:
        return [chunk.offsets for chunk in self.storage_meta.chunks]


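# A minimal usage sketch of LocalShardsWrapper (a hypothetical helper, not called
# by the examples below): two column-wise [4, 8] local shards are wrapped into one
# logical [4, 16] tensor, and the per-shard sizes/offsets stay recoverable from
# the wrapper's metadata.
def _local_shards_wrapper_sketch():
    shards = [torch.randn(4, 8), torch.randn(4, 8)]
    offsets = [torch.Size((0, 0)), torch.Size((0, 8))]
    wrapper = LocalShardsWrapper(shards, offsets)
    assert wrapper.shape == torch.Size([4, 16])
    assert wrapper.shard_sizes == [torch.Size([4, 8]), torch.Size([4, 8])]
    assert wrapper.shard_offsets == offsets

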
def run_torchrec_row_wise_even_sharding_example(rank, world_size):
    # row-wise even sharding example:
    #   One table is evenly sharded by rows within the global ProcessGroup.
    #   In our example, the table's num_embeddings is 8 and the embedding dim is 16.
    #   The global ProcessGroup has 4 ranks, so each rank will have one 2 by 16 local
    #   shard.
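    #   (i.e. rank r holds rows [2 * r, 2 * r + 2): rank 0 -> rows 0-1,
    #   rank 1 -> rows 2-3, rank 2 -> rows 4-5, rank 3 -> rows 6-7)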

    # device mesh is a representation of the worker ranks
    # create a 1-D device mesh that includes every rank
    device_type = get_device_type()
    device = torch.device(device_type)
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    # manually create the embedding table's local shards
    num_embeddings = 8
    embedding_dim = 16
    emb_table_shape = torch.Size([num_embeddings, embedding_dim])
    # tensor shape
    local_shard_shape = torch.Size(
        [num_embeddings // world_size, embedding_dim]  # (local_rows, local_cols)
    )
    # tensor offset
    local_shard_offset = torch.Size((rank * 2, embedding_dim))
    # tensor
    local_tensor = torch.randn(local_shard_shape, device=device)
    # row-wise sharding: one shard per rank
    # create the local shards wrapper
    local_shards_wrapper = LocalShardsWrapper(
        local_shards=[local_tensor],
        offsets=[local_shard_offset],
    )

    ###########################################################################
    # example 1: transform local_shards into DTensor
    # usage in TorchRec:
    #   ShardedEmbeddingCollection stores model parallel params in
    #   _model_parallel_name_to_sharded_tensor which is initialized in
    #   _initialize_torch_state() and torch.Tensor params are transformed
    #   into ShardedTensor by ShardedTensor._init_from_local_shards().
    #
    #   This allows state_dict() to always return ShardedTensor objects.

    # this is the sharding placement we use in DTensor to represent row-wise sharding
    # row_wise_sharding_placements means that the global tensor is sharded along its
    # first dim over the 1-d mesh.
    row_wise_sharding_placements: List[Placement] = [Shard(0)]

    # create a DTensor from the local shard
    dtensor = DTensor.from_local(
        local_shards_wrapper, device_mesh, row_wise_sharding_placements, run_check=False
    )

    # display the DTensor's sharding
    visualize_sharding(dtensor, header="Row-wise even sharding example in DTensor")

    ###########################################################################
    # example 2: transform DTensor into local_shards
    # usage in TorchRec:
    #   In ShardedEmbeddingCollection's load_state_dict pre hook
    #   _pre_load_state_dict_hook, if the source param is a ShardedTensor
    #   then we need to transform it into its local_shards.

    # transform DTensor into LocalShardsWrapper
    dtensor_local_shards = dtensor.to_local()
    assert isinstance(dtensor_local_shards, LocalShardsWrapper)
    shard_tensor = dtensor_local_shards.shards[0]
    assert torch.equal(shard_tensor, local_tensor)
    assert dtensor_local_shards.shard_sizes[0] == local_shard_shape  # unwrap shape
    assert dtensor_local_shards.shard_offsets[0] == local_shard_offset  # unwrap offset


def run_torchrec_row_wise_uneven_sharding_example(rank, world_size):
    # row-wise uneven sharding example:
    #   One table is unevenly sharded by rows within the global ProcessGroup.
    #   In our example, the table's num_embeddings is 8 and the embedding dim is 16.
    #   The global ProcessGroup has 4 ranks, and each rank holds a local shard
    #   of shape:
    #       rank 0: [1, 16]
    #       rank 1: [3, 16]
    #       rank 2: [1, 16]
    #       rank 3: [3, 16]

    # device mesh is a representation of the worker ranks
    # create a 1-D device mesh that includes every rank
    device_type = get_device_type()
    device = torch.device(device_type)
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    # manually create the embedding table's local shards
    num_embeddings = 8
    embedding_dim = 16
    emb_table_shape = torch.Size([num_embeddings, embedding_dim])
    # tensor shape
    local_shard_shape = (
        torch.Size([1, embedding_dim])
        if rank % 2 == 0
        else torch.Size([3, embedding_dim])
    )
    # tensor offset
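    # rank // 2 * 4 + rank % 2 gives row offsets 0, 1, 4, 5 for ranks 0..3,
    # i.e. each rank's shard starts right after the rows held by the lower ranks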
    local_shard_offset = torch.Size((rank // 2 * 4 + rank % 2 * 1, embedding_dim))
    # tensor
    local_tensor = torch.randn(local_shard_shape, device=device)
    # local shards
    # row-wise sharding: one shard per rank
    # create the local shards wrapper
    local_shards_wrapper = LocalShardsWrapper(
        local_shards=[local_tensor],
        offsets=[local_shard_offset],
    )

    ###########################################################################
    # example 1: transform local_shards into DTensor
    # create the DTensorMetadata which torchrec should provide
    row_wise_sharding_placements: List[Placement] = [Shard(0)]

    # note: for uneven sharding, we need to specify the shape and stride because
    # DTensor would assume even sharding and compute the shape/stride based on that
    # assumption. TorchRec needs to pass this information in explicitly.
    # shape/stride are the global tensor's shape and stride
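    # e.g. from a [1, 16] local shard, even-sharding inference over the 4-rank mesh
    # would give a [4, 16] global shape (and [12, 16] from a [3, 16] shard), so we
    # pass the true global shape [8, 16] and its row-major stride (16, 1) instead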
    dtensor = DTensor.from_local(
        local_shards_wrapper,  # a torch.Tensor subclass
        device_mesh,  # DeviceMesh
        row_wise_sharding_placements,  # List[Placement]
        run_check=False,
        shape=emb_table_shape,  # this is required for uneven sharding
        stride=(embedding_dim, 1),
    )
    # note: visualize_sharding() cannot yet print unevenly sharded DTensors
    # correctly because its offset computation assumes even sharding.
    visualize_sharding(dtensor, header="Row-wise uneven sharding example in DTensor")
    # check the dtensor has the correct shape and stride on all ranks
    assert dtensor.shape == emb_table_shape
    assert dtensor.stride() == (embedding_dim, 1)

    ###########################################################################
    # example 2: transform DTensor into local_shards
    # note: DTensor.to_local() returns the LocalShardsWrapper that was passed
    # to DTensor.from_local()
    dtensor_local_shards = dtensor.to_local()
    assert isinstance(dtensor_local_shards, LocalShardsWrapper)
    shard_tensor = dtensor_local_shards.shards[0]
    assert torch.equal(shard_tensor, local_tensor)
    assert dtensor_local_shards.shard_sizes[0] == local_shard_shape  # unwrap shape
    assert dtensor_local_shards.shard_offsets[0] == local_shard_offset  # unwrap offset


def run_torchrec_table_wise_sharding_example(rank, world_size):
    # table-wise example:
    #   each rank in the global ProcessGroup holds one different table.
    #   In our example, the table's num_embeddings is 8 and the embedding dim is 16.
    #   The global ProcessGroup has 4 ranks, so each rank will have one 8 by 16 complete
    #   table as its local shard.
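    #   (i.e. table i lives entirely on rank i; below we build a single-rank
    #   submesh per table to express this placement)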

    device_type = get_device_type()
    device = torch.device(device_type)
    # note: initializing this mesh also sets the current device for each rank;
    # without it, the local_tensor created below would land on device cuda:0.
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    # manually create the embedding table's local shards
    num_embeddings = 8
    embedding_dim = 16
    emb_table_shape = torch.Size([num_embeddings, embedding_dim])

    # for table i, if the current rank holds the table, then the local shard is
    # a LocalShardsWrapper containing the tensor; otherwise the local shard is
    # an empty torch.Tensor
    table_to_shards = {}  # map {table_id: local shard of table_id}
    table_to_local_tensor = {}  # map {table_id: local tensor of table_id}
    # create 4 embedding tables and place them on different ranks
    # each rank will hold one complete table, and the dict will store
    # the corresponding local shard.
    for i in range(world_size):
        # tensor
        local_tensor = (
            torch.randn(*emb_table_shape, device=device)
            if rank == i
            else torch.empty(0, device=device)
        )
        table_to_local_tensor[i] = local_tensor
        # tensor shape
        local_shard_shape = local_tensor.shape
        # tensor offset
        local_shard_offset = torch.Size((0, 0))
        # wrap local shards into a wrapper
        local_shards_wrapper = (
            LocalShardsWrapper(
                local_shards=[local_tensor],
                offsets=[local_shard_offset],
            )
            if rank == i
            else local_tensor
        )
        table_to_shards[i] = local_shards_wrapper

    ###########################################################################
    # example 1: transform local_shards into DTensor
    table_to_dtensor = {}  # same purpose as _model_parallel_name_to_sharded_tensor
    table_wise_sharding_placements = [Replicate()]  # table-wise sharding
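    # note: Replicate() here is relative to the single-rank submesh created below,
    # so the one rank in that submesh holds the complete table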

    for table_id, local_shards in table_to_shards.items():
        # create a submesh that only contains the rank on which we place the table
        # note that we cannot use ``init_device_mesh`` to create a submesh,
        # so we use the ``DeviceMesh`` API to create one directly
        device_submesh = DeviceMesh(
            device_type=device_type,
            mesh=torch.tensor(
                [table_id], dtype=torch.int64
            ),  # table ``table_id`` is placed on rank ``table_id``
        )
        # create a DTensor from the local shard for the current table
        # note: for uneven sharding, we need to specify the shape and stride because
        # DTensor would assume even sharding and compute the shape/stride based on
        # that assumption. TorchRec needs to pass this information in explicitly.
        dtensor = DTensor.from_local(
            local_shards,
            device_submesh,
            table_wise_sharding_placements,
            run_check=False,
            shape=emb_table_shape,  # this is required for uneven sharding
            stride=(embedding_dim, 1),
        )
        table_to_dtensor[table_id] = dtensor

    # print each table's sharding
    for table_id, dtensor in table_to_dtensor.items():
        visualize_sharding(
            dtensor,
            header=f"Table-wise sharding example in DTensor for Table {table_id}",
        )
        # check the dtensor has the correct shape and stride on all ranks
        assert dtensor.shape == emb_table_shape
        assert dtensor.stride() == (embedding_dim, 1)

    ###########################################################################
    # example 2: transform DTensor into torch.Tensor
    for table_id, local_tensor in table_to_local_tensor.items():
        # important: on ranks that are not part of this table's submesh,
        # DTensor.to_local() returns an empty torch.Tensor no matter what was
        # passed as the local tensor; only the rank that holds the table gets
        # its LocalShardsWrapper back.
        dtensor_local_shards = table_to_dtensor[table_id].to_local()
        if rank == table_id:
            assert isinstance(dtensor_local_shards, LocalShardsWrapper)
            shard_tensor = dtensor_local_shards.shards[0]
            assert torch.equal(shard_tensor, local_tensor)  # unwrap tensor
            assert (
                dtensor_local_shards.shard_sizes[0] == emb_table_shape
            )  # unwrap shape
            assert dtensor_local_shards.shard_offsets[0] == torch.Size(
                (0, 0)
            )  # unwrap offset
        else:
            assert dtensor_local_shards.numel() == 0


def run_example(rank, world_size, example_name):
    # the dict that stores example code
    name_to_example_code = {
        "row-wise-even": run_torchrec_row_wise_even_sharding_example,
        "row-wise-uneven": run_torchrec_row_wise_uneven_sharding_example,
        "table-wise": run_torchrec_table_wise_sharding_example,
    }
    if example_name not in name_to_example_code:
        print(f"example for {example_name} does not exist!")
        return

    # the example to run
    example_func = name_to_example_code[example_name]

    # set manual seed
    torch.manual_seed(0)

    # run the example
    example_func(rank, world_size)


if __name__ == "__main__":
    # this script is launched via torchrun which automatically manages ProcessGroup
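    # an example launch command (the file name below is just a placeholder):
    #   torchrun --nproc_per_node=4 torchrec_sharding_example.py -e row-wise-even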
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    assert world_size == 4  # our example uses 4 worker ranks
    # parse the arguments
    parser = argparse.ArgumentParser(
        description="torchrec sharding examples",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    example_prompt = (
        "choose one sharding example from below:\n"
        "\t1. row-wise-even;\n"
        "\t2. row-wise-uneven;\n"
        "\t3. table-wise\n"
        "e.g. to try the row-wise even sharding example, input 'row-wise-even'\n"
    )
    parser.add_argument("-e", "--example", help=example_prompt, required=True)
    args = parser.parse_args()
    run_example(rank, world_size, args.example)