# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools
from typing import cast, List

import torch
import torch.distributed as dist
from torch import rand, randn, Tensor
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor._ops._view_ops import (
    Broadcast,
    dim_maps,
    Flatten,
    InputDim,
    Repeat,
    Singleton,
    Split,
    view_groups,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.utils import _pytree as pytree


class TestViewOps(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 6

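    # view_groups(from_shape, to_shape) decomposes a reshape into one rule per
    # output dimension, built from the InputDim / Flatten / Split / Singleton
    # primitives imported above.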
    def test_view_groups(self):
        self.assertEqual(
            view_groups([2, 3], [3, 2]),
            (
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
                Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 5], [12, 5]),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [12, 70]),
            (
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    0,
                ),
                Split(
                    Flatten(
                        (
                            InputDim(0),
                            InputDim(1),
                            InputDim(2),
                            InputDim(3),
                            InputDim(4),
                        )
                    ),
                    (12, 70),
                    1,
                ),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]),
            (
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0),
                Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0),
                Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1),
            ),
        )
        self.assertEqual(
            view_groups([3, 4, 8, 3], [12, 4, 2, 3]),
            (
                Flatten((InputDim(0), InputDim(1))),
                Split(InputDim(2), (4, 2), 0),
                Split(InputDim(2), (4, 2), 1),
                InputDim(3),
            ),
        )
        self.assertEqual(
            view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]),
            (
                Singleton(),
                InputDim(0),
                Split(InputDim(1), (2, 4, 3), 0),
                Split(InputDim(1), (2, 4, 3), 1),
                Singleton(),
                Split(InputDim(1), (2, 4, 3), 2),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]),
            (
                Flatten((InputDim(2), InputDim(3))),
                InputDim(4),
                InputDim(5),
                Singleton(),
            ),
        )
        self.assertEqual(
            view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]),
            (
                Split(InputDim(2), (3, 4), 0),
                Split(InputDim(2), (3, 4), 1),
                InputDim(3),
                Flatten((InputDim(6), InputDim(7))),
            ),
        )
        self.assertEqual(
            view_groups([2, 3, 4], [2, -1, 4]),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

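    # Runs `op` on a DTensor input under every legal sharding and checks that
    # the sharded result matches the plain-tensor output with zero collective
    # communication.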
    def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
        dim_map = dim_maps[op]
        rules = dim_map(*args, **kwargs)
        outputs = op(*args, **kwargs)
        flat_args = pytree.arg_tree_leaves(*args)
        in_shape = flat_args[0].shape

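        # Collect input dims that cannot be sharded without communication:
        # repeated dims and every non-leading dim of a flattened group.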
        no_shard_dims = set()
        for rule in rules:
            if isinstance(rule, Repeat):
                if isinstance(rule.input_dim, InputDim):
                    no_shard_dims.add(rule.input_dim.input_dim)
            elif isinstance(rule, Flatten):
                for dim in rule.input_dims[1:]:
                    if isinstance(dim, InputDim):
                        no_shard_dims.add(dim.input_dim)
            elif isinstance(rule, Split):
                if isinstance(rule.input_dim, Flatten):
                    for dim in rule.input_dim.input_dims[1:]:
                        if isinstance(dim, InputDim):
                            no_shard_dims.add(dim.input_dim)

        if op == torch.unbind:
            no_shard_dims.add(kwargs.get("dim", 0))

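        # Per mesh dim, a placement is either Replicate or Shard(i) on any
        # input dim that is larger than 1 and not excluded above.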
        sharding_choices = cast(List[Placement], [Replicate()]) + [
            Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims
        ]

        all_sharding_choices = itertools.product(
            *(device_mesh.ndim * [sharding_choices])
        )

        for in_shard in all_sharding_choices:
            in_dt = distribute_tensor(args[0], device_mesh, in_shard)

            comm_mode = CommDebugMode()
            with comm_mode:
                out_dt = op(in_dt, *args[1:], **kwargs)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )

            full_out = out_dt.full_tensor()

            if dist.get_rank() == 0:
                self.assertEqual(outputs, full_out)

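    # Asserts that dim_maps[op] produces the expected rules, then runs the
    # sharded-correctness check above.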
    def dimmap_test(self, op, args, expected_rule_output):
        rules = dim_maps[op](*args)
        self.assertEqual(rules, expected_rule_output)
        self.call_dt_test(op, args, {}, self.device_mesh)

    @with_comms
    def test_view_ops(self):
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
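        # A 2-D mesh of shape (world_size // 2, 2); shardings are enumerated
        # over both mesh dims.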
        self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),))
        self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),))
        self.dimmap_test(torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1)))

        self.dimmap_test(torch.atleast_2d, (randn(()),), (Singleton(), Singleton()))
        self.dimmap_test(torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0)))
        self.dimmap_test(torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1)))
        self.dimmap_test(
            torch.atleast_2d,
            (randn(24, 36, 48),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.atleast_3d,
            (randn(()),),
            (Singleton(), Singleton(), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24),),
            (Singleton(), InputDim(0), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36),),
            (InputDim(0), InputDim(1), Singleton()),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42),),
            (InputDim(0), InputDim(1), InputDim(2)),
        )
        self.dimmap_test(
            torch.atleast_3d,
            (randn(24, 36, 42, 24),),
            (InputDim(0), InputDim(1), InputDim(2), InputDim(3)),
        )

        with self.assertRaises(AssertionError):
            dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4))

        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (1, 24, 36)),
            (Singleton(), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (42, 24, 36)),
            (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (12, 24, 24, 36)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 24),
                InputDim(2),
            ),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 36), (-1, 36)),
            (InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.broadcast_to,
            (rand(24, 1, 36), (-1, 1, 36)),
            (InputDim(0), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            torch.broadcast_to,
            (randn(36, 1, 24), (12, 36, 42, 24)),
            (
                Broadcast(Singleton(), 12),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
            ),
        )

        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), 36, 24, 42, -1, 24),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        self.dimmap_test(
            Tensor.expand,
            (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)),
            (
                Broadcast(Singleton(), 36),
                InputDim(0),
                Broadcast(InputDim(1), 42),
                InputDim(2),
                Broadcast(InputDim(3), 24),
            ),
        )

        self.dimmap_test(
            torch.flatten,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),))

        self.dimmap_test(
            torch.movedim,
            (randn(12, 24, 48, 96), 1, 2),
            (InputDim(0), InputDim(2), InputDim(1), InputDim(3)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(6, 12, 24), 1, 0),
            (InputDim(1), InputDim(0), InputDim(2)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12, 6), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(24, 12), (1, 0), (0, 1)),
            (InputDim(1), InputDim(0)),
        )

        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (0, 1)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )
        self.dimmap_test(
            torch.movedim,
            (randn(36, 24, 12), (1, 2), (-3, -2)),
            (InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (2, 0, 1)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )
        self.dimmap_test(
            torch.permute,
            (randn(24, 36, 42), (-1, -3, -2)),
            (InputDim(2), InputDim(0), InputDim(1)),
        )

        self.dimmap_test(
            torch.ravel,
            (randn(24, 36),),
            (Flatten((InputDim(0), InputDim(1))),),
        )
        self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),))
        self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),))

        self.dimmap_test(
            Tensor.repeat,
            (randn(24, 36), 1, 2, 1, 1, 2),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )

        self.dimmap_test(
            torch.reshape,
            (randn(6, 12, 24), (72, 24)),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        self.dimmap_test(
            torch.tile,
            (randn(24, 36), (1, 2, 1, 1, 2)),
            (
                Singleton(),
                Broadcast(Singleton(), 2),
                Singleton(),
                InputDim(0),
                Repeat(InputDim(1), 2),
            ),
        )
        self.dimmap_test(
            torch.tile,
            (randn(42, 24, 36), (1, 3)),
            (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)),
        )

        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), 2, 0),
            (InputDim(2), InputDim(1), InputDim(0), InputDim(3)),
        )
        self.dimmap_test(
            torch.transpose,
            (randn(24, 60, 42, 60), -1, 0),
            (InputDim(3), InputDim(1), InputDim(2), InputDim(0)),
        )

        self.dimmap_test(
            torch.unsqueeze,
            (randn(42, 24, 36), 1),
            (InputDim(0), Singleton(), InputDim(1), InputDim(2)),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(6, 12, 24), 72, 24),
            (Flatten((InputDim(0), InputDim(1))), InputDim(2)),
        )

        self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),))

        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 24), -1),
            (Flatten((InputDim(2), InputDim(3))),),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(1, 1, 42, 1, 24, 1), -1),
            (Flatten((InputDim(2), InputDim(3), InputDim(4))),),
        )

        self.dimmap_test(
            Tensor.view,
            (randn(48, 35, 26), (24, 4, 35, 13)),
            (
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=0,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=1,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=2,
                ),
                Split(
                    Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))),
                    group_shape=(24, 4, 35, 13),
                    split_id=3,
                ),
            ),
        )

    # TODO: Functional collectives on complex numbers are not yet fully
    # supported, so we keep a standalone test that combines view_as_complex
    # and view_as_real. Once complex numbers are supported, the following can
    # be added to the dim_map test:
    #
    # self.dimmap_test(
    #     torch.view_as_complex,
    #     (randn(24, 13, 2),),
    #     (
    #         InputDim(0),
    #         Flatten((InputDim(1), InputDim(2))),
    #     ),
    # )
    # self.dimmap_test(
    #     torch.view_as_real,
    #     (torch.randn(24, 13, dtype=torch.cfloat),),
    #     (
    #         InputDim(0),
    #         Split(InputDim(1), (13, 2), 0),
    #         Split(InputDim(1), (13, 2), 1),
    #     ),
    # )
    @with_comms
    def test_complex_view_ops(self):
        self.device_mesh = DeviceMesh(
            self.device_type, torch.arange(dist.get_world_size()).view(-1, 2)
        )
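        # Round-trip a real tensor through view_as_complex and view_as_real:
        # the trailing (13, 2) dims flatten into 13 complex elements and back.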
        inp = randn(24, 13, 2)
        intermediate = torch.view_as_complex(inp)
        out = torch.view_as_real(intermediate)

        # test dim_map correctness
        expected_view_as_complex_rule = (
            InputDim(0),
            Flatten((InputDim(1), InputDim(2))),
        )
        view_as_complex_rule = dim_maps[torch.view_as_complex](inp)
        self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule)
        expected_view_as_real_rule = (
            InputDim(0),
            Split(InputDim(1), (13, 2), 0),
            Split(InputDim(1), (13, 2), 1),
        )
        view_as_real_rule = dim_maps[torch.view_as_real](intermediate)
        self.assertEqual(view_as_real_rule, expected_view_as_real_rule)

        # test sharded computation correctness
        # NOTE: For the input to torch.view_as_complex, sharding
        #       on the last two dimensions is not supported.
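        #       (A shard there would straddle the flattened complex dim, so
        #       only Replicate and Shard(0) are exercised here.)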
        sharding_choices: List[Placement] = [Replicate(), Shard(0)]
        all_sharding_choices = itertools.product(
            *(self.device_mesh.ndim * [sharding_choices])
        )

        for inp_shard in all_sharding_choices:
            inp_dt = distribute_tensor(inp, self.device_mesh, inp_shard)

            comm_mode = CommDebugMode()
            with comm_mode:
                intermediate_dt = torch.view_as_complex(inp_dt)
                out_dt = torch.view_as_real(intermediate_dt)

            self.assertEqual(
                comm_mode.get_total_counts(), 0, "Expected no redistribution."
            )
            self.assertEqual(out, out_dt.full_tensor())

    @with_comms
    def test_dtensor_view_op_uneven(self):
        """
        Test two uneven-sharding cases for the view op:
            1) the sharded tensor dim is 1, so only the first rank has a
               non-empty shard.
            2) the sharded tensor dim is uneven, so some ranks have full
               shards, some have smaller non-empty shards, and some have
               empty shards.
        """
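        # With chunk-style sharding over 6 ranks, dim0_size 1 gives local
        # sizes [1, 0, 0, 0, 0, 0] and dim0_size 7 gives [2, 2, 2, 1, 0, 0].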
        dim0_sizes = [1, self.world_size + 1]
        for dim0_size in dim0_sizes:
            p = torch.randn(dim0_size, 2, 2, 2)
            mesh = init_device_mesh(self.device_type, (self.world_size,))
            dtensor = distribute_tensor(p, mesh, [Shard(0)])

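            # Every reshape below keeps dim 0 (the sharded dim) intact, so the
            # view happens in place on each local shard: no comms and the same
            # data pointer.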
            with CommDebugMode() as comm_mode:
                view = dtensor.view(dim0_size, 2, 4)
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)
                # When no communication happens, the data pointer should be the same.
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )

                view = dtensor.view(dim0_size, 4, 2)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)

                view = dtensor.view(dim0_size, 8)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)

                view = dtensor.view(dtensor.shape)
                self.assertEqual(
                    view.to_local().data_ptr(), dtensor.to_local().data_ptr()
                )
                self.assertEqual(len(comm_mode.get_comm_counts()), 0)


if __name__ == "__main__":
    run_tests()