# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy
import itertools
from pprint import pformat
from typing import NamedTuple

import torch
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
)
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
    SequenceParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)


funcol = torch.ops.c10d_functional


class DistMathOpsTest(DTensorTestBase):
    def _check_module(self, m1, m2, check_grad=False):
        named_parameters = dict(m1.named_parameters())
        for name, param_m2 in m2.named_parameters():
            self.assertTrue(name in named_parameters)
            param_m1 = named_parameters[name]
            if check_grad:
                param_m2 = param_m2.grad
                param_m1 = param_m1.grad
            if isinstance(param_m2, DTensor):
                replicate = [Replicate()]
                param_m2 = param_m2.redistribute(
                    device_mesh=param_m2.device_mesh, placements=replicate
                ).to_local()
            self.assertEqual(param_m2, param_m1)

    def linear_op_reductions(self, op_str):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        tensor = torch.randn(12, 8, 8)
        # TODO: check `all` correctness and test `all` on a bool tensor
        if op_str in ("any",):
            # test out a bool tensor for any
            tensor = tensor < 0
        dtensor = distribute_tensor(tensor, device_mesh, shard_spec)

        op = getattr(tensor, op_str)
        op_dt = getattr(dtensor, op_str)

        keep_dim_or_not = [True, False, None]
        for dim in range(tensor.ndim):
            for keep_dim in keep_dim_or_not:
                args = (dim, keep_dim) if keep_dim is not None else (dim,)
                if op_str in ("max", "min"):
                    # min and max return a tuple when a dim is specified
                    dim_reduced_tensor, _ = op(*args)
                    dt_reduced, _ = op_dt(*args)
                else:
                    dim_reduced_tensor = op(*args)
                    dt_reduced = op_dt(*args)
                dt_dim_reduced_tensor = dt_reduced.full_tensor()
                self.assertEqual(dt_dim_reduced_tensor, dim_reduced_tensor)

        full_reduced_tensor = op()
        dt_full_reduced = op_dt().full_tensor()
        self.assertEqual(dt_full_reduced, full_reduced_tensor)

    @with_comms
    def test_linear_op_reductions(self):
        for op_str in ("all", "sum", "prod", "max", "min", "any"):
            self.linear_op_reductions(op_str)

    @with_comms
    @skip_unless_torch_gpu
    def test_mean(self):
        self.linear_op_reductions("mean")

    # TODO: the forward test can be removed once test_softmax_with_bwd passes on CPU
    @with_comms
    def test_softmax_fwd(self):
        device_mesh = self.build_device_mesh()

        x = torch.rand(8, 12, 16, device=self.device_type)
        dims = range(3)  # used to convert -1 to the actual dim
        softmax_dims = [-1, 0, 1, 2]
        shard_dims = [-1, 0, 1, 2]
        test_list = list(itertools.product(softmax_dims, shard_dims))

        for softmax_dim, shard_dim in test_list:
            local_y = torch.nn.functional.softmax(
                x, dim=softmax_dim, dtype=torch.float32
            )
            dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            dist_y = torch.nn.functional.softmax(
                dist_x, dim=softmax_dim, dtype=torch.float32
            )
            shard_dim = normalize_dim(shard_dim, dist_x.ndim)
            if dims[shard_dim] == dims[softmax_dim]:
                self.assertTrue(dist_y.placements[0].is_replicate())
                self.assertEqual(dist_y.to_local(), local_y)
            else:
                self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim))
                self.assertEqual(dist_y.full_tensor(), local_y)

    # TODO: get test_softmax_with_bwd to pass on CPU
    # DTensor's _softmax_backward_data produces wrong results on CPU for certain dims.
    # fail_on_cpu_list = [(0, -1), (1, -1)]
    @with_comms
    @skip_unless_torch_gpu
    def test_softmax_with_bwd(self):
        device_mesh = self.build_device_mesh()

        dims = range(3)  # used to convert -1 to the actual dim
        softmax_dims = [-1, 0, 1, 2]
        shard_dims = [-1, 0, 1, 2]
        test_list = list(itertools.product(softmax_dims, shard_dims))

        for params in test_list:
            softmax_dim, shard_dim = params
            x = torch.rand(8, 12, 16, device=self.device_type, requires_grad=True)
            self.assertTrue(x.requires_grad)
            local_y = torch.nn.functional.softmax(
                x, dim=softmax_dim, dtype=torch.float32
            ).sum()
            local_y.backward()

            dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            self.assertTrue(dist_x.requires_grad)
            dist_softmax = dist_x.softmax(dim=softmax_dim)
            shard_dim = normalize_dim(shard_dim, dist_x.ndim)
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_softmax.placements[0].is_replicate())
            else:
                self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim))
            dist_y = dist_softmax.sum()
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_y.placements[0].is_replicate())
            else:
                self.assertTrue(dist_y.placements[0].is_partial())
                dist_y = dist_y.redistribute(device_mesh, [Replicate()])
            self.assertEqual(dist_y.to_local(), local_y)
            self.assertIsNone(dist_x.grad)
            dist_y.backward()
            self.assertIsNotNone(dist_x.grad)
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_x.grad.placements[0].is_replicate())
            else:
                self.assertTrue(dist_x.grad.placements[0].is_shard(dim=shard_dim))
            self.assertEqual(dist_x.grad.full_tensor(), x.grad)

    @with_comms
    @skip_unless_torch_gpu
    def test_nll_loss_and_cross_entropy(self):
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        channel_size, channel_dim = 16, 1
        test_setup = [
            (2, (8, channel_size), (8,)),  # calls aten.nll_loss_forward
            (3, (8, channel_size, 12), (8, 12)),  # calls aten.nll_loss2d_forward
        ]
        for input_ndim, input_size, target_size in test_setup:
            x = torch.rand(*input_size, device=self.device_type, requires_grad=True)
            target = torch.randint(channel_size, target_size, device=self.device_type)
            dist_target = distribute_tensor(target, device_mesh, [Replicate()])

            shard_dims = list(range(input_ndim))
            reductions = ["none", "mean", "sum"]
            # Compared with nll_loss, cross_entropy additionally calls log_softmax first.
            # Test them together since the code can be reused.
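            # A single all_gather is expected only when the input is sharded on the
            # channel (class) dimension; sharding on any other dimension should need
            # no communication, as asserted with CommDebugMode below.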
            loss_functions = [
                torch.nn.functional.nll_loss,
                torch.nn.functional.cross_entropy,
            ]
            for shard_dim, reduction, loss_fn in itertools.product(
                shard_dims, reductions, loss_functions
            ):
                dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
                y = loss_fn(x, target, reduction=reduction)
                if reduction == "none":
                    y.sum().backward()
                else:
                    y.backward()
                with comm_mode:
                    dist_y = loss_fn(dist_x, dist_target, reduction=reduction)
                    if shard_dim == channel_dim:
                        self.assertEqual(comm_mode.get_total_counts(), 1)
                        self.assertEqual(
                            comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                            1,
                        )
                        self.assertTrue(dist_y.placements[0].is_replicate())
                        self.assertEqual(dist_y.to_local(), y)
                    else:
                        self.assertEqual(comm_mode.get_total_counts(), 0)
                        if reduction == "none":
                            output_shard_dim = (
                                shard_dim if shard_dim < channel_dim else shard_dim - 1
                            )
                            self.assertTrue(
                                dist_y.placements[0].is_shard(dim=output_shard_dim)
                            )
                        else:
                            self.assertTrue(dist_y.placements[0].is_partial())
                        self.assertEqual(dist_y.full_tensor(), y)

                    if reduction == "none":
                        dist_y.sum().backward()
                    else:
                        dist_y.backward()
                    if shard_dim == channel_dim:
                        self.assertTrue(dist_x.grad.placements[0].is_replicate())
                        self.assertEqual(dist_x.grad.to_local(), x.grad)
                    else:
                        self.assertTrue(
                            dist_x.grad.placements[0].is_shard(dim=shard_dim)
                        )
                        self.assertEqual(dist_x.grad.full_tensor(), x.grad)
                    x.grad.zero_()

    @with_comms
    def test_shard_math_ops(self):
        mesh_shape = (2, self.world_size // 2)
        mesh = DeviceMesh(
            self.device_type,
            torch.arange(self.world_size).reshape(*mesh_shape),
        )
        global_tensor = torch.ones(4, 4)
        double_shard_tensor = distribute_tensor(
            global_tensor, mesh, [Shard(0), Shard(0)]
        )
        fully_shard_tensor = distribute_tensor(
            global_tensor, mesh, [Shard(0), Shard(1)]
        )

        for op in [torch.add, torch.sub, torch.mul, torch.div]:
            expect_rs = op(global_tensor, 2)
            double_shard_full_tensor = op(double_shard_tensor, 2).full_tensor()
            self.assertEqual(double_shard_full_tensor, expect_rs)

            fully_shard_full_tensor = op(fully_shard_tensor, 2).full_tensor()
            self.assertEqual(fully_shard_full_tensor, expect_rs)

    @with_comms
    def test_layer_norm_fwd(self):
        device_mesh = self.build_device_mesh()

        # NLP example from the PyTorch docs:
        # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
        batch, sentence_length, embedding_dim = 20, 5, 10
        x = torch.rand(batch, sentence_length, embedding_dim, device=self.device_type)
        norm_shape_idx_list = list(range(x.ndim))
        shard_dims = [-1, 0, 1, 2]
        elementwise_affine_list = [False, True]
        test_config_list = list(
            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
        )

        # normalized_shape is a torch.Size object
        for shard_dim, norm_idx, elementwise_affine in test_config_list:
            normalized_shape = x.shape[norm_idx:]
            layer_norm = torch.nn.LayerNorm(
                normalized_shape,
                elementwise_affine=elementwise_affine,
                device=self.device_type,
            )
            layer_norm_local = copy.deepcopy(layer_norm).to(self.device_type)

            def _replicate_fn(name, module, device_mesh):
                for name, param in module.named_parameters():
                    if name in ["weight", "bias"]:
                        param_dist = torch.nn.Parameter(
                            distribute_tensor(param, device_mesh, [Replicate()])
                        )
                        module.register_parameter(name, param_dist)

            layer_norm_dist = distribute_module(layer_norm, device_mesh, _replicate_fn)

            x_local = x
            x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])

            y_local = layer_norm_local(x_local)
            # make sure that the layer norm forward does not introduce extra collectives
            comm_mode = CommDebugMode()
            with comm_mode:
                y_dist = layer_norm_dist(x_dist)

            self.assertLessEqual(
                comm_mode.get_total_counts(),
                1,  # TODO: This should be 0!
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            from torch.distributed._tensor.placement_types import TensorMeta

            dtensor_meta = y_dist._spec.tensor_meta
            assert isinstance(dtensor_meta, TensorMeta)
            # make sure sharding propagation produced the right shape
            self.assertEqual(y_local.shape, dtensor_meta.shape)
            self.assertEqual(y_local, y_dist.full_tensor())

    @with_comms
    def test_layer_norm_bwd(self):
        device_mesh = self.build_device_mesh()

        # NLP example from the PyTorch docs:
        # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
        batch, sentence_length, embedding_dim = 20, 5, 10
        norm_shape_idx_list = list(range(3))
        shard_dims = [0, 1, 2]
        elementwise_affine_list = [False, True]
        test_config_list = list(
            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
        )

        # normalized_shape is a torch.Size object
        for shard_dim, norm_idx, elementwise_affine in test_config_list:
            x = torch.rand(
                batch,
                sentence_length,
                embedding_dim,
                device=self.device_type,
                requires_grad=True,
            )
            normalized_shape = x.shape[norm_idx:]
            layer_norm = torch.nn.LayerNorm(
                normalized_shape,
                elementwise_affine=elementwise_affine,
                device=self.device_type,
            )
            layer_norm_local = copy.deepcopy(layer_norm).to(self.device_type)

            def _replicate_fn(name, module, device_mesh):
                for name, param in module.named_parameters():
                    if name in ["weight", "bias"]:
                        param_dist = torch.nn.Parameter(
                            distribute_tensor(param, device_mesh, [Replicate()])
                        )
                        module.register_parameter(name, param_dist)

            layer_norm_dist = distribute_module(layer_norm, device_mesh, _replicate_fn)

            if elementwise_affine:
                self.assertEqual(
                    layer_norm_local.weight, layer_norm_dist.weight.full_tensor()
                )
                self.assertEqual(
                    layer_norm_local.bias, layer_norm_dist.bias.full_tensor()
                )

            x_local = x.detach().clone().requires_grad_(True)
            x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            self.assertEqual(x_local, x_dist.full_tensor())

            y_local = layer_norm_local(x_local)
            # make sure that the layer norm backward does not introduce extra collectives
            comm_mode = CommDebugMode()
            with comm_mode:
                y_dist = layer_norm_dist(x_dist)
                y_dist.sum().backward()

            expected_fwd_comm = 0 if shard_dim < norm_idx else 1

            self.assertEqual(
                sum(comm_mode.comm_module_counts["Global"]["forward"].values()),
                expected_fwd_comm,
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            self.assertEqual(y_local, y_dist.full_tensor())

            # backward step
            y_local.sum().backward()

            expected_bwd_comm = 0 if shard_dim < norm_idx else 1
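            # a single backward collective is expected only when the input is sharded
            # on one of the normalized dimensions (shard_dim >= norm_idx)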

            self.assertEqual(
                sum(comm_mode.comm_module_counts["Global"]["backward"].values()),
                expected_bwd_comm,
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            if elementwise_affine:
                # if the input is sharded on any outer dimension, the gradients of
                # weight and bias should be Partial
                dim_map = x_dist._spec.dim_map
                outer_dims = range(norm_idx)
                needs_reduction = any(dim_map[d] >= 0 for d in outer_dims)
                self.assertEqual(
                    is_tensor_partial(layer_norm_dist.weight.grad._spec),
                    needs_reduction,
                )
                self.assertEqual(
                    is_tensor_partial(layer_norm_dist.bias.grad._spec),
                    needs_reduction,
                )
                self.assertEqual(
                    layer_norm_local.weight.grad,
                    layer_norm_dist.weight.grad.full_tensor(),
                )
                self.assertEqual(
                    layer_norm_local.bias.grad,
                    layer_norm_dist.bias.grad.full_tensor(),
                )

            self.assertEqual(x_local.grad, x_dist.grad.full_tensor())

    @with_comms
    def test_layer_norm_bwd_req_grad(self):
        device_mesh = self.build_device_mesh()
        batch, seq_len, embedding_dim, vocab_size = 8, 8, 10, 32

        # build our subtest configurations and filter out invalid ones
        class SubTest(NamedTuple):
            multidim_norm: bool
            elementwise_affine: bool
            emb_req_grad: bool
            ln_req_grad: bool
            out_req_grad: bool

        subtest_fails = {}
        valid_filter = lambda cfg: not (  # noqa: E731
            cfg.ln_req_grad and not cfg.elementwise_affine
        ) and any(cfg[2:])
        subtest_cfgs = list(
            filter(
                valid_filter,
                [SubTest(*cfg) for cfg in itertools.product(*(((False, True),) * 5))],
            )
        )

        for subtest_cfg in subtest_cfgs:
            try:
                (
                    multidim_norm,
                    elementwise_affine,
                    emb_req_grad,
                    ln_req_grad,
                    out_req_grad,
                ) = subtest_cfg
                normalized_shape = (
                    (seq_len, embedding_dim) if multidim_norm else (embedding_dim,)
                )

                # configure our local and parallelized models for this subtest
                class LnTpBlock(torch.nn.Module):
                    def __init__(self):
                        super().__init__()
                        self.preln_embeddings = torch.nn.Embedding(
                            vocab_size, embedding_dim
                        )
                        self.layer_norm = torch.nn.LayerNorm(
                            normalized_shape, elementwise_affine=elementwise_affine
                        )
                        self.postln_linear = torch.nn.Linear(
                            embedding_dim, embedding_dim
                        )

                    def forward(self, tokens):
                        h = self.preln_embeddings(tokens)
                        h = self.layer_norm(h)
                        output = self.postln_linear(h)
                        return output

                parallel_plan = {
                    "preln_embeddings": RowwiseParallel(
                        input_layouts=Replicate(), output_layouts=Shard(1)
                    ),
                    "layer_norm": SequenceParallel(),
                    "postln_linear": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Replicate(),
                    ),
                }

                model = LnTpBlock()
                model_local = copy.deepcopy(model).to(device=self.device_type)
                model_dist = parallelize_module(model, device_mesh, parallel_plan)
                req_grad_map = {
                    "preln_embeddings": emb_req_grad,
                    "postln_linear": out_req_grad,
                    "layer_norm": ln_req_grad,
                }

                # apply the relevant `requires_grad` mask for this subtest to both models
                for target_model in [model_local, model_dist]:
                    for n, p in target_model.named_parameters():
                        if not req_grad_map.get(n.rpartition(".")[0], False):
                            p.requires_grad_(False)
                            assert not p.requires_grad
                        else:
                            assert p.requires_grad

                # forward step for both the local and distributed models
                x = torch.randint(
                    vocab_size, (batch, seq_len), device=self.device_type
                )
                x_local = x.detach().clone()
                output_local = model_local(x_local)

                with CommDebugMode() as comm_mode:
                    output_dist = model_dist(x)

                self.assertEqual(output_local, output_dist)

                # all requires_grad patterns should have the same forward comm counts
                expected_fwd_comm = {
                    funcol.reduce_scatter_tensor: 1,
                    funcol.all_gather_into_tensor: 2,
                }
                self.assertDictEqual(
                    comm_mode.comm_module_counts["Global"]["forward"], expected_fwd_comm
                )

                # backward step
                output_local.sum().backward()

                with CommDebugMode() as comm_mode:
                    output_dist.sum().backward()

                # ensure gradients (and parameters) remain equal between the local
                # and distributed models
                self._check_module(model_local, model_dist, check_grad=True)

                # different requires_grad patterns will have different bwd comm counts
                if out_req_grad and not any((emb_req_grad, ln_req_grad)):
                    expected_bwd_comm = {}
                elif ln_req_grad and not any((emb_req_grad, multidim_norm)):
                    expected_bwd_comm = {funcol.reduce_scatter_tensor: 1}
                elif multidim_norm:
                    expected_bwd_comm = {funcol.all_reduce: 1}
                    expected_bwd_comm[funcol.all_gather_into_tensor] = (
                        2 if emb_req_grad else 1
                    )
                else:
                    expected_bwd_comm = {
                        funcol.reduce_scatter_tensor: 1,
                        funcol.all_gather_into_tensor: 1,
                    }

                self.assertDictEqual(
                    comm_mode.comm_module_counts["Global"]["backward"],
                    expected_bwd_comm,
                )
                self.assertEqual(output_local, output_dist)

            except Exception as e:
                subtest_fails[subtest_cfg] = e
        # if any subtest failed, report the overall failure along with the failed subtests
        assert (
            not subtest_fails
        ), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"

    @with_comms
    def test_topk(self):
        device_mesh = self.build_device_mesh()
        placement_combs = [Shard(0), Shard(1), Shard(2), Replicate()]

        comm_mode = CommDebugMode()

        tensor = torch.randn(12, 8, 8, requires_grad=True)
        global_topk = tensor.topk(3, dim=0)

        for placement in placement_combs:
            dtensor = distribute_tensor(tensor, device_mesh, (placement,))
            with comm_mode:
                out_dt = dtensor.topk(3, dim=0)
            if placement.is_shard(0):
                self.assertEqual(comm_mode.get_total_counts(), 1)
                self.assertEqual(
                    comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                    1,
                )
            out_full_values = out_dt.values.full_tensor()
            self.assertEqual(global_topk.values, out_full_values)

            # TODO: support backward scatter
            # global_topk.values.sum().backward()
            # out_full_values.sum().backward()

    @with_comms
    def test_shard0_svd(self):
        device_mesh = self.build_device_mesh()
        torch.manual_seed(42)
        replicated_x = torch.randn((8, 8), device=self.device_type)
        sharded_x = distribute_tensor(replicated_x, device_mesh, (Shard(0),))
        with CommDebugMode() as comm_mode:
            U, S, V = torch.linalg.svd(sharded_x, full_matrices=False)
        ref_U, ref_S, ref_V = torch.linalg.svd(replicated_x, full_matrices=False)
        self.assertEqual(U.to_local(), ref_U)
        self.assertEqual(S.to_local(), ref_S)
        self.assertEqual(V.to_local(), ref_V)
        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(len(comm_counts), 1)
        self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 1)

    @with_comms
    def test_foreach_norm(self):
        device_mesh = self.build_device_mesh()

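        # compare _foreach_norm over Shard(0) DTensors against the plain-tensor result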
        grad0 = torch.randn(12, 8)
        grad1 = torch.randn(8, 8)

        sharded_grad0 = distribute_tensor(grad0, device_mesh, [Shard(0)])
        sharded_grad1 = distribute_tensor(grad1, device_mesh, [Shard(0)])

        # non-sharded op
        out = torch.ops.aten._foreach_norm([grad0, grad1], 2)

        # sharded op
        sharded_out = torch.ops.aten._foreach_norm([sharded_grad0, sharded_grad1], 2)

        for o, so in zip(out, sharded_out):
            self.assertEqual(so.full_tensor(), o)

    @with_comms
    def test_linalg_eigh(self):
        A = torch.randn(2, 2, dtype=torch.float64)
        mesh = self.build_device_mesh()
        dtensor_A = distribute_tensor(A, device_mesh=mesh, placements=[Replicate()])
        dtensor_A = dtensor_A + dtensor_A.mT
        dtensor_L, dtensor_Q = torch.linalg.eigh(dtensor_A)

        # TODO: we need to convert A, L, Q to local tensors because we don't have a
        # sharding strategy registered for aten.dist.default yet.
        local_A, local_L, local_Q = (
            dtensor_A.to_local(),
            dtensor_L.to_local(),
            dtensor_Q.to_local(),
        )
        distance = torch.dist(local_Q @ torch.diag(local_L) @ local_Q.mT, local_A)
        self.assertEqual(distance.item(), 0.0)


if __name__ == "__main__":
    run_tests()