• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["oncall: distributed"]
2
3import copy
4import io
5import itertools
6import math
7import pickle
8import sys
9from typing import List
10
11import torch
12import torch.distributed as dist
13from torch.distributed import distributed_c10d, rpc
14from torch.distributed._shard import sharded_tensor
15from torch.distributed._shard.api import (
16    _collect_local_shard,
17    _reshard_output,
18    _shard_tensor,
19    load_with_process_group,
20    shard_parameter,
21)
22from torch.distributed._shard.sharded_tensor import (
23    custom_sharded_op_impl,
24    pre_load_state_dict_hook,
25    Shard,
26    ShardedTensor,
27    ShardedTensorBase,
28    ShardedTensorMetadata,
29    state_dict_hook,
30)
31from torch.distributed._shard.sharded_tensor.api import (
32    _create_tensor_from_params,
33    TensorProperties,
34)
35from torch.distributed._shard.sharded_tensor.utils import (
36    _parse_and_validate_remote_device,
37)
38from torch.distributed._shard.sharding_spec import (
39    ChunkShardingSpec,
40    EnumerableShardingSpec,
41    ShardMetadata,
42)
43from torch.distributed.remote_device import _remote_device
44from torch.testing._internal.common_distributed import (
45    requires_nccl,
46    skip_if_lt_x_gpu,
47    spawn_threads_and_init_comms,
48    tp_transports,
49)
50from torch.testing._internal.common_utils import (
51    run_tests,
52    skip_but_pass_in_sandcastle_if,
53    TEST_CUDA,
54    TEST_WITH_DEV_DBG_ASAN,
55    TestCase,
56)
57from torch.testing._internal.distributed._shard.sharded_tensor import (
58    ShardedTensorTestBase,
59    with_comms,
60)
61from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
62    _chunk_sharding_specs_list_for_test,
63    MyShardedModel1,
64)
65
66
67if TEST_WITH_DEV_DBG_ASAN:
68    print(
69        "Skip dev-asan as torch + multiprocessing spawn have known issues",
70        file=sys.stderr,
71    )
72    sys.exit(0)
73
74
75class TestShardedTensorMetadata(TestCase):
76    def test_serialize_and_deserialize(self):
77        shard_metadatas = [
78            ShardMetadata(
79                shard_offsets=[0, 0],
80                shard_sizes=[5, 5],
81                placement="rank:0/cuda:0",
82            ),
83            ShardMetadata(
84                shard_offsets=[0, 5],
85                shard_sizes=[5, 5],
86                placement="rank:1/cuda:1",
87            ),
88            ShardMetadata(
89                shard_offsets=[5, 0],
90                shard_sizes=[5, 5],
91                placement="rank:2/cuda:2",
92            ),
93            ShardMetadata(
94                shard_offsets=[5, 5],
95                shard_sizes=[5, 5],
96                placement="rank:3/cuda:3",
97            ),
98        ]
99
100        dtypes = [
101            torch.float,
102            torch.double,
103            torch.cfloat,
104            torch.cdouble,
105            torch.half,
106            torch.bfloat16,
107            torch.uint8,
108            torch.int8,
109            torch.short,
110            torch.int,
111            torch.long,
112            torch.bool,
113        ]
114
115        layouts = [torch.strided, torch.sparse_coo]
116        requires_grads = [True, False]
117        memory_formats = [
118            torch.contiguous_format,
119            torch.channels_last,
120            torch.preserve_format,
121        ]
122        pin_memories = [True, False]
123
124        for tensor_properties_input in itertools.product(
125            dtypes, layouts, requires_grads, memory_formats, pin_memories
126        ):
127            (
128                dtype,
129                layout,
130                requires_grad,
131                memory_format,
132                pin_memory,
133            ) = tensor_properties_input
134
135            expected_st_metadata = sharded_tensor.ShardedTensorMetadata(
136                shard_metadatas,
137                (10, 10),
138                TensorProperties(
139                    dtype, layout, requires_grad, memory_format, pin_memory
140                ),
141            )
142
143            pickled_obj = pickle.dumps(expected_st_metadata)
144            st_metadata = pickle.loads(pickled_obj)
145            self.assertEqual(expected_st_metadata, st_metadata)
146
147
148class TestCreateTensorFromParams(TestCase):
149    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "CUDA GPU is needed")
150    def test_empty(self):
151        expected_dtype = torch.double
152        tensor_properties = TensorProperties(
153            dtype=expected_dtype,
154            layout=torch.strided,
155            requires_grad=False,
156            pin_memory=False,
157            memory_format=torch.contiguous_format,
158        )
159        local_device = torch.device("cuda:0")
160        local_tensor = _create_tensor_from_params(
161            5, 10, local_device=local_device, tensor_properties=tensor_properties
162        )
163        self.assertEqual(local_device, local_tensor.device)
164        self.assertEqual(expected_dtype, local_tensor.dtype)
165        self.assertEqual(torch.strided, local_tensor.layout)
166        self.assertEqual(False, local_tensor.requires_grad)
167
168
169class TestShardParameter(ShardedTensorTestBase):
170    @with_comms(init_rpc=False)
171    @skip_if_lt_x_gpu(4)
172    @requires_nccl()
173    def test_shard_parameter(self):
174        spec = ChunkShardingSpec(
175            dim=0,
176            placements=[
177                "rank:0/cuda:0",
178                "rank:1/cuda:1",
179                "rank:2/cuda:2",
180                "rank:3/cuda:3",
181            ],
182        )
183
184        fc = torch.nn.Linear(12, 12).cuda(self.rank)
185        weight_og = fc.weight.clone()
186        shard_parameter(fc, "weight", spec)
187
188        # Verify.
189        self.assertTrue(isinstance(fc.weight, ShardedTensor))
190        local_shards = fc.weight.local_shards()
191        self.assertEqual(1, len(local_shards))
192        self.assertEqual(torch.Size([3, 12]), local_shards[0].tensor.size())
193        self.assertEqual(3, local_shards[0].tensor.size(0))
194        self.assertEqual(12, local_shards[0].tensor.size(1))
195        self.assertEqual(
196            torch.narrow(weight_og, 0, 3 * self.rank, 3), local_shards[0].tensor
197        )
198
199    @with_comms(init_rpc=False)
200    @skip_if_lt_x_gpu(4)
201    @requires_nccl()
202    def test_shard_parameter_errors(self):
203        spec = ChunkShardingSpec(
204            dim=0,
205            placements=[
206                "rank:0/cuda:0",
207                "rank:1/cuda:1",
208                "rank:2/cuda:2",
209                "rank:3/cuda:3",
210            ],
211        )
212
213        fc = torch.nn.Linear(12, 12).cuda(self.rank)
214        with self.assertRaisesRegex(ValueError, "does not match with src_rank"):
215            shard_parameter(fc, "weight", spec, src_rank=self.rank)
216
217        with self.assertRaisesRegex(AttributeError, "has no attribute"):
218            shard_parameter(fc, "foo", spec)
219
220        with self.assertRaisesRegex(
221            ValueError, "Expected Linear.bias to be a Tensor, but found str"
222        ):
223            del fc.bias
224            fc.bias = "foo"
225            shard_parameter(fc, "bias", spec)
226
227        with self.assertRaisesRegex(ValueError, "not a contiguous Tensor"):
228            fc.bias = torch.rand(10, 10).cuda(self.rank).t()
229            shard_parameter(fc, "bias", spec)
230
231        spec = ChunkShardingSpec(
232            dim=0,
233            placements=[
234                f"rank:{self.rank}/cuda:0",
235                "rank:1/cuda:1",
236                "rank:2/cuda:2",
237                "rank:3/cuda:3",
238            ],
239        )
240        with self.assertRaisesRegex(ValueError, "does not match with sharding_spec"):
241            shard_parameter(fc, "weight", spec)
242
243        spec = EnumerableShardingSpec(
244            [
245                ShardMetadata(
246                    shard_offsets=[0, 0],
247                    shard_sizes=[5, 5],
248                    placement="rank:0/cuda:0",
249                ),
250                ShardMetadata(
251                    shard_offsets=[5, 0],
252                    shard_sizes=[5, 5],
253                    placement="rank:1/cuda:1",
254                ),
255            ]
256        )
257        with self.assertRaisesRegex(NotImplementedError, "not implemented yet!"):
258            shard_parameter(fc, "weight", spec)
259
260
261class TestShardTensor(ShardedTensorTestBase):
262    @with_comms(init_rpc=False)
263    @skip_if_lt_x_gpu(4)
264    @requires_nccl()
265    def test_shard_tensor(self):
266        spec = ChunkShardingSpec(
267            dim=0,
268            placements=[
269                "rank:0/cuda:0",
270                "rank:1/cuda:1",
271                "rank:2/cuda:2",
272                "rank:3/cuda:3",
273            ],
274        )
275        tensor = torch.rand(12, 12).cuda(self.rank)
276        st = _shard_tensor(tensor, spec)
277
278        # Verify.
279        self.assertTrue(isinstance(st, sharded_tensor.ShardedTensor))
280        local_shard = st.local_tensor()
281        self.assertEqual(1, len(st.local_shards()))
282        self.assertEqual(torch.Size([3, 12]), local_shard.size())
283        self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard)
284
285    @with_comms(init_rpc=False)
286    @skip_if_lt_x_gpu(4)
287    @requires_nccl()
288    def test_shard_tensor_with_empty_shard(self):
289        spec = ChunkShardingSpec(
290            dim=0,
291            placements=[
292                "rank:0/cuda:0",
293                "rank:1/cuda:1",
294                "rank:2/cuda:2",
295                "rank:3/cuda:3",
296            ],
297        )
298        tensor = torch.rand(9, 12).cuda(self.rank)
299        st = _shard_tensor(tensor, spec)
300
301        # Verify.
302        self.assertTrue(isinstance(st, sharded_tensor.ShardedTensor))
303        sms = st.metadata().shards_metadata
304        self.assertEqual(len(sms), 4)
305        for sm in sms:
306            self.assertTrue(sm.shard_offsets[0] + sm.shard_sizes[0] <= tensor.size(0))
307
308        local_shard = st.local_tensor()
309        self.assertEqual(1, len(st.local_shards()))
310        if dist.get_rank() < 3:
311            self.assertEqual(torch.Size([3, 12]), local_shard.size())
312            self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard)
313        else:
314            self.assertEqual(torch.Size([0, 12]), local_shard.size())
315
316    @with_comms(init_rpc=False)
317    @skip_if_lt_x_gpu(4)
318    @requires_nccl()
319    def test_shard_tensor_errors(self):
320        spec = ChunkShardingSpec(
321            dim=0,
322            placements=[
323                "rank:0/cuda:0",
324                "rank:1/cuda:1",
325                "rank:2/cuda:2",
326                "rank:3/cuda:3",
327            ],
328        )
329        tensor = torch.rand(12, 12).cuda(self.rank)
330
331        with self.assertRaisesRegex(ValueError, "does not match with src_rank"):
332            _shard_tensor(tensor, spec, src_rank=self.rank)
333
334        with self.assertRaisesRegex(ValueError, "not a contiguous Tensor"):
335            tensor_t = torch.rand(12, 12).cuda(self.rank).t()
336            _shard_tensor(tensor_t, spec)
337
338        spec = ChunkShardingSpec(
339            dim=0,
340            placements=[
341                f"rank:{self.rank}/cuda:0",
342                "rank:1/cuda:1",
343                "rank:2/cuda:2",
344                "rank:3/cuda:3",
345            ],
346        )
347        with self.assertRaisesRegex(ValueError, "does not match with sharding_spec"):
348            _shard_tensor(tensor, spec)
349
350        spec = EnumerableShardingSpec(
351            [
352                ShardMetadata(
353                    shard_offsets=[0, 0],
354                    shard_sizes=[5, 5],
355                    placement="rank:0/cuda:0",
356                ),
357                ShardMetadata(
358                    shard_offsets=[5, 0],
359                    shard_sizes=[5, 5],
360                    placement="rank:1/cuda:1",
361                ),
362            ]
363        )
364        with self.assertRaisesRegex(NotImplementedError, "not implemented yet!"):
365            _shard_tensor(tensor, spec)
366
367
368class TestModuleHookApi(ShardedTensorTestBase):
369    class DummyNNModule(torch.nn.Module):
370        def __init__(self, spec, tensor_size):
371            super().__init__()
372            self.st = sharded_tensor.rand(spec, *tensor_size)
373
374        def forward(self):
375            return self.st
376
377    @with_comms(init_rpc=False)
378    @skip_if_lt_x_gpu(4)
379    @requires_nccl()
380    def test_reshard_output(self):
381        specs = _chunk_sharding_specs_list_for_test([0, 1], seed=5)
382        spec, reshard_spec = specs[0], specs[1]
383        test_module = self.DummyNNModule(spec, [24, 12])
384        st = test_module()
385        local_shard = st.local_tensor()
386        pg = dist.distributed_c10d._get_default_group()
387        st_compare = ShardedTensor._init_from_local_shards(
388            copy.deepcopy(st.local_shards()),
389            st.size(),
390            process_group=pg,
391        )
392        st_compare._sharding_spec = copy.deepcopy(spec)
393        st_compare.reshard(reshard_spec)
394        test_module = _reshard_output(test_module, reshard_spec)
395        st = test_module()
396        local_shard = st.local_tensor()
397        local_shard_compare = st_compare.local_tensor()
398        self.assertEqual(local_shard, local_shard_compare)
399        self.assertEqual(local_shard.size(0), 24)
400        self.assertEqual(local_shard.size(1), 3)
401
402    @with_comms(init_rpc=False)
403    @skip_if_lt_x_gpu(4)
404    @requires_nccl()
405    def test_collect_local_shard(self):
406        specs = _chunk_sharding_specs_list_for_test([0], seed=5)
407        spec = specs[0]
408        test_module = self.DummyNNModule(spec, [23, 15])
409        st = test_module()
410        local_shard = st.local_tensor()
411        test_module = _collect_local_shard(test_module)
412        output = test_module()
413        self.assertTrue(isinstance(output, torch.Tensor))
414        self.assertEqual(local_shard, output)
415
416
417class TestLocalTensor(ShardedTensorTestBase):
418    @with_comms(init_rpc=False)
419    @skip_if_lt_x_gpu(4)
420    @requires_nccl()
421    def test_local_tensor(self):
422        spec = ChunkShardingSpec(
423            dim=0,
424            placements=[
425                "rank:0/cuda:0",
426                "rank:1/cuda:1",
427                "rank:2/cuda:2",
428                "rank:3/cuda:3",
429            ],
430        )
431        st = sharded_tensor.rand(spec, 24, 12)
432        local_shard = st.local_tensor()
433        self.assertEqual(torch.Size([6, 12]), local_shard.size())
434        self.assertEqual(st.local_tensor(), local_shard)
435
436    @with_comms(init_rpc=False)
437    @skip_if_lt_x_gpu(4)
438    @requires_nccl()
439    def test_local_tensor_error(self):
440        spec = ChunkShardingSpec(
441            dim=0,
442            placements=[
443                "rank:0/cuda:0",
444                "rank:0/cuda:0",
445                "rank:1/cuda:1",
446                "rank:1/cuda:1",
447                "rank:1/cuda:1",
448                "rank:2/cuda:2",
449                "rank:2/cuda:2",
450                "rank:2/cuda:2",
451                "rank:3/cuda:3",
452                "rank:3/cuda:3",
453            ],
454        )
455        st = sharded_tensor.rand(spec, 24, 12)
456        with self.assertRaisesRegex(
457            NotImplementedError, "Only single local shard is supported."
458        ):
459            local_shard = st.local_tensor()
460
461
462class TestShardedTensorChunked(ShardedTensorTestBase):
463    @with_comms
464    @skip_if_lt_x_gpu(4)
465    @requires_nccl()
466    def test_sharded_tensor_metadata(self):
467        spec = ChunkShardingSpec(
468            dim=0,
469            placements=[
470                "rank:0/cuda:0",
471                "rank:1/cuda:1",
472                "rank:2/cuda:2",
473                "rank:3/cuda:3",
474            ],
475        )
476
477        st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
478        st_metadata = st.metadata()
479        self.assertEqual(torch.Size([10, 20]), st_metadata.size)
480        self.assertEqual(torch.Size([10, 20]), st.size())
481        self.assertEqual(torch.float, st.dtype)
482        self.assertEqual(torch.strided, st.layout)
483        self.assertEqual(False, st.requires_grad)
484        self.assertTrue(st.is_contiguous())
485        self.assertFalse(st.is_pinned())
486
487        st = sharded_tensor.empty(spec, 10, 20, requires_grad=True, init_rrefs=True)
488        self.assertEqual(True, st.requires_grad)
489
490        st = sharded_tensor.empty(spec, 10, 20, dtype=torch.double, init_rrefs=True)
491        self.assertEqual(torch.double, st.dtype)
492
493        # Need CPU for pin_memory
494        spec = ChunkShardingSpec(
495            dim=0,
496            placements=[
497                "rank:0/cpu",
498                "rank:1/cpu",
499                "rank:2/cpu",
500                "rank:3/cpu",
501            ],
502        )
503
504        st = sharded_tensor.empty(spec, 10, 20, pin_memory=True, init_rrefs=True)
505        self.assertEqual(True, st.is_pinned())
506
507        # test read only properties, they're read only as we can't simply change
508        # the global metadata without changing the underlying shard's properties
509        with self.assertRaisesRegex(RuntimeError, "torch function '__set__'"):
510            st.requires_grad = True
511
512    @with_comms
513    @skip_if_lt_x_gpu(4)
514    @requires_nccl()
515    def test_complete_world_size(self):
516        for dim in [0, -2]:
517            spec = ChunkShardingSpec(
518                dim=dim,
519                placements=[
520                    "rank:0/cuda:0",
521                    "rank:1/cuda:1",
522                    "rank:2/cuda:2",
523                    "rank:3/cuda:3",
524                ],
525            )
526            st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
527
528            # Validate local shard.
529            local_shards = st.local_shards()
530            self.assertEqual(1, len(local_shards))
531            local_shard = local_shards[0].tensor
532            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
533            if self.rank == 3:
534                self.assertEqual((1, 20), local_shard.size())
535            else:
536                self.assertEqual((3, 20), local_shard.size())
537
538            # Validate global metadata.
539            st_metadata = st.metadata()
540            shards_metadata = st_metadata.shards_metadata
541            self.assertEqual(4, len(shards_metadata))
542
543            for rank, shard_metadata in enumerate(shards_metadata):
544                self.assertEqual([rank * 3, 0], shard_metadata.shard_offsets)
545                if rank == 3:
546                    self.assertEqual([1, 20], shard_metadata.shard_sizes)
547                else:
548                    self.assertEqual([3, 20], shard_metadata.shard_sizes)
549                self.assertEqual(
550                    f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)
551                )
552
553            # Validate remote shards.
554            remote_shards = st.remote_shards()
555            self.assertEqual(3, len(remote_shards))
556
557            for rpc_rank, shards in remote_shards.items():
558                self.assertEqual(1, len(shards))
559                for remote_shard in shards:
560                    self.assertEqual(rpc_rank, remote_shard.owner().id)
561                    shard = remote_shard.to_here()
562                    self.assertEqual(
563                        f"rank:{rpc_rank}/cuda:{rpc_rank}",
564                        str(shard.metadata.placement),
565                    )
566                    if rpc_rank == 3:
567                        self.assertEqual((1, 20), shard.tensor.size())
568                    else:
569                        self.assertEqual((3, 20), shard.tensor.size())
570
571    @with_comms
572    @skip_if_lt_x_gpu(4)
573    @requires_nccl()
574    def test_create_sharded_tensor_with_ones(self):
575        """Test sharded_tensor.ones(...)"""
576
577        spec = ChunkShardingSpec(
578            dim=0,
579            placements=[
580                "rank:0/cuda:0",
581                "rank:1/cuda:1",
582                "rank:2/cuda:2",
583                "rank:3/cuda:3",
584            ],
585        )
586        h, w = 10, 20
587        st = sharded_tensor.ones(spec, h, w)
588
589        # Validate local shard is initialized with torch.ones
590        local_shards = st.local_shards()
591        self.assertEqual(1, len(local_shards))
592        local_shard = local_shards[0].tensor
593        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
594        # The split: for rank!=3 ceil(h/4)=3  for rank=3 1
595        expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
596        self.assertEqual((expected_h, w), local_shard.size())
597        self.assertEqual(local_shard, torch.ones(expected_h, w))
598
599    @with_comms
600    @skip_if_lt_x_gpu(4)
601    @requires_nccl()
602    def test_gather_even(self) -> None:
603        """Test _sharded_tensor.gather(...) with evenly distributed._shards"""
604
605        spec = ChunkShardingSpec(
606            dim=0,
607            placements=[
608                "rank:0/cuda:0",
609                "rank:1/cuda:1",
610                "rank:2/cuda:2",
611                "rank:3/cuda:3",
612            ],
613        )
614        h, w = 10, 20
615        st = sharded_tensor.ones(spec, h, w)
616
617        full_tensor = None
618        dst = 1
619        if self.rank == dst:
620            full_tensor = torch.zeros(
621                h,
622                w,
623                device=torch.device(f"cuda:{dst}"),
624            )
625        st.gather(dst, full_tensor)
626
627        if self.rank == dst:
628            self.assertEqual(full_tensor, torch.ones(h, w))
629        else:
630            self.assertIsNone(full_tensor)
631
632    @with_comms
633    @skip_if_lt_x_gpu(4)
634    @requires_nccl()
635    def test_gather_uneven(self) -> None:
636        """Test _sharded_tensor.gather(...) with unevenly distributed._shards"""
637
638        spec = ChunkShardingSpec(
639            dim=0,
640            placements=[
641                "rank:0/cuda:0",
642                "rank:0/cuda:0",
643                "rank:1/cuda:1",
644                "rank:1/cuda:1",
645                "rank:2/cuda:2",
646            ],
647        )
648        h, w = 10, 20
649        st = sharded_tensor.ones(spec, h, w)
650
651        full_tensor = None
652        dst = 1
653        if self.rank == dst:
654            full_tensor = torch.zeros(
655                h,
656                w,
657                device=torch.device(f"cuda:{dst}"),
658            )
659        st.gather(dst, full_tensor)
660
661        if self.rank == dst:
662            self.assertEqual(full_tensor, torch.ones(h, w))
663        else:
664            self.assertIsNone(full_tensor)
665
666    @with_comms
667    @skip_if_lt_x_gpu(4)
668    @requires_nccl()
669    def test_create_sharded_tensor_with_zeros(self):
670        """Test sharded_tensor.zeros(...)"""
671
672        spec = ChunkShardingSpec(
673            dim=0,
674            placements=[
675                "rank:0/cuda:0",
676                "rank:1/cuda:1",
677                "rank:2/cuda:2",
678                "rank:3/cuda:3",
679            ],
680        )
681        h, w = 10, 20
682        st = sharded_tensor.zeros(spec, h, w)
683
684        # Validate local shard is initialized with torch.zeros
685        local_shards = st.local_shards()
686        self.assertEqual(1, len(local_shards))
687        local_shard = local_shards[0].tensor
688        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
689        # The split: for rank!=3 ceil(h/4)=3  for rank=3 1
690        expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
691        self.assertEqual((expected_h, w), local_shard.size())
692        self.assertEqual(local_shard, torch.zeros(expected_h, w))
693
694    @with_comms
695    @skip_if_lt_x_gpu(4)
696    @requires_nccl()
697    def test_create_sharded_tensor_with_rand(self):
698        """Test sharded_tensor.rand(...)/randn(...)"""
699
700        spec = ChunkShardingSpec(
701            dim=0,
702            placements=[
703                "rank:0/cuda:0",
704                "rank:1/cuda:1",
705                "rank:2/cuda:2",
706                "rank:3/cuda:3",
707            ],
708        )
709        h, w = 8, 2
710        seed = 1234
711
712        expected_h = 2
713        expected_device = torch.device(f"cuda:{self.rank}")
714        dtype = torch.double
715        torch.manual_seed(seed)
716        # Test sharded_tensor.rand creation
717        expected = torch.rand(expected_h, w, device=expected_device, dtype=dtype)
718        # reset seed to ensure the same random numbers are generated
719        torch.manual_seed(seed)
720        st = sharded_tensor.rand(spec, h, w, dtype=dtype)
721
722        # Validate local shard is initialized with torch.rand
723        local_shards = st.local_shards()
724        self.assertEqual(1, len(local_shards))
725        local_shard = local_shards[0].tensor
726        self.assertEqual(expected_device, local_shard.device)
727        self.assertEqual((expected_h, w), local_shard.size())
728        self.assertEqual(expected, local_shard)
729
730        # Test sharded_tensor.randn creation
731        torch.manual_seed(seed)
732        expected_randn = torch.randn(expected_h, w, device=expected_device, dtype=dtype)
733        # reset seed to ensure the same random numbers are generated
734        torch.manual_seed(seed)
735        st_randn = sharded_tensor.randn(spec, h, w, dtype=dtype)
736
737        # Validate local shard is initialized with torch.randn
738        local_shards = st_randn.local_shards()
739        self.assertEqual(1, len(local_shards))
740        local_shard = local_shards[0].tensor
741        self.assertEqual(expected_device, local_shard.device)
742        self.assertEqual((expected_h, w), local_shard.size())
743        self.assertEqual(expected_randn, local_shard)
744
745    @with_comms
746    @skip_if_lt_x_gpu(4)
747    @requires_nccl()
748    def test_create_sharded_tensor_with_full(self):
749        """Test sharded_tensor.full(...)"""
750
751        spec = ChunkShardingSpec(
752            dim=0,
753            placements=[
754                "rank:0/cuda:0",
755                "rank:1/cuda:1",
756                "rank:2/cuda:2",
757                "rank:3/cuda:3",
758            ],
759        )
760        h, w = 10, 20
761        fill_value = 1234
762        st = sharded_tensor.full(
763            spec, size=(h, w), fill_value=fill_value, dtype=torch.int32
764        )
765
766        # Validate local shard is initialized with torch.full
767        local_shards = st.local_shards()
768        self.assertEqual(1, len(local_shards))
769        local_shard = local_shards[0].tensor
770        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
771        # The split: for rank!=3 ceil(h/4)=3  for rank=3 1
772        expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
773        self.assertEqual((expected_h, w), local_shard.size())
774        self.assertEqual(
775            local_shard,
776            torch.full(size=(expected_h, w), fill_value=fill_value, dtype=torch.int32),
777        )
778
779    @with_comms
780    @skip_if_lt_x_gpu(4)
781    @requires_nccl()
782    def test_create_sharded_tensor_like(self):
783        """Test tensor like methods, i.e. torch.zeros_like(...), torch.full_like, etc."""
784
785        spec = ChunkShardingSpec(
786            dim=0,
787            placements=[
788                "rank:0/cuda:0",
789                "rank:1/cuda:1",
790                "rank:2/cuda:2",
791                "rank:3/cuda:3",
792            ],
793        )
794        h, w = 8, 8
795        expected_h = 2
796        seed = 1234
797        dtype = torch.double
798        expected_device = torch.device(f"cuda:{self.rank}")
799        st = sharded_tensor.rand(spec, (h, w), dtype=dtype)
800        tensor_like_ops = {
801            torch.zeros_like: torch.zeros,
802            torch.ones_like: torch.ones,
803            torch.rand_like: torch.rand,
804            torch.randn_like: torch.randn,
805            torch.empty_like: torch.empty,
806            torch.full_like: torch.full,
807        }
808        for op, expect_local_op in tensor_like_ops.items():
809            if op == torch.full_like:
810                # special handle full/full_like as it needs to have additional fill_value arg
811                expect_tensor = expect_local_op(
812                    (expected_h, w), 8.8, device=expected_device, dtype=dtype
813                )
814                new_op_st = op(st, 8.8, dtype=dtype)
815                self.assertEqual(new_op_st.local_tensor(), expect_tensor)
816            elif op == torch.empty_like:
817                # empty/empty_like we only compare the shape
818                expect_tensor = expect_local_op(
819                    expected_h, w, device=expected_device, dtype=dtype
820                )
821                new_op_st = op(st, dtype=dtype)
822                self.assertEqual(new_op_st.local_tensor().shape, expect_tensor.shape)
823            else:
824                torch.manual_seed(seed)
825                expect_tensor = expect_local_op(
826                    expected_h, w, device=expected_device, dtype=dtype
827                )
828                torch.manual_seed(seed)
829                new_op_st = op(st, dtype=dtype)
830                self.assertEqual(new_op_st.local_tensor(), expect_tensor)
831
832    @with_comms
833    @skip_if_lt_x_gpu(4)
834    @requires_nccl()
835    def test_partial_world_size(self):
836        spec = ChunkShardingSpec(
837            dim=0,
838            placements=[
839                "rank:2/cuda:2",
840                "rank:3/cuda:3",
841            ],
842        )
843        st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
844
845        # Validate local shard.
846        local_shards = st.local_shards()
847        if self.rank >= 2:
848            self.assertEqual(1, len(local_shards))
849            local_shard = local_shards[0].tensor
850            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
851            self.assertEqual((5, 20), local_shard.size())
852        else:
853            self.assertEqual(0, len(local_shards))
854
855        # Validate global metadata.
856        st_metadata = st.metadata()
857        shards_metadata = st_metadata.shards_metadata
858        self.assertEqual(2, len(shards_metadata))
859
860        for shard_rank, shard_metadata in enumerate(shards_metadata):
861            self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets)
862            self.assertEqual([5, 20], shard_metadata.shard_sizes)
863            self.assertEqual(
864                f"rank:{shard_rank + 2}/cuda:{shard_rank + 2}",
865                str(shard_metadata.placement),
866            )
867
868        # Validate remote shards.
869        remote_shards = st.remote_shards()
870        if self.rank >= 2:
871            self.assertEqual(1, len(remote_shards))
872        else:
873            self.assertEqual(2, len(remote_shards))
874
875        for rpc_rank, shards in remote_shards.items():
876            self.assertEqual(1, len(shards))
877            for remote_shard in shards:
878                self.assertEqual(rpc_rank, remote_shard.owner().id)
879                shard = remote_shard.to_here()
880                self.assertEqual(
881                    f"rank:{rpc_rank}/cuda:{rpc_rank}", str(shard.metadata.placement)
882                )
883                self.assertEqual((5, 20), shard.tensor.size())
884
885    @with_comms
886    @skip_if_lt_x_gpu(4)
887    @requires_nccl()
888    def test_new_group(self):
889        spec = ChunkShardingSpec(
890            dim=0,
891            placements=[
892                "rank:2/cuda:2",
893                "rank:3/cuda:3",
894            ],
895        )
896
897        pg = dist.new_group(ranks=[1, 2, 3])
898        st = sharded_tensor.empty(spec, 10, 20, process_group=pg, init_rrefs=True)
899
900        # Validate local shard.
901        local_shards = st.local_shards()
902        if self.rank >= 2:
903            self.assertEqual(1, len(local_shards))
904            local_shard = local_shards[0].tensor
905            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
906            self.assertEqual((5, 20), local_shard.size())
907        else:
908            self.assertEqual(0, len(local_shards))
909
910        # Validate global metadata.
911        st_metadata = st.metadata()
912        shards_metadata = st_metadata.shards_metadata
913        self.assertEqual(2, len(shards_metadata))
914
915        for shard_rank, shard_metadata in enumerate(shards_metadata):
916            self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets)
917            self.assertEqual([5, 20], shard_metadata.shard_sizes)
918            self.assertEqual(
919                f"rank:{shard_rank + 2}/cuda:{shard_rank + 2}",
920                str(shard_metadata.placement),
921            )
922
923        # Validate remote shards.
924        remote_shards = st.remote_shards()
925        if self.rank >= 2:
926            self.assertEqual(1, len(remote_shards))
927        else:
928            self.assertEqual(2, len(remote_shards))
929
930        for rpc_rank, shards in remote_shards.items():
931            self.assertEqual(1, len(shards))
932            for remote_shard in shards:
933                shard = remote_shard.to_here()
934                self.assertEqual(rpc_rank, remote_shard.owner().id)
935                self.assertEqual(
936                    f"rank:{rpc_rank}/cuda:{rpc_rank}", str(shard.metadata.placement)
937                )
938                self.assertEqual((5, 20), shard.tensor.size())
939
940    @with_comms
941    @skip_if_lt_x_gpu(4)
942    @requires_nccl()
943    def test_multiple_local_shards(self):
944        spec = ChunkShardingSpec(
945            dim=0,
946            placements=[
947                "rank:0/cuda:0",
948                "rank:1/cuda:1",
949                "rank:2/cuda:2",
950                "rank:3/cuda:3",
951                "rank:0/cuda:0",
952                "rank:1/cuda:1",
953                "rank:2/cuda:2",
954                "rank:3/cuda:3",
955            ],
956        )
957        st = sharded_tensor.empty(spec, 16, 20, init_rrefs=True)
958
959        # Validate local shards.
960        local_shards = st.local_shards()
961        self.assertEqual(2, len(local_shards))
962        for local_shard in local_shards:
963            self.assertEqual(
964                torch.device(f"cuda:{self.rank}"), local_shard.tensor.device
965            )
966            self.assertEqual((2, 20), local_shard.tensor.size())
967
968        # Validate global metadata.
969        st_metadata = st.metadata()
970        shards_metadata = st_metadata.shards_metadata
971        self.assertEqual(8, len(shards_metadata))
972
973        for shard_idx, shard_metadata in enumerate(shards_metadata):
974            self.assertEqual([shard_idx * 2, 0], shard_metadata.shard_offsets)
975            self.assertEqual([2, 20], shard_metadata.shard_sizes)
976            self.assertEqual(
977                f"rank:{shard_idx % 4}/cuda:{shard_idx % 4}",
978                str(shard_metadata.placement),
979            )
980
981        # Validate remote shards.
982        remote_shards = st.remote_shards()
983        self.assertEqual(3, len(remote_shards))
984        owners = {}
985        for rpc_rank, shards in remote_shards.items():
986            self.assertEqual(2, len(shards))
987            for remote_shard in shards:
988                shard = remote_shard.to_here()
989                self.assertEqual((2, 20), shard.tensor.size())
990                self.assertEqual(rpc_rank, remote_shard.owner().id)
991
992    @skip_if_lt_x_gpu(4)
993    @requires_nccl()
994    def test_sharding_columns(self):
995        self.init_pg()
996
997        for dim in [1, -1]:
998            spec = ChunkShardingSpec(
999                dim=dim,
1000                placements=[
1001                    "rank:0/cuda:0",
1002                    "rank:1/cuda:1",
1003                    "rank:2/cuda:2",
1004                    "rank:3/cuda:3",
1005                ],
1006            )
1007
1008            st = sharded_tensor.empty(spec, 10, 32)
1009
1010            # Validate local shard.
1011            local_shards = st.local_shards()
1012            self.assertEqual(1, len(local_shards))
1013            local_shard = local_shards[0].tensor
1014            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
1015            self.assertEqual((10, 8), local_shard.size())
1016
1017            # Validate global metadata.
1018            st_metadata = st.metadata()
1019            shards_metadata = st_metadata.shards_metadata
1020            self.assertEqual(4, len(shards_metadata))
1021
1022            for rank, shard_metadata in enumerate(shards_metadata):
1023                self.assertEqual([0, rank * 8], shard_metadata.shard_offsets)
1024                self.assertEqual([10, 8], shard_metadata.shard_sizes)
1025                self.assertEqual(
1026                    f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)
1027                )
1028
1029    @skip_if_lt_x_gpu(4)
1030    @requires_nccl()
1031    def test_invalid_sharding(self):
1032        self.init_pg()
1033
1034        with self.assertRaisesRegex(
1035            NotImplementedError, "does not support named dimension"
1036        ):
1037            spec = ChunkShardingSpec(dim="H", placements=["rank:1/cuda:1"])
1038            sharded_tensor.empty(spec, 10, 20)
1039
1040        for dim in [2, 3, 4, -3, -4, -5]:
1041            spec = ChunkShardingSpec(dim=dim, placements=["rank:1/cuda:1"])
1042            with self.assertRaisesRegex(ValueError, "Invalid sharding dim"):
1043                sharded_tensor.empty(spec, 10, 20)
1044
1045        spec = ChunkShardingSpec(dim=0, placements=["rank:5/cuda:1"])
1046        with self.assertRaisesRegex(
1047            ValueError, "Global rank 5 does not exist in input process group"
1048        ):
1049            sharded_tensor.empty(spec, 10, 20)
1050
1051        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
1052        st = sharded_tensor.empty(spec, 10, 20)
1053        tensor = torch.empty(10, 20)
1054        with self.assertRaisesRegex(
1055            RuntimeError, r".*not supported for ShardedTensor!$"
1056        ):
1057            torch.add(st, tensor)
1058
1059        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
1060        with self.assertRaisesRegex(
1061            ValueError, "Only torch.strided layout is currently supported"
1062        ):
1063            sharded_tensor.empty(spec, 10, 20, layout=torch.sparse_coo)
1064
1065        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
1066        with self.assertRaisesRegex(
1067            ValueError,
1068            "Only torch.contiguous_format memory_format is currently supported",
1069        ):
1070            sharded_tensor.empty(spec, 10, 20, memory_format=torch.channels_last)
1071
1072        spec = ChunkShardingSpec(dim=0, placements=["worker0/cuda:1"])
1073        with self.assertRaisesRegex(
1074            RuntimeError, "RPC framework needs to be initialized"
1075        ):
1076            sharded_tensor.empty(spec, 10, 20)
1077
1078        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
1079        with self.assertRaisesRegex(
1080            RuntimeError, "RPC Framework needs to be initialized"
1081        ):
1082            st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
1083
1084        with self.assertRaisesRegex(
1085            RuntimeError, "ShardedTensor created with init_rrefs=False"
1086        ):
1087            st = sharded_tensor.empty(spec, 10, 20)
1088            st.remote_shards()
1089
1090        self.init_rpc()
1091        spec = ChunkShardingSpec(dim=0, placements=["workerfoo/cuda:1"])
1092        with self.assertRaisesRegex(ValueError, "Invalid worker name"):
1093            sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
1094
1095    @skip_if_lt_x_gpu(4)
1096    @requires_nccl()
1097    def test_invalid_pg_rpc_ranks(self):
1098        self.init_pg()
1099
1100        # Init RPC with different ranks.
1101        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
1102            _transports=tp_transports()
1103        )
1104        rpc_backend_options.init_method = f"file://{self.file_name}"
1105        rank = (self.rank + 1) % self.world_size
1106        rpc.init_rpc(
1107            name=f"worker{rank}",
1108            rank=rank,
1109            world_size=self.world_size,
1110            rpc_backend_options=rpc_backend_options,
1111        )
1112
1113        spec = ChunkShardingSpec(dim=0, placements=["rank:1/cuda:1"])
1114        with self.assertRaisesRegex(
1115            ValueError, "Default ProcessGroup and RPC ranks must be the same"
1116        ):
1117            sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
1118
1119    @skip_if_lt_x_gpu(4)
1120    @requires_nccl()
1121    def test_insufficient_sharding_dims(self):
1122        self.init_pg()
1123
1124        spec = ChunkShardingSpec(
1125            dim=0,
1126            placements=[
1127                "rank:0/cuda:0",
1128                "rank:1/cuda:1",
1129                "rank:2/cuda:2",
1130                "rank:3/cuda:3",
1131            ],
1132        )
1133        st = sharded_tensor.empty(spec, 2, 20)
1134
1135        # Validate local shard.
1136        local_shards = st.local_shards()
1137        if self.rank <= 1:
1138            self.assertEqual(1, len(local_shards))
1139            local_shard = local_shards[0].tensor
1140            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
1141            self.assertEqual((1, 20), local_shard.size())
1142        else:
1143            self.assertEqual(1, len(local_shards))
1144            local_shard = local_shards[0].tensor
1145            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
1146            self.assertEqual(local_shard.numel(), 0)
1147
1148        # Validate global metadata.
1149        st_metadata = st.metadata()
1150        shards_metadata = st_metadata.shards_metadata
1151        self.assertEqual(4, len(shards_metadata))
1152
1153        for shard_rank, shard_metadata in enumerate(shards_metadata):
1154            self.assertEqual([shard_rank, 0], shard_metadata.shard_offsets)
1155            self.assertEqual(
1156                f"rank:{shard_rank}/cuda:{shard_rank}", str(shard_metadata.placement)
1157            )
1158            if shard_rank <= 1:
1159                self.assertEqual([1, 20], shard_metadata.shard_sizes)
1160            else:
1161                self.assertEqual([0, 20], shard_metadata.shard_sizes)
1162
1163    @with_comms
1164    @skip_if_lt_x_gpu(4)
1165    @requires_nccl()
1166    def test_sharded_tensor_sizes(self):
1167        spec = ChunkShardingSpec(
1168            dim=0,
1169            placements=[
1170                "rank:0/cuda:0",
1171                "rank:1/cuda:1",
1172                "rank:2/cuda:2",
1173                "rank:3/cuda:3",
1174            ],
1175        )
1176
1177        # Test with *args
1178        st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
1179        self.assertEqual(torch.Size([10, 20]), st.size())
1180
1181        # Test with single *args
1182        st = sharded_tensor.empty(spec, 10, init_rrefs=True)
1183        self.assertEqual(torch.Size([10]), st.size())
1184
1185        # Test with list
1186        st = sharded_tensor.empty(spec, [10, 20], init_rrefs=True)
1187        self.assertEqual(torch.Size([10, 20]), st.size())
1188
1189        # Test with tuple
1190        st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
1191        self.assertEqual(torch.Size([10, 20]), st.size())
1192
1193        # Test with row size
1194        st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
1195        self.assertEqual(st.size(0), 10)
1196
1197        # Test with col size
1198        st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
1199        self.assertEqual(st.size(1), 20)
1200
1201        # Test with negative indexed size
1202        st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
1203        self.assertEqual(st.size(-1), 20)
1204
1205        # Test with dim/ndim
1206        self.assertEqual(st.dim(), 2)
1207        self.assertEqual(st.ndim, 2)
1208        # Test with invalid input
1209        st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
1210        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1211            st.size(-3)
1212        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1213            st.size(2)
1214
1215        with self.assertRaises(TypeError):
1216            st = sharded_tensor.empty(spec, "foo")
1217
1218    @with_comms
1219    @skip_if_lt_x_gpu(4)
1220    @requires_nccl()
1221    def test_state_dict(self):
1222        spec = ChunkShardingSpec(
1223            dim=0,
1224            placements=[
1225                "rank:0/cuda:0",
1226                "rank:1/cuda:1",
1227                "rank:2/cuda:2",
1228                "rank:3/cuda:3",
1229            ],
1230        )
1231
1232        m = MyShardedModel1(spec)
1233
1234        # Test save
1235        m._register_state_dict_hook(state_dict_hook)
1236        buffer = io.BytesIO()
1237        mod_state_dict = m.state_dict()
1238        mod_state_keys = mod_state_dict.keys()
1239        self.assertTrue("sharded_tensor1" in mod_state_keys)
1240        self.assertTrue("submodule.sharded_tensor2" in mod_state_keys)
1241        torch.save(mod_state_dict, buffer)
1242
1243        # Test load.
1244        module_load = MyShardedModel1()
1245        module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
1246
1247        buffer.seek(0)
1248        state_dict_deser = torch.load(buffer)
1249        module_load.load_state_dict(state_dict_deser, strict=False)
1250
1251        module_load._register_state_dict_hook(state_dict_hook)
1252        loaded_dict_keys = module_load.state_dict().keys()
1253        self.assertTrue("sharded_tensor1" in loaded_dict_keys)
1254        self.assertTrue("submodule.sharded_tensor2" in loaded_dict_keys)
1255        # Verify after load.
1256        self.assertTrue(torch.equal(m.sharded_tensor1, module_load.sharded_tensor1))
1257        self.assertTrue(
1258            torch.equal(
1259                m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2
1260            )
1261        )
1262
1263    @with_comms
1264    @skip_if_lt_x_gpu(4)
1265    @requires_nccl()
1266    def test_state_dict_new_group(self):
1267        spec = ChunkShardingSpec(
1268            dim=0,
1269            placements=[
1270                "rank:2/cuda:0",
1271                "rank:3/cuda:1",
1272                "rank:2/cuda:2",
1273                "rank:3/cuda:3",
1274            ],
1275        )
1276
1277        pg = dist.new_group([2, 3])
1278
1279        m = MyShardedModel1(spec, pg)
1280
1281        # Test save
1282        m._register_state_dict_hook(state_dict_hook)
1283        buffer = io.BytesIO()
1284        torch.save(m.state_dict(), buffer)
1285
1286        # Test load.
1287        module_load = MyShardedModel1(spec=None, group=pg)
1288        module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
1289
1290        buffer.seek(0)
1291        with load_with_process_group(pg):
1292            state_dict_deser = torch.load(buffer)
1293            module_load.load_state_dict(state_dict_deser, strict=False)
1294
1295        # Verify after load.
1296        self.assertTrue(torch.equal(m.sharded_tensor1, module_load.sharded_tensor1))
1297        self.assertTrue(
1298            torch.equal(
1299                m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2
1300            )
1301        )
1302
1303    @with_comms
1304    @skip_if_lt_x_gpu(4)
1305    @requires_nccl()
1306    def test_state_dict_no_sharded_tensors(self):
1307        # Verify hooks don't affect modules with no ShardedTensors.
1308        m = torch.nn.Linear(10, 10)
1309
1310        # Test save
1311        state_dict_before = m.state_dict()
1312        m._register_state_dict_hook(state_dict_hook)
1313        buffer = io.BytesIO()
1314        torch.save(m.state_dict(), buffer)
1315        self.assertEqual(state_dict_before, m.state_dict())
1316
1317        # Test load.
1318        module_load = torch.nn.Linear(10, 10)
1319        module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
1320
1321        buffer.seek(0)
1322        state_dict_deser = torch.load(buffer)
1323        module_load.load_state_dict(state_dict_deser, strict=False)
1324
1325        # Verify after load.
1326        self.assertEqual(m.weight, module_load.weight)
1327        self.assertEqual(m.bias, module_load.bias)
1328
1329    @skip_if_lt_x_gpu(4)
1330    @requires_nccl()
1331    def test_load_state_dict_errors(self):
1332        self.init_rpc()
1333
1334        dist.init_process_group(
1335            backend="nccl",
1336            world_size=self.world_size,
1337            rank=self.rank,
1338            init_method=f"file://{self.file_name}",
1339        )
1340
1341        spec = ChunkShardingSpec(
1342            dim=0,
1343            placements=[
1344                "rank:0/cuda:0",
1345                "rank:1/cuda:1",
1346                "rank:2/cuda:2",
1347                "rank:3/cuda:3",
1348            ],
1349        )
1350
1351        m = MyShardedModel1(spec)
1352
1353        # Test save
1354        m._register_state_dict_hook(state_dict_hook)
1355        buffer = io.BytesIO()
1356        torch.save(m.state_dict(), buffer)
1357
1358        pg = dist.new_group(ranks=[0, 2, 3])
1359
1360        buffer.seek(0)
1361        if self.rank != 0:
1362            with self.assertRaisesRegex(RuntimeError, "Local rank at save time was"):
1363                with load_with_process_group(pg):
1364                    state_dict_deser = torch.load(buffer)
1365        else:
1366            with self.assertRaisesRegex(
1367                RuntimeError, "Local world size at save time was"
1368            ):
1369                with load_with_process_group(pg):
1370                    state_dict_deser = torch.load(buffer)
1371
1372        dist.destroy_process_group()
1373        buffer.seek(0)
1374        with self.assertRaisesRegex(
1375            RuntimeError, "Need to initialize default process group"
1376        ):
1377            state_dict_deser = torch.load(buffer)
1378        rpc.shutdown()
1379
1380    @with_comms
1381    @skip_if_lt_x_gpu(4)
1382    @requires_nccl()
1383    def test_cleanup(self):
1384        def create_tensors():
1385            spec = ChunkShardingSpec(
1386                dim=0,
1387                placements=[
1388                    "rank:0/cuda:0",
1389                    "rank:1/cuda:1",
1390                    "rank:2/cuda:2",
1391                    "rank:3/cuda:3",
1392                ],
1393            )
1394            st1 = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
1395            st2 = sharded_tensor.empty(spec, 10, 20)
1396
1397        create_tensors()
1398        self.assertEqual(0, len(sharded_tensor.api._sharded_tensor_map))
1399
1400
1401class TestShardedTensorEnumerable(ShardedTensorTestBase):
1402    @with_comms
1403    @skip_if_lt_x_gpu(4)
1404    @requires_nccl()
1405    def test_sharded_tensor_metadata(self):
1406        spec = EnumerableShardingSpec(
1407            [
1408                ShardMetadata(
1409                    shard_offsets=[0, 0],
1410                    shard_sizes=[5, 5],
1411                    placement="rank:0/cuda:0",
1412                ),
1413                ShardMetadata(
1414                    shard_offsets=[0, 5],
1415                    shard_sizes=[5, 5],
1416                    placement="rank:1/cuda:1",
1417                ),
1418                ShardMetadata(
1419                    shard_offsets=[5, 0],
1420                    shard_sizes=[5, 5],
1421                    placement="rank:2/cuda:2",
1422                ),
1423                ShardMetadata(
1424                    shard_offsets=[5, 5],
1425                    shard_sizes=[5, 5],
1426                    placement="rank:3/cuda:3",
1427                ),
1428            ]
1429        )
1430
1431        st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True)
1432        st_metadata = st.metadata()
1433        self.assertEqual(torch.Size([10, 10]), st_metadata.size)
1434        self.assertEqual(torch.float, st.dtype)
1435        self.assertEqual(torch.strided, st.layout)
1436        self.assertEqual(False, st.requires_grad)
1437        self.assertTrue(st.is_contiguous())
1438        self.assertFalse(st.is_pinned())
1439
1440        st = sharded_tensor.empty(spec, 10, 10, requires_grad=True, init_rrefs=True)
1441        self.assertEqual(True, st.requires_grad)
1442
1443        st = sharded_tensor.empty(spec, 10, 10, dtype=torch.double, init_rrefs=True)
1444        self.assertEqual(torch.double, st.dtype)
1445
1446        # Need CPU for pin_memory
1447        spec = EnumerableShardingSpec(
1448            [
1449                ShardMetadata(
1450                    shard_offsets=[0, 0],
1451                    shard_sizes=[5, 5],
1452                    placement="rank:0/cpu",
1453                ),
1454                ShardMetadata(
1455                    shard_offsets=[0, 5],
1456                    shard_sizes=[5, 5],
1457                    placement="rank:1/cpu",
1458                ),
1459                ShardMetadata(
1460                    shard_offsets=[5, 0],
1461                    shard_sizes=[5, 5],
1462                    placement="rank:2/cpu",
1463                ),
1464                ShardMetadata(
1465                    shard_offsets=[5, 5],
1466                    shard_sizes=[5, 5],
1467                    placement="rank:3/cpu",
1468                ),
1469            ]
1470        )
1471
1472        st = sharded_tensor.empty(spec, 10, 10, pin_memory=True, init_rrefs=True)
1473        self.assertTrue(st.is_pinned())
1474
1475    @with_comms
1476    @skip_if_lt_x_gpu(4)
1477    @requires_nccl()
1478    def test_grid_sharding(self):
1479        spec = EnumerableShardingSpec(
1480            [
1481                ShardMetadata(
1482                    shard_offsets=[0, 0],
1483                    shard_sizes=[5, 5],
1484                    placement="rank:0/cuda:0",
1485                ),
1486                ShardMetadata(
1487                    shard_offsets=[0, 5],
1488                    shard_sizes=[5, 5],
1489                    placement="rank:1/cuda:1",
1490                ),
1491                ShardMetadata(
1492                    shard_offsets=[5, 0],
1493                    shard_sizes=[5, 5],
1494                    placement="rank:2/cuda:2",
1495                ),
1496                ShardMetadata(
1497                    shard_offsets=[5, 5],
1498                    shard_sizes=[5, 5],
1499                    placement="rank:3/cuda:3",
1500                ),
1501            ]
1502        )
1503
1504        st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True)
1505        self.assertEqual((10, 10), st.size())
1506        self.assertEqual(1, len(st.local_shards()))
1507
1508        # Verify local shard.
1509        local_shard = st.local_shards()[0]
1510        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device)
1511        self.assertEqual((5, 5), local_shard.tensor.size())
1512
1513        # Verify local shard metadata.
1514        self.assertEqual(
1515            (self.rank // 2 * 5, (self.rank % 2) * 5),
1516            local_shard.metadata.shard_offsets,
1517        )
1518        self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
1519        self.assertEqual(
1520            f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement)
1521        )
1522
1523        # Verify global metadata.
1524        st_metadata = st.metadata()
1525        shards_metadata = st_metadata.shards_metadata
1526        self.assertEqual(4, len(shards_metadata))
1527        for rank, shard_metadata in enumerate(shards_metadata):
1528            self.assertEqual(
1529                (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets
1530            )
1531            self.assertEqual((5, 5), shard_metadata.shard_sizes)
1532            self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement))
1533
1534        # Validate remote shards.
1535        remote_shards = st.remote_shards()
1536        self.assertEqual(3, len(remote_shards))
1537
1538        for rpc_rank, shards in remote_shards.items():
1539            self.assertEqual(1, len(shards))
1540            for remote_shard in shards:
1541                self.assertEqual(rpc_rank, remote_shard.owner().id)
1542                shard = remote_shard.to_here()
1543                self.assertEqual((5, 5), shard.tensor.size())
1544
1545    @with_comms
1546    @skip_if_lt_x_gpu(4)
1547    @requires_nccl()
1548    def test_create_sharded_tensor_with_ones(self):
1549        """Test sharded_tensor.ones(...)"""
1550
1551        spec = EnumerableShardingSpec(
1552            [
1553                ShardMetadata(
1554                    shard_offsets=[0, 0],
1555                    shard_sizes=[5, 5],
1556                    placement="rank:0/cuda:0",
1557                ),
1558                ShardMetadata(
1559                    shard_offsets=[0, 5],
1560                    shard_sizes=[5, 5],
1561                    placement="rank:1/cuda:1",
1562                ),
1563                ShardMetadata(
1564                    shard_offsets=[5, 0],
1565                    shard_sizes=[5, 5],
1566                    placement="rank:2/cuda:2",
1567                ),
1568                ShardMetadata(
1569                    shard_offsets=[5, 5],
1570                    shard_sizes=[5, 5],
1571                    placement="rank:3/cuda:3",
1572                ),
1573            ]
1574        )
1575
1576        st = sharded_tensor.ones(spec, 10, 10, init_rrefs=True)
1577        self.assertEqual((10, 10), st.size())
1578        self.assertEqual(1, len(st.local_shards()))
1579
1580        # Verify local shard is initialized with torch.ones
1581        local_shard = st.local_shards()[0]
1582        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device)
1583        self.assertEqual((5, 5), local_shard.tensor.size())
1584        self.assertEqual(local_shard.tensor, torch.ones(5, 5))
1585
1586    @with_comms
1587    @skip_if_lt_x_gpu(4)
1588    @requires_nccl()
1589    def test_gather_even(self) -> None:
1590        """Test _sharded_tensor.gather(...) with evenly distributed._shards"""
1591
1592        spec = EnumerableShardingSpec(
1593            [
1594                ShardMetadata(
1595                    shard_offsets=[0, 0],
1596                    shard_sizes=[5, 5],
1597                    placement="rank:0/cuda:0",
1598                ),
1599                ShardMetadata(
1600                    shard_offsets=[0, 5],
1601                    shard_sizes=[5, 5],
1602                    placement="rank:1/cuda:1",
1603                ),
1604                ShardMetadata(
1605                    shard_offsets=[5, 0],
1606                    shard_sizes=[5, 5],
1607                    placement="rank:2/cuda:2",
1608                ),
1609                ShardMetadata(
1610                    shard_offsets=[5, 5],
1611                    shard_sizes=[5, 5],
1612                    placement="rank:3/cuda:3",
1613                ),
1614            ]
1615        )
1616
1617        h, w = 10, 10
1618        st = sharded_tensor.ones(spec, h, w, init_rrefs=True)
1619
1620        full_tensor = None
1621        dst = 0
1622        if self.rank == dst:
1623            full_tensor = torch.zeros(h, w, device=torch.device(f"cuda:{dst}"))
1624        st.gather(dst, full_tensor)
1625
1626        if self.rank == dst:
1627            self.assertEqual(full_tensor, torch.ones(h, w))
1628        else:
1629            self.assertIsNone(full_tensor)
1630
1631    @with_comms
1632    @skip_if_lt_x_gpu(4)
1633    @requires_nccl()
1634    def test_gather_uneven(self) -> None:
1635        """Test _sharded_tensor.gather(...) with unevenly distributed._shards"""
1636
1637        spec = EnumerableShardingSpec(
1638            [
1639                ShardMetadata(
1640                    shard_offsets=[0, 0],
1641                    shard_sizes=[5, 5],
1642                    placement="rank:0/cuda:0",
1643                ),
1644                ShardMetadata(
1645                    shard_offsets=[0, 5],
1646                    shard_sizes=[5, 5],
1647                    placement="rank:1/cuda:1",
1648                ),
1649                ShardMetadata(
1650                    shard_offsets=[5, 0],
1651                    shard_sizes=[5, 5],
1652                    placement="rank:0/cuda:0",
1653                ),
1654                ShardMetadata(
1655                    shard_offsets=[5, 5],
1656                    shard_sizes=[5, 5],
1657                    placement="rank:3/cuda:3",
1658                ),
1659            ]
1660        )
1661
1662        h, w = 10, 10
1663        st = sharded_tensor.ones(spec, h, w, init_rrefs=True)
1664
1665        full_tensor = None
1666        dst = 0
1667        if self.rank == dst:
1668            full_tensor = torch.zeros(h, w, device=torch.device(f"cuda:{dst}"))
1669        st.gather(dst, full_tensor)
1670
1671        if self.rank == dst:
1672            self.assertEqual(full_tensor, torch.ones(h, w))
1673        else:
1674            self.assertIsNone(full_tensor)
1675
1676    @with_comms
1677    @skip_if_lt_x_gpu(4)
1678    @requires_nccl()
1679    def test_sharded_tensor_to_cpu(self):
1680        cpu_spec = ChunkShardingSpec(
1681            dim=0,
1682            placements=[
1683                "rank:0/cpu",
1684                "rank:1/cpu",
1685                "rank:2/cpu",
1686                "rank:3/cpu",
1687            ],
1688        )
1689        spec = ChunkShardingSpec(
1690            dim=0,
1691            placements=[
1692                "rank:0/cuda:0",
1693                "rank:1/cuda:1",
1694                "rank:2/cuda:2",
1695                "rank:3/cuda:3",
1696            ],
1697        )
1698        h, w = 10, 20
1699        gloo_pg = dist.new_group(backend="gloo")
1700
1701        # CPU sharded tensor should return the same instance (no copy)
1702        st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg)
1703        new_st_cpu = st_cpu.cpu()
1704        self.assertTrue(st_cpu is new_st_cpu)
1705
1706        # GPU sharded tensor to cpu
1707        st = sharded_tensor.zeros(spec, h, w)
1708        # test ability to move st to CPU
1709        spec_before_move = st.sharding_spec()
1710        new_st = st.cpu(process_group=gloo_pg)
1711        # return a copy of original st
1712        self.assertFalse(st is new_st)
1713        # check the spec is still ChunkShardingSpec
1714        spec_after_move = new_st.sharding_spec()
1715        self.assertIsInstance(spec_after_move, ChunkShardingSpec)
1716        self.assertIsInstance(new_st._process_group, distributed_c10d.ProcessGroup)
1717        # test specs before and after the move almost the same except placement device
1718        self.assertEqual(spec_before_move.dim, spec_after_move.dim)
1719        self.assertEqual(
1720            len(spec_before_move.placements), len(spec_after_move.placements)
1721        )
1722        for i, remote_device_after in enumerate(spec_after_move.placements):
1723            remote_device_before = spec_before_move.placements[i]
1724            self.assertEqual(remote_device_before.rank(), remote_device_after.rank())
1725            self.assertEqual(str(remote_device_after.device()), "cpu")
1726
1727        # ensure metdata also get changed to CPU
1728        metas = new_st.metadata().shards_metadata
1729        for meta in metas:
1730            self.assertEqual(str(meta.placement.device()), "cpu")
1731
1732        # Test if a mixed sharded tensor (ShardedTensor with different devices) to cpu
1733        mixed_spec = ChunkShardingSpec(
1734            dim=0,
1735            placements=[
1736                "rank:0/cpu",
1737                "rank:1/cpu",
1738                "rank:2/cuda:2",
1739                "rank:3/cuda:3",
1740            ],
1741        )
1742
1743        st = sharded_tensor.zeros(mixed_spec, h, w, process_group=gloo_pg)
1744        new_st = st.cpu()
1745        # return a copy of original st
1746        self.assertFalse(st is new_st)
1747        # check the spec is still ChunkShardingSpec
1748        spec_after_move = new_st.sharding_spec()
1749        self.assertIsInstance(spec_after_move, ChunkShardingSpec)
1750        # test specs before and after the move almost the same except placement device
1751        self.assertEqual(mixed_spec.dim, spec_after_move.dim)
1752        self.assertEqual(len(mixed_spec.placements), len(spec_after_move.placements))
1753        for i, remote_device_after in enumerate(spec_after_move.placements):
1754            remote_device_before = mixed_spec.placements[i]
1755            self.assertEqual(remote_device_before.rank(), remote_device_after.rank())
1756            self.assertEqual(str(remote_device_after.device()), "cpu")
1757
1758        # ensure metdata also get changed to CPU
1759        metas = new_st.metadata().shards_metadata
1760        for meta in metas:
1761            self.assertEqual(str(meta.placement.device()), "cpu")
1762
1763    @with_comms
1764    @skip_if_lt_x_gpu(4)
1765    @requires_nccl()
1766    def test_sharded_tensor_to_cuda(self):
1767        cpu_spec = ChunkShardingSpec(
1768            dim=0,
1769            placements=[
1770                "rank:0/cpu",
1771                "rank:1/cpu",
1772                "rank:2/cpu",
1773                "rank:3/cpu",
1774            ],
1775        )
1776        spec = ChunkShardingSpec(
1777            dim=0,
1778            placements=[
1779                "rank:0/cuda:0",
1780                "rank:1/cuda:1",
1781                "rank:2/cuda:2",
1782                "rank:3/cuda:3",
1783            ],
1784        )
1785        h, w = 10, 20
1786        # CUDA sharded tensor should return a new ShardedTensor, but same
1787        # local shards(no movements)
1788        st_cuda = sharded_tensor.zeros(spec, h, w)
1789        new_st_cuda = st_cuda.cuda()
1790        self.assertTrue(st_cuda is not new_st_cuda)
1791        self.assertTrue(st_cuda.local_tensor() is new_st_cuda.local_tensor())
1792
1793        gloo_pg = dist.new_group(backend="gloo")
1794
1795        # CPU sharded tensor to GPU
1796        st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg)
1797        # test ability to move st to GPU
1798        spec_before_move = st_cpu.sharding_spec()
1799        new_st_gpu = st_cpu.cuda()
1800        # check the spec is still ChunkShardingSpec
1801        spec_after_move = new_st_gpu.sharding_spec()
1802        self.assertIsInstance(spec_after_move, ChunkShardingSpec)
1803        # test specs before and after the move almost the same except placement device
1804        self.assertEqual(spec_before_move.dim, spec_after_move.dim)
1805        self.assertEqual(
1806            len(spec_before_move.placements), len(spec_after_move.placements)
1807        )
1808        for i, remote_device_after in enumerate(spec_after_move.placements):
1809            remote_device_before = spec_before_move.placements[i]
1810            self.assertEqual(remote_device_before.rank(), remote_device_after.rank())
1811            self.assertEqual(str(remote_device_before.device().type), "cpu")
1812            self.assertEqual(str(remote_device_after.device().type), "cuda")
1813
1814        # ensure metdata also get changed to GPU
1815        metas = new_st_gpu.metadata().shards_metadata
1816        for meta in metas:
1817            self.assertEqual(str(meta.placement.device().type), "cuda")
1818
1819    @with_comms
1820    @skip_if_lt_x_gpu(4)
1821    @requires_nccl()
1822    def test_sharded_tensor_to_test(self):
1823        spec = ChunkShardingSpec(
1824            dim=0,
1825            placements=[
1826                "rank:0/cuda:0",
1827                "rank:1/cuda:1",
1828                "rank:2/cuda:2",
1829                "rank:3/cuda:3",
1830            ],
1831        )
1832        h, w = 10, 20
1833        # CUDA sharded tensor should return a new ShardedTensor, but same
1834        # local shards(no movements)
1835        st = sharded_tensor.zeros(spec, h, w)
1836        # test same dtype, device return itself
1837        st_self = st.to(dtype=st.dtype, device="cuda")
1838        self.assertTrue(st_self is st)
1839
1840        # test dtype to
1841        st_16 = st.to(torch.float16)
1842        self.assertFalse(st_16 is st)
1843        self.assertEqual(st_16.dtype, torch.float16)
1844        # test device to
1845        st_cpu = st.to(device=torch.device("cpu"))
1846        self.assertFalse(st_cpu is st)
1847        self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
1848        st_cuda = st_cpu.to(device=torch.device("cuda"))
1849        self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
1850        # non-kwarg device to
1851        st_cuda = st_cpu.to(torch.device("cuda"))
1852        self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
1853        st_cpu = st_cuda.to(torch.device("cpu"))
1854        self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
1855        # with string like device conversion
1856        st_cpu = st_cuda.to("cpu")
1857        self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
1858        st_cuda = st_cpu.to("cuda")
1859        self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
1860        # with int like device conversion
1861        st_cpu = st_cuda.to("cpu")
1862        self.assertEqual(st_cpu.local_tensor().device.type, "cpu")
1863        st_cuda = st_cpu.to(self.rank)
1864        self.assertEqual(st_cuda.local_tensor().device.type, "cuda")
1865
1866        # test tensor to
1867        cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda")
1868        st_cuda = st.to(cuda_tensor)
1869        self.assertFalse(st_cuda is st)
1870        self.assertEqual(st_cuda.dtype, torch.float16)
1871
1872        cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda:2")
1873        st_cuda = st.to(cuda_tensor)
1874        self.assertEqual(st_cuda.dtype, torch.float16)
1875
1876        # test dtype and device together
1877        st_cpu_16 = st.to("cpu", torch.float16)
1878        self.assertEqual(st_cpu_16.dtype, torch.float16)
1879        self.assertEqual(st_cpu_16.local_tensor().device.type, "cpu")
1880
1881        st_cuda_32 = st_cpu_16.to("cuda", torch.float32)
1882        self.assertEqual(st_cuda_32.dtype, torch.float32)
1883        self.assertEqual(st_cuda_32.local_tensor().device.type, "cuda")
1884
1885        # test pass additional process group
1886        gloo_pg = dist.new_group(backend="gloo")
1887        st_gloo = st.to(device="cpu", process_group=gloo_pg)
1888        self.assertFalse(st_gloo is st)
1889        self.assertEqual(st_gloo.local_tensor().device.type, "cpu")
1890        self.assertEqual(st_gloo._process_group, gloo_pg)
1891
1892    @with_comms
1893    @skip_if_lt_x_gpu(4)
1894    @requires_nccl()
1895    def test_sharded_tensor_device(self):
1896        spec = ChunkShardingSpec(
1897            dim=0,
1898            placements=[
1899                "rank:0/cuda:0",
1900                "rank:1/cuda:1",
1901                "rank:2/cuda:2",
1902                "rank:3/cuda:3",
1903            ],
1904        )
1905        h, w = 10, 20
1906        # CUDA sharded tensor should return a new ShardedTensor, but same
1907        # local shards(no movements)
1908        st = sharded_tensor.zeros(spec, h, w)
1909        current_device = torch.device(torch.cuda.current_device())
1910        self.assertEqual(current_device, st.device)
1911
1912        # test after to cpu, device get changed
1913        cpu_device = torch.device("cpu")
1914        st_cpu = st.to(device=cpu_device)
1915        self.assertEqual(st_cpu.device, cpu_device)
1916
1917    @skip_if_lt_x_gpu(4)
1918    @requires_nccl()
1919    def test_uneven_shards(self):
1920        self.init_pg()
1921
1922        spec = EnumerableShardingSpec(
1923            [
1924                ShardMetadata(
1925                    shard_offsets=[0, 0],
1926                    shard_sizes=[2, 4],
1927                    placement="rank:0/cuda:0",
1928                ),
1929                ShardMetadata(
1930                    shard_offsets=[0, 4],
1931                    shard_sizes=[4, 2],
1932                    placement="rank:1/cuda:1",
1933                ),
1934                ShardMetadata(
1935                    shard_offsets=[2, 0],
1936                    shard_sizes=[4, 4],
1937                    placement="rank:2/cuda:2",
1938                ),
1939                ShardMetadata(
1940                    shard_offsets=[4, 4],
1941                    shard_sizes=[2, 2],
1942                    placement="rank:3/cuda:3",
1943                ),
1944            ]
1945        )
1946
1947        st = sharded_tensor.empty(spec, 6, 6)
1948        self.assertEqual((6, 6), st.size())
1949        self.assertEqual(1, len(st.local_shards()))
1950
1951        def verify_size(rank, tensor_dims):
1952            if rank == 0:
1953                self.assertEqual((2, 4), tensor_dims)
1954            elif rank == 1:
1955                self.assertEqual((4, 2), tensor_dims)
1956            elif rank == 2:
1957                self.assertEqual((4, 4), tensor_dims)
1958            elif rank == 3:
1959                self.assertEqual((2, 2), tensor_dims)
1960
1961        def verify_offsets(rank, offsets):
1962            if rank == 0:
1963                self.assertEqual((0, 0), offsets)
1964            elif rank == 1:
1965                self.assertEqual((0, 4), offsets)
1966            elif rank == 2:
1967                self.assertEqual((2, 0), offsets)
1968            elif rank == 3:
1969                self.assertEqual((4, 4), offsets)
1970
1971        # Verify local shard.
1972        local_shard = st.local_shards()[0]
1973        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device)
1974        verify_size(self.rank, local_shard.tensor.size())
1975
1976        # Verify local shard metadata.
1977        verify_offsets(self.rank, local_shard.metadata.shard_offsets)
1978        verify_size(self.rank, local_shard.metadata.shard_sizes)
1979        self.assertEqual(
1980            f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement)
1981        )
1982
1983        # Verify global metadata.
1984        st_metadata = st.metadata()
1985        shards_metadata = st_metadata.shards_metadata
1986        self.assertEqual(4, len(shards_metadata))
1987        for rank, shard_metadata in enumerate(shards_metadata):
1988            verify_offsets(rank, shard_metadata.shard_offsets)
1989            verify_size(rank, shard_metadata.shard_sizes)
1990            self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement))
1991
1992    @with_comms
1993    @skip_if_lt_x_gpu(4)
1994    @requires_nccl()
1995    def test_partial_world_size(self):
1996        spec = EnumerableShardingSpec(
1997            [
1998                ShardMetadata(
1999                    shard_offsets=[0, 0],
2000                    shard_sizes=[5, 5],
2001                    placement="rank:0/cuda:0",
2002                ),
2003                ShardMetadata(
2004                    shard_offsets=[5, 0],
2005                    shard_sizes=[5, 5],
2006                    placement="rank:1/cuda:1",
2007                ),
2008            ]
2009        )
2010
2011        st = sharded_tensor.empty(spec, 10, 5, init_rrefs=True)
2012        self.assertEqual((10, 5), st.size())
2013        if self.rank <= 1:
2014            self.assertEqual(1, len(st.local_shards()))
2015        else:
2016            self.assertEqual(0, len(st.local_shards()))
2017
2018        if self.rank <= 1:
2019            # Verify local shard.
2020            local_shard = st.local_shards()[0]
2021            self.assertEqual(
2022                torch.device(f"cuda:{self.rank}"), local_shard.tensor.device
2023            )
2024            self.assertEqual((5, 5), local_shard.tensor.size())
2025
2026            # Verify local shard metadata.
2027            self.assertEqual((self.rank * 5, 0), local_shard.metadata.shard_offsets)
2028            self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
2029            self.assertEqual(
2030                f"rank:{self.rank}/cuda:{self.rank}",
2031                str(local_shard.metadata.placement),
2032            )
2033
2034        # Verify global metadata.
2035        st_metadata = st.metadata()
2036        shards_metadata = st_metadata.shards_metadata
2037        self.assertEqual(2, len(shards_metadata))
2038        for rank, shard_metadata in enumerate(shards_metadata):
2039            self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
2040            self.assertEqual((5, 5), shard_metadata.shard_sizes)
2041            self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement))
2042
2043        # Validate remote shards.
2044        remote_shards = st.remote_shards()
2045        if self.rank <= 1:
2046            self.assertEqual(1, len(remote_shards))
2047        else:
2048            self.assertEqual(2, len(remote_shards))
2049
2050        for rpc_rank, shards in remote_shards.items():
2051            self.assertEqual(1, len(shards))
2052
2053            for remote_shard in shards:
2054                self.assertEqual(rpc_rank, remote_shard.owner().id)
2055                shard = remote_shard.to_here()
2056                self.assertEqual((5, 5), shard.tensor.size())
2057
2058    @with_comms
2059    @skip_if_lt_x_gpu(4)
2060    @requires_nccl()
2061    def test_new_group(self):
2062        spec = EnumerableShardingSpec(
2063            [
2064                ShardMetadata(
2065                    shard_offsets=[0, 0],
2066                    shard_sizes=[5, 5],
2067                    placement="rank:1/cuda:1",
2068                ),
2069                ShardMetadata(
2070                    shard_offsets=[5, 0],
2071                    shard_sizes=[5, 5],
2072                    placement="rank:3/cuda:3",
2073                ),
2074            ]
2075        )
2076
2077        pg = dist.new_group(ranks=[1, 2, 3])
2078
2079        st = sharded_tensor.empty(spec, 10, 5, process_group=pg, init_rrefs=True)
2080        self.assertEqual((10, 5), st.size())
2081        if self.rank == 1 or self.rank == 3:
2082            # Verify local shard.
2083            local_shard = st.local_shards()[0]
2084            self.assertEqual(
2085                torch.device(f"cuda:{self.rank}"), local_shard.tensor.device
2086            )
2087            self.assertEqual((5, 5), local_shard.tensor.size())
2088
2089            # Verify local shard metadata.
2090            self.assertEqual(
2091                (self.rank // 2 * 5, 0), local_shard.metadata.shard_offsets
2092            )
2093            self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
2094            self.assertEqual(
2095                f"rank:{self.rank}/cuda:{self.rank}",
2096                str(local_shard.metadata.placement),
2097            )
2098
2099        # Verify global metadata.
2100        st_metadata = st.metadata()
2101        shards_metadata = st_metadata.shards_metadata
2102        self.assertEqual(2, len(shards_metadata))
2103        for rank, shard_metadata in enumerate(shards_metadata):
2104            self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
2105            self.assertEqual((5, 5), shard_metadata.shard_sizes)
2106            self.assertEqual(
2107                f"rank:{rank * 2 + 1}/cuda:{rank * 2 + 1}",
2108                str(shard_metadata.placement),
2109            )
2110
2111        # Validate remote shards.
2112        remote_shards = st.remote_shards()
2113        if self.rank == 1 or self.rank == 3:
2114            self.assertEqual(1, len(remote_shards))
2115        else:
2116            self.assertEqual(2, len(remote_shards))
2117
2118        for rpc_rank, shards in remote_shards.items():
2119            self.assertEqual(1, len(shards))
2120
2121            for remote_shard in shards:
2122                self.assertEqual(rpc_rank, remote_shard.owner().id)
2123                shard = remote_shard.to_here()
2124                self.assertEqual((5, 5), shard.tensor.size())
2125
2126    @with_comms
2127    @skip_if_lt_x_gpu(4)
2128    @requires_nccl()
2129    def test_multiple_local_shards(self):
2130        spec = EnumerableShardingSpec(
2131            [
2132                ShardMetadata(
2133                    shard_offsets=[0, 0],
2134                    shard_sizes=[5, 5],
2135                    placement="rank:0/cuda:0",
2136                ),
2137                ShardMetadata(
2138                    shard_offsets=[0, 5],
2139                    shard_sizes=[5, 5],
2140                    placement="rank:1/cuda:1",
2141                ),
2142                ShardMetadata(
2143                    shard_offsets=[5, 0],
2144                    shard_sizes=[5, 5],
2145                    placement="rank:0/cuda:0",
2146                ),
2147                ShardMetadata(
2148                    shard_offsets=[5, 5],
2149                    shard_sizes=[5, 5],
2150                    placement="rank:1/cuda:1",
2151                ),
2152            ]
2153        )
2154
2155        st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True)
2156        self.assertEqual((10, 10), st.size())
2157
2158        if self.rank <= 1:
2159            self.assertEqual(2, len(st.local_shards()))
2160
2161            # Verify local shards.
2162            for idx, local_shard in enumerate(st.local_shards()):
2163                self.assertEqual(
2164                    torch.device(f"cuda:{self.rank}"), local_shard.tensor.device
2165                )
2166                self.assertEqual((5, 5), local_shard.tensor.size())
2167
2168                # Verify local shard metadata.
2169                self.assertEqual(
2170                    (idx * 5, self.rank * 5), local_shard.metadata.shard_offsets
2171                )
2172                self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
2173                self.assertEqual(
2174                    f"rank:{self.rank}/cuda:{self.rank}",
2175                    str(local_shard.metadata.placement),
2176                )
2177        else:
2178            self.assertEqual(0, len(st.local_shards()))
2179
2180        # Verify global metadata.
2181        st_metadata = st.metadata()
2182        shards_metadata = st_metadata.shards_metadata
2183        self.assertEqual(4, len(shards_metadata))
2184        for shard_rank, shard_metadata in enumerate(shards_metadata):
2185            self.assertEqual(
2186                (shard_rank // 2 * 5, (shard_rank % 2) * 5),
2187                shard_metadata.shard_offsets,
2188            )
2189            self.assertEqual((5, 5), shard_metadata.shard_sizes)
2190            self.assertEqual(
2191                f"rank:{shard_rank % 2}/cuda:{shard_rank % 2}",
2192                str(shard_metadata.placement),
2193            )
2194
2195        # Validate remote shards.
2196        remote_shards = st.remote_shards()
2197        if self.rank <= 1:
2198            self.assertEqual(1, len(remote_shards))
2199        else:
2200            self.assertEqual(2, len(remote_shards))
2201
2202        owners = {}
2203        for rpc_rank, shards in remote_shards.items():
2204            self.assertEqual(2, len(shards))
2205            for remote_shard in shards:
2206                self.assertEqual(rpc_rank, remote_shard.owner().id)
2207                shard = remote_shard.to_here()
2208                self.assertEqual((5, 5), shard.tensor.size())
2209
2210    @with_comms
2211    @skip_if_lt_x_gpu(4)
2212    @requires_nccl()
2213    def test_with_rpc_names(self):
2214        spec = EnumerableShardingSpec(
2215            [
2216                ShardMetadata(
2217                    shard_offsets=[0, 0],
2218                    shard_sizes=[5, 5],
2219                    placement="worker0/cuda:0",
2220                ),
2221                ShardMetadata(
2222                    shard_offsets=[0, 5],
2223                    shard_sizes=[5, 5],
2224                    placement="worker1/cuda:1",
2225                ),
2226                ShardMetadata(
2227                    shard_offsets=[5, 0],
2228                    shard_sizes=[5, 5],
2229                    placement="worker2/cuda:2",
2230                ),
2231                ShardMetadata(
2232                    shard_offsets=[5, 5],
2233                    shard_sizes=[5, 5],
2234                    placement="worker3/cuda:3",
2235                ),
2236            ]
2237        )
2238
2239        st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True)
2240        self.assertEqual((10, 10), st.size())
2241        self.assertEqual(1, len(st.local_shards()))
2242
2243        # Verify local shard.
2244        local_shard = st.local_shards()[0]
2245        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device)
2246        self.assertEqual((5, 5), local_shard.tensor.size())
2247
2248        # Verify local shard metadata.
2249        self.assertEqual(
2250            (self.rank // 2 * 5, (self.rank % 2) * 5),
2251            local_shard.metadata.shard_offsets,
2252        )
2253        self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
2254        self.assertEqual(
2255            f"worker{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement)
2256        )
2257
2258        # Verify global metadata.
2259        st_metadata = st.metadata()
2260        shards_metadata = st_metadata.shards_metadata
2261        self.assertEqual(4, len(shards_metadata))
2262        for rank, shard_metadata in enumerate(shards_metadata):
2263            self.assertEqual(
2264                (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets
2265            )
2266            self.assertEqual((5, 5), shard_metadata.shard_sizes)
2267            self.assertEqual(f"worker{rank}/cuda:{rank}", str(shard_metadata.placement))
2268
2269        # Validate remote shards.
2270        remote_shards = st.remote_shards()
2271        self.assertEqual(3, len(remote_shards))
2272
2273        for rpc_rank, shards in remote_shards.items():
2274            self.assertEqual(1, len(shards))
2275            for remote_shard in shards:
2276                self.assertEqual(rpc_rank, remote_shard.owner().id)
2277                shard = remote_shard.to_here()
2278                self.assertEqual((5, 5), shard.tensor.size())
2279
2280
2281class TestShardedTensorFromLocalTensor(ShardedTensorTestBase):
2282    def _generate_st_from_chunk_local_tensor(self, st_size, sharding_spec):
2283        tensor_meta = sharding_spec.build_metadata(st_size, TensorProperties())
2284        pg = dist.distributed_c10d._get_default_group()
2285
2286        local_tensor = None
2287        local_shard_metadata = None
2288        rank_to_metadata = {}
2289        for shard_metadata in tensor_meta.shards_metadata:
2290            rank, device = _parse_and_validate_remote_device(
2291                pg, shard_metadata.placement
2292            )
2293            rank_to_metadata[rank] = shard_metadata
2294            if rank == self.rank:
2295                local_tensor = torch.rand(shard_metadata.shard_sizes).cuda(device)
2296                local_shard_metadata = shard_metadata
2297
2298        # TODO: figure out what the API should behave when some rank have no shard
2299        # see https://github.com/pytorch/pytorch/issues/73133
2300        assert local_tensor is not None
2301        st = ShardedTensor._init_from_local_tensor(
2302            local_tensor,
2303            sharding_spec,
2304            st_size,
2305            init_rrefs=True,
2306        )
2307        self.assertEqual(tuple(st_size), st.size())
2308        self.assertEqual(1, len(st.local_shards()))
2309
2310        # Verify local shard.
2311        local_shard = st.local_shards()[0]
2312        self.assertEqual(st.local_tensor(), local_tensor)
2313        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device)
2314
2315        # Verify local shard metadata.
2316        self.assertEqual(
2317            local_shard_metadata.shard_offsets, local_shard.metadata.shard_offsets
2318        )
2319        self.assertEqual(
2320            local_shard_metadata.shard_sizes, local_shard.metadata.shard_sizes
2321        )
2322        self.assertEqual(local_shard_metadata.placement, local_shard.metadata.placement)
2323
2324        # Verify global metadata.
2325        st_shards_metadata = st.metadata().shards_metadata
2326        self.assertEqual(self.world_size, len(st_shards_metadata))
2327        self.assertEqual(tensor_meta.shards_metadata, st_shards_metadata)
2328
2329        # Validate remote shards.
2330        remote_shards = st.remote_shards()
2331        self.assertEqual(self.world_size - 1, len(remote_shards))
2332        for rpc_rank, shards in remote_shards.items():
2333            self.assertEqual(1, len(shards))
2334            for remote_shard in shards:
2335                self.assertEqual(rpc_rank, remote_shard.owner().id)
2336                # If remote shard does not exist, to_here() will throw exception.
2337                if tensor_meta.shards_metadata[rpc_rank]:
2338                    shard = remote_shard.to_here()
2339                    self.assertEqual(
2340                        rank_to_metadata[rpc_rank].shard_sizes, shard.tensor.size()
2341                    )
2342
2343    @with_comms
2344    @skip_if_lt_x_gpu(4)
2345    @requires_nccl()
2346    def test_init_from_local_tensor(self):
2347        chunk_specs = _chunk_sharding_specs_list_for_test([0, 1, 1, 0], seed=31)
2348        for spec in chunk_specs:
2349            self._generate_st_from_chunk_local_tensor([20, 10], spec)
2350            self._generate_st_from_chunk_local_tensor([21, 11], spec)
2351            self._generate_st_from_chunk_local_tensor([23, 16], spec)
2352            self._generate_st_from_chunk_local_tensor([44, 16, 8], spec)
2353
2354    @with_comms
2355    @skip_if_lt_x_gpu(4)
2356    @requires_nccl()
2357    def test_init_from_local_tensor_errors(self):
2358        enumerable_sharding_spec = EnumerableShardingSpec(
2359            [
2360                ShardMetadata(
2361                    shard_offsets=[0, 0],
2362                    shard_sizes=[5, 5],
2363                    placement="rank:0/cuda:0",
2364                ),
2365                ShardMetadata(
2366                    shard_offsets=[5, 0],
2367                    shard_sizes=[5, 5],
2368                    placement="rank:1/cuda:1",
2369                ),
2370            ]
2371        )
2372        st_size = [24, 12]
2373        local_tensor = torch.rand(*st_size).cuda(self.rank)
2374        with self.assertRaisesRegex(ValueError, "do not cover the entire tensor"):
2375            ShardedTensor._init_from_local_tensor(
2376                local_tensor,
2377                enumerable_sharding_spec,
2378                st_size,
2379            )
2380        chunk_specs = _chunk_sharding_specs_list_for_test([0], seed=31)
2381        with self.assertRaisesRegex(
2382            ValueError, "local_tensor is not a contiguous Tensor."
2383        ):
2384            ShardedTensor._init_from_local_tensor(
2385                local_tensor.t(),
2386                chunk_specs[0],
2387                st_size,
2388            )
2389
2390
2391class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
2392    @with_comms(init_rpc=False)
2393    @skip_if_lt_x_gpu(4)
2394    @requires_nccl()
2395    def test_local_shards(self):
2396        shard_offsets = [(self.rank // 2) * 5, (self.rank % 2) * 5]
2397        local_shard_metadata = ShardMetadata(
2398            shard_offsets=shard_offsets,
2399            shard_sizes=[5, 5],
2400            placement=f"rank:{self.rank}/cuda:{self.rank}",
2401        )
2402
2403        local_tensor = torch.randn(5, 5, device=f"cuda:{self.rank}")
2404        local_shard = sharded_tensor.Shard(local_tensor, local_shard_metadata)
2405        local_shard_from_offsets = sharded_tensor.Shard.from_tensor_and_offsets(
2406            local_tensor, shard_offsets=shard_offsets, rank=self.rank
2407        )
2408        self.assertEqual(local_shard.metadata, local_shard_from_offsets.metadata)
2409
2410        wrong_local_shard_metadata = ShardMetadata(
2411            shard_offsets=shard_offsets,
2412            shard_sizes=[6, 5],
2413            placement=f"rank:{self.rank}/cuda:{self.rank}",
2414        )
2415        with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"):
2416            local_shard_from_wrong_meta = sharded_tensor.Shard(
2417                local_tensor,
2418                metadata=wrong_local_shard_metadata,
2419            )
2420
2421    @with_comms
2422    @skip_if_lt_x_gpu(4)
2423    @requires_nccl()
2424    def test_init_from_local_shards(self):
2425        local_shard_metadata = ShardMetadata(
2426            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2427            shard_sizes=[5, 5],
2428            placement=f"rank:{self.rank}/cuda:{self.rank}",
2429        )
2430
2431        local_shards = [
2432            sharded_tensor.Shard(
2433                torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata
2434            )
2435        ]
2436
2437        st = sharded_tensor.init_from_local_shards(
2438            local_shards, [10, 10], init_rrefs=True
2439        )
2440        self.assertEqual((10, 10), st.size())
2441        self.assertEqual(1, len(st.local_shards()))
2442
2443        # Verify local shard.
2444        local_shard = st.local_shards()[0]
2445        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device)
2446        self.assertEqual((5, 5), local_shard.tensor.size())
2447
2448        # Verify local shard metadata.
2449        self.assertEqual(
2450            (self.rank // 2 * 5, (self.rank % 2) * 5),
2451            local_shard.metadata.shard_offsets,
2452        )
2453        self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
2454        self.assertEqual(
2455            f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement)
2456        )
2457
2458        # Verify global metadata.
2459        shards_metadata = st.metadata().shards_metadata
2460        self.assertEqual(4, len(shards_metadata))
2461        for rank, shard_metadata in enumerate(shards_metadata):
2462            self.assertEqual(
2463                (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets
2464            )
2465            self.assertEqual((5, 5), shard_metadata.shard_sizes)
2466            self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement))
2467
2468        # Validate remote shards.
2469        remote_shards = st.remote_shards()
2470        self.assertEqual(3, len(remote_shards))
2471
2472        for rpc_rank, shards in remote_shards.items():
2473            self.assertEqual(1, len(shards))
2474            for remote_shard in shards:
2475                self.assertEqual(rpc_rank, remote_shard.owner().id)
2476                shard = remote_shard.to_here()
2477                self.assertEqual((5, 5), shard.tensor.size())
2478
2479    @skip_if_lt_x_gpu(4)
2480    def test_st_base_init_from_local_shards_and_global_metadata(self):
2481        world_size = 4
2482        shards_metadata = []
2483        shards = []
2484        for rank in range(world_size):
2485            local_shard_metadata = ShardMetadata(
2486                shard_offsets=[(rank // 2) * 5, (rank % 2) * 5],
2487                shard_sizes=[5, 5],
2488                placement=f"rank:{rank}/cuda:{rank}",
2489            )
2490            shards_metadata.append(local_shard_metadata)
2491            shards.append(
2492                sharded_tensor.Shard(
2493                    torch.randn(5, 5, device=f"cuda:{rank}"), local_shard_metadata
2494                )
2495            )
2496
2497        tensor_properties = TensorProperties(
2498            dtype=torch.get_default_dtype(),
2499            layout=torch.strided,
2500            requires_grad=False,
2501            memory_format=torch.contiguous_format,
2502            pin_memory=False,
2503        )
2504
2505        sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata(
2506            shards_metadata=shards_metadata,
2507            size=torch.Size([10, 10]),
2508            tensor_properties=tensor_properties,
2509        )
2510
2511        st_base = sharded_tensor.ShardedTensorBase._init_from_local_shards_and_global_metadata(
2512            shards, sharded_tensor_metadata=sharded_tensor_metadata
2513        )
2514        self.assertEqual(4, len(st_base.local_shards()))
2515
2516        # Verify local shard of st_base
2517        local_shard = st_base.local_shards()[0]
2518        self.assertEqual(torch.device("cuda:0"), local_shard.tensor.device)
2519        self.assertEqual((5, 5), local_shard.tensor.size())
2520
2521        # Verify local shard metadata.
2522        self.assertEqual(
2523            (0, 0),
2524            local_shard.metadata.shard_offsets,
2525        )
2526        self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
2527        self.assertEqual("rank:0/cuda:0", str(local_shard.metadata.placement))
2528
2529        # Verify global metadata.
2530        shards_metadata = st_base.metadata().shards_metadata
2531        self.assertEqual(4, len(shards_metadata))
2532        for rank, shard_metadata in enumerate(shards_metadata):
2533            self.assertEqual(
2534                (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets
2535            )
2536            self.assertEqual((5, 5), shard_metadata.shard_sizes)
2537            self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement))
2538
2539    @with_comms
2540    @skip_if_lt_x_gpu(4)
2541    @requires_nccl()
2542    def test_init_from_local_shards_and_global_metadata(self):
2543        local_shard_metadata = ShardMetadata(
2544            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2545            shard_sizes=[5, 5],
2546            placement=f"rank:{self.rank}/cuda:{self.rank}",
2547        )
2548
2549        shards_metadata = []
2550        for r in range(self.world_size):
2551            if r == self.rank:
2552                shards_metadata.append(local_shard_metadata)
2553            else:
2554                shards_metadata.append(
2555                    ShardMetadata(
2556                        shard_offsets=[(r // 2) * 5, (r % 2) * 5],
2557                        shard_sizes=[5, 5],
2558                        placement=f"rank:{r}/cuda:{r}",
2559                    )
2560                )
2561
2562        local_shards = [
2563            sharded_tensor.Shard(
2564                torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata
2565            )
2566        ]
2567
2568        tensor_properties = TensorProperties(
2569            dtype=torch.get_default_dtype(),
2570            layout=torch.strided,
2571            requires_grad=False,
2572            memory_format=torch.contiguous_format,
2573            pin_memory=False,
2574        )
2575
2576        sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata(
2577            shards_metadata=shards_metadata,
2578            size=torch.Size([10, 10]),
2579            tensor_properties=tensor_properties,
2580        )
2581
2582        st = ShardedTensor._init_from_local_shards_and_global_metadata(
2583            local_shards,
2584            sharded_tensor_metadata,
2585            init_rrefs=True,
2586        )
2587        self.assertEqual((10, 10), st.size())
2588        self.assertEqual(1, len(st.local_shards()))
2589
2590        # Verify local shard.
2591        local_shard = st.local_shards()[0]
2592        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device)
2593        self.assertEqual((5, 5), local_shard.tensor.size())
2594
2595        # Verify local shard metadata.
2596        self.assertEqual(
2597            (self.rank // 2 * 5, (self.rank % 2) * 5),
2598            local_shard.metadata.shard_offsets,
2599        )
2600        self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
2601        self.assertEqual(
2602            f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement)
2603        )
2604
2605        # Verify global metadata.
2606        shards_metadata = st.metadata().shards_metadata
2607        self.assertEqual(4, len(shards_metadata))
2608        for rank, shard_metadata in enumerate(shards_metadata):
2609            self.assertEqual(
2610                (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets
2611            )
2612            self.assertEqual((5, 5), shard_metadata.shard_sizes)
2613            self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement))
2614
2615        # Validate remote shards.
2616        remote_shards = st.remote_shards()
2617        self.assertEqual(3, len(remote_shards))
2618
2619        for rpc_rank, shards in remote_shards.items():
2620            self.assertEqual(1, len(shards))
2621            for remote_shard in shards:
2622                self.assertEqual(rpc_rank, remote_shard.owner().id)
2623                shard = remote_shard.to_here()
2624                self.assertEqual((5, 5), shard.tensor.size())
2625
2626    @with_comms
2627    @skip_if_lt_x_gpu(4)
2628    @requires_nccl()
2629    def test_init_from_local_shards_new_group(self):
2630        new_pg = dist.new_group(ranks=[1, 2, 3])
2631
2632        if self.rank != 0:
2633            local_shard_metadata = ShardMetadata(
2634                shard_offsets=[5 * (self.rank - 1), 0],
2635                shard_sizes=[5, 5],
2636                placement=f"rank:{self.rank}/cuda:{self.rank}",
2637            )
2638            local_shards = [
2639                sharded_tensor.Shard(
2640                    torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata
2641                )
2642            ]
2643
2644            st = sharded_tensor.init_from_local_shards(
2645                local_shards, [15, 5], process_group=new_pg
2646            )
2647
2648            # Verify local shard.
2649            local_shard = st.local_shards()[0]
2650            self.assertEqual(
2651                torch.device(f"cuda:{self.rank}"), local_shard.tensor.device
2652            )
2653            self.assertEqual((5, 5), local_shard.tensor.size())
2654
2655            # Verify local shard metadata.
2656            self.assertEqual(
2657                ((self.rank - 1) * 5, 0), local_shard.metadata.shard_offsets
2658            )
2659            self.assertEqual((5, 5), local_shard.metadata.shard_sizes)
2660            self.assertEqual(
2661                f"rank:{self.rank}/cuda:{self.rank}",
2662                str(local_shard.metadata.placement),
2663            )
2664
2665            # Verify global metadata.
2666            st_metadata = st.metadata()
2667            shards_metadata = st_metadata.shards_metadata
2668            self.assertEqual(3, len(shards_metadata))
2669            for rank, shard_metadata in enumerate(shards_metadata):
2670                self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
2671                self.assertEqual((5, 5), shard_metadata.shard_sizes)
2672                self.assertEqual(
2673                    f"rank:{rank + 1}/cuda:{rank + 1}", str(shard_metadata.placement)
2674                )
2675
2676    @with_comms
2677    @skip_if_lt_x_gpu(4)
2678    @requires_nccl()
2679    def test_init_from_local_shards_invalid_local_shards(self):
2680        local_shard_metadata = ShardMetadata(
2681            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2682            shard_sizes=[5, 5],
2683            placement=f"rank:{self.rank}/cuda:{self.rank}",
2684        )
2685
2686        indices = [[0, 1, 1], [2, 0, 2]]
2687        values = [3.2, 4.5, 5.8]
2688        sparse_tensor = torch.sparse_coo_tensor(
2689            indices, values, (5, 5), device=f"cuda:{self.rank}"
2690        )
2691
2692        empty_local_shards = []
2693        with self.assertRaisesRegex(ValueError, "have no local shards on all ranks"):
2694            st = sharded_tensor.init_from_local_shards(
2695                empty_local_shards, [10, 10], init_rrefs=True
2696            )
2697
2698        wrong_layout_shards = [
2699            sharded_tensor.Shard(sparse_tensor, local_shard_metadata)
2700        ]
2701        with self.assertRaisesRegex(
2702            ValueError, "Only torch.strided layout is currently supported"
2703        ):
2704            st = sharded_tensor.init_from_local_shards(
2705                wrong_layout_shards, [10, 10], init_rrefs=True
2706            )
2707
2708        wrong_memory_format_shards = [
2709            sharded_tensor.Shard(
2710                torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata
2711            )
2712        ]
2713        with self.assertRaisesRegex(
2714            ValueError,
2715            "Only torch.contiguous_format memory_format is currently supported",
2716        ):
2717            st = sharded_tensor.init_from_local_shards(
2718                wrong_memory_format_shards, [10, 10], init_rrefs=True
2719            )
2720
2721        with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"):
2722            wrong_size_shards = [
2723                sharded_tensor.Shard(
2724                    torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
2725                )
2726            ]
2727
2728        with self.assertRaisesRegex(
2729            ValueError, "Local shard tensor device does not match"
2730        ):
2731            wrong_device_shards = [
2732                sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
2733            ]
2734
2735    @with_comms
2736    @skip_if_lt_x_gpu(4)
2737    @requires_nccl()
2738    def test_init_from_local_shards_invalid_property_cross_ranks(self):
2739        local_shard_metadata = ShardMetadata(
2740            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2741            shard_sizes=[5, 5],
2742            placement=f"rank:{self.rank}/cuda:{self.rank}",
2743        )
2744        tensor_overall_size = [10, 10] if self.rank == 0 else [10, 5]
2745        wrong_dtype_shards = [
2746            sharded_tensor.Shard(
2747                torch.ones(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata
2748            )
2749        ]
2750        with self.assertRaisesRegex(
2751            ValueError,
2752            "ShardedTensor global_size property does not match from different ranks!",
2753        ):
2754            st = sharded_tensor.init_from_local_shards(
2755                wrong_dtype_shards, tensor_overall_size, init_rrefs=True
2756            )
2757
2758        tensor_dtype = torch.int if self.rank == 0 else torch.float32
2759        wrong_dtype_shards = [
2760            sharded_tensor.Shard(
2761                torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=tensor_dtype),
2762                local_shard_metadata,
2763            )
2764        ]
2765        with self.assertRaisesRegex(
2766            ValueError,
2767            "ShardedTensor dtype property does not match from different ranks!",
2768        ):
2769            st = sharded_tensor.init_from_local_shards(
2770                wrong_dtype_shards, [10, 10], init_rrefs=True
2771            )
2772
2773        tensor_requires_grad = True if self.rank == 0 else False
2774        wrong_requires_grad_shards = [
2775            sharded_tensor.Shard(
2776                torch.randn(
2777                    5, 5, device=f"cuda:{self.rank}", requires_grad=tensor_requires_grad
2778                ),
2779                local_shard_metadata,
2780            )
2781        ]
2782        with self.assertRaisesRegex(
2783            ValueError,
2784            "ShardedTensor requires_grad property does not match from different ranks!",
2785        ):
2786            st = sharded_tensor.init_from_local_shards(
2787                wrong_requires_grad_shards, [10, 10], init_rrefs=True
2788            )
2789
2790        local_shard_metadata = ShardMetadata(
2791            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2792            shard_sizes=[5, 5],
2793            placement=f"rank:{self.rank}/cpu",
2794        )
2795
2796    @with_comms(init_rpc=False, backend="gloo")
2797    @skip_if_lt_x_gpu(4)
2798    def test_init_from_local_shards_invalid_pin_memory(self):
2799        # pin memory can only be on dense cpu
2800        local_shard_metadata = ShardMetadata(
2801            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2802            shard_sizes=[5, 5],
2803            placement=f"rank:{self.rank}/cpu",
2804        )
2805        wrong_pin_memory_local_shards = [
2806            sharded_tensor.Shard(
2807                torch.randn(5, 5, pin_memory=True), local_shard_metadata
2808            ),
2809            sharded_tensor.Shard(
2810                torch.randn(5, 5, pin_memory=False), local_shard_metadata
2811            ),
2812        ]
2813        with self.assertRaisesRegex(
2814            ValueError, "Local shards' tensor pin_memory property need to be the same"
2815        ):
2816            st = sharded_tensor.init_from_local_shards(
2817                wrong_pin_memory_local_shards, [10, 10], init_rrefs=True
2818            )
2819
2820        tensor_pin_memory = True if self.rank == 0 else False
2821        wrong_pin_memory_shards_cross_ranks = [
2822            sharded_tensor.Shard(
2823                torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata
2824            )
2825        ]
2826        with self.assertRaisesRegex(
2827            ValueError,
2828            "ShardedTensor pin_memory property does not match from different ranks!",
2829        ):
2830            st = sharded_tensor.init_from_local_shards(
2831                wrong_pin_memory_shards_cross_ranks, [10, 10], init_rrefs=True
2832            )
2833
2834    @with_comms
2835    @skip_if_lt_x_gpu(4)
2836    @requires_nccl()
2837    def test_init_from_local_shards_invalid_shards_overlap(self):
2838        local_shard_size = [5, 5] if self.rank != 0 else [6, 6]
2839        local_shard_metadata = ShardMetadata(
2840            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2841            shard_sizes=local_shard_size,
2842            placement=f"rank:{self.rank}/cuda:{self.rank}",
2843        )
2844
2845        local_shards = [
2846            sharded_tensor.Shard(
2847                torch.randn(local_shard_size, device=f"cuda:{self.rank}"),
2848                local_shard_metadata,
2849            )
2850        ]
2851
2852        with self.assertRaisesRegex(ValueError, "overlap"):
2853            sharded_tensor.init_from_local_shards(
2854                local_shards, [10, 10], init_rrefs=True
2855            )
2856
2857    @with_comms
2858    @skip_if_lt_x_gpu(4)
2859    @requires_nccl()
2860    def test_init_from_local_shards_invalid_shards_gaps(self):
2861        local_shard_size = [5, 5] if self.rank != 0 else [4, 4]
2862        local_shard_metadata = ShardMetadata(
2863            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2864            shard_sizes=local_shard_size,
2865            placement=f"rank:{self.rank}/cuda:{self.rank}",
2866        )
2867
2868        local_shards = [
2869            sharded_tensor.Shard(
2870                torch.randn(local_shard_size, device=f"cuda:{self.rank}"),
2871                local_shard_metadata,
2872            )
2873        ]
2874
2875        with self.assertRaisesRegex(ValueError, "does not match tensor volume"):
2876            sharded_tensor.init_from_local_shards(
2877                local_shards, [10, 10], init_rrefs=True
2878            )
2879
2880    @with_comms
2881    @skip_if_lt_x_gpu(4)
2882    @requires_nccl()
2883    def test_init_from_local_shards_and_global_metadata_invalid_shards(self):
2884        local_shard_metadata = ShardMetadata(
2885            shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5],
2886            shard_sizes=[5, 5],
2887            placement=f"rank:{self.rank}/cuda:{self.rank}",
2888        )
2889
2890        shards_metadata = []
2891        for r in range(self.world_size):
2892            if r == self.rank:
2893                shards_metadata.append(local_shard_metadata)
2894            else:
2895                shards_metadata.append(
2896                    ShardMetadata(
2897                        shard_offsets=[(r // 2) * 5, (r % 2) * 5],
2898                        shard_sizes=[5, 5],
2899                        placement=f"rank:{r}/cuda:{r}",
2900                    )
2901                )
2902
2903        tensor_properties = TensorProperties(
2904            dtype=torch.get_default_dtype(),
2905            layout=torch.strided,
2906            requires_grad=False,
2907            memory_format=torch.contiguous_format,
2908            pin_memory=False,
2909        )
2910
2911        sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata(
2912            shards_metadata=shards_metadata,
2913            size=torch.Size([10, 10]),
2914            tensor_properties=tensor_properties,
2915        )
2916
2917        empty_local_shards = []
2918        with self.assertRaisesRegex(
2919            RuntimeError, "does not match number of local shards metadata"
2920        ):
2921            ShardedTensor._init_from_local_shards_and_global_metadata(
2922                empty_local_shards, sharded_tensor_metadata
2923            )
2924
2925        wrong_num_shards = [
2926            sharded_tensor.Shard(
2927                torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata
2928            ),
2929            sharded_tensor.Shard(
2930                torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata
2931            ),
2932        ]
2933        with self.assertRaisesRegex(
2934            RuntimeError, "does not match number of local shards metadata"
2935        ):
2936            ShardedTensor._init_from_local_shards_and_global_metadata(
2937                wrong_num_shards, sharded_tensor_metadata
2938            )
2939
2940        with self.assertRaisesRegex(
2941            ValueError, "Shard tensor size does not match with metadata.shard_lengths"
2942        ):
2943            wrong_size_shards = [
2944                sharded_tensor.Shard(
2945                    torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata
2946                )
2947            ]
2948
2949        with self.assertRaisesRegex(
2950            ValueError,
2951            "Local shard tensor device does not match with local Shard's placement",
2952        ):
2953            wrong_device_shards = [
2954                sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)
2955            ]
2956
2957        wrong_dtype_shards = [
2958            sharded_tensor.Shard(
2959                torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=torch.int),
2960                local_shard_metadata,
2961            )
2962        ]
2963        with self.assertRaisesRegex(
2964            ValueError, "Local shards' tensor dtype property is incompatible with"
2965        ):
2966            ShardedTensor._init_from_local_shards_and_global_metadata(
2967                wrong_dtype_shards, sharded_tensor_metadata
2968            )
2969
2970        indices = [[0, 1, 1], [2, 0, 2]]
2971        values = [3.2, 4.5, 5.8]
2972        sparse_tensor = torch.sparse_coo_tensor(
2973            indices, values, (5, 5), device=f"cuda:{self.rank}"
2974        )
2975
2976        wrong_layout_shards = [
2977            sharded_tensor.Shard(sparse_tensor, local_shard_metadata)
2978        ]
2979        with self.assertRaisesRegex(
2980            ValueError, "Local shards' tensor layout property is incompatible with"
2981        ):
2982            ShardedTensor._init_from_local_shards_and_global_metadata(
2983                wrong_layout_shards, sharded_tensor_metadata
2984            )
2985
2986        wrong_requires_grad_shards = [
2987            sharded_tensor.Shard(
2988                torch.randn(5, 5, device=f"cuda:{self.rank}", requires_grad=True),
2989                local_shard_metadata,
2990            )
2991        ]
2992        with self.assertRaisesRegex(
2993            ValueError,
2994            "Local shards' tensor requires_grad property is incompatible with",
2995        ):
2996            ShardedTensor._init_from_local_shards_and_global_metadata(
2997                wrong_requires_grad_shards, sharded_tensor_metadata
2998            )
2999
3000        wrong_memory_format_shards = [
3001            sharded_tensor.Shard(
3002                torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata
3003            )
3004        ]
3005        with self.assertRaisesRegex(
3006            ValueError,
3007            "Only torch.contiguous_format memory_format is currently supported",
3008        ):
3009            ShardedTensor._init_from_local_shards_and_global_metadata(
3010                wrong_memory_format_shards, sharded_tensor_metadata
3011            )
3012        # pin_memory can only be on CPU
3013        local_shard_metadata.placement = _remote_device(f"rank:{self.rank}/cpu")
3014        wrong_pin_memory_shards = [
3015            sharded_tensor.Shard(
3016                torch.randn(5, 5, pin_memory=True), local_shard_metadata
3017            )
3018        ]
3019        with self.assertRaisesRegex(
3020            ValueError, "Local shards' tensor pin_memory property is incompatible with"
3021        ):
3022            ShardedTensor._init_from_local_shards_and_global_metadata(
3023                wrong_pin_memory_shards, sharded_tensor_metadata
3024            )
3025
3026
3027class TestShardedTensorCustomOps(ShardedTensorTestBase):
3028    @with_comms
3029    @skip_if_lt_x_gpu(4)
3030    @requires_nccl()
3031    def test_custom_op(self):
3032        @custom_sharded_op_impl(torch.asin)
3033        def my_sharded_asin(types, args, kwargs, process_group):
3034            return torch.asin(args[0].local_shards()[0].tensor)
3035
3036        spec = ChunkShardingSpec(
3037            dim=0,
3038            placements=[
3039                "rank:0/cuda:0",
3040                "rank:1/cuda:1",
3041                "rank:2/cuda:2",
3042                "rank:3/cuda:3",
3043            ],
3044        )
3045
3046        st = sharded_tensor.rand(spec, 10, 10)
3047        res = torch.asin(st)
3048        self.assertEqual(res, torch.asin(st.local_shards()[0].tensor))
3049
3050    @with_comms
3051    @skip_if_lt_x_gpu(4)
3052    @requires_nccl()
3053    def test_custom_op_override(self):
3054        t = torch.rand(10, 10).cuda(self.rank)
3055
3056        from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
3057
3058        @custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
3059        def my_sharded_linear(types, args, kwargs, process_group):
3060            return t
3061
3062        spec = ChunkShardingSpec(
3063            dim=0,
3064            placements=[
3065                "rank:0/cuda:0",
3066                "rank:1/cuda:1",
3067                "rank:2/cuda:2",
3068                "rank:3/cuda:3",
3069            ],
3070        )
3071        m = torch.nn.Linear(32, 16).cuda(self.rank)
3072        shard_parameter(m, "weight", spec)
3073
3074        result = m(torch.rand(15, 32).cuda(self.rank))
3075        self.assertEqual(t, result)
3076
3077    @with_comms
3078    @skip_if_lt_x_gpu(4)
3079    @requires_nccl()
3080    def test_custom_op_errors(self):
3081        with self.assertRaisesRegex(TypeError, "expects signature"):
3082
3083            @custom_sharded_op_impl(torch.nn.functional.linear)
3084            def my_op1(types, args, kwargs, process_group, random_param):
3085                pass
3086
3087        with self.assertRaisesRegex(TypeError, "expects signature"):
3088
3089            @custom_sharded_op_impl(torch.nn.functional.linear)
3090            def my_op2(types):
3091                pass
3092
3093
3094class TestShardMetadata(ShardedTensorTestBase):
3095    @with_comms
3096    @requires_nccl()
3097    def test_shard_metadata_init(self):
3098        pg = dist.distributed_c10d._get_default_group()
3099
3100        md = ShardMetadata([10], [0])
3101        self.assertIsNone(md.placement)
3102        with self.assertRaisesRegex(ValueError, "remote device is None"):
3103            _parse_and_validate_remote_device(pg, md.placement)
3104
3105        # String placement gets converted by ctor
3106        md = ShardMetadata([10], [0], "rank:0/cpu")
3107        self.assertEqual(md.placement, _remote_device("rank:0/cpu"))
3108        rank, device = _parse_and_validate_remote_device(pg, md.placement)
3109        self.assertEqual(0, rank)
3110        self.assertEqual(device, torch.device("cpu"))
3111
3112    @with_comms
3113    @requires_nccl()
3114    def test_create_shard_with_no_placement(self):
3115        md = ShardMetadata([0], [10])
3116        shard = Shard(torch.zeros(10), md)
3117        self.assertIsNone(shard.metadata.placement)
3118
3119
3120class TestShardedTensorSubGroupInit(TestCase):
3121    @spawn_threads_and_init_comms(world_size=4)
3122    def test_sub_process_group_sharded_tensor_init(self):
3123        world_pg = dist.GroupMember.WORLD
3124        rank = dist.get_rank()
3125
3126        sub_group_sz = 2
3127        sub_pg_ranks = [r for r in range(4) if r % sub_group_sz == rank % sub_group_sz]
3128        sub_pg = dist.new_group(
3129            sub_pg_ranks,
3130            backend=dist.get_backend(world_pg),
3131            use_local_synchronization=True,
3132        )
3133        dist.barrier(sub_pg)
3134
3135        ShardedTensor._init_from_local_shards(
3136            [
3137                Shard(
3138                    tensor=torch.tensor([1, 2, 3], device="meta"),
3139                    metadata=ShardMetadata(
3140                        shard_offsets=[3 * (rank // sub_group_sz)],
3141                        shard_sizes=[3],
3142                        placement=f"rank:{rank}/meta",
3143                    ),
3144                )
3145            ],
3146            6,
3147            process_group=sub_pg,
3148        )
3149
3150    @spawn_threads_and_init_comms(world_size=4)
3151    def test_sub_process_group_placement_validation(self):
3152        world_pg = dist.GroupMember.WORLD
3153        self.assertIsNotNone(world_pg)
3154        rank = dist.get_rank()
3155
3156        sub_group_sz = 2
3157        sub_pg_ranks = [r for r in range(4) if r % sub_group_sz == rank % sub_group_sz]
3158        sub_pg = dist.new_group(
3159            sub_pg_ranks,
3160            backend=dist.get_backend(world_pg),
3161            use_local_synchronization=True,
3162        )
3163        dist.barrier(sub_pg)
3164
3165        for r in sub_pg_ranks:
3166            _parse_and_validate_remote_device(
3167                sub_pg, _remote_device(f"rank:{r}/cuda:{r % sub_group_sz}")
3168            )
3169
3170
3171class TestCreateTensorNoProcessGroupMode(TestCase):
3172    def test_init_from_local_shards_and_global_metadata(self):
3173        st_metadata: ShardedTensorMetadata = ShardedTensorMetadata(
3174            shards_metadata=[
3175                ShardMetadata(
3176                    shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu"
3177                ),
3178                ShardMetadata(
3179                    shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cpu"
3180                ),
3181            ],
3182            size=torch.Size([4, 2]),
3183        )
3184        st_local_shards: List[Shard] = []
3185        for shard_metadata in st_metadata.shards_metadata:
3186            st_local_shards.append(
3187                Shard(
3188                    tensor=torch.zeros(
3189                        shard_metadata.shard_sizes,
3190                        device=shard_metadata.placement.device(),
3191                    ),
3192                    metadata=shard_metadata,
3193                )
3194            )
3195
3196        ShardedTensorBase._init_from_local_shards_and_global_metadata(
3197            local_shards=st_local_shards,
3198            sharded_tensor_metadata=st_metadata,
3199        )
3200
3201    def test_non_contiguous_local_shards(self):
3202        st_metadata: ShardedTensorMetadata = ShardedTensorMetadata(
3203            shards_metadata=[
3204                ShardMetadata(
3205                    shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cpu"
3206                ),
3207                ShardMetadata(
3208                    shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cpu"
3209                ),
3210            ],
3211            size=torch.Size([4, 2]),
3212        )
3213        st_local_shards: List[Shard] = []
3214        src = torch.randn(4, 2)
3215        for shard_metadata in st_metadata.shards_metadata:
3216            offsets = shard_metadata.shard_offsets
3217            sizes = shard_metadata.shard_sizes
3218            st_local_shards.append(
3219                Shard(
3220                    tensor=src[
3221                        offsets[0] : offsets[0] + sizes[0],
3222                        offsets[1] : offsets[1] + sizes[1],
3223                    ],
3224                    metadata=shard_metadata,
3225                )
3226            )
3227
3228        ShardedTensorBase._init_from_local_shards_and_global_metadata(
3229            local_shards=st_local_shards,
3230            sharded_tensor_metadata=st_metadata,
3231        )
3232
3233
3234if __name__ == "__main__":
3235    run_tests()
3236