# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    distribute_tensor,
    DTensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
)
from torch.distributed.tensor.placement_types import _Partial
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    NUM_DEVICES,
    RMSNormPython,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional


class TensorParallelStyleTest(DTensorTestBase):
    @property
    def world_size(self):
        return NUM_DEVICES

    @with_comms
    def test_colwise_parallel_style(self):
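        """ColwiseParallel on nn.Linear shards the output on the last dim.

        With the default (replicated) input there is no forward collective and a
        single all_reduce in backward; with input_layouts=Shard(0) the forward
        runs an all_gather and the backward a reduce_scatter.
        """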
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        tensor = torch.rand(8, 16, device=self.device_type, requires_grad=True)
        model = nn.Linear(16, 16, device=self.device_type)

        default_col_parallel = ColwiseParallel()
        colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
        with comm_mode:
            out = colwise_mod(tensor)
            # ensure output is sharded on the last dim
            self.assertEqual(out.shape, (8, 16 // self.world_size))
            # ensure no communication happened in fwd
            self.assertEqual(comm_mode.get_total_counts(), 0)

            out.sum().backward()
            # allreduce in bwd
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1)
            self.assertEqual(comm_mode.get_total_counts(), 1)

        sharded_col_parallel = ColwiseParallel(input_layouts=Shard(0))
        colwise_mod = parallelize_module(deepcopy(model), mesh, sharded_col_parallel)
        with comm_mode:
            out = colwise_mod(tensor)
            # ensure output is sharded on the last dim
            self.assertEqual(out.shape, (8 * self.world_size, 16 // self.world_size))
            # allgather in fwd
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
            )
            self.assertEqual(comm_mode.get_total_counts(), 1)

            out.sum().backward()
            # reduce_scatter in bwd
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
            )
            self.assertEqual(comm_mode.get_total_counts(), 2)

    @with_comms
    def test_colwise_parallel_embedding(self):
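        """ColwiseParallel on nn.Embedding shards the output on the last dim
        and needs no communication in either forward or backward."""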
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        tensor = torch.arange(8, device=self.device_type).reshape(4, 2)
        model = nn.Embedding(16, 16, device=self.device_type)

        default_col_parallel = ColwiseParallel()
        colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
        with comm_mode:
            out = colwise_mod(tensor)
            # ensure output is sharded on the last dim
            self.assertEqual(out.shape, (4, 2, 16 // self.world_size))
            # ensure no communication happened in fwd
            self.assertEqual(comm_mode.get_total_counts(), 0)

            out.sum().backward()
            # no comm in bwd
            self.assertEqual(comm_mode.get_total_counts(), 0)

    @with_comms
    def test_rowwise_parallel_style(self):
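        """RowwiseParallel on nn.Linear consumes an input sharded on the last dim.

        With the default (replicated) output the forward runs an all_reduce and
        the backward needs no collective; with output_layouts=Shard(0) the
        forward runs a reduce_scatter and the backward an all_gather.
        """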
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        tensor = torch.rand(
            8, 16 // self.world_size, device=self.device_type, requires_grad=True
        )
        model = nn.Linear(16, 16, device=self.device_type)

        default_row_parallel = RowwiseParallel()
        rowwise_mod = parallelize_module(deepcopy(model), mesh, default_row_parallel)
        with comm_mode:
            out = rowwise_mod(tensor)
            # ensure output is replicated
            self.assertEqual(out.shape, (8, 16))
            # allreduce in fwd
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1)
            self.assertEqual(comm_mode.get_total_counts(), 1)

            out.sum().backward()
            # no comm in bwd
            self.assertEqual(comm_mode.get_total_counts(), 1)

        sharded_row_parallel = RowwiseParallel(output_layouts=Shard(0))
        rowwise_mod = parallelize_module(deepcopy(model), mesh, sharded_row_parallel)
        with comm_mode:
            out = rowwise_mod(tensor)
            # ensure output is sharded on dim 0
            self.assertEqual(out.shape, (8 // self.world_size, 16))
            # reduce_scatter in fwd
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
            )
            self.assertEqual(comm_mode.get_total_counts(), 1)

            out.sum().backward()
            # allgather in bwd
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
            )
            self.assertEqual(comm_mode.get_total_counts(), 2)

    @with_comms
    def test_rowwise_parallel_embedding(self):
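        """RowwiseParallel on nn.Embedding with replicated input layouts.

        With the default (replicated) output the forward runs an all_reduce and
        the backward needs no collective; with output_layouts=Shard(1) the
        forward runs a reduce_scatter and the backward an all_gather.
        """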
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        tensor = torch.arange(8, device=self.device_type).reshape(4, 2)
        model = nn.Embedding(16, 16, device=self.device_type)

        rowwise_mod = parallelize_module(
            deepcopy(model), mesh, RowwiseParallel(input_layouts=Replicate())
        )
        with comm_mode:
            out = rowwise_mod(tensor)
            # ensure output is replicated
            self.assertEqual(out.shape, (4, 2, 16))
            # ensure allreduce communication happened in fwd
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1)

            out.sum().backward()
            # no comm in bwd
            self.assertEqual(comm_mode.get_total_counts(), 1)

        sharded_row_parallel = RowwiseParallel(
            input_layouts=Replicate(), output_layouts=Shard(1)
        )

        rowwise_mod = parallelize_module(deepcopy(model), mesh, sharded_row_parallel)

        inp_indices = torch.arange(8, device=self.device_type)
        with comm_mode:
            out = rowwise_mod(inp_indices)
            # ensure output is sharded on the last dim
            self.assertEqual(out.shape, (8, 16 // self.world_size))
            # reduce_scatter in fwd
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
            )
            out.sum().backward()
            # allgather comm in bwd
            self.assertEqual(comm_mode.get_total_counts(), 2)
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
            )

    @with_comms
    def test_prepare_module_input(self):
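        """PrepareModuleInput redistributes a Shard(0) module input to
        Replicate, which materializes as an all_gather over the mesh."""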
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        tensor = torch.ones(2, 16, device=self.device_type)
        expected_tensor = torch.ones(2 * self.world_size, 16, device=self.device_type)
        prepare_inp_style = PrepareModuleInput(
            input_layouts=Shard(0), desired_input_layouts=Replicate()
        )

        model = nn.Identity()
        allgather_mod = parallelize_module(model, mesh, prepare_inp_style)
        output = allgather_mod(tensor).full_tensor()
        self.assertEqual(output, expected_tensor)

    @with_comms
    def test_prepare_module_input_multiple_inputs(self):
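        """PrepareModuleInput with multiple positional inputs.

        Mismatched lengths between input_layouts and desired_input_layouts raise
        an AssertionError; a layout spec shorter than the actual module inputs
        raises a ValueError; a None entry leaves that input untouched.
        """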
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(8, 8)

            def forward(self, x, y):
                return self.linear(x) + y

        # Raise AssertionError if input_layouts and desired_input_layouts do not have the same length.
        test_mod = TestModule().to(self.device_type)
        with self.assertRaisesRegex(
            AssertionError,
            "input_layouts and desired_input_layouts should have same length!",
        ):
            prepare_inps_dimension_mismatch = PrepareModuleInput(
                input_layouts=Shard(0), desired_input_layouts=(Replicate(), None)
            )
        # Raise ValueError if module inputs and input_layouts do not have the same length.
        prepare_inps_short_dimension = PrepareModuleInput(
            input_layouts=Shard(0), desired_input_layouts=Replicate()
        )
        parallelize_module(test_mod.linear, mesh, ColwiseParallel())
        parallelize_module(test_mod, mesh, prepare_inps_short_dimension)
        with self.assertRaisesRegex(
            ValueError, "module inputs and input_layouts should have same length!"
        ):
            output = test_mod(
                torch.randn(2, 8, device=self.device_type),
                torch.ones(
                    self.world_size * 2, 8 // self.world_size, device=self.device_type
                ),
            )

        test_mod = TestModule().to(self.device_type)
        prepare_inps = PrepareModuleInput(
            input_layouts=(Shard(0), None), desired_input_layouts=(Replicate(), None)
        )

        parallelize_module(test_mod.linear, mesh, ColwiseParallel())
        parallelize_module(test_mod, mesh, prepare_inps)
        output = test_mod(
            torch.randn(2, 8, device=self.device_type),
            torch.ones(
                self.world_size * 2, 8 // self.world_size, device=self.device_type
            ),
        )
        self.assertEqual(output.shape, (self.world_size * 2, 8 // self.world_size))

    @with_comms
    def test_prepare_module_kwargs_input(self):
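        """PrepareModuleInput can annotate keyword-only inputs via
        input_kwarg_layouts / desired_input_kwarg_layouts, including the case
        where a kwarg is already a DTensor."""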
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        class TestKwargModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(8, 8)

            def forward(self, x, *, y, z=2):
                return self.linear(x) + y + z

        test_mod = TestKwargModule().to(self.device_type)
        prepare_inps_simple = PrepareModuleInput(
            input_kwarg_layouts={"y": Shard(0)},
            desired_input_kwarg_layouts={"y": Replicate()},
        )
        parallelize_module(
            test_mod.linear, mesh, ColwiseParallel(use_local_output=False)
        )
        parallelize_module(test_mod, mesh, prepare_inps_simple)

        comm_mode = CommDebugMode()
        with comm_mode:
            output = test_mod(
                torch.randn(1 * self.world_size, 8, device=self.device_type),
                y=torch.ones(1, 8, device=self.device_type),
            )

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(output.shape, (1 * self.world_size, 8))

        class TestKwargOnlyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(8, 8)

            def forward(self, *, x, y=2, z=None):
                return self.linear(x) + y + z

        test_kwonly_mod = TestKwargOnlyModule().to(self.device_type)
        prepare_inps_simple = PrepareModuleInput(
            input_kwarg_layouts={"x": Shard(0), "z": Shard(0)},
            desired_input_kwarg_layouts={"x": Replicate(), "z": Replicate()},
        )
        parallelize_module(
            test_kwonly_mod.linear, mesh, ColwiseParallel(use_local_output=False)
        )
        parallelize_module(test_kwonly_mod, mesh, prepare_inps_simple)

        with comm_mode:
            output = test_kwonly_mod(
                x=torch.randn(1, 8, device=self.device_type),
                z=torch.ones(1, 8, device=self.device_type),
            )

        self.assertEqual(comm_mode.get_total_counts(), 2)
        self.assertEqual(output.shape, (1 * self.world_size, 8))

        # test the case where x is a DTensor
        x_dt = DTensor.from_local(
            torch.randn(1, 8, device=self.device_type), mesh, [Shard(0)]
        )
        with comm_mode:
            output = test_kwonly_mod(
                x=x_dt, z=torch.ones(1, 8, device=self.device_type)
            )

        self.assertEqual(comm_mode.get_total_counts(), 2)
        self.assertEqual(output.shape, (1 * self.world_size, 8))

    @with_comms
    def test_prepare_module_output(self):
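        """PrepareModuleOutput redistributes a replicated module output to
        Shard(0), so each rank keeps only its chunk of the output."""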
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        tensor = torch.ones(8, 16, device=self.device_type)
        expected_tensor = torch.ones(8 // self.world_size, 16, device=self.device_type)
        prepare_out_style = PrepareModuleOutput(
            output_layouts=Replicate(), desired_output_layouts=Shard(0)
        )

        model = nn.Identity()
        chunk_mod = parallelize_module(model, mesh, prepare_out_style)
        output = chunk_mod(tensor)
        self.assertEqual(output, expected_tensor)

    @with_comms
    def test_sequence_parallel_style(self):
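        """SequenceParallel runs norms and dropout on sequence-sharded inputs.

        For LayerNorm, RMSNorm, and Dropout a Shard(1) input produces a Shard(1)
        output with zero communication, and affine weight/bias grads land as
        _Partial. An input sharded on a non-sequence dim gets redistributed,
        costing one collective in forward and one in backward.
        """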
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        batch, N, embedding_dim = 20, 8, 12

        global_input = torch.rand(
            batch,
            N * self.world_size,
            embedding_dim,
            device=self.device_type,
            requires_grad=True,
        )
        sharded_input = distribute_tensor(global_input, mesh, [Shard(1)])

        # test LayerNorm
        for elementwise_affine in [True, False]:
            norm = nn.LayerNorm(
                embedding_dim,
                elementwise_affine=elementwise_affine,
                device=self.device_type,
            )
            sp_norm = parallelize_module(deepcopy(norm), mesh, SequenceParallel())

            output = norm(global_input)
            output.sum().backward()

            with comm_mode:
                sharded_out = sp_norm(sharded_input)
                grad_out = torch.ones_like(sharded_out)
                sharded_out.backward(grad_out)
                self.assertIsInstance(sharded_out, DTensor)
                self.assertEqual(sharded_out.placements, (Shard(1),))
                self.assertEqual(comm_mode.get_total_counts(), 0)
                self.assertEqual(
                    comm_mode.get_comm_counts()[c10d_functional.all_reduce], 0
                )
                if elementwise_affine:
                    self.assertEqual(sp_norm.weight.grad.placements, (_Partial(),))
                    self.assertEqual(sp_norm.bias.grad.placements, (_Partial(),))

                self.assertEqual(sharded_out.full_tensor(), output)

        # test RMSNorm
        rmsnorm = RMSNormPython(embedding_dim).to(self.device_type)
        sp_rmsnorm = parallelize_module(deepcopy(rmsnorm), mesh, SequenceParallel())

        output = rmsnorm(global_input)
        output.sum().backward()

        with comm_mode:
            sharded_out = sp_rmsnorm(sharded_input)
            grad_out = torch.ones_like(sharded_out)
            sharded_out.backward(grad_out)
            self.assertIsInstance(sharded_out, DTensor)
            self.assertEqual(sharded_out.placements, (Shard(1),))
            self.assertEqual(sp_rmsnorm.weight.grad.placements, (_Partial(),))
            self.assertEqual(comm_mode.get_total_counts(), 0)
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 0)

            self.assertEqual(sharded_out.full_tensor(), output)

        # test dropout
        dropout = nn.Dropout(0.5).to(self.device_type)
        sp_dropout = parallelize_module(deepcopy(dropout), mesh, SequenceParallel())

        output = dropout(global_input)
        output.sum().backward()
        with comm_mode:
            sharded_out = sp_dropout(sharded_input)
            grad_out = torch.ones_like(sharded_out)
            sharded_out.backward(grad_out)
            self.assertIsInstance(sharded_out, DTensor)
            self.assertEqual(sharded_out.placements, (Shard(1),))
            self.assertEqual(comm_mode.get_total_counts(), 0)

        # test an input sharded on a non-sequence dim
        sharded_batch_input = distribute_tensor(global_input, mesh, [Shard(0)])
        rmsnorm = RMSNormPython(embedding_dim).to(self.device_type)
        sp_rmsnorm = parallelize_module(deepcopy(rmsnorm), mesh, SequenceParallel())

        with comm_mode:
            sharded_out = sp_rmsnorm(sharded_batch_input)
            grad_out = torch.ones_like(sharded_out)
            sharded_out.backward(grad_out)
            self.assertIsInstance(sharded_out, DTensor)
            # output is still sharded on the sequence dim
            self.assertEqual(sharded_out.placements, (Shard(1),))
            self.assertEqual(sp_rmsnorm.weight.grad.placements, (_Partial(),))
            # communication happens in both fwd and bwd to redistribute the input
            self.assertEqual(comm_mode.get_total_counts(), 2)


if __name__ == "__main__":
    run_tests()