1# Copyright (c) Meta Platforms, Inc. and affiliates 2# Owner(s): ["oncall: distributed"] 3 4import itertools 5from typing import cast, List 6 7import torch 8import torch.distributed as dist 9from torch import rand, randn, Tensor 10from torch.distributed._tensor import ( 11 DeviceMesh, 12 distribute_tensor, 13 init_device_mesh, 14 Replicate, 15 Shard, 16) 17from torch.distributed._tensor.placement_types import Placement 18from torch.distributed.tensor._ops._view_ops import ( 19 Broadcast, 20 dim_maps, 21 Flatten, 22 InputDim, 23 Repeat, 24 Singleton, 25 Split, 26 view_groups, 27) 28from torch.distributed.tensor.debug import CommDebugMode 29from torch.testing._internal.common_utils import run_tests 30from torch.testing._internal.distributed._tensor.common_dtensor import ( 31 DTensorTestBase, 32 with_comms, 33) 34from torch.utils import _pytree as pytree 35 36 37class TestViewOps(DTensorTestBase): 38 @property 39 def world_size(self) -> int: 40 return 6 41 42 def test_view_groups(self): 43 self.assertEqual( 44 view_groups([2, 3], [3, 2]), 45 ( 46 Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), 47 Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), 48 ), 49 ) 50 self.assertEqual( 51 view_groups([3, 4, 5], [12, 5]), 52 (Flatten((InputDim(0), InputDim(1))), InputDim(2)), 53 ) 54 self.assertEqual( 55 view_groups([2, 3, 4, 5, 7], [12, 70]), 56 ( 57 Split( 58 Flatten( 59 ( 60 InputDim(0), 61 InputDim(1), 62 InputDim(2), 63 InputDim(3), 64 InputDim(4), 65 ) 66 ), 67 (12, 70), 68 0, 69 ), 70 Split( 71 Flatten( 72 ( 73 InputDim(0), 74 InputDim(1), 75 InputDim(2), 76 InputDim(3), 77 InputDim(4), 78 ) 79 ), 80 (12, 70), 81 1, 82 ), 83 ), 84 ) 85 self.assertEqual( 86 view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]), 87 ( 88 Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0), 89 Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1), 90 Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0), 91 Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1), 92 ), 93 ) 94 self.assertEqual( 95 view_groups([3, 4, 8, 3], [12, 4, 2, 3]), 96 ( 97 Flatten((InputDim(0), InputDim(1))), 98 Split(InputDim(2), (4, 2), 0), 99 Split(InputDim(2), (4, 2), 1), 100 InputDim(3), 101 ), 102 ) 103 self.assertEqual( 104 view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]), 105 ( 106 Singleton(), 107 InputDim(0), 108 Split(InputDim(1), (2, 4, 3), 0), 109 Split(InputDim(1), (2, 4, 3), 1), 110 Singleton(), 111 Split(InputDim(1), (2, 4, 3), 2), 112 Singleton(), 113 ), 114 ) 115 self.assertEqual( 116 view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]), 117 ( 118 Flatten((InputDim(2), InputDim(3))), 119 InputDim(4), 120 InputDim(5), 121 Singleton(), 122 ), 123 ) 124 self.assertEqual( 125 view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]), 126 ( 127 Split(InputDim(2), (3, 4), 0), 128 Split(InputDim(2), (3, 4), 1), 129 InputDim(3), 130 Flatten((InputDim(6), InputDim(7))), 131 ), 132 ) 133 self.assertEqual( 134 view_groups([2, 3, 4], [2, -1, 4]), 135 (InputDim(0), InputDim(1), InputDim(2)), 136 ) 137 138 def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): 139 dim_map = dim_maps[op] 140 rules = dim_map(*args, **kwargs) 141 outputs = op(*args, **kwargs) 142 flat_args = pytree.arg_tree_leaves(*args) 143 in_shape = flat_args[0].shape 144 145 no_shard_dims = set() 146 for rule in rules: 147 if isinstance(rule, Repeat): 148 if isinstance(rule.input_dim, InputDim): 149 no_shard_dims.add(rule.input_dim.input_dim) 150 elif isinstance(rule, Flatten): 151 for dim in rule.input_dims[1:]: 152 if isinstance(dim, InputDim): 153 no_shard_dims.add(dim.input_dim) 154 elif isinstance(rule, Split): 155 if isinstance(rule.input_dim, Flatten): 156 for dim in rule.input_dim.input_dims[1:]: 157 if isinstance(dim, InputDim): 158 no_shard_dims.add(dim.input_dim) 159 160 if op == torch.unbind: 161 no_shard_dims.add(kwargs.get("dim", 0)) 162 163 sharding_choices = cast(List[Placement], [Replicate()]) + [ 164 Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims 165 ] 166 167 all_sharding_choices = itertools.product( 168 *(device_mesh.ndim * [sharding_choices]) 169 ) 170 171 for in_shard in all_sharding_choices: 172 in_dt = distribute_tensor(args[0], device_mesh, in_shard) 173 174 comm_mode = CommDebugMode() 175 with comm_mode: 176 out_dt = op(in_dt, *args[1:], **kwargs) 177 178 self.assertEqual( 179 comm_mode.get_total_counts(), 0, "Expected no redistribution." 180 ) 181 182 full_out = out_dt.full_tensor() 183 184 if dist.get_rank() == 0: 185 self.assertEqual(outputs, full_out) 186 187 def dimmap_test(self, op, args, expected_rule_output): 188 rules = dim_maps[op](*args) 189 self.assertEqual(rules, expected_rule_output) 190 self.call_dt_test(op, args, {}, self.device_mesh) 191 192 @with_comms 193 def test_view_ops(self): 194 self.device_mesh = DeviceMesh( 195 self.device_type, torch.arange(dist.get_world_size()).view(-1, 2) 196 ) 197 self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),)) 198 self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),)) 199 self.dimmap_test(torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1))) 200 201 self.dimmap_test(torch.atleast_2d, (randn(()),), (Singleton(), Singleton())) 202 self.dimmap_test(torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0))) 203 self.dimmap_test(torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1))) 204 self.dimmap_test( 205 torch.atleast_2d, 206 (randn(24, 36, 48),), 207 (InputDim(0), InputDim(1), InputDim(2)), 208 ) 209 210 self.dimmap_test( 211 torch.atleast_3d, 212 (randn(()),), 213 (Singleton(), Singleton(), Singleton()), 214 ) 215 self.dimmap_test( 216 torch.atleast_3d, 217 (randn(24),), 218 (Singleton(), InputDim(0), Singleton()), 219 ) 220 self.dimmap_test( 221 torch.atleast_3d, 222 (randn(24, 36),), 223 (InputDim(0), InputDim(1), Singleton()), 224 ) 225 self.dimmap_test( 226 torch.atleast_3d, 227 (randn(24, 36, 42),), 228 (InputDim(0), InputDim(1), InputDim(2)), 229 ) 230 self.dimmap_test( 231 torch.atleast_3d, 232 (randn(24, 36, 42, 24),), 233 (InputDim(0), InputDim(1), InputDim(2), InputDim(3)), 234 ) 235 236 with self.assertRaises(AssertionError): 237 dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4)) 238 239 self.dimmap_test( 240 torch.broadcast_to, 241 (rand(24, 36), (1, 24, 36)), 242 (Singleton(), InputDim(0), InputDim(1)), 243 ) 244 self.dimmap_test( 245 torch.broadcast_to, 246 (rand(24, 36), (42, 24, 36)), 247 (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)), 248 ) 249 self.dimmap_test( 250 torch.broadcast_to, 251 (rand(24, 1, 36), (12, 24, 24, 36)), 252 ( 253 Broadcast(Singleton(), 12), 254 InputDim(0), 255 Broadcast(InputDim(1), 24), 256 InputDim(2), 257 ), 258 ) 259 self.dimmap_test( 260 torch.broadcast_to, 261 (rand(24, 36), (-1, 36)), 262 (InputDim(0), InputDim(1)), 263 ) 264 self.dimmap_test( 265 torch.broadcast_to, 266 (rand(24, 1, 36), (-1, 1, 36)), 267 (InputDim(0), InputDim(1), InputDim(2)), 268 ) 269 270 self.dimmap_test( 271 torch.broadcast_to, 272 (randn(36, 1, 24), (12, 36, 42, 24)), 273 ( 274 Broadcast(Singleton(), 12), 275 InputDim(0), 276 Broadcast(InputDim(1), 42), 277 InputDim(2), 278 ), 279 ) 280 281 self.dimmap_test( 282 Tensor.expand, 283 (randn(24, 1, 36, 1), 36, 24, 42, -1, 24), 284 ( 285 Broadcast(Singleton(), 36), 286 InputDim(0), 287 Broadcast(InputDim(1), 42), 288 InputDim(2), 289 Broadcast(InputDim(3), 24), 290 ), 291 ) 292 293 self.dimmap_test( 294 Tensor.expand, 295 (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)), 296 ( 297 Broadcast(Singleton(), 36), 298 InputDim(0), 299 Broadcast(InputDim(1), 42), 300 InputDim(2), 301 Broadcast(InputDim(3), 24), 302 ), 303 ) 304 305 self.dimmap_test( 306 torch.flatten, 307 (randn(24, 36),), 308 (Flatten((InputDim(0), InputDim(1))),), 309 ) 310 self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),)) 311 self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),)) 312 313 self.dimmap_test( 314 torch.movedim, 315 (randn(12, 24, 48, 96), 1, 2), 316 (InputDim(0), InputDim(2), InputDim(1), InputDim(3)), 317 ) 318 self.dimmap_test( 319 torch.movedim, 320 (randn(6, 12, 24), 1, 0), 321 (InputDim(1), InputDim(0), InputDim(2)), 322 ) 323 self.dimmap_test( 324 torch.movedim, 325 (randn(24, 12, 6), (1, 2), (0, 1)), 326 (InputDim(1), InputDim(2), InputDim(0)), 327 ) 328 self.dimmap_test( 329 torch.movedim, 330 (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)), 331 (InputDim(1), InputDim(2), InputDim(0)), 332 ) 333 self.dimmap_test( 334 torch.movedim, 335 (randn(24, 12), (1, 0), (0, 1)), 336 (InputDim(1), InputDim(0)), 337 ) 338 339 self.dimmap_test( 340 torch.movedim, 341 (randn(36, 24, 12), (1, 2), (0, 1)), 342 (InputDim(1), InputDim(2), InputDim(0)), 343 ) 344 self.dimmap_test( 345 torch.movedim, 346 (randn(36, 24, 12), (1, 2), (-3, -2)), 347 (InputDim(1), InputDim(2), InputDim(0)), 348 ) 349 350 self.dimmap_test( 351 torch.permute, 352 (randn(24, 36, 42), (2, 0, 1)), 353 (InputDim(2), InputDim(0), InputDim(1)), 354 ) 355 self.dimmap_test( 356 torch.permute, 357 (randn(24, 36, 42), (-1, -3, -2)), 358 (InputDim(2), InputDim(0), InputDim(1)), 359 ) 360 361 self.dimmap_test( 362 torch.ravel, 363 (randn(24, 36),), 364 (Flatten((InputDim(0), InputDim(1))),), 365 ) 366 self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),)) 367 self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),)) 368 369 self.dimmap_test( 370 Tensor.repeat, 371 (randn(24, 36), 1, 2, 1, 1, 2), 372 ( 373 Singleton(), 374 Broadcast(Singleton(), 2), 375 Singleton(), 376 InputDim(0), 377 Repeat(InputDim(1), 2), 378 ), 379 ) 380 381 self.dimmap_test( 382 torch.reshape, 383 (randn(6, 12, 24), (72, 24)), 384 (Flatten((InputDim(0), InputDim(1))), InputDim(2)), 385 ) 386 387 self.dimmap_test( 388 torch.tile, 389 (randn(24, 36), (1, 2, 1, 1, 2)), 390 ( 391 Singleton(), 392 Broadcast(Singleton(), 2), 393 Singleton(), 394 InputDim(0), 395 Repeat(InputDim(1), 2), 396 ), 397 ) 398 self.dimmap_test( 399 torch.tile, 400 (randn(42, 24, 36), (1, 3)), 401 (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)), 402 ) 403 404 self.dimmap_test( 405 torch.transpose, 406 (randn(24, 60, 42, 60), 2, 0), 407 (InputDim(2), InputDim(1), InputDim(0), InputDim(3)), 408 ) 409 self.dimmap_test( 410 torch.transpose, 411 (randn(24, 60, 42, 60), -1, 0), 412 (InputDim(3), InputDim(1), InputDim(2), InputDim(0)), 413 ) 414 415 self.dimmap_test( 416 torch.unsqueeze, 417 (randn(42, 24, 36), 1), 418 (InputDim(0), Singleton(), InputDim(1), InputDim(2)), 419 ) 420 421 self.dimmap_test( 422 Tensor.view, 423 (randn(6, 12, 24), 72, 24), 424 (Flatten((InputDim(0), InputDim(1))), InputDim(2)), 425 ) 426 427 self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),)) 428 429 self.dimmap_test( 430 Tensor.view, 431 (randn(1, 1, 42, 24), -1), 432 (Flatten((InputDim(2), InputDim(3))),), 433 ) 434 435 self.dimmap_test( 436 Tensor.view, 437 (randn(1, 1, 42, 1, 24, 1), -1), 438 (Flatten((InputDim(2), InputDim(input_dim=3), InputDim(4))),), 439 ) 440 441 self.dimmap_test( 442 Tensor.view, 443 (randn(48, 35, 26), (24, 4, 35, 13)), 444 ( 445 Split( 446 Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), 447 group_shape=(24, 4, 35, 13), 448 split_id=0, 449 ), 450 Split( 451 Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), 452 group_shape=(24, 4, 35, 13), 453 split_id=1, 454 ), 455 Split( 456 Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), 457 group_shape=(24, 4, 35, 13), 458 split_id=2, 459 ), 460 Split( 461 Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), 462 group_shape=(24, 4, 35, 13), 463 split_id=3, 464 ), 465 ), 466 ) 467 468 # TODO: Currently functional collectives on complex numbers are not fully supported, 469 # so we are having a standalone test for view_as_complex and view_as_real combined. 470 # Once complex numbers are supported, we can add the following to the dim_map test. 471 # 472 # self.dimmap_test( 473 # torch.view_as_complex, 474 # (randn(24, 13, 2),), 475 # ( 476 # InputDim(0), 477 # Flatten((InputDim(1), InputDim(2))), 478 # ), 479 # ) 480 # self.dimmap_test( 481 # torch.view_as_real, 482 # (torch.randn(24, 13, dtype=torch.cfloat),), 483 # ( 484 # InputDim(0), 485 # Split(InputDim(1), (13, 2), 0), 486 # Split(InputDim(1), (13, 2), 1), 487 # ), 488 # ) 489 @with_comms 490 def test_complex_view_ops(self): 491 self.device_mesh = DeviceMesh( 492 self.device_type, torch.arange(dist.get_world_size()).view(-1, 2) 493 ) 494 inp = randn(24, 13, 2) 495 intermediate = torch.view_as_complex(inp) 496 out = torch.view_as_real(intermediate) 497 498 # test dim_map correctness 499 expected_view_as_complex_rule = ( 500 InputDim(0), 501 Flatten((InputDim(1), InputDim(2))), 502 ) 503 view_as_complex_rule = dim_maps[torch.view_as_complex](inp) 504 self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule) 505 expected_view_as_real_rule = ( 506 InputDim(0), 507 Split(InputDim(1), (13, 2), 0), 508 Split(InputDim(1), (13, 2), 1), 509 ) 510 view_as_real_rule = dim_maps[torch.view_as_real](intermediate) 511 self.assertEqual(view_as_real_rule, expected_view_as_real_rule) 512 513 # test sharded computation correctness 514 # NOTE: For the input to torch.view_as_complex, sharding 515 # on the last two dimensions is not supported. 516 sharding_choices: List[Placement] = [Replicate(), Shard(0)] 517 all_sharding_choices = itertools.product( 518 *(self.device_mesh.ndim * [sharding_choices]) 519 ) 520 521 for inp_shard in all_sharding_choices: 522 inp_dt = distribute_tensor(inp, self.device_mesh, inp_shard) 523 524 comm_mode = CommDebugMode() 525 with comm_mode: 526 intermediate_dt = torch.view_as_complex(inp_dt) 527 out_dt = torch.view_as_real(intermediate_dt) 528 529 self.assertEqual( 530 comm_mode.get_total_counts(), 0, "Expected no redistribution." 531 ) 532 self.assertEqual(out, out_dt.full_tensor()) 533 534 @with_comms 535 def test_dtensor_view_op_uneven(self): 536 """ 537 Test two uneven cases for view op: 538 1) the sharded tensor dim is 1 so that only the first rank has an non-empty shard. 539 2) the sharded tensor dim is uneven such that some ranks have full shards, 540 smaller non-empty shards, and empty shards. 541 """ 542 dim0_sizes = [1, self.world_size + 1] 543 for dim0_size in dim0_sizes: 544 p = torch.randn(dim0_size, 2, 2, 2) 545 mesh = init_device_mesh(self.device_type, (self.world_size,)) 546 dtensor = distribute_tensor(p, mesh, [Shard(0)]) 547 548 with CommDebugMode() as comm_mode: 549 view = dtensor.view(dim0_size, 2, 4) 550 self.assertEqual(len(comm_mode.get_comm_counts()), 0) 551 # when no communication happens, the data pointer should be the same. 552 self.assertEqual( 553 view.to_local().data_ptr(), dtensor.to_local().data_ptr() 554 ) 555 556 view = dtensor.view(dim0_size, 4, 2) 557 self.assertEqual( 558 view.to_local().data_ptr(), dtensor.to_local().data_ptr() 559 ) 560 self.assertEqual(len(comm_mode.get_comm_counts()), 0) 561 562 view = dtensor.view(dim0_size, 8) 563 self.assertEqual( 564 view.to_local().data_ptr(), dtensor.to_local().data_ptr() 565 ) 566 self.assertEqual(len(comm_mode.get_comm_counts()), 0) 567 568 view = dtensor.view(dtensor.shape) 569 self.assertEqual( 570 view.to_local().data_ptr(), dtensor.to_local().data_ptr() 571 ) 572 self.assertEqual(len(comm_mode.get_comm_counts()), 0) 573 574 575if __name__ == "__main__": 576 run_tests() 577