# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch.distributed._tensor import (
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


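# Functional collectives namespace; CommDebugMode reports communication counts
# keyed by these ops (e.g. funcol.all_reduce below).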
funcol = torch.ops.c10d_functional


class TestEmbeddingOp(DTensorTestBase):
    def _apply_sharding(self, embedding_mod, shard_dim, device_mesh):
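        """Distribute ``embedding_mod`` over ``device_mesh``, sharding every
        parameter along ``shard_dim`` (0 = rowwise over the vocabulary,
        1 = colwise over the embedding dimension)."""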
        def shard_embedding_fn(name, module, device_mesh):
            for name, param in module.named_parameters():
                dist_param = torch.nn.Parameter(
                    distribute_tensor(param, device_mesh, [Shard(shard_dim)])
                )
                module.register_parameter(name, dist_param)

        sharded_embedding = distribute_module(
            embedding_mod, device_mesh, shard_embedding_fn
        )
        return sharded_embedding

    def _run_embedding_op_test(
        self,
        device_mesh,
        shard_dim,
        input_size,
        num_embeddings,
        embedding_dim,
        **kwargs,
    ):
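        """Run the same lookup through a plain nn.Embedding and a copy sharded
        along ``shard_dim``, then compare forward outputs and weight gradients,
        asserting that neither forward nor backward triggers any communication."""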
        # Use the same seed for both embeddings.
        torch.manual_seed(0)
        local_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )
        sharded_embedding = torch.nn.Embedding(
            num_embeddings,
            embedding_dim,
            device=self.device_type,
            **kwargs,
        )

        # Copy the local embedding's weight into the embedding that will be sharded.
        sharded_embedding.weight = torch.nn.Parameter(
            local_embedding.weight.clone().detach()
        )

        sharded_embedding = self._apply_sharding(
            sharded_embedding, shard_dim, device_mesh
        )

        # Run sharded computation
        torch.manual_seed(10)
        inp = torch.randint(
            0, num_embeddings, tuple(input_size), device=self.device_type
        )
        target = torch.empty(
            *inp.size(), embedding_dim, dtype=torch.float, device=self.device_type
        ).random_(0, 1)
        dist_inp = distribute_tensor(inp, device_mesh, [Replicate()])

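        # With a replicated input the sharded lookup is expected to be computed
        # locally on each rank; any reduction needed to materialize the full
        # output happens later in full_tensor(), outside the debug mode.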
        # fwd computation, ensure no comm happened
        with CommDebugMode() as fwd_mode:
            dist_output = sharded_embedding(dist_inp)
            self.assertEqual(fwd_mode.get_total_counts(), 0)

        output = dist_output.full_tensor()
        # Run local computation
        local_output = local_embedding(inp)

        # Verify
        self.assertEqual(local_output, output)

        # Use a sample cross entropy loss to verify backward and grad computation.
        loss = torch.nn.CrossEntropyLoss()
        emb_loss = loss(
            output,
            target,
        )
        emb_dup_loss = loss(
            local_output,
            target,
        )

        # local embedding backward
        emb_dup_loss.backward()

        # sharded embedding bwd computation, ensure no comm happened
        with CommDebugMode() as bwd_mode:
            emb_loss.backward()
            self.assertEqual(bwd_mode.get_total_counts(), 0)

        gradient = sharded_embedding.weight.grad.full_tensor()

        local_grad = local_embedding.weight.grad

        # Verify gradient.
        self.assertEqual(gradient, local_grad)

        # Validate the torch.nn.functional.embedding version as well.
        local_output = torch.nn.functional.embedding(
            inp,
            local_embedding.weight,
            **kwargs,
        )
        sharded_output = torch.nn.functional.embedding(
            DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False),
            sharded_embedding.weight,
            **kwargs,
        )
        self.assertEqual(local_output, sharded_output.full_tensor())

    @with_comms
    def test_sharded_embedding_colwise(self):
        mesh = self.build_device_mesh()
        self._run_embedding_op_test(mesh, 1, [5, 4], 17, 12)
        self._run_embedding_op_test(mesh, 1, [6, 7, 6], 21, 11)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4, 7], 23, 16)
        self._run_embedding_op_test(mesh, 1, [4], 15, 14)
        self._run_embedding_op_test(mesh, 1, [34], 15, 14, padding_idx=10)
        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12)

    @with_comms
    def test_sharded_embedding_colwise_max_norm_errors(self):
        mesh = self.build_device_mesh()
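        # max_norm makes the embedding renormalize its weight in place via
        # aten.embedding_renorm_, which has no sharding strategy registered for
        # DTensor, hence the NotImplementedError asserted below.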
        with self.assertRaisesRegex(
            NotImplementedError,
            "aten.embedding_renorm_.default does not have a sharding strategy registered.",
        ):
            self._run_embedding_op_test(
                mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
            )

    @with_comms
    def test_sharded_embedding_rowwise(self):
        mesh = self.build_device_mesh()
        # test correctness
        self._run_embedding_op_test(mesh, 0, [5, 12], 16, 22)
        self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
        self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)

        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        # test collectives
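        # Rowwise sharding splits the vocabulary across ranks, so the local lookup
        # is expected to mask out-of-range indices and yield a _MaskPartial result;
        # materializing it via full_tensor() should sum the partials with a single
        # all_reduce.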
        embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
        sharded_embedding = self._apply_sharding(embedding_mod, 0, mesh)
        inp = torch.randint(0, 10, (8, 8), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
        output = sharded_embedding(replicated_inp)
        self.assertIsInstance(output.placements[0], _MaskPartial)

        comm_mode = CommDebugMode()

        with comm_mode:
            output.full_tensor()
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)

    @with_comms
    def test_multiple_embeddings_rowwise(self):
        mesh = self.build_device_mesh()

        inp = torch.randint(0, 10, (4, 4), device=self.device_type)
        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)

        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        # case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
        # and MaskBuffer, because of cache hit from sharding propagation

        emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
        sharded_emb1 = self._apply_sharding(emb1, 0, mesh)
        output1 = sharded_emb1(replicated_inp)

        emb2 = torch.nn.Embedding(10, 29, device=self.device_type)
        sharded_emb2 = self._apply_sharding(emb2, 0, mesh)
        output2 = sharded_emb2(replicated_inp)

        partial_placement1 = output1.placements[0]
        self.assertIsInstance(partial_placement1, _MaskPartial)
        output1.full_tensor()

        partial_placement2 = output2.placements[0]
        self.assertIsInstance(partial_placement2, _MaskPartial)
        output2.full_tensor()

        self.assertEqual(id(partial_placement1), id(partial_placement2))

        # case 2: two embeddings with the same logical_dim_size, but different logical_shape
        # thus they will have different _MaskPartial placements (with no cache hit)

        emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
        sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
        output3 = sharded_emb3(replicated_inp)
        partial_placement3 = output3.placements[0]
        self.assertIsInstance(partial_placement3, _MaskPartial)
        output3.full_tensor()

        # not equal because of different logical_shape, despite the same logical_dim_size
        self.assertNotEqual(partial_placement1, partial_placement3)


if __name__ == "__main__":
    run_tests()