# Owner(s): ["oncall: distributed"]

import sys
from itertools import product

import torch
from torch.distributed._shard import _shard_tensor, sharded_tensor
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
    _chunk_sharding_specs_list_for_test,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestReshard(ShardedTensorTestBase):
    def _run_sharded_tensor_reshard(self, sharding_spec, reshard_spec, input_size):
        # Same seed on every rank so each rank builds an identical source tensor.
        torch.manual_seed(0)
        local_tensor = torch.rand(*input_size).cuda(self.rank)
        st = _shard_tensor(local_tensor, sharding_spec)
        st_compare = _shard_tensor(local_tensor, reshard_spec)
        # Reshard in place, then compare against a tensor sharded directly
        # with the target spec.
        st.reshard(reshard_spec)
        self.assertEqual(1, len(st.local_shards()))
        self.assertEqual(1, len(st_compare.local_shards()))
        # Normalize shard ordering by placement rank before comparing metadata.
        st_compare._metadata.shards_metadata.sort(
            key=lambda metadata: metadata.placement.rank()
        )
        self.assertEqual(st._metadata, st_compare._metadata)
        self.assertEqual(st.local_tensor(), st_compare.local_tensor())
        self.assertEqual(
            st.local_shards()[0].metadata, st_compare.local_shards()[0].metadata
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_sharded_tensor_reshard(self):
        # Exercise resharding across all combinations of source/target sharding
        # dims on a variety of evenly and unevenly divisible sizes.
        dims = [0, 1]
        for sharding_dim, reshard_dim in product(dims, dims):
            specs = _chunk_sharding_specs_list_for_test(
                [sharding_dim, reshard_dim], seed=5
            )
            spec, reshard_spec = specs[0], specs[1]
            self._run_sharded_tensor_reshard(spec, reshard_spec, [13, 21])
            self._run_sharded_tensor_reshard(spec, reshard_spec, [14, 23])
            self._run_sharded_tensor_reshard(spec, reshard_spec, [15, 26])
            self._run_sharded_tensor_reshard(spec, reshard_spec, [12, 24])

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_sharded_tensor_reshard_errors(self):
        specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
        spec, reshard_spec = specs[0], specs[1]
        enumerable_sharding_spec = EnumerableShardingSpec(
            [
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
            ]
        )
        st = sharded_tensor.rand(spec, 24, 12)
        # Resharding only supports ChunkShardingSpec targets.
        with self.assertRaisesRegex(
            NotImplementedError, "Only ChunkShardingSpec supported for reshard."
        ):
            st.reshard(enumerable_sharding_spec)
        # Fake a second local shard to trigger the single-local-shard check.
        st._local_shards = [st.local_shards()[0], st.local_shards()[0]]
        with self.assertRaisesRegex(
            NotImplementedError, "Only single local shard supported for reshard."
        ):
            st.reshard(reshard_spec)


if __name__ == "__main__":
    run_tests()