# mypy: allow-untyped-defs
# Shared helpers for ShardedTensor tests.

import copy
import random

import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# One shard placement per rank for the default 4-GPU test setup.
PLACEMENTS = [
    "rank:0/cuda:0",
    "rank:1/cuda:1",
    "rank:2/cuda:2",
    "rank:3/cuda:3",
]

DEFAULT_GPU_NUM = 4


def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
    """Return one ChunkShardingSpec per entry in ``sharding_dims``.

    Each spec shards along the given dim and uses a deterministically
    shuffled ordering of PLACEMENTS (seeded with ``seed + i``). Note that
    ``shuffle`` reorders the module-level PLACEMENTS list in place;
    ``deepcopy`` then snapshots the current order into the spec.
    """
    spec_list = []
    for i, dim in enumerate(sharding_dims):
        random.Random(seed + i).shuffle(PLACEMENTS)
        spec_list.append(
            ChunkShardingSpec(
                dim=dim,
                placements=copy.deepcopy(PLACEMENTS),
            )
        )
    return spec_list


class MyShardedModel2(torch.nn.Module):
    def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
        super().__init__()
        if spec is not None:
            # A 10x20 sharded tensor of random values, laid out per ``spec``.
            self.sharded_tensor2 = sharded_tensor.rand(
                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
            )
        else:
            self.sharded_tensor2 = None
        # A small regular (non-sharded) parameter, replicated on every rank.
        self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2))


class MyShardedModel1(torch.nn.Module):
    def __init__(self, spec=None, group=None, init_rrefs=True) -> None:
        super().__init__()
        if spec is not None:
            self.sharded_tensor1 = sharded_tensor.rand(
                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
            )
        else:
            self.sharded_tensor1 = None
        self.random_tensor1 = torch.nn.Parameter(torch.rand(2, 2))
        # Nest a second sharded module so tests exercise nested hierarchies.
        self.submodule = MyShardedModel2(spec, group, init_rrefs)
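

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the helpers above). A
# minimal example of how these pieces combine, assuming a 4-GPU host and a
# launch via `torchrun --nproc_per_node=4 <this file>` so that RANK,
# WORLD_SIZE, and MASTER_ADDR are set in the environment. The name
# `_example_main` and the launch setup are assumptions for illustration only.
# ---------------------------------------------------------------------------
def _example_main():
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    # Build a spec that chunks tensors along dim 0; the fixed seed makes the
    # shuffled placements identical on every rank.
    (spec,) = _chunk_sharding_specs_list_for_test([0])

    # init_rrefs=False keeps the example free of the RPC framework.
    model = MyShardedModel1(spec=spec, init_rrefs=False)

    # Each rank holds only its chunk of the 10x20 sharded tensor, while the
    # 2x2 regular parameters are replicated everywhere.
    shapes = [s.tensor.shape for s in model.sharded_tensor1.local_shards()]
    print(f"rank {dist.get_rank()}: local shard shapes {shapes}")

    dist.destroy_process_group()


if __name__ == "__main__":
    _example_main()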