1# Owner(s): ["oncall: distributed"] 2 3import copy 4import io 5import itertools 6import math 7import pickle 8import sys 9from typing import List 10 11import torch 12import torch.distributed as dist 13from torch.distributed import distributed_c10d, rpc 14from torch.distributed._shard import sharded_tensor 15from torch.distributed._shard.api import ( 16 _collect_local_shard, 17 _reshard_output, 18 _shard_tensor, 19 load_with_process_group, 20 shard_parameter, 21) 22from torch.distributed._shard.sharded_tensor import ( 23 custom_sharded_op_impl, 24 pre_load_state_dict_hook, 25 Shard, 26 ShardedTensor, 27 ShardedTensorBase, 28 ShardedTensorMetadata, 29 state_dict_hook, 30) 31from torch.distributed._shard.sharded_tensor.api import ( 32 _create_tensor_from_params, 33 TensorProperties, 34) 35from torch.distributed._shard.sharded_tensor.utils import ( 36 _parse_and_validate_remote_device, 37) 38from torch.distributed._shard.sharding_spec import ( 39 ChunkShardingSpec, 40 EnumerableShardingSpec, 41 ShardMetadata, 42) 43from torch.distributed.remote_device import _remote_device 44from torch.testing._internal.common_distributed import ( 45 requires_nccl, 46 skip_if_lt_x_gpu, 47 spawn_threads_and_init_comms, 48 tp_transports, 49) 50from torch.testing._internal.common_utils import ( 51 run_tests, 52 skip_but_pass_in_sandcastle_if, 53 TEST_CUDA, 54 TEST_WITH_DEV_DBG_ASAN, 55 TestCase, 56) 57from torch.testing._internal.distributed._shard.sharded_tensor import ( 58 ShardedTensorTestBase, 59 with_comms, 60) 61from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import ( 62 _chunk_sharding_specs_list_for_test, 63 MyShardedModel1, 64) 65 66 67if TEST_WITH_DEV_DBG_ASAN: 68 print( 69 "Skip dev-asan as torch + multiprocessing spawn have known issues", 70 file=sys.stderr, 71 ) 72 sys.exit(0) 73 74 75class TestShardedTensorMetadata(TestCase): 76 def test_serialize_and_deserialize(self): 77 shard_metadatas = [ 78 ShardMetadata( 79 shard_offsets=[0, 0], 80 shard_sizes=[5, 5], 81 placement="rank:0/cuda:0", 82 ), 83 ShardMetadata( 84 shard_offsets=[0, 5], 85 shard_sizes=[5, 5], 86 placement="rank:1/cuda:1", 87 ), 88 ShardMetadata( 89 shard_offsets=[5, 0], 90 shard_sizes=[5, 5], 91 placement="rank:2/cuda:2", 92 ), 93 ShardMetadata( 94 shard_offsets=[5, 5], 95 shard_sizes=[5, 5], 96 placement="rank:3/cuda:3", 97 ), 98 ] 99 100 dtypes = [ 101 torch.float, 102 torch.double, 103 torch.cfloat, 104 torch.cdouble, 105 torch.half, 106 torch.bfloat16, 107 torch.uint8, 108 torch.int8, 109 torch.short, 110 torch.int, 111 torch.long, 112 torch.bool, 113 ] 114 115 layouts = [torch.strided, torch.sparse_coo] 116 requires_grads = [True, False] 117 memory_formats = [ 118 torch.contiguous_format, 119 torch.channels_last, 120 torch.preserve_format, 121 ] 122 pin_memories = [True, False] 123 124 for tensor_properties_input in itertools.product( 125 dtypes, layouts, requires_grads, memory_formats, pin_memories 126 ): 127 ( 128 dtype, 129 layout, 130 requires_grad, 131 memory_format, 132 pin_memory, 133 ) = tensor_properties_input 134 135 expected_st_metadata = sharded_tensor.ShardedTensorMetadata( 136 shard_metadatas, 137 (10, 10), 138 TensorProperties( 139 dtype, layout, requires_grad, memory_format, pin_memory 140 ), 141 ) 142 143 pickled_obj = pickle.dumps(expected_st_metadata) 144 st_metadata = pickle.loads(pickled_obj) 145 self.assertEqual(expected_st_metadata, st_metadata) 146 147 148class TestCreateTensorFromParams(TestCase): 149 @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "CUDA GPU is needed") 150 def test_empty(self): 151 expected_dtype = torch.double 152 tensor_properties = TensorProperties( 153 dtype=expected_dtype, 154 layout=torch.strided, 155 requires_grad=False, 156 pin_memory=False, 157 memory_format=torch.contiguous_format, 158 ) 159 local_device = torch.device("cuda:0") 160 local_tensor = _create_tensor_from_params( 161 5, 10, local_device=local_device, tensor_properties=tensor_properties 162 ) 163 self.assertEqual(local_device, local_tensor.device) 164 self.assertEqual(expected_dtype, local_tensor.dtype) 165 self.assertEqual(torch.strided, local_tensor.layout) 166 self.assertEqual(False, local_tensor.requires_grad) 167 168 169class TestShardParameter(ShardedTensorTestBase): 170 @with_comms(init_rpc=False) 171 @skip_if_lt_x_gpu(4) 172 @requires_nccl() 173 def test_shard_parameter(self): 174 spec = ChunkShardingSpec( 175 dim=0, 176 placements=[ 177 "rank:0/cuda:0", 178 "rank:1/cuda:1", 179 "rank:2/cuda:2", 180 "rank:3/cuda:3", 181 ], 182 ) 183 184 fc = torch.nn.Linear(12, 12).cuda(self.rank) 185 weight_og = fc.weight.clone() 186 shard_parameter(fc, "weight", spec) 187 188 # Verify. 189 self.assertTrue(isinstance(fc.weight, ShardedTensor)) 190 local_shards = fc.weight.local_shards() 191 self.assertEqual(1, len(local_shards)) 192 self.assertEqual(torch.Size([3, 12]), local_shards[0].tensor.size()) 193 self.assertEqual(3, local_shards[0].tensor.size(0)) 194 self.assertEqual(12, local_shards[0].tensor.size(1)) 195 self.assertEqual( 196 torch.narrow(weight_og, 0, 3 * self.rank, 3), local_shards[0].tensor 197 ) 198 199 @with_comms(init_rpc=False) 200 @skip_if_lt_x_gpu(4) 201 @requires_nccl() 202 def test_shard_parameter_errors(self): 203 spec = ChunkShardingSpec( 204 dim=0, 205 placements=[ 206 "rank:0/cuda:0", 207 "rank:1/cuda:1", 208 "rank:2/cuda:2", 209 "rank:3/cuda:3", 210 ], 211 ) 212 213 fc = torch.nn.Linear(12, 12).cuda(self.rank) 214 with self.assertRaisesRegex(ValueError, "does not match with src_rank"): 215 shard_parameter(fc, "weight", spec, src_rank=self.rank) 216 217 with self.assertRaisesRegex(AttributeError, "has no attribute"): 218 shard_parameter(fc, "foo", spec) 219 220 with self.assertRaisesRegex( 221 ValueError, "Expected Linear.bias to be a Tensor, but found str" 222 ): 223 del fc.bias 224 fc.bias = "foo" 225 shard_parameter(fc, "bias", spec) 226 227 with self.assertRaisesRegex(ValueError, "not a contiguous Tensor"): 228 fc.bias = torch.rand(10, 10).cuda(self.rank).t() 229 shard_parameter(fc, "bias", spec) 230 231 spec = ChunkShardingSpec( 232 dim=0, 233 placements=[ 234 f"rank:{self.rank}/cuda:0", 235 "rank:1/cuda:1", 236 "rank:2/cuda:2", 237 "rank:3/cuda:3", 238 ], 239 ) 240 with self.assertRaisesRegex(ValueError, "does not match with sharding_spec"): 241 shard_parameter(fc, "weight", spec) 242 243 spec = EnumerableShardingSpec( 244 [ 245 ShardMetadata( 246 shard_offsets=[0, 0], 247 shard_sizes=[5, 5], 248 placement="rank:0/cuda:0", 249 ), 250 ShardMetadata( 251 shard_offsets=[5, 0], 252 shard_sizes=[5, 5], 253 placement="rank:1/cuda:1", 254 ), 255 ] 256 ) 257 with self.assertRaisesRegex(NotImplementedError, "not implemented yet!"): 258 shard_parameter(fc, "weight", spec) 259 260 261class TestShardTensor(ShardedTensorTestBase): 262 @with_comms(init_rpc=False) 263 @skip_if_lt_x_gpu(4) 264 @requires_nccl() 265 def test_shard_tensor(self): 266 spec = ChunkShardingSpec( 267 dim=0, 268 placements=[ 269 "rank:0/cuda:0", 270 "rank:1/cuda:1", 271 "rank:2/cuda:2", 272 "rank:3/cuda:3", 273 ], 274 ) 275 tensor = torch.rand(12, 12).cuda(self.rank) 276 st = _shard_tensor(tensor, spec) 277 278 # Verify. 279 self.assertTrue(isinstance(st, sharded_tensor.ShardedTensor)) 280 local_shard = st.local_tensor() 281 self.assertEqual(1, len(st.local_shards())) 282 self.assertEqual(torch.Size([3, 12]), local_shard.size()) 283 self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard) 284 285 @with_comms(init_rpc=False) 286 @skip_if_lt_x_gpu(4) 287 @requires_nccl() 288 def test_shard_tensor_with_empty_shard(self): 289 spec = ChunkShardingSpec( 290 dim=0, 291 placements=[ 292 "rank:0/cuda:0", 293 "rank:1/cuda:1", 294 "rank:2/cuda:2", 295 "rank:3/cuda:3", 296 ], 297 ) 298 tensor = torch.rand(9, 12).cuda(self.rank) 299 st = _shard_tensor(tensor, spec) 300 301 # Verify. 302 self.assertTrue(isinstance(st, sharded_tensor.ShardedTensor)) 303 sms = st.metadata().shards_metadata 304 self.assertEqual(len(sms), 4) 305 for sm in sms: 306 self.assertTrue(sm.shard_offsets[0] + sm.shard_sizes[0] <= tensor.size(0)) 307 308 local_shard = st.local_tensor() 309 self.assertEqual(1, len(st.local_shards())) 310 if dist.get_rank() < 3: 311 self.assertEqual(torch.Size([3, 12]), local_shard.size()) 312 self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard) 313 else: 314 self.assertEqual(torch.Size([0, 12]), local_shard.size()) 315 316 @with_comms(init_rpc=False) 317 @skip_if_lt_x_gpu(4) 318 @requires_nccl() 319 def test_shard_tensor_errors(self): 320 spec = ChunkShardingSpec( 321 dim=0, 322 placements=[ 323 "rank:0/cuda:0", 324 "rank:1/cuda:1", 325 "rank:2/cuda:2", 326 "rank:3/cuda:3", 327 ], 328 ) 329 tensor = torch.rand(12, 12).cuda(self.rank) 330 331 with self.assertRaisesRegex(ValueError, "does not match with src_rank"): 332 _shard_tensor(tensor, spec, src_rank=self.rank) 333 334 with self.assertRaisesRegex(ValueError, "not a contiguous Tensor"): 335 tensor_t = torch.rand(12, 12).cuda(self.rank).t() 336 _shard_tensor(tensor_t, spec) 337 338 spec = ChunkShardingSpec( 339 dim=0, 340 placements=[ 341 f"rank:{self.rank}/cuda:0", 342 "rank:1/cuda:1", 343 "rank:2/cuda:2", 344 "rank:3/cuda:3", 345 ], 346 ) 347 with self.assertRaisesRegex(ValueError, "does not match with sharding_spec"): 348 _shard_tensor(tensor, spec) 349 350 spec = EnumerableShardingSpec( 351 [ 352 ShardMetadata( 353 shard_offsets=[0, 0], 354 shard_sizes=[5, 5], 355 placement="rank:0/cuda:0", 356 ), 357 ShardMetadata( 358 shard_offsets=[5, 0], 359 shard_sizes=[5, 5], 360 placement="rank:1/cuda:1", 361 ), 362 ] 363 ) 364 with self.assertRaisesRegex(NotImplementedError, "not implemented yet!"): 365 _shard_tensor(tensor, spec) 366 367 368class TestModuleHookApi(ShardedTensorTestBase): 369 class DummyNNModule(torch.nn.Module): 370 def __init__(self, spec, tensor_size): 371 super().__init__() 372 self.st = sharded_tensor.rand(spec, *tensor_size) 373 374 def forward(self): 375 return self.st 376 377 @with_comms(init_rpc=False) 378 @skip_if_lt_x_gpu(4) 379 @requires_nccl() 380 def test_reshard_output(self): 381 specs = _chunk_sharding_specs_list_for_test([0, 1], seed=5) 382 spec, reshard_spec = specs[0], specs[1] 383 test_module = self.DummyNNModule(spec, [24, 12]) 384 st = test_module() 385 local_shard = st.local_tensor() 386 pg = dist.distributed_c10d._get_default_group() 387 st_compare = ShardedTensor._init_from_local_shards( 388 copy.deepcopy(st.local_shards()), 389 st.size(), 390 process_group=pg, 391 ) 392 st_compare._sharding_spec = copy.deepcopy(spec) 393 st_compare.reshard(reshard_spec) 394 test_module = _reshard_output(test_module, reshard_spec) 395 st = test_module() 396 local_shard = st.local_tensor() 397 local_shard_compare = st_compare.local_tensor() 398 self.assertEqual(local_shard, local_shard_compare) 399 self.assertEqual(local_shard.size(0), 24) 400 self.assertEqual(local_shard.size(1), 3) 401 402 @with_comms(init_rpc=False) 403 @skip_if_lt_x_gpu(4) 404 @requires_nccl() 405 def test_collect_local_shard(self): 406 specs = _chunk_sharding_specs_list_for_test([0], seed=5) 407 spec = specs[0] 408 test_module = self.DummyNNModule(spec, [23, 15]) 409 st = test_module() 410 local_shard = st.local_tensor() 411 test_module = _collect_local_shard(test_module) 412 output = test_module() 413 self.assertTrue(isinstance(output, torch.Tensor)) 414 self.assertEqual(local_shard, output) 415 416 417class TestLocalTensor(ShardedTensorTestBase): 418 @with_comms(init_rpc=False) 419 @skip_if_lt_x_gpu(4) 420 @requires_nccl() 421 def test_local_tensor(self): 422 spec = ChunkShardingSpec( 423 dim=0, 424 placements=[ 425 "rank:0/cuda:0", 426 "rank:1/cuda:1", 427 "rank:2/cuda:2", 428 "rank:3/cuda:3", 429 ], 430 ) 431 st = sharded_tensor.rand(spec, 24, 12) 432 local_shard = st.local_tensor() 433 self.assertEqual(torch.Size([6, 12]), local_shard.size()) 434 self.assertEqual(st.local_tensor(), local_shard) 435 436 @with_comms(init_rpc=False) 437 @skip_if_lt_x_gpu(4) 438 @requires_nccl() 439 def test_local_tensor_error(self): 440 spec = ChunkShardingSpec( 441 dim=0, 442 placements=[ 443 "rank:0/cuda:0", 444 "rank:0/cuda:0", 445 "rank:1/cuda:1", 446 "rank:1/cuda:1", 447 "rank:1/cuda:1", 448 "rank:2/cuda:2", 449 "rank:2/cuda:2", 450 "rank:2/cuda:2", 451 "rank:3/cuda:3", 452 "rank:3/cuda:3", 453 ], 454 ) 455 st = sharded_tensor.rand(spec, 24, 12) 456 with self.assertRaisesRegex( 457 NotImplementedError, "Only single local shard is supported." 458 ): 459 local_shard = st.local_tensor() 460 461 462class TestShardedTensorChunked(ShardedTensorTestBase): 463 @with_comms 464 @skip_if_lt_x_gpu(4) 465 @requires_nccl() 466 def test_sharded_tensor_metadata(self): 467 spec = ChunkShardingSpec( 468 dim=0, 469 placements=[ 470 "rank:0/cuda:0", 471 "rank:1/cuda:1", 472 "rank:2/cuda:2", 473 "rank:3/cuda:3", 474 ], 475 ) 476 477 st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) 478 st_metadata = st.metadata() 479 self.assertEqual(torch.Size([10, 20]), st_metadata.size) 480 self.assertEqual(torch.Size([10, 20]), st.size()) 481 self.assertEqual(torch.float, st.dtype) 482 self.assertEqual(torch.strided, st.layout) 483 self.assertEqual(False, st.requires_grad) 484 self.assertTrue(st.is_contiguous()) 485 self.assertFalse(st.is_pinned()) 486 487 st = sharded_tensor.empty(spec, 10, 20, requires_grad=True, init_rrefs=True) 488 self.assertEqual(True, st.requires_grad) 489 490 st = sharded_tensor.empty(spec, 10, 20, dtype=torch.double, init_rrefs=True) 491 self.assertEqual(torch.double, st.dtype) 492 493 # Need CPU for pin_memory 494 spec = ChunkShardingSpec( 495 dim=0, 496 placements=[ 497 "rank:0/cpu", 498 "rank:1/cpu", 499 "rank:2/cpu", 500 "rank:3/cpu", 501 ], 502 ) 503 504 st = sharded_tensor.empty(spec, 10, 20, pin_memory=True, init_rrefs=True) 505 self.assertEqual(True, st.is_pinned()) 506 507 # test read only properties, they're read only as we can't simply change 508 # the global metadata without changing the underlying shard's properties 509 with self.assertRaisesRegex(RuntimeError, "torch function '__set__'"): 510 st.requires_grad = True 511 512 @with_comms 513 @skip_if_lt_x_gpu(4) 514 @requires_nccl() 515 def test_complete_world_size(self): 516 for dim in [0, -2]: 517 spec = ChunkShardingSpec( 518 dim=dim, 519 placements=[ 520 "rank:0/cuda:0", 521 "rank:1/cuda:1", 522 "rank:2/cuda:2", 523 "rank:3/cuda:3", 524 ], 525 ) 526 st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) 527 528 # Validate local shard. 529 local_shards = st.local_shards() 530 self.assertEqual(1, len(local_shards)) 531 local_shard = local_shards[0].tensor 532 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 533 if self.rank == 3: 534 self.assertEqual((1, 20), local_shard.size()) 535 else: 536 self.assertEqual((3, 20), local_shard.size()) 537 538 # Validate global metadata. 539 st_metadata = st.metadata() 540 shards_metadata = st_metadata.shards_metadata 541 self.assertEqual(4, len(shards_metadata)) 542 543 for rank, shard_metadata in enumerate(shards_metadata): 544 self.assertEqual([rank * 3, 0], shard_metadata.shard_offsets) 545 if rank == 3: 546 self.assertEqual([1, 20], shard_metadata.shard_sizes) 547 else: 548 self.assertEqual([3, 20], shard_metadata.shard_sizes) 549 self.assertEqual( 550 f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement) 551 ) 552 553 # Validate remote shards. 554 remote_shards = st.remote_shards() 555 self.assertEqual(3, len(remote_shards)) 556 557 for rpc_rank, shards in remote_shards.items(): 558 self.assertEqual(1, len(shards)) 559 for remote_shard in shards: 560 self.assertEqual(rpc_rank, remote_shard.owner().id) 561 shard = remote_shard.to_here() 562 self.assertEqual( 563 f"rank:{rpc_rank}/cuda:{rpc_rank}", 564 str(shard.metadata.placement), 565 ) 566 if rpc_rank == 3: 567 self.assertEqual((1, 20), shard.tensor.size()) 568 else: 569 self.assertEqual((3, 20), shard.tensor.size()) 570 571 @with_comms 572 @skip_if_lt_x_gpu(4) 573 @requires_nccl() 574 def test_create_sharded_tensor_with_ones(self): 575 """Test sharded_tensor.ones(...)""" 576 577 spec = ChunkShardingSpec( 578 dim=0, 579 placements=[ 580 "rank:0/cuda:0", 581 "rank:1/cuda:1", 582 "rank:2/cuda:2", 583 "rank:3/cuda:3", 584 ], 585 ) 586 h, w = 10, 20 587 st = sharded_tensor.ones(spec, h, w) 588 589 # Validate local shard is initialized with torch.ones 590 local_shards = st.local_shards() 591 self.assertEqual(1, len(local_shards)) 592 local_shard = local_shards[0].tensor 593 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 594 # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 595 expected_h = 1 if self.rank == 3 else math.ceil(h / 4) 596 self.assertEqual((expected_h, w), local_shard.size()) 597 self.assertEqual(local_shard, torch.ones(expected_h, w)) 598 599 @with_comms 600 @skip_if_lt_x_gpu(4) 601 @requires_nccl() 602 def test_gather_even(self) -> None: 603 """Test _sharded_tensor.gather(...) with evenly distributed._shards""" 604 605 spec = ChunkShardingSpec( 606 dim=0, 607 placements=[ 608 "rank:0/cuda:0", 609 "rank:1/cuda:1", 610 "rank:2/cuda:2", 611 "rank:3/cuda:3", 612 ], 613 ) 614 h, w = 10, 20 615 st = sharded_tensor.ones(spec, h, w) 616 617 full_tensor = None 618 dst = 1 619 if self.rank == dst: 620 full_tensor = torch.zeros( 621 h, 622 w, 623 device=torch.device(f"cuda:{dst}"), 624 ) 625 st.gather(dst, full_tensor) 626 627 if self.rank == dst: 628 self.assertEqual(full_tensor, torch.ones(h, w)) 629 else: 630 self.assertIsNone(full_tensor) 631 632 @with_comms 633 @skip_if_lt_x_gpu(4) 634 @requires_nccl() 635 def test_gather_uneven(self) -> None: 636 """Test _sharded_tensor.gather(...) with unevenly distributed._shards""" 637 638 spec = ChunkShardingSpec( 639 dim=0, 640 placements=[ 641 "rank:0/cuda:0", 642 "rank:0/cuda:0", 643 "rank:1/cuda:1", 644 "rank:1/cuda:1", 645 "rank:2/cuda:2", 646 ], 647 ) 648 h, w = 10, 20 649 st = sharded_tensor.ones(spec, h, w) 650 651 full_tensor = None 652 dst = 1 653 if self.rank == dst: 654 full_tensor = torch.zeros( 655 h, 656 w, 657 device=torch.device(f"cuda:{dst}"), 658 ) 659 st.gather(dst, full_tensor) 660 661 if self.rank == dst: 662 self.assertEqual(full_tensor, torch.ones(h, w)) 663 else: 664 self.assertIsNone(full_tensor) 665 666 @with_comms 667 @skip_if_lt_x_gpu(4) 668 @requires_nccl() 669 def test_create_sharded_tensor_with_zeros(self): 670 """Test sharded_tensor.zeros(...)""" 671 672 spec = ChunkShardingSpec( 673 dim=0, 674 placements=[ 675 "rank:0/cuda:0", 676 "rank:1/cuda:1", 677 "rank:2/cuda:2", 678 "rank:3/cuda:3", 679 ], 680 ) 681 h, w = 10, 20 682 st = sharded_tensor.zeros(spec, h, w) 683 684 # Validate local shard is initialized with torch.zeros 685 local_shards = st.local_shards() 686 self.assertEqual(1, len(local_shards)) 687 local_shard = local_shards[0].tensor 688 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 689 # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 690 expected_h = 1 if self.rank == 3 else math.ceil(h / 4) 691 self.assertEqual((expected_h, w), local_shard.size()) 692 self.assertEqual(local_shard, torch.zeros(expected_h, w)) 693 694 @with_comms 695 @skip_if_lt_x_gpu(4) 696 @requires_nccl() 697 def test_create_sharded_tensor_with_rand(self): 698 """Test sharded_tensor.rand(...)/randn(...)""" 699 700 spec = ChunkShardingSpec( 701 dim=0, 702 placements=[ 703 "rank:0/cuda:0", 704 "rank:1/cuda:1", 705 "rank:2/cuda:2", 706 "rank:3/cuda:3", 707 ], 708 ) 709 h, w = 8, 2 710 seed = 1234 711 712 expected_h = 2 713 expected_device = torch.device(f"cuda:{self.rank}") 714 dtype = torch.double 715 torch.manual_seed(seed) 716 # Test sharded_tensor.rand creation 717 expected = torch.rand(expected_h, w, device=expected_device, dtype=dtype) 718 # reset seed to ensure the same random numbers are generated 719 torch.manual_seed(seed) 720 st = sharded_tensor.rand(spec, h, w, dtype=dtype) 721 722 # Validate local shard is initialized with torch.rand 723 local_shards = st.local_shards() 724 self.assertEqual(1, len(local_shards)) 725 local_shard = local_shards[0].tensor 726 self.assertEqual(expected_device, local_shard.device) 727 self.assertEqual((expected_h, w), local_shard.size()) 728 self.assertEqual(expected, local_shard) 729 730 # Test sharded_tensor.randn creation 731 torch.manual_seed(seed) 732 expected_randn = torch.randn(expected_h, w, device=expected_device, dtype=dtype) 733 # reset seed to ensure the same random numbers are generated 734 torch.manual_seed(seed) 735 st_randn = sharded_tensor.randn(spec, h, w, dtype=dtype) 736 737 # Validate local shard is initialized with torch.randn 738 local_shards = st_randn.local_shards() 739 self.assertEqual(1, len(local_shards)) 740 local_shard = local_shards[0].tensor 741 self.assertEqual(expected_device, local_shard.device) 742 self.assertEqual((expected_h, w), local_shard.size()) 743 self.assertEqual(expected_randn, local_shard) 744 745 @with_comms 746 @skip_if_lt_x_gpu(4) 747 @requires_nccl() 748 def test_create_sharded_tensor_with_full(self): 749 """Test sharded_tensor.full(...)""" 750 751 spec = ChunkShardingSpec( 752 dim=0, 753 placements=[ 754 "rank:0/cuda:0", 755 "rank:1/cuda:1", 756 "rank:2/cuda:2", 757 "rank:3/cuda:3", 758 ], 759 ) 760 h, w = 10, 20 761 fill_value = 1234 762 st = sharded_tensor.full( 763 spec, size=(h, w), fill_value=fill_value, dtype=torch.int32 764 ) 765 766 # Validate local shard is initialized with torch.full 767 local_shards = st.local_shards() 768 self.assertEqual(1, len(local_shards)) 769 local_shard = local_shards[0].tensor 770 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 771 # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 772 expected_h = 1 if self.rank == 3 else math.ceil(h / 4) 773 self.assertEqual((expected_h, w), local_shard.size()) 774 self.assertEqual( 775 local_shard, 776 torch.full(size=(expected_h, w), fill_value=fill_value, dtype=torch.int32), 777 ) 778 779 @with_comms 780 @skip_if_lt_x_gpu(4) 781 @requires_nccl() 782 def test_create_sharded_tensor_like(self): 783 """Test tensor like methods, i.e. torch.zeros_like(...), torch.full_like, etc.""" 784 785 spec = ChunkShardingSpec( 786 dim=0, 787 placements=[ 788 "rank:0/cuda:0", 789 "rank:1/cuda:1", 790 "rank:2/cuda:2", 791 "rank:3/cuda:3", 792 ], 793 ) 794 h, w = 8, 8 795 expected_h = 2 796 seed = 1234 797 dtype = torch.double 798 expected_device = torch.device(f"cuda:{self.rank}") 799 st = sharded_tensor.rand(spec, (h, w), dtype=dtype) 800 tensor_like_ops = { 801 torch.zeros_like: torch.zeros, 802 torch.ones_like: torch.ones, 803 torch.rand_like: torch.rand, 804 torch.randn_like: torch.randn, 805 torch.empty_like: torch.empty, 806 torch.full_like: torch.full, 807 } 808 for op, expect_local_op in tensor_like_ops.items(): 809 if op == torch.full_like: 810 # special handle full/full_like as it needs to have additional fill_value arg 811 expect_tensor = expect_local_op( 812 (expected_h, w), 8.8, device=expected_device, dtype=dtype 813 ) 814 new_op_st = op(st, 8.8, dtype=dtype) 815 self.assertEqual(new_op_st.local_tensor(), expect_tensor) 816 elif op == torch.empty_like: 817 # empty/empty_like we only compare the shape 818 expect_tensor = expect_local_op( 819 expected_h, w, device=expected_device, dtype=dtype 820 ) 821 new_op_st = op(st, dtype=dtype) 822 self.assertEqual(new_op_st.local_tensor().shape, expect_tensor.shape) 823 else: 824 torch.manual_seed(seed) 825 expect_tensor = expect_local_op( 826 expected_h, w, device=expected_device, dtype=dtype 827 ) 828 torch.manual_seed(seed) 829 new_op_st = op(st, dtype=dtype) 830 self.assertEqual(new_op_st.local_tensor(), expect_tensor) 831 832 @with_comms 833 @skip_if_lt_x_gpu(4) 834 @requires_nccl() 835 def test_partial_world_size(self): 836 spec = ChunkShardingSpec( 837 dim=0, 838 placements=[ 839 "rank:2/cuda:2", 840 "rank:3/cuda:3", 841 ], 842 ) 843 st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) 844 845 # Validate local shard. 846 local_shards = st.local_shards() 847 if self.rank >= 2: 848 self.assertEqual(1, len(local_shards)) 849 local_shard = local_shards[0].tensor 850 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 851 self.assertEqual((5, 20), local_shard.size()) 852 else: 853 self.assertEqual(0, len(local_shards)) 854 855 # Validate global metadata. 856 st_metadata = st.metadata() 857 shards_metadata = st_metadata.shards_metadata 858 self.assertEqual(2, len(shards_metadata)) 859 860 for shard_rank, shard_metadata in enumerate(shards_metadata): 861 self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets) 862 self.assertEqual([5, 20], shard_metadata.shard_sizes) 863 self.assertEqual( 864 f"rank:{shard_rank + 2}/cuda:{shard_rank + 2}", 865 str(shard_metadata.placement), 866 ) 867 868 # Validate remote shards. 869 remote_shards = st.remote_shards() 870 if self.rank >= 2: 871 self.assertEqual(1, len(remote_shards)) 872 else: 873 self.assertEqual(2, len(remote_shards)) 874 875 for rpc_rank, shards in remote_shards.items(): 876 self.assertEqual(1, len(shards)) 877 for remote_shard in shards: 878 self.assertEqual(rpc_rank, remote_shard.owner().id) 879 shard = remote_shard.to_here() 880 self.assertEqual( 881 f"rank:{rpc_rank}/cuda:{rpc_rank}", str(shard.metadata.placement) 882 ) 883 self.assertEqual((5, 20), shard.tensor.size()) 884 885 @with_comms 886 @skip_if_lt_x_gpu(4) 887 @requires_nccl() 888 def test_new_group(self): 889 spec = ChunkShardingSpec( 890 dim=0, 891 placements=[ 892 "rank:2/cuda:2", 893 "rank:3/cuda:3", 894 ], 895 ) 896 897 pg = dist.new_group(ranks=[1, 2, 3]) 898 st = sharded_tensor.empty(spec, 10, 20, process_group=pg, init_rrefs=True) 899 900 # Validate local shard. 901 local_shards = st.local_shards() 902 if self.rank >= 2: 903 self.assertEqual(1, len(local_shards)) 904 local_shard = local_shards[0].tensor 905 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 906 self.assertEqual((5, 20), local_shard.size()) 907 else: 908 self.assertEqual(0, len(local_shards)) 909 910 # Validate global metadata. 911 st_metadata = st.metadata() 912 shards_metadata = st_metadata.shards_metadata 913 self.assertEqual(2, len(shards_metadata)) 914 915 for shard_rank, shard_metadata in enumerate(shards_metadata): 916 self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets) 917 self.assertEqual([5, 20], shard_metadata.shard_sizes) 918 self.assertEqual( 919 f"rank:{shard_rank + 2}/cuda:{shard_rank + 2}", 920 str(shard_metadata.placement), 921 ) 922 923 # Validate remote shards. 924 remote_shards = st.remote_shards() 925 if self.rank >= 2: 926 self.assertEqual(1, len(remote_shards)) 927 else: 928 self.assertEqual(2, len(remote_shards)) 929 930 for rpc_rank, shards in remote_shards.items(): 931 self.assertEqual(1, len(shards)) 932 for remote_shard in shards: 933 shard = remote_shard.to_here() 934 self.assertEqual(rpc_rank, remote_shard.owner().id) 935 self.assertEqual( 936 f"rank:{rpc_rank}/cuda:{rpc_rank}", str(shard.metadata.placement) 937 ) 938 self.assertEqual((5, 20), shard.tensor.size()) 939 940 @with_comms 941 @skip_if_lt_x_gpu(4) 942 @requires_nccl() 943 def test_multiple_local_shards(self): 944 spec = ChunkShardingSpec( 945 dim=0, 946 placements=[ 947 "rank:0/cuda:0", 948 "rank:1/cuda:1", 949 "rank:2/cuda:2", 950 "rank:3/cuda:3", 951 "rank:0/cuda:0", 952 "rank:1/cuda:1", 953 "rank:2/cuda:2", 954 "rank:3/cuda:3", 955 ], 956 ) 957 st = sharded_tensor.empty(spec, 16, 20, init_rrefs=True) 958 959 # Validate local shards. 960 local_shards = st.local_shards() 961 self.assertEqual(2, len(local_shards)) 962 for local_shard in local_shards: 963 self.assertEqual( 964 torch.device(f"cuda:{self.rank}"), local_shard.tensor.device 965 ) 966 self.assertEqual((2, 20), local_shard.tensor.size()) 967 968 # Validate global metadata. 969 st_metadata = st.metadata() 970 shards_metadata = st_metadata.shards_metadata 971 self.assertEqual(8, len(shards_metadata)) 972 973 for shard_idx, shard_metadata in enumerate(shards_metadata): 974 self.assertEqual([shard_idx * 2, 0], shard_metadata.shard_offsets) 975 self.assertEqual([2, 20], shard_metadata.shard_sizes) 976 self.assertEqual( 977 f"rank:{shard_idx % 4}/cuda:{shard_idx % 4}", 978 str(shard_metadata.placement), 979 ) 980 981 # Validate remote shards. 982 remote_shards = st.remote_shards() 983 self.assertEqual(3, len(remote_shards)) 984 owners = {} 985 for rpc_rank, shards in remote_shards.items(): 986 self.assertEqual(2, len(shards)) 987 for remote_shard in shards: 988 shard = remote_shard.to_here() 989 self.assertEqual((2, 20), shard.tensor.size()) 990 self.assertEqual(rpc_rank, remote_shard.owner().id) 991 992 @skip_if_lt_x_gpu(4) 993 @requires_nccl() 994 def test_sharding_columns(self): 995 self.init_pg() 996 997 for dim in [1, -1]: 998 spec = ChunkShardingSpec( 999 dim=dim, 1000 placements=[ 1001 "rank:0/cuda:0", 1002 "rank:1/cuda:1", 1003 "rank:2/cuda:2", 1004 "rank:3/cuda:3", 1005 ], 1006 ) 1007 1008 st = sharded_tensor.empty(spec, 10, 32) 1009 1010 # Validate local shard. 1011 local_shards = st.local_shards() 1012 self.assertEqual(1, len(local_shards)) 1013 local_shard = local_shards[0].tensor 1014 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 1015 self.assertEqual((10, 8), local_shard.size()) 1016 1017 # Validate global metadata. 1018 st_metadata = st.metadata() 1019 shards_metadata = st_metadata.shards_metadata 1020 self.assertEqual(4, len(shards_metadata)) 1021 1022 for rank, shard_metadata in enumerate(shards_metadata): 1023 self.assertEqual([0, rank * 8], shard_metadata.shard_offsets) 1024 self.assertEqual([10, 8], shard_metadata.shard_sizes) 1025 self.assertEqual( 1026 f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement) 1027 ) 1028 1029 @skip_if_lt_x_gpu(4) 1030 @requires_nccl() 1031 def test_invalid_sharding(self): 1032 self.init_pg() 1033 1034 with self.assertRaisesRegex( 1035 NotImplementedError, "does not support named dimension" 1036 ): 1037 spec = ChunkShardingSpec(dim="H", placements=["rank:1/cuda:1"]) 1038 sharded_tensor.empty(spec, 10, 20) 1039 1040 for dim in [2, 3, 4, -3, -4, -5]: 1041 spec = ChunkShardingSpec(dim=dim, placements=["rank:1/cuda:1"]) 1042 with self.assertRaisesRegex(ValueError, "Invalid sharding dim"): 1043 sharded_tensor.empty(spec, 10, 20) 1044 1045 spec = ChunkShardingSpec(dim=0, placements=["rank:5/cuda:1"]) 1046 with self.assertRaisesRegex( 1047 ValueError, "Global rank 5 does not exist in input process group" 1048 ): 1049 sharded_tensor.empty(spec, 10, 20) 1050 1051 spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) 1052 st = sharded_tensor.empty(spec, 10, 20) 1053 tensor = torch.empty(10, 20) 1054 with self.assertRaisesRegex( 1055 RuntimeError, r".*not supported for ShardedTensor!$" 1056 ): 1057 torch.add(st, tensor) 1058 1059 spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) 1060 with self.assertRaisesRegex( 1061 ValueError, "Only torch.strided layout is currently supported" 1062 ): 1063 sharded_tensor.empty(spec, 10, 20, layout=torch.sparse_coo) 1064 1065 spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) 1066 with self.assertRaisesRegex( 1067 ValueError, 1068 "Only torch.contiguous_format memory_format is currently supported", 1069 ): 1070 sharded_tensor.empty(spec, 10, 20, memory_format=torch.channels_last) 1071 1072 spec = ChunkShardingSpec(dim=0, placements=["worker0/cuda:1"]) 1073 with self.assertRaisesRegex( 1074 RuntimeError, "RPC framework needs to be initialized" 1075 ): 1076 sharded_tensor.empty(spec, 10, 20) 1077 1078 spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) 1079 with self.assertRaisesRegex( 1080 RuntimeError, "RPC Framework needs to be initialized" 1081 ): 1082 st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) 1083 1084 with self.assertRaisesRegex( 1085 RuntimeError, "ShardedTensor created with init_rrefs=False" 1086 ): 1087 st = sharded_tensor.empty(spec, 10, 20) 1088 st.remote_shards() 1089 1090 self.init_rpc() 1091 spec = ChunkShardingSpec(dim=0, placements=["workerfoo/cuda:1"]) 1092 with self.assertRaisesRegex(ValueError, "Invalid worker name"): 1093 sharded_tensor.empty(spec, 10, 20, init_rrefs=True) 1094 1095 @skip_if_lt_x_gpu(4) 1096 @requires_nccl() 1097 def test_invalid_pg_rpc_ranks(self): 1098 self.init_pg() 1099 1100 # Init RPC with different ranks. 1101 rpc_backend_options = rpc.TensorPipeRpcBackendOptions( 1102 _transports=tp_transports() 1103 ) 1104 rpc_backend_options.init_method = f"file://{self.file_name}" 1105 rank = (self.rank + 1) % self.world_size 1106 rpc.init_rpc( 1107 name=f"worker{rank}", 1108 rank=rank, 1109 world_size=self.world_size, 1110 rpc_backend_options=rpc_backend_options, 1111 ) 1112 1113 spec = ChunkShardingSpec(dim=0, placements=["rank:1/cuda:1"]) 1114 with self.assertRaisesRegex( 1115 ValueError, "Default ProcessGroup and RPC ranks must be the same" 1116 ): 1117 sharded_tensor.empty(spec, 10, 20, init_rrefs=True) 1118 1119 @skip_if_lt_x_gpu(4) 1120 @requires_nccl() 1121 def test_insufficient_sharding_dims(self): 1122 self.init_pg() 1123 1124 spec = ChunkShardingSpec( 1125 dim=0, 1126 placements=[ 1127 "rank:0/cuda:0", 1128 "rank:1/cuda:1", 1129 "rank:2/cuda:2", 1130 "rank:3/cuda:3", 1131 ], 1132 ) 1133 st = sharded_tensor.empty(spec, 2, 20) 1134 1135 # Validate local shard. 1136 local_shards = st.local_shards() 1137 if self.rank <= 1: 1138 self.assertEqual(1, len(local_shards)) 1139 local_shard = local_shards[0].tensor 1140 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 1141 self.assertEqual((1, 20), local_shard.size()) 1142 else: 1143 self.assertEqual(1, len(local_shards)) 1144 local_shard = local_shards[0].tensor 1145 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) 1146 self.assertEqual(local_shard.numel(), 0) 1147 1148 # Validate global metadata. 1149 st_metadata = st.metadata() 1150 shards_metadata = st_metadata.shards_metadata 1151 self.assertEqual(4, len(shards_metadata)) 1152 1153 for shard_rank, shard_metadata in enumerate(shards_metadata): 1154 self.assertEqual([shard_rank, 0], shard_metadata.shard_offsets) 1155 self.assertEqual( 1156 f"rank:{shard_rank}/cuda:{shard_rank}", str(shard_metadata.placement) 1157 ) 1158 if shard_rank <= 1: 1159 self.assertEqual([1, 20], shard_metadata.shard_sizes) 1160 else: 1161 self.assertEqual([0, 20], shard_metadata.shard_sizes) 1162 1163 @with_comms 1164 @skip_if_lt_x_gpu(4) 1165 @requires_nccl() 1166 def test_sharded_tensor_sizes(self): 1167 spec = ChunkShardingSpec( 1168 dim=0, 1169 placements=[ 1170 "rank:0/cuda:0", 1171 "rank:1/cuda:1", 1172 "rank:2/cuda:2", 1173 "rank:3/cuda:3", 1174 ], 1175 ) 1176 1177 # Test with *args 1178 st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) 1179 self.assertEqual(torch.Size([10, 20]), st.size()) 1180 1181 # Test with single *args 1182 st = sharded_tensor.empty(spec, 10, init_rrefs=True) 1183 self.assertEqual(torch.Size([10]), st.size()) 1184 1185 # Test with list 1186 st = sharded_tensor.empty(spec, [10, 20], init_rrefs=True) 1187 self.assertEqual(torch.Size([10, 20]), st.size()) 1188 1189 # Test with tuple 1190 st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) 1191 self.assertEqual(torch.Size([10, 20]), st.size()) 1192 1193 # Test with row size 1194 st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) 1195 self.assertEqual(st.size(0), 10) 1196 1197 # Test with col size 1198 st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) 1199 self.assertEqual(st.size(1), 20) 1200 1201 # Test with negative indexed size 1202 st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) 1203 self.assertEqual(st.size(-1), 20) 1204 1205 # Test with dim/ndim 1206 self.assertEqual(st.dim(), 2) 1207 self.assertEqual(st.ndim, 2) 1208 # Test with invalid input 1209 st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) 1210 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1211 st.size(-3) 1212 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1213 st.size(2) 1214 1215 with self.assertRaises(TypeError): 1216 st = sharded_tensor.empty(spec, "foo") 1217 1218 @with_comms 1219 @skip_if_lt_x_gpu(4) 1220 @requires_nccl() 1221 def test_state_dict(self): 1222 spec = ChunkShardingSpec( 1223 dim=0, 1224 placements=[ 1225 "rank:0/cuda:0", 1226 "rank:1/cuda:1", 1227 "rank:2/cuda:2", 1228 "rank:3/cuda:3", 1229 ], 1230 ) 1231 1232 m = MyShardedModel1(spec) 1233 1234 # Test save 1235 m._register_state_dict_hook(state_dict_hook) 1236 buffer = io.BytesIO() 1237 mod_state_dict = m.state_dict() 1238 mod_state_keys = mod_state_dict.keys() 1239 self.assertTrue("sharded_tensor1" in mod_state_keys) 1240 self.assertTrue("submodule.sharded_tensor2" in mod_state_keys) 1241 torch.save(mod_state_dict, buffer) 1242 1243 # Test load. 1244 module_load = MyShardedModel1() 1245 module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True) 1246 1247 buffer.seek(0) 1248 state_dict_deser = torch.load(buffer) 1249 module_load.load_state_dict(state_dict_deser, strict=False) 1250 1251 module_load._register_state_dict_hook(state_dict_hook) 1252 loaded_dict_keys = module_load.state_dict().keys() 1253 self.assertTrue("sharded_tensor1" in loaded_dict_keys) 1254 self.assertTrue("submodule.sharded_tensor2" in loaded_dict_keys) 1255 # Verify after load. 1256 self.assertTrue(torch.equal(m.sharded_tensor1, module_load.sharded_tensor1)) 1257 self.assertTrue( 1258 torch.equal( 1259 m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2 1260 ) 1261 ) 1262 1263 @with_comms 1264 @skip_if_lt_x_gpu(4) 1265 @requires_nccl() 1266 def test_state_dict_new_group(self): 1267 spec = ChunkShardingSpec( 1268 dim=0, 1269 placements=[ 1270 "rank:2/cuda:0", 1271 "rank:3/cuda:1", 1272 "rank:2/cuda:2", 1273 "rank:3/cuda:3", 1274 ], 1275 ) 1276 1277 pg = dist.new_group([2, 3]) 1278 1279 m = MyShardedModel1(spec, pg) 1280 1281 # Test save 1282 m._register_state_dict_hook(state_dict_hook) 1283 buffer = io.BytesIO() 1284 torch.save(m.state_dict(), buffer) 1285 1286 # Test load. 1287 module_load = MyShardedModel1(spec=None, group=pg) 1288 module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True) 1289 1290 buffer.seek(0) 1291 with load_with_process_group(pg): 1292 state_dict_deser = torch.load(buffer) 1293 module_load.load_state_dict(state_dict_deser, strict=False) 1294 1295 # Verify after load. 1296 self.assertTrue(torch.equal(m.sharded_tensor1, module_load.sharded_tensor1)) 1297 self.assertTrue( 1298 torch.equal( 1299 m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2 1300 ) 1301 ) 1302 1303 @with_comms 1304 @skip_if_lt_x_gpu(4) 1305 @requires_nccl() 1306 def test_state_dict_no_sharded_tensors(self): 1307 # Verify hooks don't affect modules with no ShardedTensors. 1308 m = torch.nn.Linear(10, 10) 1309 1310 # Test save 1311 state_dict_before = m.state_dict() 1312 m._register_state_dict_hook(state_dict_hook) 1313 buffer = io.BytesIO() 1314 torch.save(m.state_dict(), buffer) 1315 self.assertEqual(state_dict_before, m.state_dict()) 1316 1317 # Test load. 1318 module_load = torch.nn.Linear(10, 10) 1319 module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True) 1320 1321 buffer.seek(0) 1322 state_dict_deser = torch.load(buffer) 1323 module_load.load_state_dict(state_dict_deser, strict=False) 1324 1325 # Verify after load. 1326 self.assertEqual(m.weight, module_load.weight) 1327 self.assertEqual(m.bias, module_load.bias) 1328 1329 @skip_if_lt_x_gpu(4) 1330 @requires_nccl() 1331 def test_load_state_dict_errors(self): 1332 self.init_rpc() 1333 1334 dist.init_process_group( 1335 backend="nccl", 1336 world_size=self.world_size, 1337 rank=self.rank, 1338 init_method=f"file://{self.file_name}", 1339 ) 1340 1341 spec = ChunkShardingSpec( 1342 dim=0, 1343 placements=[ 1344 "rank:0/cuda:0", 1345 "rank:1/cuda:1", 1346 "rank:2/cuda:2", 1347 "rank:3/cuda:3", 1348 ], 1349 ) 1350 1351 m = MyShardedModel1(spec) 1352 1353 # Test save 1354 m._register_state_dict_hook(state_dict_hook) 1355 buffer = io.BytesIO() 1356 torch.save(m.state_dict(), buffer) 1357 1358 pg = dist.new_group(ranks=[0, 2, 3]) 1359 1360 buffer.seek(0) 1361 if self.rank != 0: 1362 with self.assertRaisesRegex(RuntimeError, "Local rank at save time was"): 1363 with load_with_process_group(pg): 1364 state_dict_deser = torch.load(buffer) 1365 else: 1366 with self.assertRaisesRegex( 1367 RuntimeError, "Local world size at save time was" 1368 ): 1369 with load_with_process_group(pg): 1370 state_dict_deser = torch.load(buffer) 1371 1372 dist.destroy_process_group() 1373 buffer.seek(0) 1374 with self.assertRaisesRegex( 1375 RuntimeError, "Need to initialize default process group" 1376 ): 1377 state_dict_deser = torch.load(buffer) 1378 rpc.shutdown() 1379 1380 @with_comms 1381 @skip_if_lt_x_gpu(4) 1382 @requires_nccl() 1383 def test_cleanup(self): 1384 def create_tensors(): 1385 spec = ChunkShardingSpec( 1386 dim=0, 1387 placements=[ 1388 "rank:0/cuda:0", 1389 "rank:1/cuda:1", 1390 "rank:2/cuda:2", 1391 "rank:3/cuda:3", 1392 ], 1393 ) 1394 st1 = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) 1395 st2 = sharded_tensor.empty(spec, 10, 20) 1396 1397 create_tensors() 1398 self.assertEqual(0, len(sharded_tensor.api._sharded_tensor_map)) 1399 1400 1401class TestShardedTensorEnumerable(ShardedTensorTestBase): 1402 @with_comms 1403 @skip_if_lt_x_gpu(4) 1404 @requires_nccl() 1405 def test_sharded_tensor_metadata(self): 1406 spec = EnumerableShardingSpec( 1407 [ 1408 ShardMetadata( 1409 shard_offsets=[0, 0], 1410 shard_sizes=[5, 5], 1411 placement="rank:0/cuda:0", 1412 ), 1413 ShardMetadata( 1414 shard_offsets=[0, 5], 1415 shard_sizes=[5, 5], 1416 placement="rank:1/cuda:1", 1417 ), 1418 ShardMetadata( 1419 shard_offsets=[5, 0], 1420 shard_sizes=[5, 5], 1421 placement="rank:2/cuda:2", 1422 ), 1423 ShardMetadata( 1424 shard_offsets=[5, 5], 1425 shard_sizes=[5, 5], 1426 placement="rank:3/cuda:3", 1427 ), 1428 ] 1429 ) 1430 1431 st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True) 1432 st_metadata = st.metadata() 1433 self.assertEqual(torch.Size([10, 10]), st_metadata.size) 1434 self.assertEqual(torch.float, st.dtype) 1435 self.assertEqual(torch.strided, st.layout) 1436 self.assertEqual(False, st.requires_grad) 1437 self.assertTrue(st.is_contiguous()) 1438 self.assertFalse(st.is_pinned()) 1439 1440 st = sharded_tensor.empty(spec, 10, 10, requires_grad=True, init_rrefs=True) 1441 self.assertEqual(True, st.requires_grad) 1442 1443 st = sharded_tensor.empty(spec, 10, 10, dtype=torch.double, init_rrefs=True) 1444 self.assertEqual(torch.double, st.dtype) 1445 1446 # Need CPU for pin_memory 1447 spec = EnumerableShardingSpec( 1448 [ 1449 ShardMetadata( 1450 shard_offsets=[0, 0], 1451 shard_sizes=[5, 5], 1452 placement="rank:0/cpu", 1453 ), 1454 ShardMetadata( 1455 shard_offsets=[0, 5], 1456 shard_sizes=[5, 5], 1457 placement="rank:1/cpu", 1458 ), 1459 ShardMetadata( 1460 shard_offsets=[5, 0], 1461 shard_sizes=[5, 5], 1462 placement="rank:2/cpu", 1463 ), 1464 ShardMetadata( 1465 shard_offsets=[5, 5], 1466 shard_sizes=[5, 5], 1467 placement="rank:3/cpu", 1468 ), 1469 ] 1470 ) 1471 1472 st = sharded_tensor.empty(spec, 10, 10, pin_memory=True, init_rrefs=True) 1473 self.assertTrue(st.is_pinned()) 1474 1475 @with_comms 1476 @skip_if_lt_x_gpu(4) 1477 @requires_nccl() 1478 def test_grid_sharding(self): 1479 spec = EnumerableShardingSpec( 1480 [ 1481 ShardMetadata( 1482 shard_offsets=[0, 0], 1483 shard_sizes=[5, 5], 1484 placement="rank:0/cuda:0", 1485 ), 1486 ShardMetadata( 1487 shard_offsets=[0, 5], 1488 shard_sizes=[5, 5], 1489 placement="rank:1/cuda:1", 1490 ), 1491 ShardMetadata( 1492 shard_offsets=[5, 0], 1493 shard_sizes=[5, 5], 1494 placement="rank:2/cuda:2", 1495 ), 1496 ShardMetadata( 1497 shard_offsets=[5, 5], 1498 shard_sizes=[5, 5], 1499 placement="rank:3/cuda:3", 1500 ), 1501 ] 1502 ) 1503 1504 st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True) 1505 self.assertEqual((10, 10), st.size()) 1506 self.assertEqual(1, len(st.local_shards())) 1507 1508 # Verify local shard. 1509 local_shard = st.local_shards()[0] 1510 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) 1511 self.assertEqual((5, 5), local_shard.tensor.size()) 1512 1513 # Verify local shard metadata. 1514 self.assertEqual( 1515 (self.rank // 2 * 5, (self.rank % 2) * 5), 1516 local_shard.metadata.shard_offsets, 1517 ) 1518 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 1519 self.assertEqual( 1520 f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) 1521 ) 1522 1523 # Verify global metadata. 1524 st_metadata = st.metadata() 1525 shards_metadata = st_metadata.shards_metadata 1526 self.assertEqual(4, len(shards_metadata)) 1527 for rank, shard_metadata in enumerate(shards_metadata): 1528 self.assertEqual( 1529 (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets 1530 ) 1531 self.assertEqual((5, 5), shard_metadata.shard_sizes) 1532 self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) 1533 1534 # Validate remote shards. 1535 remote_shards = st.remote_shards() 1536 self.assertEqual(3, len(remote_shards)) 1537 1538 for rpc_rank, shards in remote_shards.items(): 1539 self.assertEqual(1, len(shards)) 1540 for remote_shard in shards: 1541 self.assertEqual(rpc_rank, remote_shard.owner().id) 1542 shard = remote_shard.to_here() 1543 self.assertEqual((5, 5), shard.tensor.size()) 1544 1545 @with_comms 1546 @skip_if_lt_x_gpu(4) 1547 @requires_nccl() 1548 def test_create_sharded_tensor_with_ones(self): 1549 """Test sharded_tensor.ones(...)""" 1550 1551 spec = EnumerableShardingSpec( 1552 [ 1553 ShardMetadata( 1554 shard_offsets=[0, 0], 1555 shard_sizes=[5, 5], 1556 placement="rank:0/cuda:0", 1557 ), 1558 ShardMetadata( 1559 shard_offsets=[0, 5], 1560 shard_sizes=[5, 5], 1561 placement="rank:1/cuda:1", 1562 ), 1563 ShardMetadata( 1564 shard_offsets=[5, 0], 1565 shard_sizes=[5, 5], 1566 placement="rank:2/cuda:2", 1567 ), 1568 ShardMetadata( 1569 shard_offsets=[5, 5], 1570 shard_sizes=[5, 5], 1571 placement="rank:3/cuda:3", 1572 ), 1573 ] 1574 ) 1575 1576 st = sharded_tensor.ones(spec, 10, 10, init_rrefs=True) 1577 self.assertEqual((10, 10), st.size()) 1578 self.assertEqual(1, len(st.local_shards())) 1579 1580 # Verify local shard is initialized with torch.ones 1581 local_shard = st.local_shards()[0] 1582 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) 1583 self.assertEqual((5, 5), local_shard.tensor.size()) 1584 self.assertEqual(local_shard.tensor, torch.ones(5, 5)) 1585 1586 @with_comms 1587 @skip_if_lt_x_gpu(4) 1588 @requires_nccl() 1589 def test_gather_even(self) -> None: 1590 """Test _sharded_tensor.gather(...) with evenly distributed._shards""" 1591 1592 spec = EnumerableShardingSpec( 1593 [ 1594 ShardMetadata( 1595 shard_offsets=[0, 0], 1596 shard_sizes=[5, 5], 1597 placement="rank:0/cuda:0", 1598 ), 1599 ShardMetadata( 1600 shard_offsets=[0, 5], 1601 shard_sizes=[5, 5], 1602 placement="rank:1/cuda:1", 1603 ), 1604 ShardMetadata( 1605 shard_offsets=[5, 0], 1606 shard_sizes=[5, 5], 1607 placement="rank:2/cuda:2", 1608 ), 1609 ShardMetadata( 1610 shard_offsets=[5, 5], 1611 shard_sizes=[5, 5], 1612 placement="rank:3/cuda:3", 1613 ), 1614 ] 1615 ) 1616 1617 h, w = 10, 10 1618 st = sharded_tensor.ones(spec, h, w, init_rrefs=True) 1619 1620 full_tensor = None 1621 dst = 0 1622 if self.rank == dst: 1623 full_tensor = torch.zeros(h, w, device=torch.device(f"cuda:{dst}")) 1624 st.gather(dst, full_tensor) 1625 1626 if self.rank == dst: 1627 self.assertEqual(full_tensor, torch.ones(h, w)) 1628 else: 1629 self.assertIsNone(full_tensor) 1630 1631 @with_comms 1632 @skip_if_lt_x_gpu(4) 1633 @requires_nccl() 1634 def test_gather_uneven(self) -> None: 1635 """Test _sharded_tensor.gather(...) with unevenly distributed._shards""" 1636 1637 spec = EnumerableShardingSpec( 1638 [ 1639 ShardMetadata( 1640 shard_offsets=[0, 0], 1641 shard_sizes=[5, 5], 1642 placement="rank:0/cuda:0", 1643 ), 1644 ShardMetadata( 1645 shard_offsets=[0, 5], 1646 shard_sizes=[5, 5], 1647 placement="rank:1/cuda:1", 1648 ), 1649 ShardMetadata( 1650 shard_offsets=[5, 0], 1651 shard_sizes=[5, 5], 1652 placement="rank:0/cuda:0", 1653 ), 1654 ShardMetadata( 1655 shard_offsets=[5, 5], 1656 shard_sizes=[5, 5], 1657 placement="rank:3/cuda:3", 1658 ), 1659 ] 1660 ) 1661 1662 h, w = 10, 10 1663 st = sharded_tensor.ones(spec, h, w, init_rrefs=True) 1664 1665 full_tensor = None 1666 dst = 0 1667 if self.rank == dst: 1668 full_tensor = torch.zeros(h, w, device=torch.device(f"cuda:{dst}")) 1669 st.gather(dst, full_tensor) 1670 1671 if self.rank == dst: 1672 self.assertEqual(full_tensor, torch.ones(h, w)) 1673 else: 1674 self.assertIsNone(full_tensor) 1675 1676 @with_comms 1677 @skip_if_lt_x_gpu(4) 1678 @requires_nccl() 1679 def test_sharded_tensor_to_cpu(self): 1680 cpu_spec = ChunkShardingSpec( 1681 dim=0, 1682 placements=[ 1683 "rank:0/cpu", 1684 "rank:1/cpu", 1685 "rank:2/cpu", 1686 "rank:3/cpu", 1687 ], 1688 ) 1689 spec = ChunkShardingSpec( 1690 dim=0, 1691 placements=[ 1692 "rank:0/cuda:0", 1693 "rank:1/cuda:1", 1694 "rank:2/cuda:2", 1695 "rank:3/cuda:3", 1696 ], 1697 ) 1698 h, w = 10, 20 1699 gloo_pg = dist.new_group(backend="gloo") 1700 1701 # CPU sharded tensor should return the same instance (no copy) 1702 st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg) 1703 new_st_cpu = st_cpu.cpu() 1704 self.assertTrue(st_cpu is new_st_cpu) 1705 1706 # GPU sharded tensor to cpu 1707 st = sharded_tensor.zeros(spec, h, w) 1708 # test ability to move st to CPU 1709 spec_before_move = st.sharding_spec() 1710 new_st = st.cpu(process_group=gloo_pg) 1711 # return a copy of original st 1712 self.assertFalse(st is new_st) 1713 # check the spec is still ChunkShardingSpec 1714 spec_after_move = new_st.sharding_spec() 1715 self.assertIsInstance(spec_after_move, ChunkShardingSpec) 1716 self.assertIsInstance(new_st._process_group, distributed_c10d.ProcessGroup) 1717 # test specs before and after the move almost the same except placement device 1718 self.assertEqual(spec_before_move.dim, spec_after_move.dim) 1719 self.assertEqual( 1720 len(spec_before_move.placements), len(spec_after_move.placements) 1721 ) 1722 for i, remote_device_after in enumerate(spec_after_move.placements): 1723 remote_device_before = spec_before_move.placements[i] 1724 self.assertEqual(remote_device_before.rank(), remote_device_after.rank()) 1725 self.assertEqual(str(remote_device_after.device()), "cpu") 1726 1727 # ensure metdata also get changed to CPU 1728 metas = new_st.metadata().shards_metadata 1729 for meta in metas: 1730 self.assertEqual(str(meta.placement.device()), "cpu") 1731 1732 # Test if a mixed sharded tensor (ShardedTensor with different devices) to cpu 1733 mixed_spec = ChunkShardingSpec( 1734 dim=0, 1735 placements=[ 1736 "rank:0/cpu", 1737 "rank:1/cpu", 1738 "rank:2/cuda:2", 1739 "rank:3/cuda:3", 1740 ], 1741 ) 1742 1743 st = sharded_tensor.zeros(mixed_spec, h, w, process_group=gloo_pg) 1744 new_st = st.cpu() 1745 # return a copy of original st 1746 self.assertFalse(st is new_st) 1747 # check the spec is still ChunkShardingSpec 1748 spec_after_move = new_st.sharding_spec() 1749 self.assertIsInstance(spec_after_move, ChunkShardingSpec) 1750 # test specs before and after the move almost the same except placement device 1751 self.assertEqual(mixed_spec.dim, spec_after_move.dim) 1752 self.assertEqual(len(mixed_spec.placements), len(spec_after_move.placements)) 1753 for i, remote_device_after in enumerate(spec_after_move.placements): 1754 remote_device_before = mixed_spec.placements[i] 1755 self.assertEqual(remote_device_before.rank(), remote_device_after.rank()) 1756 self.assertEqual(str(remote_device_after.device()), "cpu") 1757 1758 # ensure metdata also get changed to CPU 1759 metas = new_st.metadata().shards_metadata 1760 for meta in metas: 1761 self.assertEqual(str(meta.placement.device()), "cpu") 1762 1763 @with_comms 1764 @skip_if_lt_x_gpu(4) 1765 @requires_nccl() 1766 def test_sharded_tensor_to_cuda(self): 1767 cpu_spec = ChunkShardingSpec( 1768 dim=0, 1769 placements=[ 1770 "rank:0/cpu", 1771 "rank:1/cpu", 1772 "rank:2/cpu", 1773 "rank:3/cpu", 1774 ], 1775 ) 1776 spec = ChunkShardingSpec( 1777 dim=0, 1778 placements=[ 1779 "rank:0/cuda:0", 1780 "rank:1/cuda:1", 1781 "rank:2/cuda:2", 1782 "rank:3/cuda:3", 1783 ], 1784 ) 1785 h, w = 10, 20 1786 # CUDA sharded tensor should return a new ShardedTensor, but same 1787 # local shards(no movements) 1788 st_cuda = sharded_tensor.zeros(spec, h, w) 1789 new_st_cuda = st_cuda.cuda() 1790 self.assertTrue(st_cuda is not new_st_cuda) 1791 self.assertTrue(st_cuda.local_tensor() is new_st_cuda.local_tensor()) 1792 1793 gloo_pg = dist.new_group(backend="gloo") 1794 1795 # CPU sharded tensor to GPU 1796 st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg) 1797 # test ability to move st to GPU 1798 spec_before_move = st_cpu.sharding_spec() 1799 new_st_gpu = st_cpu.cuda() 1800 # check the spec is still ChunkShardingSpec 1801 spec_after_move = new_st_gpu.sharding_spec() 1802 self.assertIsInstance(spec_after_move, ChunkShardingSpec) 1803 # test specs before and after the move almost the same except placement device 1804 self.assertEqual(spec_before_move.dim, spec_after_move.dim) 1805 self.assertEqual( 1806 len(spec_before_move.placements), len(spec_after_move.placements) 1807 ) 1808 for i, remote_device_after in enumerate(spec_after_move.placements): 1809 remote_device_before = spec_before_move.placements[i] 1810 self.assertEqual(remote_device_before.rank(), remote_device_after.rank()) 1811 self.assertEqual(str(remote_device_before.device().type), "cpu") 1812 self.assertEqual(str(remote_device_after.device().type), "cuda") 1813 1814 # ensure metdata also get changed to GPU 1815 metas = new_st_gpu.metadata().shards_metadata 1816 for meta in metas: 1817 self.assertEqual(str(meta.placement.device().type), "cuda") 1818 1819 @with_comms 1820 @skip_if_lt_x_gpu(4) 1821 @requires_nccl() 1822 def test_sharded_tensor_to_test(self): 1823 spec = ChunkShardingSpec( 1824 dim=0, 1825 placements=[ 1826 "rank:0/cuda:0", 1827 "rank:1/cuda:1", 1828 "rank:2/cuda:2", 1829 "rank:3/cuda:3", 1830 ], 1831 ) 1832 h, w = 10, 20 1833 # CUDA sharded tensor should return a new ShardedTensor, but same 1834 # local shards(no movements) 1835 st = sharded_tensor.zeros(spec, h, w) 1836 # test same dtype, device return itself 1837 st_self = st.to(dtype=st.dtype, device="cuda") 1838 self.assertTrue(st_self is st) 1839 1840 # test dtype to 1841 st_16 = st.to(torch.float16) 1842 self.assertFalse(st_16 is st) 1843 self.assertEqual(st_16.dtype, torch.float16) 1844 # test device to 1845 st_cpu = st.to(device=torch.device("cpu")) 1846 self.assertFalse(st_cpu is st) 1847 self.assertEqual(st_cpu.local_tensor().device.type, "cpu") 1848 st_cuda = st_cpu.to(device=torch.device("cuda")) 1849 self.assertEqual(st_cuda.local_tensor().device.type, "cuda") 1850 # non-kwarg device to 1851 st_cuda = st_cpu.to(torch.device("cuda")) 1852 self.assertEqual(st_cuda.local_tensor().device.type, "cuda") 1853 st_cpu = st_cuda.to(torch.device("cpu")) 1854 self.assertEqual(st_cpu.local_tensor().device.type, "cpu") 1855 # with string like device conversion 1856 st_cpu = st_cuda.to("cpu") 1857 self.assertEqual(st_cpu.local_tensor().device.type, "cpu") 1858 st_cuda = st_cpu.to("cuda") 1859 self.assertEqual(st_cuda.local_tensor().device.type, "cuda") 1860 # with int like device conversion 1861 st_cpu = st_cuda.to("cpu") 1862 self.assertEqual(st_cpu.local_tensor().device.type, "cpu") 1863 st_cuda = st_cpu.to(self.rank) 1864 self.assertEqual(st_cuda.local_tensor().device.type, "cuda") 1865 1866 # test tensor to 1867 cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda") 1868 st_cuda = st.to(cuda_tensor) 1869 self.assertFalse(st_cuda is st) 1870 self.assertEqual(st_cuda.dtype, torch.float16) 1871 1872 cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda:2") 1873 st_cuda = st.to(cuda_tensor) 1874 self.assertEqual(st_cuda.dtype, torch.float16) 1875 1876 # test dtype and device together 1877 st_cpu_16 = st.to("cpu", torch.float16) 1878 self.assertEqual(st_cpu_16.dtype, torch.float16) 1879 self.assertEqual(st_cpu_16.local_tensor().device.type, "cpu") 1880 1881 st_cuda_32 = st_cpu_16.to("cuda", torch.float32) 1882 self.assertEqual(st_cuda_32.dtype, torch.float32) 1883 self.assertEqual(st_cuda_32.local_tensor().device.type, "cuda") 1884 1885 # test pass additional process group 1886 gloo_pg = dist.new_group(backend="gloo") 1887 st_gloo = st.to(device="cpu", process_group=gloo_pg) 1888 self.assertFalse(st_gloo is st) 1889 self.assertEqual(st_gloo.local_tensor().device.type, "cpu") 1890 self.assertEqual(st_gloo._process_group, gloo_pg) 1891 1892 @with_comms 1893 @skip_if_lt_x_gpu(4) 1894 @requires_nccl() 1895 def test_sharded_tensor_device(self): 1896 spec = ChunkShardingSpec( 1897 dim=0, 1898 placements=[ 1899 "rank:0/cuda:0", 1900 "rank:1/cuda:1", 1901 "rank:2/cuda:2", 1902 "rank:3/cuda:3", 1903 ], 1904 ) 1905 h, w = 10, 20 1906 # CUDA sharded tensor should return a new ShardedTensor, but same 1907 # local shards(no movements) 1908 st = sharded_tensor.zeros(spec, h, w) 1909 current_device = torch.device(torch.cuda.current_device()) 1910 self.assertEqual(current_device, st.device) 1911 1912 # test after to cpu, device get changed 1913 cpu_device = torch.device("cpu") 1914 st_cpu = st.to(device=cpu_device) 1915 self.assertEqual(st_cpu.device, cpu_device) 1916 1917 @skip_if_lt_x_gpu(4) 1918 @requires_nccl() 1919 def test_uneven_shards(self): 1920 self.init_pg() 1921 1922 spec = EnumerableShardingSpec( 1923 [ 1924 ShardMetadata( 1925 shard_offsets=[0, 0], 1926 shard_sizes=[2, 4], 1927 placement="rank:0/cuda:0", 1928 ), 1929 ShardMetadata( 1930 shard_offsets=[0, 4], 1931 shard_sizes=[4, 2], 1932 placement="rank:1/cuda:1", 1933 ), 1934 ShardMetadata( 1935 shard_offsets=[2, 0], 1936 shard_sizes=[4, 4], 1937 placement="rank:2/cuda:2", 1938 ), 1939 ShardMetadata( 1940 shard_offsets=[4, 4], 1941 shard_sizes=[2, 2], 1942 placement="rank:3/cuda:3", 1943 ), 1944 ] 1945 ) 1946 1947 st = sharded_tensor.empty(spec, 6, 6) 1948 self.assertEqual((6, 6), st.size()) 1949 self.assertEqual(1, len(st.local_shards())) 1950 1951 def verify_size(rank, tensor_dims): 1952 if rank == 0: 1953 self.assertEqual((2, 4), tensor_dims) 1954 elif rank == 1: 1955 self.assertEqual((4, 2), tensor_dims) 1956 elif rank == 2: 1957 self.assertEqual((4, 4), tensor_dims) 1958 elif rank == 3: 1959 self.assertEqual((2, 2), tensor_dims) 1960 1961 def verify_offsets(rank, offsets): 1962 if rank == 0: 1963 self.assertEqual((0, 0), offsets) 1964 elif rank == 1: 1965 self.assertEqual((0, 4), offsets) 1966 elif rank == 2: 1967 self.assertEqual((2, 0), offsets) 1968 elif rank == 3: 1969 self.assertEqual((4, 4), offsets) 1970 1971 # Verify local shard. 1972 local_shard = st.local_shards()[0] 1973 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) 1974 verify_size(self.rank, local_shard.tensor.size()) 1975 1976 # Verify local shard metadata. 1977 verify_offsets(self.rank, local_shard.metadata.shard_offsets) 1978 verify_size(self.rank, local_shard.metadata.shard_sizes) 1979 self.assertEqual( 1980 f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) 1981 ) 1982 1983 # Verify global metadata. 1984 st_metadata = st.metadata() 1985 shards_metadata = st_metadata.shards_metadata 1986 self.assertEqual(4, len(shards_metadata)) 1987 for rank, shard_metadata in enumerate(shards_metadata): 1988 verify_offsets(rank, shard_metadata.shard_offsets) 1989 verify_size(rank, shard_metadata.shard_sizes) 1990 self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) 1991 1992 @with_comms 1993 @skip_if_lt_x_gpu(4) 1994 @requires_nccl() 1995 def test_partial_world_size(self): 1996 spec = EnumerableShardingSpec( 1997 [ 1998 ShardMetadata( 1999 shard_offsets=[0, 0], 2000 shard_sizes=[5, 5], 2001 placement="rank:0/cuda:0", 2002 ), 2003 ShardMetadata( 2004 shard_offsets=[5, 0], 2005 shard_sizes=[5, 5], 2006 placement="rank:1/cuda:1", 2007 ), 2008 ] 2009 ) 2010 2011 st = sharded_tensor.empty(spec, 10, 5, init_rrefs=True) 2012 self.assertEqual((10, 5), st.size()) 2013 if self.rank <= 1: 2014 self.assertEqual(1, len(st.local_shards())) 2015 else: 2016 self.assertEqual(0, len(st.local_shards())) 2017 2018 if self.rank <= 1: 2019 # Verify local shard. 2020 local_shard = st.local_shards()[0] 2021 self.assertEqual( 2022 torch.device(f"cuda:{self.rank}"), local_shard.tensor.device 2023 ) 2024 self.assertEqual((5, 5), local_shard.tensor.size()) 2025 2026 # Verify local shard metadata. 2027 self.assertEqual((self.rank * 5, 0), local_shard.metadata.shard_offsets) 2028 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 2029 self.assertEqual( 2030 f"rank:{self.rank}/cuda:{self.rank}", 2031 str(local_shard.metadata.placement), 2032 ) 2033 2034 # Verify global metadata. 2035 st_metadata = st.metadata() 2036 shards_metadata = st_metadata.shards_metadata 2037 self.assertEqual(2, len(shards_metadata)) 2038 for rank, shard_metadata in enumerate(shards_metadata): 2039 self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) 2040 self.assertEqual((5, 5), shard_metadata.shard_sizes) 2041 self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) 2042 2043 # Validate remote shards. 2044 remote_shards = st.remote_shards() 2045 if self.rank <= 1: 2046 self.assertEqual(1, len(remote_shards)) 2047 else: 2048 self.assertEqual(2, len(remote_shards)) 2049 2050 for rpc_rank, shards in remote_shards.items(): 2051 self.assertEqual(1, len(shards)) 2052 2053 for remote_shard in shards: 2054 self.assertEqual(rpc_rank, remote_shard.owner().id) 2055 shard = remote_shard.to_here() 2056 self.assertEqual((5, 5), shard.tensor.size()) 2057 2058 @with_comms 2059 @skip_if_lt_x_gpu(4) 2060 @requires_nccl() 2061 def test_new_group(self): 2062 spec = EnumerableShardingSpec( 2063 [ 2064 ShardMetadata( 2065 shard_offsets=[0, 0], 2066 shard_sizes=[5, 5], 2067 placement="rank:1/cuda:1", 2068 ), 2069 ShardMetadata( 2070 shard_offsets=[5, 0], 2071 shard_sizes=[5, 5], 2072 placement="rank:3/cuda:3", 2073 ), 2074 ] 2075 ) 2076 2077 pg = dist.new_group(ranks=[1, 2, 3]) 2078 2079 st = sharded_tensor.empty(spec, 10, 5, process_group=pg, init_rrefs=True) 2080 self.assertEqual((10, 5), st.size()) 2081 if self.rank == 1 or self.rank == 3: 2082 # Verify local shard. 2083 local_shard = st.local_shards()[0] 2084 self.assertEqual( 2085 torch.device(f"cuda:{self.rank}"), local_shard.tensor.device 2086 ) 2087 self.assertEqual((5, 5), local_shard.tensor.size()) 2088 2089 # Verify local shard metadata. 2090 self.assertEqual( 2091 (self.rank // 2 * 5, 0), local_shard.metadata.shard_offsets 2092 ) 2093 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 2094 self.assertEqual( 2095 f"rank:{self.rank}/cuda:{self.rank}", 2096 str(local_shard.metadata.placement), 2097 ) 2098 2099 # Verify global metadata. 2100 st_metadata = st.metadata() 2101 shards_metadata = st_metadata.shards_metadata 2102 self.assertEqual(2, len(shards_metadata)) 2103 for rank, shard_metadata in enumerate(shards_metadata): 2104 self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) 2105 self.assertEqual((5, 5), shard_metadata.shard_sizes) 2106 self.assertEqual( 2107 f"rank:{rank * 2 + 1}/cuda:{rank * 2 + 1}", 2108 str(shard_metadata.placement), 2109 ) 2110 2111 # Validate remote shards. 2112 remote_shards = st.remote_shards() 2113 if self.rank == 1 or self.rank == 3: 2114 self.assertEqual(1, len(remote_shards)) 2115 else: 2116 self.assertEqual(2, len(remote_shards)) 2117 2118 for rpc_rank, shards in remote_shards.items(): 2119 self.assertEqual(1, len(shards)) 2120 2121 for remote_shard in shards: 2122 self.assertEqual(rpc_rank, remote_shard.owner().id) 2123 shard = remote_shard.to_here() 2124 self.assertEqual((5, 5), shard.tensor.size()) 2125 2126 @with_comms 2127 @skip_if_lt_x_gpu(4) 2128 @requires_nccl() 2129 def test_multiple_local_shards(self): 2130 spec = EnumerableShardingSpec( 2131 [ 2132 ShardMetadata( 2133 shard_offsets=[0, 0], 2134 shard_sizes=[5, 5], 2135 placement="rank:0/cuda:0", 2136 ), 2137 ShardMetadata( 2138 shard_offsets=[0, 5], 2139 shard_sizes=[5, 5], 2140 placement="rank:1/cuda:1", 2141 ), 2142 ShardMetadata( 2143 shard_offsets=[5, 0], 2144 shard_sizes=[5, 5], 2145 placement="rank:0/cuda:0", 2146 ), 2147 ShardMetadata( 2148 shard_offsets=[5, 5], 2149 shard_sizes=[5, 5], 2150 placement="rank:1/cuda:1", 2151 ), 2152 ] 2153 ) 2154 2155 st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True) 2156 self.assertEqual((10, 10), st.size()) 2157 2158 if self.rank <= 1: 2159 self.assertEqual(2, len(st.local_shards())) 2160 2161 # Verify local shards. 2162 for idx, local_shard in enumerate(st.local_shards()): 2163 self.assertEqual( 2164 torch.device(f"cuda:{self.rank}"), local_shard.tensor.device 2165 ) 2166 self.assertEqual((5, 5), local_shard.tensor.size()) 2167 2168 # Verify local shard metadata. 2169 self.assertEqual( 2170 (idx * 5, self.rank * 5), local_shard.metadata.shard_offsets 2171 ) 2172 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 2173 self.assertEqual( 2174 f"rank:{self.rank}/cuda:{self.rank}", 2175 str(local_shard.metadata.placement), 2176 ) 2177 else: 2178 self.assertEqual(0, len(st.local_shards())) 2179 2180 # Verify global metadata. 2181 st_metadata = st.metadata() 2182 shards_metadata = st_metadata.shards_metadata 2183 self.assertEqual(4, len(shards_metadata)) 2184 for shard_rank, shard_metadata in enumerate(shards_metadata): 2185 self.assertEqual( 2186 (shard_rank // 2 * 5, (shard_rank % 2) * 5), 2187 shard_metadata.shard_offsets, 2188 ) 2189 self.assertEqual((5, 5), shard_metadata.shard_sizes) 2190 self.assertEqual( 2191 f"rank:{shard_rank % 2}/cuda:{shard_rank % 2}", 2192 str(shard_metadata.placement), 2193 ) 2194 2195 # Validate remote shards. 2196 remote_shards = st.remote_shards() 2197 if self.rank <= 1: 2198 self.assertEqual(1, len(remote_shards)) 2199 else: 2200 self.assertEqual(2, len(remote_shards)) 2201 2202 owners = {} 2203 for rpc_rank, shards in remote_shards.items(): 2204 self.assertEqual(2, len(shards)) 2205 for remote_shard in shards: 2206 self.assertEqual(rpc_rank, remote_shard.owner().id) 2207 shard = remote_shard.to_here() 2208 self.assertEqual((5, 5), shard.tensor.size()) 2209 2210 @with_comms 2211 @skip_if_lt_x_gpu(4) 2212 @requires_nccl() 2213 def test_with_rpc_names(self): 2214 spec = EnumerableShardingSpec( 2215 [ 2216 ShardMetadata( 2217 shard_offsets=[0, 0], 2218 shard_sizes=[5, 5], 2219 placement="worker0/cuda:0", 2220 ), 2221 ShardMetadata( 2222 shard_offsets=[0, 5], 2223 shard_sizes=[5, 5], 2224 placement="worker1/cuda:1", 2225 ), 2226 ShardMetadata( 2227 shard_offsets=[5, 0], 2228 shard_sizes=[5, 5], 2229 placement="worker2/cuda:2", 2230 ), 2231 ShardMetadata( 2232 shard_offsets=[5, 5], 2233 shard_sizes=[5, 5], 2234 placement="worker3/cuda:3", 2235 ), 2236 ] 2237 ) 2238 2239 st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True) 2240 self.assertEqual((10, 10), st.size()) 2241 self.assertEqual(1, len(st.local_shards())) 2242 2243 # Verify local shard. 2244 local_shard = st.local_shards()[0] 2245 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) 2246 self.assertEqual((5, 5), local_shard.tensor.size()) 2247 2248 # Verify local shard metadata. 2249 self.assertEqual( 2250 (self.rank // 2 * 5, (self.rank % 2) * 5), 2251 local_shard.metadata.shard_offsets, 2252 ) 2253 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 2254 self.assertEqual( 2255 f"worker{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) 2256 ) 2257 2258 # Verify global metadata. 2259 st_metadata = st.metadata() 2260 shards_metadata = st_metadata.shards_metadata 2261 self.assertEqual(4, len(shards_metadata)) 2262 for rank, shard_metadata in enumerate(shards_metadata): 2263 self.assertEqual( 2264 (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets 2265 ) 2266 self.assertEqual((5, 5), shard_metadata.shard_sizes) 2267 self.assertEqual(f"worker{rank}/cuda:{rank}", str(shard_metadata.placement)) 2268 2269 # Validate remote shards. 2270 remote_shards = st.remote_shards() 2271 self.assertEqual(3, len(remote_shards)) 2272 2273 for rpc_rank, shards in remote_shards.items(): 2274 self.assertEqual(1, len(shards)) 2275 for remote_shard in shards: 2276 self.assertEqual(rpc_rank, remote_shard.owner().id) 2277 shard = remote_shard.to_here() 2278 self.assertEqual((5, 5), shard.tensor.size()) 2279 2280 2281class TestShardedTensorFromLocalTensor(ShardedTensorTestBase): 2282 def _generate_st_from_chunk_local_tensor(self, st_size, sharding_spec): 2283 tensor_meta = sharding_spec.build_metadata(st_size, TensorProperties()) 2284 pg = dist.distributed_c10d._get_default_group() 2285 2286 local_tensor = None 2287 local_shard_metadata = None 2288 rank_to_metadata = {} 2289 for shard_metadata in tensor_meta.shards_metadata: 2290 rank, device = _parse_and_validate_remote_device( 2291 pg, shard_metadata.placement 2292 ) 2293 rank_to_metadata[rank] = shard_metadata 2294 if rank == self.rank: 2295 local_tensor = torch.rand(shard_metadata.shard_sizes).cuda(device) 2296 local_shard_metadata = shard_metadata 2297 2298 # TODO: figure out what the API should behave when some rank have no shard 2299 # see https://github.com/pytorch/pytorch/issues/73133 2300 assert local_tensor is not None 2301 st = ShardedTensor._init_from_local_tensor( 2302 local_tensor, 2303 sharding_spec, 2304 st_size, 2305 init_rrefs=True, 2306 ) 2307 self.assertEqual(tuple(st_size), st.size()) 2308 self.assertEqual(1, len(st.local_shards())) 2309 2310 # Verify local shard. 2311 local_shard = st.local_shards()[0] 2312 self.assertEqual(st.local_tensor(), local_tensor) 2313 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) 2314 2315 # Verify local shard metadata. 2316 self.assertEqual( 2317 local_shard_metadata.shard_offsets, local_shard.metadata.shard_offsets 2318 ) 2319 self.assertEqual( 2320 local_shard_metadata.shard_sizes, local_shard.metadata.shard_sizes 2321 ) 2322 self.assertEqual(local_shard_metadata.placement, local_shard.metadata.placement) 2323 2324 # Verify global metadata. 2325 st_shards_metadata = st.metadata().shards_metadata 2326 self.assertEqual(self.world_size, len(st_shards_metadata)) 2327 self.assertEqual(tensor_meta.shards_metadata, st_shards_metadata) 2328 2329 # Validate remote shards. 2330 remote_shards = st.remote_shards() 2331 self.assertEqual(self.world_size - 1, len(remote_shards)) 2332 for rpc_rank, shards in remote_shards.items(): 2333 self.assertEqual(1, len(shards)) 2334 for remote_shard in shards: 2335 self.assertEqual(rpc_rank, remote_shard.owner().id) 2336 # If remote shard does not exist, to_here() will throw exception. 2337 if tensor_meta.shards_metadata[rpc_rank]: 2338 shard = remote_shard.to_here() 2339 self.assertEqual( 2340 rank_to_metadata[rpc_rank].shard_sizes, shard.tensor.size() 2341 ) 2342 2343 @with_comms 2344 @skip_if_lt_x_gpu(4) 2345 @requires_nccl() 2346 def test_init_from_local_tensor(self): 2347 chunk_specs = _chunk_sharding_specs_list_for_test([0, 1, 1, 0], seed=31) 2348 for spec in chunk_specs: 2349 self._generate_st_from_chunk_local_tensor([20, 10], spec) 2350 self._generate_st_from_chunk_local_tensor([21, 11], spec) 2351 self._generate_st_from_chunk_local_tensor([23, 16], spec) 2352 self._generate_st_from_chunk_local_tensor([44, 16, 8], spec) 2353 2354 @with_comms 2355 @skip_if_lt_x_gpu(4) 2356 @requires_nccl() 2357 def test_init_from_local_tensor_errors(self): 2358 enumerable_sharding_spec = EnumerableShardingSpec( 2359 [ 2360 ShardMetadata( 2361 shard_offsets=[0, 0], 2362 shard_sizes=[5, 5], 2363 placement="rank:0/cuda:0", 2364 ), 2365 ShardMetadata( 2366 shard_offsets=[5, 0], 2367 shard_sizes=[5, 5], 2368 placement="rank:1/cuda:1", 2369 ), 2370 ] 2371 ) 2372 st_size = [24, 12] 2373 local_tensor = torch.rand(*st_size).cuda(self.rank) 2374 with self.assertRaisesRegex(ValueError, "do not cover the entire tensor"): 2375 ShardedTensor._init_from_local_tensor( 2376 local_tensor, 2377 enumerable_sharding_spec, 2378 st_size, 2379 ) 2380 chunk_specs = _chunk_sharding_specs_list_for_test([0], seed=31) 2381 with self.assertRaisesRegex( 2382 ValueError, "local_tensor is not a contiguous Tensor." 2383 ): 2384 ShardedTensor._init_from_local_tensor( 2385 local_tensor.t(), 2386 chunk_specs[0], 2387 st_size, 2388 ) 2389 2390 2391class TestShardedTensorFromLocalShards(ShardedTensorTestBase): 2392 @with_comms(init_rpc=False) 2393 @skip_if_lt_x_gpu(4) 2394 @requires_nccl() 2395 def test_local_shards(self): 2396 shard_offsets = [(self.rank // 2) * 5, (self.rank % 2) * 5] 2397 local_shard_metadata = ShardMetadata( 2398 shard_offsets=shard_offsets, 2399 shard_sizes=[5, 5], 2400 placement=f"rank:{self.rank}/cuda:{self.rank}", 2401 ) 2402 2403 local_tensor = torch.randn(5, 5, device=f"cuda:{self.rank}") 2404 local_shard = sharded_tensor.Shard(local_tensor, local_shard_metadata) 2405 local_shard_from_offsets = sharded_tensor.Shard.from_tensor_and_offsets( 2406 local_tensor, shard_offsets=shard_offsets, rank=self.rank 2407 ) 2408 self.assertEqual(local_shard.metadata, local_shard_from_offsets.metadata) 2409 2410 wrong_local_shard_metadata = ShardMetadata( 2411 shard_offsets=shard_offsets, 2412 shard_sizes=[6, 5], 2413 placement=f"rank:{self.rank}/cuda:{self.rank}", 2414 ) 2415 with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"): 2416 local_shard_from_wrong_meta = sharded_tensor.Shard( 2417 local_tensor, 2418 metadata=wrong_local_shard_metadata, 2419 ) 2420 2421 @with_comms 2422 @skip_if_lt_x_gpu(4) 2423 @requires_nccl() 2424 def test_init_from_local_shards(self): 2425 local_shard_metadata = ShardMetadata( 2426 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2427 shard_sizes=[5, 5], 2428 placement=f"rank:{self.rank}/cuda:{self.rank}", 2429 ) 2430 2431 local_shards = [ 2432 sharded_tensor.Shard( 2433 torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata 2434 ) 2435 ] 2436 2437 st = sharded_tensor.init_from_local_shards( 2438 local_shards, [10, 10], init_rrefs=True 2439 ) 2440 self.assertEqual((10, 10), st.size()) 2441 self.assertEqual(1, len(st.local_shards())) 2442 2443 # Verify local shard. 2444 local_shard = st.local_shards()[0] 2445 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) 2446 self.assertEqual((5, 5), local_shard.tensor.size()) 2447 2448 # Verify local shard metadata. 2449 self.assertEqual( 2450 (self.rank // 2 * 5, (self.rank % 2) * 5), 2451 local_shard.metadata.shard_offsets, 2452 ) 2453 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 2454 self.assertEqual( 2455 f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) 2456 ) 2457 2458 # Verify global metadata. 2459 shards_metadata = st.metadata().shards_metadata 2460 self.assertEqual(4, len(shards_metadata)) 2461 for rank, shard_metadata in enumerate(shards_metadata): 2462 self.assertEqual( 2463 (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets 2464 ) 2465 self.assertEqual((5, 5), shard_metadata.shard_sizes) 2466 self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) 2467 2468 # Validate remote shards. 2469 remote_shards = st.remote_shards() 2470 self.assertEqual(3, len(remote_shards)) 2471 2472 for rpc_rank, shards in remote_shards.items(): 2473 self.assertEqual(1, len(shards)) 2474 for remote_shard in shards: 2475 self.assertEqual(rpc_rank, remote_shard.owner().id) 2476 shard = remote_shard.to_here() 2477 self.assertEqual((5, 5), shard.tensor.size()) 2478 2479 @skip_if_lt_x_gpu(4) 2480 def test_st_base_init_from_local_shards_and_global_metadata(self): 2481 world_size = 4 2482 shards_metadata = [] 2483 shards = [] 2484 for rank in range(world_size): 2485 local_shard_metadata = ShardMetadata( 2486 shard_offsets=[(rank // 2) * 5, (rank % 2) * 5], 2487 shard_sizes=[5, 5], 2488 placement=f"rank:{rank}/cuda:{rank}", 2489 ) 2490 shards_metadata.append(local_shard_metadata) 2491 shards.append( 2492 sharded_tensor.Shard( 2493 torch.randn(5, 5, device=f"cuda:{rank}"), local_shard_metadata 2494 ) 2495 ) 2496 2497 tensor_properties = TensorProperties( 2498 dtype=torch.get_default_dtype(), 2499 layout=torch.strided, 2500 requires_grad=False, 2501 memory_format=torch.contiguous_format, 2502 pin_memory=False, 2503 ) 2504 2505 sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata( 2506 shards_metadata=shards_metadata, 2507 size=torch.Size([10, 10]), 2508 tensor_properties=tensor_properties, 2509 ) 2510 2511 st_base = sharded_tensor.ShardedTensorBase._init_from_local_shards_and_global_metadata( 2512 shards, sharded_tensor_metadata=sharded_tensor_metadata 2513 ) 2514 self.assertEqual(4, len(st_base.local_shards())) 2515 2516 # Verify local shard of st_base 2517 local_shard = st_base.local_shards()[0] 2518 self.assertEqual(torch.device("cuda:0"), local_shard.tensor.device) 2519 self.assertEqual((5, 5), local_shard.tensor.size()) 2520 2521 # Verify local shard metadata. 2522 self.assertEqual( 2523 (0, 0), 2524 local_shard.metadata.shard_offsets, 2525 ) 2526 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 2527 self.assertEqual("rank:0/cuda:0", str(local_shard.metadata.placement)) 2528 2529 # Verify global metadata. 2530 shards_metadata = st_base.metadata().shards_metadata 2531 self.assertEqual(4, len(shards_metadata)) 2532 for rank, shard_metadata in enumerate(shards_metadata): 2533 self.assertEqual( 2534 (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets 2535 ) 2536 self.assertEqual((5, 5), shard_metadata.shard_sizes) 2537 self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) 2538 2539 @with_comms 2540 @skip_if_lt_x_gpu(4) 2541 @requires_nccl() 2542 def test_init_from_local_shards_and_global_metadata(self): 2543 local_shard_metadata = ShardMetadata( 2544 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2545 shard_sizes=[5, 5], 2546 placement=f"rank:{self.rank}/cuda:{self.rank}", 2547 ) 2548 2549 shards_metadata = [] 2550 for r in range(self.world_size): 2551 if r == self.rank: 2552 shards_metadata.append(local_shard_metadata) 2553 else: 2554 shards_metadata.append( 2555 ShardMetadata( 2556 shard_offsets=[(r // 2) * 5, (r % 2) * 5], 2557 shard_sizes=[5, 5], 2558 placement=f"rank:{r}/cuda:{r}", 2559 ) 2560 ) 2561 2562 local_shards = [ 2563 sharded_tensor.Shard( 2564 torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata 2565 ) 2566 ] 2567 2568 tensor_properties = TensorProperties( 2569 dtype=torch.get_default_dtype(), 2570 layout=torch.strided, 2571 requires_grad=False, 2572 memory_format=torch.contiguous_format, 2573 pin_memory=False, 2574 ) 2575 2576 sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata( 2577 shards_metadata=shards_metadata, 2578 size=torch.Size([10, 10]), 2579 tensor_properties=tensor_properties, 2580 ) 2581 2582 st = ShardedTensor._init_from_local_shards_and_global_metadata( 2583 local_shards, 2584 sharded_tensor_metadata, 2585 init_rrefs=True, 2586 ) 2587 self.assertEqual((10, 10), st.size()) 2588 self.assertEqual(1, len(st.local_shards())) 2589 2590 # Verify local shard. 2591 local_shard = st.local_shards()[0] 2592 self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) 2593 self.assertEqual((5, 5), local_shard.tensor.size()) 2594 2595 # Verify local shard metadata. 2596 self.assertEqual( 2597 (self.rank // 2 * 5, (self.rank % 2) * 5), 2598 local_shard.metadata.shard_offsets, 2599 ) 2600 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 2601 self.assertEqual( 2602 f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) 2603 ) 2604 2605 # Verify global metadata. 2606 shards_metadata = st.metadata().shards_metadata 2607 self.assertEqual(4, len(shards_metadata)) 2608 for rank, shard_metadata in enumerate(shards_metadata): 2609 self.assertEqual( 2610 (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets 2611 ) 2612 self.assertEqual((5, 5), shard_metadata.shard_sizes) 2613 self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) 2614 2615 # Validate remote shards. 2616 remote_shards = st.remote_shards() 2617 self.assertEqual(3, len(remote_shards)) 2618 2619 for rpc_rank, shards in remote_shards.items(): 2620 self.assertEqual(1, len(shards)) 2621 for remote_shard in shards: 2622 self.assertEqual(rpc_rank, remote_shard.owner().id) 2623 shard = remote_shard.to_here() 2624 self.assertEqual((5, 5), shard.tensor.size()) 2625 2626 @with_comms 2627 @skip_if_lt_x_gpu(4) 2628 @requires_nccl() 2629 def test_init_from_local_shards_new_group(self): 2630 new_pg = dist.new_group(ranks=[1, 2, 3]) 2631 2632 if self.rank != 0: 2633 local_shard_metadata = ShardMetadata( 2634 shard_offsets=[5 * (self.rank - 1), 0], 2635 shard_sizes=[5, 5], 2636 placement=f"rank:{self.rank}/cuda:{self.rank}", 2637 ) 2638 local_shards = [ 2639 sharded_tensor.Shard( 2640 torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata 2641 ) 2642 ] 2643 2644 st = sharded_tensor.init_from_local_shards( 2645 local_shards, [15, 5], process_group=new_pg 2646 ) 2647 2648 # Verify local shard. 2649 local_shard = st.local_shards()[0] 2650 self.assertEqual( 2651 torch.device(f"cuda:{self.rank}"), local_shard.tensor.device 2652 ) 2653 self.assertEqual((5, 5), local_shard.tensor.size()) 2654 2655 # Verify local shard metadata. 2656 self.assertEqual( 2657 ((self.rank - 1) * 5, 0), local_shard.metadata.shard_offsets 2658 ) 2659 self.assertEqual((5, 5), local_shard.metadata.shard_sizes) 2660 self.assertEqual( 2661 f"rank:{self.rank}/cuda:{self.rank}", 2662 str(local_shard.metadata.placement), 2663 ) 2664 2665 # Verify global metadata. 2666 st_metadata = st.metadata() 2667 shards_metadata = st_metadata.shards_metadata 2668 self.assertEqual(3, len(shards_metadata)) 2669 for rank, shard_metadata in enumerate(shards_metadata): 2670 self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) 2671 self.assertEqual((5, 5), shard_metadata.shard_sizes) 2672 self.assertEqual( 2673 f"rank:{rank + 1}/cuda:{rank + 1}", str(shard_metadata.placement) 2674 ) 2675 2676 @with_comms 2677 @skip_if_lt_x_gpu(4) 2678 @requires_nccl() 2679 def test_init_from_local_shards_invalid_local_shards(self): 2680 local_shard_metadata = ShardMetadata( 2681 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2682 shard_sizes=[5, 5], 2683 placement=f"rank:{self.rank}/cuda:{self.rank}", 2684 ) 2685 2686 indices = [[0, 1, 1], [2, 0, 2]] 2687 values = [3.2, 4.5, 5.8] 2688 sparse_tensor = torch.sparse_coo_tensor( 2689 indices, values, (5, 5), device=f"cuda:{self.rank}" 2690 ) 2691 2692 empty_local_shards = [] 2693 with self.assertRaisesRegex(ValueError, "have no local shards on all ranks"): 2694 st = sharded_tensor.init_from_local_shards( 2695 empty_local_shards, [10, 10], init_rrefs=True 2696 ) 2697 2698 wrong_layout_shards = [ 2699 sharded_tensor.Shard(sparse_tensor, local_shard_metadata) 2700 ] 2701 with self.assertRaisesRegex( 2702 ValueError, "Only torch.strided layout is currently supported" 2703 ): 2704 st = sharded_tensor.init_from_local_shards( 2705 wrong_layout_shards, [10, 10], init_rrefs=True 2706 ) 2707 2708 wrong_memory_format_shards = [ 2709 sharded_tensor.Shard( 2710 torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata 2711 ) 2712 ] 2713 with self.assertRaisesRegex( 2714 ValueError, 2715 "Only torch.contiguous_format memory_format is currently supported", 2716 ): 2717 st = sharded_tensor.init_from_local_shards( 2718 wrong_memory_format_shards, [10, 10], init_rrefs=True 2719 ) 2720 2721 with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"): 2722 wrong_size_shards = [ 2723 sharded_tensor.Shard( 2724 torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata 2725 ) 2726 ] 2727 2728 with self.assertRaisesRegex( 2729 ValueError, "Local shard tensor device does not match" 2730 ): 2731 wrong_device_shards = [ 2732 sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata) 2733 ] 2734 2735 @with_comms 2736 @skip_if_lt_x_gpu(4) 2737 @requires_nccl() 2738 def test_init_from_local_shards_invalid_property_cross_ranks(self): 2739 local_shard_metadata = ShardMetadata( 2740 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2741 shard_sizes=[5, 5], 2742 placement=f"rank:{self.rank}/cuda:{self.rank}", 2743 ) 2744 tensor_overall_size = [10, 10] if self.rank == 0 else [10, 5] 2745 wrong_dtype_shards = [ 2746 sharded_tensor.Shard( 2747 torch.ones(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata 2748 ) 2749 ] 2750 with self.assertRaisesRegex( 2751 ValueError, 2752 "ShardedTensor global_size property does not match from different ranks!", 2753 ): 2754 st = sharded_tensor.init_from_local_shards( 2755 wrong_dtype_shards, tensor_overall_size, init_rrefs=True 2756 ) 2757 2758 tensor_dtype = torch.int if self.rank == 0 else torch.float32 2759 wrong_dtype_shards = [ 2760 sharded_tensor.Shard( 2761 torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=tensor_dtype), 2762 local_shard_metadata, 2763 ) 2764 ] 2765 with self.assertRaisesRegex( 2766 ValueError, 2767 "ShardedTensor dtype property does not match from different ranks!", 2768 ): 2769 st = sharded_tensor.init_from_local_shards( 2770 wrong_dtype_shards, [10, 10], init_rrefs=True 2771 ) 2772 2773 tensor_requires_grad = True if self.rank == 0 else False 2774 wrong_requires_grad_shards = [ 2775 sharded_tensor.Shard( 2776 torch.randn( 2777 5, 5, device=f"cuda:{self.rank}", requires_grad=tensor_requires_grad 2778 ), 2779 local_shard_metadata, 2780 ) 2781 ] 2782 with self.assertRaisesRegex( 2783 ValueError, 2784 "ShardedTensor requires_grad property does not match from different ranks!", 2785 ): 2786 st = sharded_tensor.init_from_local_shards( 2787 wrong_requires_grad_shards, [10, 10], init_rrefs=True 2788 ) 2789 2790 local_shard_metadata = ShardMetadata( 2791 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2792 shard_sizes=[5, 5], 2793 placement=f"rank:{self.rank}/cpu", 2794 ) 2795 2796 @with_comms(init_rpc=False, backend="gloo") 2797 @skip_if_lt_x_gpu(4) 2798 def test_init_from_local_shards_invalid_pin_memory(self): 2799 # pin memory can only be on dense cpu 2800 local_shard_metadata = ShardMetadata( 2801 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2802 shard_sizes=[5, 5], 2803 placement=f"rank:{self.rank}/cpu", 2804 ) 2805 wrong_pin_memory_local_shards = [ 2806 sharded_tensor.Shard( 2807 torch.randn(5, 5, pin_memory=True), local_shard_metadata 2808 ), 2809 sharded_tensor.Shard( 2810 torch.randn(5, 5, pin_memory=False), local_shard_metadata 2811 ), 2812 ] 2813 with self.assertRaisesRegex( 2814 ValueError, "Local shards' tensor pin_memory property need to be the same" 2815 ): 2816 st = sharded_tensor.init_from_local_shards( 2817 wrong_pin_memory_local_shards, [10, 10], init_rrefs=True 2818 ) 2819 2820 tensor_pin_memory = True if self.rank == 0 else False 2821 wrong_pin_memory_shards_cross_ranks = [ 2822 sharded_tensor.Shard( 2823 torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata 2824 ) 2825 ] 2826 with self.assertRaisesRegex( 2827 ValueError, 2828 "ShardedTensor pin_memory property does not match from different ranks!", 2829 ): 2830 st = sharded_tensor.init_from_local_shards( 2831 wrong_pin_memory_shards_cross_ranks, [10, 10], init_rrefs=True 2832 ) 2833 2834 @with_comms 2835 @skip_if_lt_x_gpu(4) 2836 @requires_nccl() 2837 def test_init_from_local_shards_invalid_shards_overlap(self): 2838 local_shard_size = [5, 5] if self.rank != 0 else [6, 6] 2839 local_shard_metadata = ShardMetadata( 2840 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2841 shard_sizes=local_shard_size, 2842 placement=f"rank:{self.rank}/cuda:{self.rank}", 2843 ) 2844 2845 local_shards = [ 2846 sharded_tensor.Shard( 2847 torch.randn(local_shard_size, device=f"cuda:{self.rank}"), 2848 local_shard_metadata, 2849 ) 2850 ] 2851 2852 with self.assertRaisesRegex(ValueError, "overlap"): 2853 sharded_tensor.init_from_local_shards( 2854 local_shards, [10, 10], init_rrefs=True 2855 ) 2856 2857 @with_comms 2858 @skip_if_lt_x_gpu(4) 2859 @requires_nccl() 2860 def test_init_from_local_shards_invalid_shards_gaps(self): 2861 local_shard_size = [5, 5] if self.rank != 0 else [4, 4] 2862 local_shard_metadata = ShardMetadata( 2863 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2864 shard_sizes=local_shard_size, 2865 placement=f"rank:{self.rank}/cuda:{self.rank}", 2866 ) 2867 2868 local_shards = [ 2869 sharded_tensor.Shard( 2870 torch.randn(local_shard_size, device=f"cuda:{self.rank}"), 2871 local_shard_metadata, 2872 ) 2873 ] 2874 2875 with self.assertRaisesRegex(ValueError, "does not match tensor volume"): 2876 sharded_tensor.init_from_local_shards( 2877 local_shards, [10, 10], init_rrefs=True 2878 ) 2879 2880 @with_comms 2881 @skip_if_lt_x_gpu(4) 2882 @requires_nccl() 2883 def test_init_from_local_shards_and_global_metadata_invalid_shards(self): 2884 local_shard_metadata = ShardMetadata( 2885 shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], 2886 shard_sizes=[5, 5], 2887 placement=f"rank:{self.rank}/cuda:{self.rank}", 2888 ) 2889 2890 shards_metadata = [] 2891 for r in range(self.world_size): 2892 if r == self.rank: 2893 shards_metadata.append(local_shard_metadata) 2894 else: 2895 shards_metadata.append( 2896 ShardMetadata( 2897 shard_offsets=[(r // 2) * 5, (r % 2) * 5], 2898 shard_sizes=[5, 5], 2899 placement=f"rank:{r}/cuda:{r}", 2900 ) 2901 ) 2902 2903 tensor_properties = TensorProperties( 2904 dtype=torch.get_default_dtype(), 2905 layout=torch.strided, 2906 requires_grad=False, 2907 memory_format=torch.contiguous_format, 2908 pin_memory=False, 2909 ) 2910 2911 sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata( 2912 shards_metadata=shards_metadata, 2913 size=torch.Size([10, 10]), 2914 tensor_properties=tensor_properties, 2915 ) 2916 2917 empty_local_shards = [] 2918 with self.assertRaisesRegex( 2919 RuntimeError, "does not match number of local shards metadata" 2920 ): 2921 ShardedTensor._init_from_local_shards_and_global_metadata( 2922 empty_local_shards, sharded_tensor_metadata 2923 ) 2924 2925 wrong_num_shards = [ 2926 sharded_tensor.Shard( 2927 torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata 2928 ), 2929 sharded_tensor.Shard( 2930 torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata 2931 ), 2932 ] 2933 with self.assertRaisesRegex( 2934 RuntimeError, "does not match number of local shards metadata" 2935 ): 2936 ShardedTensor._init_from_local_shards_and_global_metadata( 2937 wrong_num_shards, sharded_tensor_metadata 2938 ) 2939 2940 with self.assertRaisesRegex( 2941 ValueError, "Shard tensor size does not match with metadata.shard_lengths" 2942 ): 2943 wrong_size_shards = [ 2944 sharded_tensor.Shard( 2945 torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata 2946 ) 2947 ] 2948 2949 with self.assertRaisesRegex( 2950 ValueError, 2951 "Local shard tensor device does not match with local Shard's placement", 2952 ): 2953 wrong_device_shards = [ 2954 sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata) 2955 ] 2956 2957 wrong_dtype_shards = [ 2958 sharded_tensor.Shard( 2959 torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=torch.int), 2960 local_shard_metadata, 2961 ) 2962 ] 2963 with self.assertRaisesRegex( 2964 ValueError, "Local shards' tensor dtype property is incompatible with" 2965 ): 2966 ShardedTensor._init_from_local_shards_and_global_metadata( 2967 wrong_dtype_shards, sharded_tensor_metadata 2968 ) 2969 2970 indices = [[0, 1, 1], [2, 0, 2]] 2971 values = [3.2, 4.5, 5.8] 2972 sparse_tensor = torch.sparse_coo_tensor( 2973 indices, values, (5, 5), device=f"cuda:{self.rank}" 2974 ) 2975 2976 wrong_layout_shards = [ 2977 sharded_tensor.Shard(sparse_tensor, local_shard_metadata) 2978 ] 2979 with self.assertRaisesRegex( 2980 ValueError, "Local shards' tensor layout property is incompatible with" 2981 ): 2982 ShardedTensor._init_from_local_shards_and_global_metadata( 2983 wrong_layout_shards, sharded_tensor_metadata 2984 ) 2985 2986 wrong_requires_grad_shards = [ 2987 sharded_tensor.Shard( 2988 torch.randn(5, 5, device=f"cuda:{self.rank}", requires_grad=True), 2989 local_shard_metadata, 2990 ) 2991 ] 2992 with self.assertRaisesRegex( 2993 ValueError, 2994 "Local shards' tensor requires_grad property is incompatible with", 2995 ): 2996 ShardedTensor._init_from_local_shards_and_global_metadata( 2997 wrong_requires_grad_shards, sharded_tensor_metadata 2998 ) 2999 3000 wrong_memory_format_shards = [ 3001 sharded_tensor.Shard( 3002 torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata 3003 ) 3004 ] 3005 with self.assertRaisesRegex( 3006 ValueError, 3007 "Only torch.contiguous_format memory_format is currently supported", 3008 ): 3009 ShardedTensor._init_from_local_shards_and_global_metadata( 3010 wrong_memory_format_shards, sharded_tensor_metadata 3011 ) 3012 # pin_memory can only be on CPU 3013 local_shard_metadata.placement = _remote_device(f"rank:{self.rank}/cpu") 3014 wrong_pin_memory_shards = [ 3015 sharded_tensor.Shard( 3016 torch.randn(5, 5, pin_memory=True), local_shard_metadata 3017 ) 3018 ] 3019 with self.assertRaisesRegex( 3020 ValueError, "Local shards' tensor pin_memory property is incompatible with" 3021 ): 3022 ShardedTensor._init_from_local_shards_and_global_metadata( 3023 wrong_pin_memory_shards, sharded_tensor_metadata 3024 ) 3025 3026 3027class TestShardedTensorCustomOps(ShardedTensorTestBase): 3028 @with_comms 3029 @skip_if_lt_x_gpu(4) 3030 @requires_nccl() 3031 def test_custom_op(self): 3032 @custom_sharded_op_impl(torch.asin) 3033 def my_sharded_asin(types, args, kwargs, process_group): 3034 return torch.asin(args[0].local_shards()[0].tensor) 3035 3036 spec = ChunkShardingSpec( 3037 dim=0, 3038 placements=[ 3039 "rank:0/cuda:0", 3040 "rank:1/cuda:1", 3041 "rank:2/cuda:2", 3042 "rank:3/cuda:3", 3043 ], 3044 ) 3045 3046 st = sharded_tensor.rand(spec, 10, 10) 3047 res = torch.asin(st) 3048 self.assertEqual(res, torch.asin(st.local_shards()[0].tensor)) 3049 3050 @with_comms 3051 @skip_if_lt_x_gpu(4) 3052 @requires_nccl() 3053 def test_custom_op_override(self): 3054 t = torch.rand(10, 10).cuda(self.rank) 3055 3056 from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op 3057 3058 @custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear) 3059 def my_sharded_linear(types, args, kwargs, process_group): 3060 return t 3061 3062 spec = ChunkShardingSpec( 3063 dim=0, 3064 placements=[ 3065 "rank:0/cuda:0", 3066 "rank:1/cuda:1", 3067 "rank:2/cuda:2", 3068 "rank:3/cuda:3", 3069 ], 3070 ) 3071 m = torch.nn.Linear(32, 16).cuda(self.rank) 3072 shard_parameter(m, "weight", spec) 3073 3074 result = m(torch.rand(15, 32).cuda(self.rank)) 3075 self.assertEqual(t, result) 3076 3077 @with_comms 3078 @skip_if_lt_x_gpu(4) 3079 @requires_nccl() 3080 def test_custom_op_errors(self): 3081 with self.assertRaisesRegex(TypeError, "expects signature"): 3082 3083 @custom_sharded_op_impl(torch.nn.functional.linear) 3084 def my_op1(types, args, kwargs, process_group, random_param): 3085 pass 3086 3087 with self.assertRaisesRegex(TypeError, "expects signature"): 3088 3089 @custom_sharded_op_impl(torch.nn.functional.linear) 3090 def my_op2(types): 3091 pass 3092 3093 3094class TestShardMetadata(ShardedTensorTestBase): 3095 @with_comms 3096 @requires_nccl() 3097 def test_shard_metadata_init(self): 3098 pg = dist.distributed_c10d._get_default_group() 3099 3100 md = ShardMetadata([10], [0]) 3101 self.assertIsNone(md.placement) 3102 with self.assertRaisesRegex(ValueError, "remote device is None"): 3103 _parse_and_validate_remote_device(pg, md.placement) 3104 3105 # String placement gets converted by ctor 3106 md = ShardMetadata([10], [0], "rank:0/cpu") 3107 self.assertEqual(md.placement, _remote_device("rank:0/cpu")) 3108 rank, device = _parse_and_validate_remote_device(pg, md.placement) 3109 self.assertEqual(0, rank) 3110 self.assertEqual(device, torch.device("cpu")) 3111 3112 @with_comms 3113 @requires_nccl() 3114 def test_create_shard_with_no_placement(self): 3115 md = ShardMetadata([0], [10]) 3116 shard = Shard(torch.zeros(10), md) 3117 self.assertIsNone(shard.metadata.placement) 3118 3119 3120class TestShardedTensorSubGroupInit(TestCase): 3121 @spawn_threads_and_init_comms(world_size=4) 3122 def test_sub_process_group_sharded_tensor_init(self): 3123 world_pg = dist.GroupMember.WORLD 3124 rank = dist.get_rank() 3125 3126 sub_group_sz = 2 3127 sub_pg_ranks = [r for r in range(4) if r % sub_group_sz == rank % sub_group_sz] 3128 sub_pg = dist.new_group( 3129 sub_pg_ranks, 3130 backend=dist.get_backend(world_pg), 3131 use_local_synchronization=True, 3132 ) 3133 dist.barrier(sub_pg) 3134 3135 ShardedTensor._init_from_local_shards( 3136 [ 3137 Shard( 3138 tensor=torch.tensor([1, 2, 3], device="meta"), 3139 metadata=ShardMetadata( 3140 shard_offsets=[3 * (rank // sub_group_sz)], 3141 shard_sizes=[3], 3142 placement=f"rank:{rank}/meta", 3143 ), 3144 ) 3145 ], 3146 6, 3147 process_group=sub_pg, 3148 ) 3149 3150 @spawn_threads_and_init_comms(world_size=4) 3151 def test_sub_process_group_placement_validation(self): 3152 world_pg = dist.GroupMember.WORLD 3153 self.assertIsNotNone(world_pg) 3154 rank = dist.get_rank() 3155 3156 sub_group_sz = 2 3157 sub_pg_ranks = [r for r in range(4) if r % sub_group_sz == rank % sub_group_sz] 3158 sub_pg = dist.new_group( 3159 sub_pg_ranks, 3160 backend=dist.get_backend(world_pg), 3161 use_local_synchronization=True, 3162 ) 3163 dist.barrier(sub_pg) 3164 3165 for r in sub_pg_ranks: 3166 _parse_and_validate_remote_device( 3167 sub_pg, _remote_device(f"rank:{r}/cuda:{r % sub_group_sz}") 3168 ) 3169 3170 3171class TestCreateTensorNoProcessGroupMode(TestCase): 3172 def test_init_from_local_shards_and_global_metadata(self): 3173 st_metadata: ShardedTensorMetadata = ShardedTensorMetadata( 3174 shards_metadata=[ 3175 ShardMetadata( 3176 shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu" 3177 ), 3178 ShardMetadata( 3179 shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cpu" 3180 ), 3181 ], 3182 size=torch.Size([4, 2]), 3183 ) 3184 st_local_shards: List[Shard] = [] 3185 for shard_metadata in st_metadata.shards_metadata: 3186 st_local_shards.append( 3187 Shard( 3188 tensor=torch.zeros( 3189 shard_metadata.shard_sizes, 3190 device=shard_metadata.placement.device(), 3191 ), 3192 metadata=shard_metadata, 3193 ) 3194 ) 3195 3196 ShardedTensorBase._init_from_local_shards_and_global_metadata( 3197 local_shards=st_local_shards, 3198 sharded_tensor_metadata=st_metadata, 3199 ) 3200 3201 def test_non_contiguous_local_shards(self): 3202 st_metadata: ShardedTensorMetadata = ShardedTensorMetadata( 3203 shards_metadata=[ 3204 ShardMetadata( 3205 shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu" 3206 ), 3207 ShardMetadata( 3208 shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cpu" 3209 ), 3210 ], 3211 size=torch.Size([4, 2]), 3212 ) 3213 st_local_shards: List[Shard] = [] 3214 src = torch.randn(4, 2) 3215 for shard_metadata in st_metadata.shards_metadata: 3216 offsets = shard_metadata.shard_offsets 3217 sizes = shard_metadata.shard_sizes 3218 st_local_shards.append( 3219 Shard( 3220 tensor=src[ 3221 offsets[0] : offsets[0] + sizes[0], 3222 offsets[1] : offsets[1] + sizes[1], 3223 ], 3224 metadata=shard_metadata, 3225 ) 3226 ) 3227 3228 ShardedTensorBase._init_from_local_shards_and_global_metadata( 3229 local_shards=st_local_shards, 3230 sharded_tensor_metadata=st_metadata, 3231 ) 3232 3233 3234if __name__ == "__main__": 3235 run_tests() 3236