# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy
import itertools
from pprint import pformat
from typing import NamedTuple

import torch
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
)
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
    SequenceParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)


funcol = torch.ops.c10d_functional


class DistMathOpsTest(DTensorTestBase):
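    """Test that math ops (reductions, softmax, losses, norms, linalg) on DTensor match plain tensor results."""
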
    def _check_module(self, m1, m2, check_grad=False):
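        """Assert that every parameter (or gradient) of m2 matches m1, gathering DTensors to full replicas first."""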
        named_parameters = dict(m1.named_parameters())
        for name, param_m2 in m2.named_parameters():
            self.assertTrue(name in named_parameters)
            param_m1 = named_parameters[name]
            if check_grad:
                param_m2 = param_m2.grad
                param_m1 = param_m1.grad
            if isinstance(param_m2, DTensor):
                replicate = [Replicate()]
                param_m2 = param_m2.redistribute(
                    device_mesh=param_m2.device_mesh, placements=replicate
                ).to_local()
            self.assertEqual(param_m2, param_m1)

    def linear_op_reductions(self, op_str):
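        """Apply the given reduction op to a Shard(0) DTensor and compare against the plain tensor result, along each dim and over the full tensor."""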
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        tensor = torch.randn(12, 8, 8)
        # TODO: check `all` correctness and test `all` on a bool tensor
        if op_str == "any":
            # test out a bool tensor for any
            tensor = tensor < 0
        dtensor = distribute_tensor(tensor, device_mesh, shard_spec)

        op = getattr(tensor, op_str)
        op_dt = getattr(dtensor, op_str)

        keep_dim_or_not = [True, False, None]
        for dim in range(tensor.ndim):
            for keep_dim in keep_dim_or_not:
                args = (dim, keep_dim) if keep_dim is not None else (dim,)
                if op_str in ("max", "min"):
                    # min and max return a tuple when dim specified
                    dim_reduced_tensor, _ = op(*args)
                    dt_reduced, _ = op_dt(*args)
                else:
                    dim_reduced_tensor = op(*args)
                    dt_reduced = op_dt(*args)
                dt_dim_reduced_tensor = dt_reduced.full_tensor()
                self.assertEqual(dt_dim_reduced_tensor, dim_reduced_tensor)

        full_reduced_tensor = op()
        dt_full_reduced = op_dt().full_tensor()
        self.assertEqual(dt_full_reduced, full_reduced_tensor)

    @with_comms
    def test_linear_op_reductions(self):
        for op_str in ("all", "sum", "prod", "max", "min", "any"):
            self.linear_op_reductions(op_str)

    @with_comms
    @skip_unless_torch_gpu
    def test_mean(self):
        self.linear_op_reductions("mean")

    # TODO: forward test can be removed once test_softmax_with_bwd passes on CPU
    @with_comms
    def test_softmax_fwd(self):
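        """Softmax forward on a sharded DTensor: the output is replicated when softmax runs along the shard dim, otherwise it stays sharded; values must match the local result."""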
        device_mesh = self.build_device_mesh()

        x = torch.rand(8, 12, 16, device=self.device_type)
        dims = range(3)  # used to convert -1 to the actual dim
        softmax_dims = [-1, 0, 1, 2]
        shard_dims = [-1, 0, 1, 2]
        test_list = list(itertools.product(softmax_dims, shard_dims))

        for softmax_dim, shard_dim in test_list:
            local_y = torch.nn.functional.softmax(
                x, dim=softmax_dim, dtype=torch.float32
            )
            dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            dist_y = torch.nn.functional.softmax(
                dist_x, dim=softmax_dim, dtype=torch.float32
            )
            shard_dim = normalize_dim(shard_dim, dist_x.ndim)
            if dims[shard_dim] == dims[softmax_dim]:
                self.assertTrue(dist_y.placements[0].is_replicate())
                self.assertEqual(dist_y.to_local(), local_y)
            else:
                self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim))
                self.assertEqual(dist_y.full_tensor(), local_y)

    # TODO: get test_softmax_with_bwd to pass on CPU
    # DTensor's _softmax_backward_data produces wrong results on CPU for certain dimensions.
    # fail_on_cpu_list = [(0, -1), (1, -1)]
    @with_comms
    @skip_unless_torch_gpu
    def test_softmax_with_bwd(self):
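        """Softmax forward and backward through DTensor: check the placements and values of the output and of the input gradient against the local computation."""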
        device_mesh = self.build_device_mesh()

        dims = range(3)  # used to convert -1 to the actual dim
        softmax_dims = [-1, 0, 1, 2]
        shard_dims = [-1, 0, 1, 2]
        test_list = list(itertools.product(softmax_dims, shard_dims))

        for params in test_list:
            softmax_dim, shard_dim = params
            x = torch.rand(8, 12, 16, device=self.device_type, requires_grad=True)
            self.assertTrue(x.requires_grad)
            local_y = torch.nn.functional.softmax(
                x, dim=softmax_dim, dtype=torch.float32
            ).sum()
            local_y.backward()

            dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            self.assertTrue(dist_x.requires_grad)
            dist_softmax = dist_x.softmax(dim=softmax_dim)
            shard_dim = normalize_dim(shard_dim, dist_x.ndim)
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_softmax.placements[0].is_replicate())
            else:
                self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim))
            dist_y = dist_softmax.sum()
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_y.placements[0].is_replicate())
            else:
                self.assertTrue(dist_y.placements[0].is_partial())
                dist_y = dist_y.redistribute(device_mesh, [Replicate()])
            self.assertEqual(dist_y.to_local(), local_y)
            self.assertIsNone(dist_x.grad)
            dist_y.backward()
            self.assertIsNotNone(dist_x.grad)
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_x.grad.placements[0].is_replicate())
            else:
                self.assertTrue(dist_x.grad.placements[0].is_shard(dim=shard_dim))
            self.assertEqual(dist_x.grad.full_tensor(), x.grad)

    @with_comms
    @skip_unless_torch_gpu
    def test_nll_loss_and_cross_entropy(self):
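        """nll_loss and cross_entropy on sharded inputs: sharding the channel dim requires one all-gather, other shardings need no communication; losses and input gradients must match the local run."""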
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        channel_size, channel_dim = 16, 1
        test_setup = [
            (2, (8, channel_size), (8,)),  # calling aten.nll_loss_forward
            (3, (8, channel_size, 12), (8, 12)),  # calling aten.nll_loss2d_forward
        ]
        for input_ndim, input_size, target_size in test_setup:
            x = torch.rand(*input_size, device=self.device_type, requires_grad=True)
            target = torch.randint(channel_size, target_size, device=self.device_type)
            dist_target = distribute_tensor(target, device_mesh, [Replicate()])

            shard_dims = list(range(input_ndim))
            reductions = ["none", "mean", "sum"]
            # Compared with nll_loss, cross_entropy additionally calls log_softmax first.
            # Test them together since the code can be reused.
            loss_functions = [
                torch.nn.functional.nll_loss,
                torch.nn.functional.cross_entropy,
            ]
            for shard_dim, reduction, loss_fn in itertools.product(
                shard_dims, reductions, loss_functions
            ):
                dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
                y = loss_fn(x, target, reduction=reduction)
                if reduction == "none":
                    y.sum().backward()
                else:
                    y.backward()
                with comm_mode:
                    dist_y = loss_fn(dist_x, dist_target, reduction=reduction)
                    if shard_dim == channel_dim:
                        self.assertEqual(comm_mode.get_total_counts(), 1)
                        self.assertEqual(
                            comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                            1,
                        )
                        self.assertTrue(dist_y.placements[0].is_replicate())
                        self.assertEqual(dist_y.to_local(), y)
                    else:
                        self.assertEqual(comm_mode.get_total_counts(), 0)
                        if reduction == "none":
                            output_shard_dim = (
                                shard_dim if shard_dim < channel_dim else shard_dim - 1
                            )
                            self.assertTrue(
                                dist_y.placements[0].is_shard(dim=output_shard_dim)
                            )
                        else:
                            self.assertTrue(dist_y.placements[0].is_partial())
                        self.assertEqual(dist_y.full_tensor(), y)

                    if reduction == "none":
                        dist_y.sum().backward()
                    else:
                        dist_y.backward()
                    if shard_dim == channel_dim:
                        self.assertTrue(dist_x.grad.placements[0].is_replicate())
                        self.assertEqual(dist_x.grad.to_local(), x.grad)
                    else:
                        self.assertTrue(
                            dist_x.grad.placements[0].is_shard(dim=shard_dim)
                        )
                        self.assertEqual(dist_x.grad.full_tensor(), x.grad)
                    x.grad.zero_()

    @with_comms
    def test_shard_math_ops(self):
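        """Elementwise add/sub/mul/div with a scalar on DTensors sharded over a 2D mesh should match the global tensor result."""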
        mesh_shape = (2, self.world_size // 2)
        mesh = DeviceMesh(
            self.device_type,
            torch.arange(self.world_size).reshape(*mesh_shape),
        )
        global_tensor = torch.ones(4, 4)
        double_shard_tensor = distribute_tensor(
            global_tensor, mesh, [Shard(0), Shard(0)]
        )
        fully_shard_tensor = distribute_tensor(
            global_tensor, mesh, [Shard(0), Shard(1)]
        )

        for op in [torch.add, torch.sub, torch.mul, torch.div]:
            expect_rs = op(global_tensor, 2)
            double_shard_full_tensor = op(double_shard_tensor, 2).full_tensor()
            self.assertEqual(double_shard_full_tensor, expect_rs)

            fully_shard_full_tensor = op(fully_shard_tensor, 2).full_tensor()
            self.assertEqual(fully_shard_full_tensor, expect_rs)

    @with_comms
    def test_layer_norm_fwd(self):
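        """LayerNorm forward with replicated weight/bias and sharded input: outputs must match the local module for every shard dim, normalized shape, and elementwise_affine setting."""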
        device_mesh = self.build_device_mesh()

        # NLP example from pytorch docs
        # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
        batch, sentence_length, embedding_dim = 20, 5, 10
        x = torch.rand(batch, sentence_length, embedding_dim, device=self.device_type)
        norm_shape_idx_list = list(range(x.ndim))
        shard_dims = [-1, 0, 1, 2]
        elementwise_affine_list = [False, True]
        test_config_list = list(
            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
        )

        # normalized shape is a torch.Size object
        for shard_dim, norm_idx, elementwise_affine in test_config_list:
            normalized_shape = x.shape[norm_idx:]
            layer_norm = torch.nn.LayerNorm(
                normalized_shape,
                elementwise_affine=elementwise_affine,
                device=self.device_type,
            )
            layer_norm_local = copy.deepcopy(layer_norm).to(self.device_type)

            def _replicate_fn(name, module, device_mesh):
                for name, param in module.named_parameters():
                    if name in ["weight", "bias"]:
                        param_dist = torch.nn.Parameter(
                            distribute_tensor(param, device_mesh, [Replicate()])
                        )
                        module.register_parameter(name, param_dist)

            layer_norm_dist = distribute_module(layer_norm, device_mesh, _replicate_fn)

            x_local = x
            x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])

            y_local = layer_norm_local(x_local)
            # make sure that forward layer norm does not introduce extra collectives
            comm_mode = CommDebugMode()
            with comm_mode:
                y_dist = layer_norm_dist(x_dist)

            self.assertLessEqual(
                comm_mode.get_total_counts(),
                1,  # TODO: This should be 0!
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            from torch.distributed._tensor.placement_types import TensorMeta

            dtensor_meta = y_dist._spec.tensor_meta
            assert isinstance(dtensor_meta, TensorMeta)
            # make sure sharding propagation computed the right shape
            self.assertEqual(y_local.shape, dtensor_meta.shape)
            self.assertEqual(y_local, y_dist.full_tensor())

    @with_comms
    def test_layer_norm_bwd(self):
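        """LayerNorm forward/backward under input sharding: verify outputs, input gradients, the expected comm count per pass, and that weight/bias grads are Partial iff the input is sharded on an outer dim."""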
        device_mesh = self.build_device_mesh()

        # NLP example from pytorch docs
        # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
        batch, sentence_length, embedding_dim = 20, 5, 10
        norm_shape_idx_list = list(range(3))
        shard_dims = [0, 1, 2]
        elementwise_affine_list = [False, True]
        test_config_list = list(
            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
        )

        # normalized shape is a torch.Size object
        for shard_dim, norm_idx, elementwise_affine in test_config_list:
            x = torch.rand(
                batch,
                sentence_length,
                embedding_dim,
                device=self.device_type,
                requires_grad=True,
            )
            normalized_shape = x.shape[norm_idx:]
            layer_norm = torch.nn.LayerNorm(
                normalized_shape,
                elementwise_affine=elementwise_affine,
                device=self.device_type,
            )
            layer_norm_local = copy.deepcopy(layer_norm).to(self.device_type)

            def _replicate_fn(name, module, device_mesh):
                for name, param in module.named_parameters():
                    if name in ["weight", "bias"]:
                        param_dist = torch.nn.Parameter(
                            distribute_tensor(param, device_mesh, [Replicate()])
                        )
                        module.register_parameter(name, param_dist)

            layer_norm_dist = distribute_module(layer_norm, device_mesh, _replicate_fn)

            if elementwise_affine:
                self.assertEqual(
                    layer_norm_local.weight, layer_norm_dist.weight.full_tensor()
                )
                self.assertEqual(
                    layer_norm_local.bias, layer_norm_dist.bias.full_tensor()
                )

            x_local = x.detach().clone().requires_grad_(True)
            x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            self.assertEqual(x_local, x_dist.full_tensor())

            y_local = layer_norm_local(x_local)
            # make sure that forward/backward layer norm does not introduce extra collectives
            comm_mode = CommDebugMode()
            with comm_mode:
                y_dist = layer_norm_dist(x_dist)
                y_dist.sum().backward()

            expected_fwd_comm = 0 if shard_dim < norm_idx else 1

            self.assertEqual(
                sum(comm_mode.comm_module_counts["Global"]["forward"].values()),
                expected_fwd_comm,
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            self.assertEqual(y_local, y_dist.full_tensor())

            # backward step
            y_local.sum().backward()

            expected_bwd_comm = 0 if shard_dim < norm_idx else 1

            self.assertEqual(
                sum(comm_mode.comm_module_counts["Global"]["backward"].values()),
                expected_bwd_comm,
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            if elementwise_affine:
                # if input is sharded on any outer dimension, the gradient of weight
                # and bias should be Partial
                dim_map = x_dist._spec.dim_map
                outer_dims = range(norm_idx)
                needs_reduction = any(dim_map[d] >= 0 for d in outer_dims)
                self.assertEqual(
                    is_tensor_partial(layer_norm_dist.weight.grad._spec),
                    needs_reduction,
                )
                self.assertEqual(
                    is_tensor_partial(layer_norm_dist.bias.grad._spec),
                    needs_reduction,
                )
                self.assertEqual(
                    layer_norm_local.weight.grad,
                    layer_norm_dist.weight.grad.full_tensor(),
                )
                self.assertEqual(
                    layer_norm_local.bias.grad,
                    layer_norm_dist.bias.grad.full_tensor(),
                )

            self.assertEqual(x_local.grad, x_dist.grad.full_tensor())

    @with_comms
    def test_layer_norm_bwd_req_grad(self):
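        """LayerNorm inside a small tensor-parallel block (embedding -> LayerNorm -> linear) under different requires_grad patterns: outputs, gradients, and forward/backward comm counts must match expectations."""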
        device_mesh = self.build_device_mesh()
        batch, seq_len, embedding_dim, vocab_size = 8, 8, 10, 32

        # build our subtest configurations and filter out invalid ones
        class SubTest(NamedTuple):
            multidim_norm: bool
            elementwise_affine: bool
            emb_req_grad: bool
            ln_req_grad: bool
            out_req_grad: bool

        subtest_fails = {}
        valid_filter = lambda cfg: not (  # noqa: E731
            cfg.ln_req_grad and not cfg.elementwise_affine
        ) and any(cfg[2:])
        subtest_cfgs = list(
            filter(
                valid_filter,
                [SubTest(*cfg) for cfg in itertools.product(*(((False, True),) * 5))],
            )
        )

        for subtest_cfg in subtest_cfgs:
            try:
                (
                    multidim_norm,
                    elementwise_affine,
                    emb_req_grad,
                    ln_req_grad,
                    out_req_grad,
                ) = subtest_cfg
                normalized_shape = (
                    (seq_len, embedding_dim) if multidim_norm else (embedding_dim,)
                )

                # configure our local and parallelized models for this subtest
                class LnTpBlock(torch.nn.Module):
                    def __init__(self):
                        super().__init__()
                        self.preln_embeddings = torch.nn.Embedding(
                            vocab_size, embedding_dim
                        )
                        self.layer_norm = torch.nn.LayerNorm(
                            normalized_shape, elementwise_affine=elementwise_affine
                        )
                        self.postln_linear = torch.nn.Linear(
                            embedding_dim, embedding_dim
                        )

                    def forward(self, tokens):
                        h = self.preln_embeddings(tokens)
                        h = self.layer_norm(h)
                        output = self.postln_linear(h)
                        return output

                parallel_plan = {
                    "preln_embeddings": RowwiseParallel(
                        input_layouts=Replicate(), output_layouts=Shard(1)
                    ),
                    "layer_norm": SequenceParallel(),
                    "postln_linear": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Replicate(),
                    ),
                }

                model = LnTpBlock()
                model_local = copy.deepcopy(model).to(device=self.device_type)
                model_dist = parallelize_module(model, device_mesh, parallel_plan)
                req_grad_map = {
                    "preln_embeddings": emb_req_grad,
                    "postln_linear": out_req_grad,
                    "layer_norm": ln_req_grad,
                }

                # apply the relevant `requires_grad` mask for this subtest to both models
                for target_model in [model_local, model_dist]:
                    for n, p in target_model.named_parameters():
                        if not req_grad_map.get(n.rpartition(".")[0], False):
                            p.requires_grad_(False)
                            assert not p.requires_grad
                        else:
                            assert p.requires_grad

                # forward step for both local and distributed models
                x = torch.randint(vocab_size, (batch, seq_len), device=self.device_type)
                x_local = x.detach().clone()
                output_local = model_local(x_local)

                with CommDebugMode() as comm_mode:
                    output_dist = model_dist(x)

                self.assertEqual(output_local, output_dist)

                # all requires_grad patterns should have the same forward comm counts
                expected_fwd_comm = {
                    funcol.reduce_scatter_tensor: 1,
                    funcol.all_gather_into_tensor: 2,
                }
                self.assertDictEqual(
                    comm_mode.comm_module_counts["Global"]["forward"], expected_fwd_comm
                )

                # backward step
                output_local.sum().backward()

                with CommDebugMode() as comm_mode:
                    output_dist.sum().backward()

                # ensure gradients (and parameters) remain equal between local and distributed models
                self._check_module(model_local, model_dist, check_grad=True)

                # different requires_grad patterns will have different bwd comm counts
                if out_req_grad and not any((emb_req_grad, ln_req_grad)):
                    expected_bwd_comm = {}
                elif ln_req_grad and not any((emb_req_grad, multidim_norm)):
                    expected_bwd_comm = {funcol.reduce_scatter_tensor: 1}
                elif multidim_norm:
                    expected_bwd_comm = {funcol.all_reduce: 1}
                    expected_bwd_comm[funcol.all_gather_into_tensor] = (
                        2 if emb_req_grad else 1
                    )
                else:
                    expected_bwd_comm = {
                        funcol.reduce_scatter_tensor: 1,
                        funcol.all_gather_into_tensor: 1,
                    }

                self.assertDictEqual(
                    comm_mode.comm_module_counts["Global"]["backward"],
                    expected_bwd_comm,
                )
                self.assertEqual(output_local, output_dist)

            except Exception as e:
                subtest_fails[subtest_cfg] = e
        # if any subtest fails, provide the failed subtests and report the overall failure
        assert (
            not subtest_fails
        ), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"

    @with_comms
    def test_topk(self):
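        """topk along dim 0 on DTensors with various placements: Shard(0) requires a single all-gather; values must match the global topk."""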
        device_mesh = self.build_device_mesh()
        placement_combs = [Shard(0), Shard(1), Shard(2), Replicate()]

        comm_mode = CommDebugMode()

        tensor = torch.randn(12, 8, 8, requires_grad=True)
        global_topk = tensor.topk(3, dim=0)

        for placement in placement_combs:
            dtensor = distribute_tensor(tensor, device_mesh, (placement,))
            with comm_mode:
                out_dt = dtensor.topk(3, dim=0)
            if placement.is_shard(0):
                self.assertEqual(comm_mode.get_total_counts(), 1)
                self.assertEqual(
                    comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                    1,
                )
            out_full_values = out_dt.values.full_tensor()
            self.assertEqual(global_topk.values, out_full_values)

            # TODO: support backward scatter
            # global_topk.values.sum().backward()
            # out_full_values.sum().backward()

    @with_comms
    def test_shard0_svd(self):
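        """SVD of a Shard(0) DTensor: expect exactly one all-gather and factors matching the replicated computation."""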
        device_mesh = self.build_device_mesh()
        torch.manual_seed(42)
        replicated_x = torch.randn((8, 8), device=self.device_type)
        sharded_x = distribute_tensor(replicated_x, device_mesh, (Shard(0),))
        with CommDebugMode() as comm_mode:
            U, S, V = torch.linalg.svd(sharded_x, full_matrices=False)
        ref_U, ref_S, ref_V = torch.linalg.svd(replicated_x, full_matrices=False)
        self.assertEqual(U.to_local(), ref_U)
        self.assertEqual(S.to_local(), ref_S)
        self.assertEqual(V.to_local(), ref_V)
        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(len(comm_counts), 1)
        self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 1)

    @with_comms
    def test_foreach_norm(self):
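        """_foreach_norm over Shard(0) DTensors should match the per-tensor norms of the unsharded inputs."""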
        device_mesh = self.build_device_mesh()

        grad0 = torch.randn(12, 8)
        grad1 = torch.randn(8, 8)

        sharded_grad0 = distribute_tensor(grad0, device_mesh, [Shard(0)])
        sharded_grad1 = distribute_tensor(grad1, device_mesh, [Shard(0)])

        # non-sharded op
        out = torch.ops.aten._foreach_norm([grad0, grad1], 2)

        # sharded op
        sharded_out = torch.ops.aten._foreach_norm([sharded_grad0, sharded_grad1], 2)

        for o, so in zip(out, sharded_out):
            self.assertEqual(so.full_tensor(), o)

    @with_comms
    def test_linalg_eigh(self):
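        """eigh on a replicated symmetric DTensor: reconstructing A as Q @ diag(L) @ Q^T should give zero distance from the original."""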
        A = torch.randn(2, 2, dtype=torch.float64)
        mesh = self.build_device_mesh()
        dtensor_A = distribute_tensor(A, device_mesh=mesh, placements=[Replicate()])
        dtensor_A = dtensor_A + dtensor_A.mT
        dtensor_L, dtensor_Q = torch.linalg.eigh(dtensor_A)

        # TODO: we need to convert A, L, Q to local because we don't have a
        # sharding strategy registered for aten.dist.default yet.
        local_A, local_L, local_Q = (
            dtensor_A.to_local(),
            dtensor_L.to_local(),
            dtensor_Q.to_local(),
        )
        distance = torch.dist(local_Q @ torch.diag(local_L) @ local_Q.mT, local_A)
        self.assertEqual(distance.item(), 0.0)


if __name__ == "__main__":
    run_tests()