# Owner(s): ["oncall: distributed"]

import sys
from itertools import product

import torch
from torch.distributed._shard import _shard_tensor, sharded_tensor
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
    _chunk_sharding_specs_list_for_test,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestReshard(ShardedTensorTestBase):
    def _run_sharded_tensor_reshard(self, sharding_spec, reshard_spec, input_size):
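        """
        Shard a random tensor with ``sharding_spec``, reshard it to
        ``reshard_spec``, and verify the result matches a tensor that was
        sharded with ``reshard_spec`` directly: same metadata, same local
        shard metadata, and same local tensor values.
        """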
        torch.manual_seed(0)
        local_tensor = torch.rand(*input_size).cuda(self.rank)
        st = _shard_tensor(local_tensor, sharding_spec)
        st_compare = _shard_tensor(local_tensor, reshard_spec)
        st.reshard(reshard_spec)
        self.assertEqual(1, len(st.local_shards()))
        self.assertEqual(1, len(st_compare.local_shards()))
        # Normalize the shard ordering by placement rank so the metadata
        # comparison below does not depend on how the shards were enumerated.
        st_compare._metadata.shards_metadata.sort(
            key=lambda metadata: metadata.placement.rank()
        )
        self.assertEqual(st._metadata, st_compare._metadata)
        self.assertEqual(st.local_tensor(), st_compare.local_tensor())
        self.assertEqual(
            st.local_shards()[0].metadata, st_compare.local_shards()[0].metadata
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_sharded_tensor_reshard(self):
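        """
        Reshard a chunk-sharded tensor across every (sharding_dim, reshard_dim)
        combination over dims 0 and 1, for a variety of tensor sizes, including
        ones that do not split evenly across ranks.
        """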
        dims = [0, 1]
        for sharding_dim, reshard_dim in product(dims, dims):
            specs = _chunk_sharding_specs_list_for_test(
                [sharding_dim, reshard_dim], seed=5
            )
            spec, reshard_spec = specs[0], specs[1]
            self._run_sharded_tensor_reshard(spec, reshard_spec, [13, 21])
            self._run_sharded_tensor_reshard(spec, reshard_spec, [14, 23])
            self._run_sharded_tensor_reshard(spec, reshard_spec, [15, 26])
            self._run_sharded_tensor_reshard(spec, reshard_spec, [12, 24])

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_sharded_tensor_reshard_errors(self):
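        """
        Resharding raises NotImplementedError when the target spec is not a
        ChunkShardingSpec, or when the source tensor holds more than one
        local shard.
        """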
        specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6)
        spec, reshard_spec = specs[0], specs[1]
        enumerable_sharding_spec = EnumerableShardingSpec(
            [
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
            ]
        )
        st = sharded_tensor.rand(spec, 24, 12)
        with self.assertRaisesRegex(
            NotImplementedError, "Only ChunkShardingSpec supported for reshard."
        ):
            st.reshard(enumerable_sharding_spec)
        st._local_shards = [st.local_shards()[0], st.local_shards()[0]]
        with self.assertRaisesRegex(
            NotImplementedError, "Only single local shard supported for reshard."
        ):
            st.reshard(reshard_spec)


if __name__ == "__main__":
    run_tests()