# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


funcol = torch.ops.c10d_functional

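# The tests below exercise DTensor.redistribute() between the Shard, Replicate,
# and Partial placements on 1d and multi-dim device meshes, using CommDebugMode
# to assert exactly which collectives (and how many) each transition issues.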
class RedistributeTest(DTensorTestBase):
    @property
    def world_size(self):
        return 4

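    # Shard -> Replicate: the forward redistribution below is expected to issue
    # exactly one all_gather_into_tensor, while the backward (gradient flowing
    # Replicate -> Shard) is a local chunk and should issue no collectives.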
    @with_comms
    def test_shard_to_replicate_forward_backward(self):
        # 1) test shard -> replicate forward
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]

        comm_mode = CommDebugMode()
        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]
            expected_tensor = torch.randn(
                input_size, device=self.device_type, requires_grad=True
            )
            dtensor = distribute_tensor(expected_tensor, device_mesh, shard_spec)
            with comm_mode:
                reshard_dtensor = dtensor.redistribute(device_mesh, replica_spec)
            self.assertEqual(reshard_dtensor.size(), torch.Size(input_size))
            self.assertEqual(expected_tensor, reshard_dtensor.to_local())
            self.assertEqual(
                comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1
            )

            # 2) test shard -> replicate backward:
            # should give gradient as shard
            grad_output = torch.ones_like(reshard_dtensor)
            with comm_mode:
                reshard_dtensor.backward(grad_output)
            grad_input = dtensor.grad
            self.assertEqual(grad_input.placements, shard_spec)
            self.assertEqual(
                grad_input.to_local(), torch.ones(dtensor.to_local().size())
            )
            self.assertEqual(comm_mode.get_total_counts(), 0)

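    # Replicate -> Replicate: both the forward redistribution and its backward
    # are expected to be communication-free (zero collective counts).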
    @with_comms
    def test_replicate_to_replicate_forward_backward(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)

        comm_mode = CommDebugMode()

        # 1) test replicate -> replicate forward
        replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec)
        with comm_mode:
            reshard_replica_tensor = replica_tensor.redistribute(
                device_mesh, replica_spec
            )
        self.assertEqual(replica_tensor.size(), local_tensor.size())
        self.assertEqual(replica_tensor, reshard_replica_tensor)
        self.assertEqual(comm_mode.get_total_counts(), 0)

        # 2) test replicate -> replicate backward:
        # should give gradient as replicate
        grad_output = torch.ones_like(reshard_replica_tensor)
        with comm_mode:
            reshard_replica_tensor.backward(grad_output)
        grad_input = replica_tensor.grad
        self.assertEqual(grad_input.placements, replica_spec)
        self.assertEqual(grad_input.to_local(), torch.ones(12, 3))
        self.assertEqual(comm_mode.get_total_counts(), 0)

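    # to_local(grad_placements=[Partial()]) marks the gradient of the returned
    # local tensor as a partial sum, so the backward below is expected to issue
    # exactly one all_reduce.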
    @with_comms
    def test_replicate_to_local_partial_grad(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)

        replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec)

        comm_mode = CommDebugMode()

        with comm_mode:
            out = replica_tensor.redistribute(placements=[Replicate()]).to_local(
                grad_placements=[Partial()]
            )
            out.backward(torch.ones_like(out))

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)

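    # Replicate -> Shard: the forward is a local chunk with no communication;
    # the backward gathers the sharded gradient back to Replicate, which the
    # test below expects to show up as one all_gather_into_tensor.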
    @with_comms
    def test_replicate_to_shard_forward_backward(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]

        comm_mode = CommDebugMode()
        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]
            # 1) test replicate -> shard forward
            local_replica = torch.randn(
                input_size, device=self.device_type, requires_grad=True
            )
            splitted_list = list(
                torch.chunk(local_replica, self.world_size, dim=shard_dim)
            )

            # the expected local tensor is this rank's chunk of the split list
            local_tensor = splitted_list[self.rank]
            replica_tensor = distribute_tensor(local_replica, device_mesh, replica_spec)
            with comm_mode:
                reshard_tensor = replica_tensor.redistribute(device_mesh, shard_spec)
            self.assertEqual(reshard_tensor.size(), replica_tensor.size())
            self.assertEqual(reshard_tensor.placements, shard_spec)
            self.assertEqual(reshard_tensor.to_local(), local_tensor)
            self.assertEqual(comm_mode.get_total_counts(), 0)

            # 2) test replicate -> shard backward:
            # should give gradient as replicate
            grad_output = torch.ones_like(reshard_tensor)
            with comm_mode:
                reshard_tensor.backward(grad_output)
            grad_input = replica_tensor.grad
            self.assertEqual(grad_input.placements, replica_spec)
            self.assertEqual(grad_input.to_local(), torch.ones(input_size))
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(
                comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1
            )

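    # Partial -> Replicate: the forward reduction below is expected to issue a
    # single all_reduce; the backward through DTensor.from_local should be a
    # pass-through with no collectives.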
    @with_comms
    def test_partial_to_replicate_forward_backward(self):
        # Although we don't allow users to reshard to a partial placement
        # (i.e. a user can't redistribute to Partial), we do allow replicate
        # -> partial internally, and the partial -> replicate backward should
        # work as expected.
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        partial_local = torch.ones(12, 3, device=self.device_type, requires_grad=True)
        partial_spec = [Partial()]
        replica_spec = [Replicate()]

        comm_mode = CommDebugMode()
        # test partial -> replicate, which triggers an all_reduce
        partial_tensor = DTensor.from_local(partial_local, device_mesh, partial_spec)
        with comm_mode:
            global_partial_tensor = partial_tensor.redistribute(
                device_mesh, replica_spec
            )

        self.assertEqual(partial_tensor.size(), partial_local.size())
        self.assertEqual(
            partial_local * self.world_size, global_partial_tensor.to_local()
        )
        self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)

        # test backward to have replicate grad on partial:
        # for the from_local backward, we want replicate() -> partial() to be
        # a pass-through.
        with comm_mode:
            global_partial_tensor.backward(torch.ones_like(global_partial_tensor))
        self.assertIsNotNone(partial_local.grad)
        self.assertEqual(partial_local.grad.size(), partial_local.size())
        self.assertEqual(partial_local.grad, torch.ones_like(partial_local))
        self.assertEqual(comm_mode.get_total_counts(), 0)

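    # Redistributing to Partial through the public redistribute() API is
    # rejected; the test below goes through the internal Redistribute.apply
    # autograd function and expects the conversion to be communication-free.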
    @with_comms
    def test_replicate_to_partial(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
        partial_spec = Partial()
        replica_spec = Replicate()
        # 1) test replicate -> partial forward
        replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec])
        with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
            partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])

        from torch.distributed.tensor._redistribute import Redistribute

        comm_mode = CommDebugMode()

        with comm_mode:
            partial_tensor = Redistribute.apply(
                replica_tensor, device_mesh, [partial_spec]
            )
        self.assertEqual(partial_tensor.size(), local_tensor.size())
        # the replicated value should be evenly divided across ranks so that
        # the partial sum stays unchanged
        self.assertEqual(
            replica_tensor.to_local() / self.world_size, partial_tensor.to_local()
        )
        self.assertEqual(comm_mode.get_total_counts(), 0)

        # replicate to partial on sub groups
        local_tensor = torch.randn(12, 3, device=self.device_type)
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(self.world_size).reshape(self.world_size // 2, 2),
        )
        # 2) test replicate -> partial on 2d-mesh subgroups
        replica_tensor = distribute_tensor(
            local_tensor, device_mesh, [replica_spec, replica_spec]
        )
        with comm_mode:
            partial_tensor = Redistribute.apply(
                replica_tensor, device_mesh, [partial_spec, partial_spec]
            )
        self.assertEqual(partial_tensor.size(), local_tensor.size())

        self.assertEqual(
            replica_tensor.to_local() / self.world_size,
            partial_tensor.to_local(),
        )
        self.assertEqual(comm_mode.get_total_counts(), 0)

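    # Partial -> Shard: the reduction and sharding are fused, so each
    # redistribution below is expected to issue one reduce_scatter_tensor;
    # uneven dims are split using the ceil-division chunk sizes computed below.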
    @with_comms
    def test_partial_to_shard(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        partial_spec = [Partial()]
        my_rank = device_mesh.get_rank()

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]

        comm_mode = CommDebugMode()

        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]

            partial_local = torch.ones(input_size, device=self.device_type)
            partial_tensor = DTensor.from_local(
                partial_local, device_mesh, partial_spec, run_check=False
            )

            full_chunk_size = (
                input_size[shard_dim] + self.world_size - 1
            ) // self.world_size
            chunk_sizes = [
                max(
                    min(input_size[shard_dim], full_chunk_size * (idx + 1))
                    - full_chunk_size * idx,
                    0,
                )
                for idx in range(self.world_size)
            ]

            local_shape = list(input_size)
            local_shape[shard_dim] = chunk_sizes[my_rank]

            # test partial -> shard, which triggers a reduce_scatter
            with comm_mode:
                scatter_shard_tensor = partial_tensor.redistribute(
                    device_mesh, shard_spec
                )
            self.assertEqual(scatter_shard_tensor.size(), partial_tensor.size())
            self.assertEqual(scatter_shard_tensor.placements, shard_spec)
            self.assertEqual(
                scatter_shard_tensor.to_local(),
                torch.ones(local_shape) * self.world_size,
            )
            self.assertEqual(
                comm_mode.get_comm_counts()[funcol.reduce_scatter_tensor], 1
            )

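    # Negative shard dims passed to redistribute() should be normalized to
    # their positive counterpart (Shard(-1) on a 2d tensor becomes Shard(1)).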
    @with_comms
    def test_redistribute_negative_shard_dim(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
        shard_spec = [Shard(1)]
        shard_minus_spec = [Shard(-1)]

        shard_tensor = distribute_tensor(local_tensor, device_mesh, shard_spec)
        self.assertEqual(shard_tensor.placements[0].dim, 1)
        reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec)
        # the negative shard dim should be normalized to the positive dim 1
        self.assertEqual(reshard_tensor.placements[0].dim, 1)

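    # full_tensor() should round-trip unevenly sharded (and smaller-than-mesh)
    # inputs on a 2d mesh back to the original tensor.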
    @with_comms
    def test_redistribute_uneven_sharding(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
        data_to_test = [
            # uneven on last mesh dim
            torch.randn((10, 5), device=self.device_type),
            # uneven on both mesh dims
            torch.randn((9, 5), device=self.device_type),
            # smaller than mesh dim shape
            torch.randn((3, 5), device=self.device_type),
            torch.randn((1, 3), device=self.device_type),
        ]

        sharding_to_tests = [
            [Shard(0), Shard(0)],
            [Shard(0), Shard(1)],
        ]

        for input_tensor in data_to_test:
            for placements in sharding_to_tests:
                dt = distribute_tensor(input_tensor, mesh, placements)
                dt_full_tensor = dt.full_tensor()
                self.assertEqual(dt_full_tensor, input_tensor)

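    # Shard(i) -> Shard(j) on a 1d mesh uses the fused shard_dim_alltoall op on
    # CUDA and falls back to an all_gather-based path otherwise; on the 2d mesh
    # the expected total collective counts per src/dst pair are listed in
    # comm_counts_2d below.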
    @with_comms
    def test_redistribute_shard_dim_change(self):
        # test 1d device mesh
        mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size))
        data_to_test = [
            # evenly sharded case
            torch.randn((8, 8), device=self.device_type),
            # 3d or more dims
            torch.randn((8, 8, 8), device=self.device_type),
            # uneven case 1
            torch.randn((8, 5), device=self.device_type),
            # uneven case 2
            torch.randn((5, 8), device=self.device_type),
            # uneven case 3
            torch.randn((5, 5), device=self.device_type),
        ]

        sharding_src_dst_pairs = [([Shard(0)], [Shard(1)]), ([Shard(1)], [Shard(0)])]

        comm_mode = CommDebugMode()

        for input_data in data_to_test:
            for src, dst in sharding_src_dst_pairs:
                expected_dt = distribute_tensor(input_data.clone(), mesh_1d, dst)
                sharded_dt = distribute_tensor(input_data, mesh_1d, src)
                with comm_mode:
                    out_dt = sharded_dt.redistribute(mesh_1d, dst)
                self.assertEqual(out_dt.placements, expected_dt.placements)
                local_out_dt = out_dt.to_local()
                local_expected_dt = expected_dt.to_local()
                self.assertEqual(local_out_dt, local_expected_dt)
                if self.device_type == "cuda":
                    self.assertEqual(
                        comm_mode.get_comm_counts()[
                            torch.ops._dtensor.shard_dim_alltoall
                        ],
                        1,
                    )
                else:
                    self.assertEqual(
                        comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                        1,
                    )

        # test 2d device mesh
        mesh_2d = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 2)
        )
        data_to_test_2d = [
            # evenly sharded case
            torch.randn((8, 8), device=self.device_type),
            # 3d or more dims
            torch.randn((8, 8, 8), device=self.device_type),
            # uneven case 1
            torch.randn((8, 5), device=self.device_type),
            # uneven case 2
            torch.randn((5, 8), device=self.device_type),
            # uneven case 3
            torch.randn((5, 5), device=self.device_type),
        ]
        sharding_src_dst_pairs_2d = [
            ([Shard(0), Shard(1)], [Shard(0), Shard(0)]),
            ([Shard(0), Shard(1)], [Shard(1), Shard(0)]),
            ([Shard(0), Shard(0)], [Shard(1), Shard(1)]),
        ]
        comm_counts_2d = [
            1,  # 1: S1 -> S0
            2,  # 1: S1 -> R, 0: S0 -> S1, 1: R -> S0
            2,  # 1: S0 -> R, 0: S0 -> S1, 1: R -> S1
        ]

        for input_data in data_to_test_2d:
            if input_data.ndim > 2:
                sharding_spec_combs = sharding_src_dst_pairs_2d + [
                    ([Shard(0), Shard(2)], [Shard(1), Shard(0)]),
                    ([Shard(1), Shard(1)], [Shard(1), Shard(2)]),
                ]
                comm_counts_2d = comm_counts_2d + [
                    2,  # 1: S2 -> R, 0: S0 -> S1, 1: R -> S0
                    1,  # 1: S1 -> S2
                ]
            else:
                sharding_spec_combs = sharding_src_dst_pairs_2d

            for idx, (src, dst) in enumerate(sharding_spec_combs):
                expected_dt = distribute_tensor(input_data.clone(), mesh_2d, dst)
                sharded_dt = distribute_tensor(input_data, mesh_2d, src)
                with comm_mode:
                    out_dt = sharded_dt.redistribute(mesh_2d, dst)

                self.assertEqual(out_dt.placements, expected_dt.placements)
                self.assertEqual(comm_mode.get_total_counts(), comm_counts_2d[idx])

                local_out_dt = out_dt.to_local()
                local_expected_dt = expected_dt.to_local()
                self.assertEqual(local_out_dt, local_expected_dt)

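    # shard_dim_alltoall is exercised on a 2d mesh (so group_rank differs from
    # global_rank) and on a meta tensor, whose output shape should match the
    # real output's.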
    @with_comms
    def test_shard_dim_alltoall(self):
        # init 2d mesh here so we can test when group_rank != global_rank
        mesh = init_device_mesh(self.device_type, (2, 2))
        tensor = torch.randn(12, self.world_size, device=self.device_type)
        new_tensor = shard_dim_alltoall(tensor, 0, 1, mesh, 0)

        meta_tensor = torch.randn(12, self.world_size, device="meta")
        new_meta_tensor = shard_dim_alltoall(meta_tensor, 0, 1, mesh, 0)

        self.assertEqual(new_tensor.shape, new_meta_tensor.shape)


class MultiDimRedistributeTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 8

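    # Exhaustively check every input/output placement combination on 1d, 2d,
    # and 3d meshes; for Partial inputs the reconstructed full tensor should be
    # scaled by the product of the mesh sizes on the partial dims (num_sums).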
    @with_comms
    def test_multi_dim_mesh(self):
        devices = torch.arange(self.world_size)
        for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]:
            device_mesh = DeviceMesh(self.device_type, mesh_shape)
            tensor_shape = (16, 24)

            if torch.distributed.get_rank() == 0:
                full_tensor = torch.randn(*tensor_shape)
            else:
                # these should be entirely ignored because distribute_tensor
                # is expected to override the shards on ranks != 0
                full_tensor = torch.ones(*tensor_shape)

            possibilities = [Replicate()] + [Shard(i) for i in range(full_tensor.ndim)]
            all_outputs = list(itertools.product(*(mesh_shape.ndim * [possibilities])))
            all_inputs = list(
                itertools.product(*(mesh_shape.ndim * [possibilities + [Partial()]]))
            )

            for inputs in all_inputs:
                # if partial, temporarily make it Replicate, then reinterpret
                # the replicated entries as Partial afterwards
                repl_inputs = [Replicate() if s.is_partial() else s for s in inputs]
                dt = distribute_tensor(full_tensor, device_mesh, repl_inputs)

                if repl_inputs != inputs:
                    # create a new DTensor reinterpreting some of the
                    # replicated entries as "Partial"
                    dt = DTensor.from_local(
                        dt.to_local(), device_mesh, inputs, run_check=False
                    )

                for outputs in all_outputs:
                    # redistribute on target outputs
                    dt2 = dt.redistribute(device_mesh, outputs)

                    # materialize the full (replicated) tensor
                    local_full = dt2.full_tensor()

                    if torch.distributed.get_rank() == 0:
                        self.assertEqual(local_full.shape, full_tensor.shape)

                        num_sums = 1
                        for idx, input in enumerate(inputs):
                            if input.is_partial():
                                num_sums *= mesh_shape.size(idx)
                        expected = num_sums * full_tensor
                        self.assertEqual(local_full, expected)

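    # Shard-dim changes on a 3d mesh: the expected number of collectives for
    # each src/dst placement pair is listed in comm_counts_3d below.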
    @with_comms
    def test_redistribute_shard_dim_multi_dim_mesh(self):
        mesh = init_device_mesh(self.device_type, (2, 2, 2))
        input_data = torch.randn((8, 8, 8), device=self.device_type)

        sharding_src_dst_pairs_3d = [
            ([Shard(0), Shard(0), Shard(0)], [Shard(1), Shard(1), Shard(1)]),
            ([Shard(0), Shard(1), Shard(0)], [Shard(1), Shard(0), Shard(0)]),
            ([Shard(0), Shard(1), Shard(2)], [Shard(2), Shard(1), Shard(0)]),
            ([Shard(1), Shard(0), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
            ([Shard(1), Replicate(), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
            ([Shard(0), Shard(0), Shard(1)], [Shard(0), Shard(1), Shard(2)]),
        ]
        comm_counts_3d = [
            3,  # 2: S0 -> R, 1: S1 -> R, 0: S0 -> S1
            3,  # 2: S0 -> R, 1: S1 -> R, 0: S0 -> S1, 1: R -> S0, 2: R -> S0
            2,  # 2: S2 -> R, 0: S1 -> S2
            1,  # 0: S1 -> R
            2,  # 2: S0 -> R, 1: R -> S0, 2: R -> S0, 0: S1 -> R
            2,  # 2: S1 -> S2, 1: S0 -> S1
        ]

        comm_mode = CommDebugMode()
        for idx, (src_placement, dst_placement) in enumerate(sharding_src_dst_pairs_3d):
            expected_dt = distribute_tensor(input_data.clone(), mesh, dst_placement)
            sharded_dt = distribute_tensor(input_data, mesh, src_placement)

            with comm_mode:
                out_dt = sharded_dt.redistribute(mesh, dst_placement)

            self.assertEqual(out_dt.placements, expected_dt.placements)
            self.assertEqual(comm_mode.get_total_counts(), comm_counts_3d[idx])

            local_out_dt = out_dt.to_local()
            local_expected_dt = expected_dt.to_local()
            self.assertEqual(local_out_dt, local_expected_dt)


if __name__ == "__main__":
    run_tests()