# mypy: allow-untyped-defs

import copy
import random

import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

PLACEMENTS = [
    "rank:0/cuda:0",
    "rank:1/cuda:1",
    "rank:2/cuda:2",
    "rank:3/cuda:3",
]

DEFAULT_GPU_NUM = 4


def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
    """Return one ChunkShardingSpec per entry in ``sharding_dims``.

    Each spec shards along the given dim and carries a deterministically
    shuffled copy of the default 4-GPU placements.
    """
    spec_list = []
    for i in range(len(sharding_dims)):
        # Shuffle the module-level placements in place with a per-spec seed
        # so each spec sees a different but reproducible device ordering.
        random.Random(seed + i).shuffle(PLACEMENTS)
        spec_list.append(
            ChunkShardingSpec(
                dim=sharding_dims[i],
                placements=copy.deepcopy(PLACEMENTS),
            )
        )
    return spec_list
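
# Illustrative (hypothetical) usage of the helper above; not executed by this
# module. It assumes a test sharding along dims 0 and 1 over the default
# 4-GPU placements:
#
#     specs = _chunk_sharding_specs_list_for_test([0, 1], seed=0)
#     # specs[0].dim == 0 and specs[1].dim == 1; each spec holds its own
#     # shuffled copy of PLACEMENTS.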


class MyShardedModel2(torch.nn.Module):
    """Leaf module with one optional ShardedTensor and one regular parameter."""

    def __init__(
        self,
        spec=None,
        group=None,
        init_rrefs=True,
    ) -> None:
        super().__init__()
        if spec is not None:
            self.sharded_tensor2 = sharded_tensor.rand(
                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
            )
        else:
            self.sharded_tensor2 = None
        self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2))


class MyShardedModel1(torch.nn.Module):
    """Top-level test module.

    Holds one optional ShardedTensor, one regular parameter, and a
    MyShardedModel2 submodule built from the same spec.
    """

    def __init__(
        self,
        spec=None,
        group=None,
        init_rrefs=True,
    ) -> None:
        super().__init__()
        if spec is not None:
            self.sharded_tensor1 = sharded_tensor.rand(
                spec, 10, 20, process_group=group, init_rrefs=init_rrefs
            )
        else:
            self.sharded_tensor1 = None
        self.random_tensor1 = torch.nn.Parameter(torch.rand(2, 2))
        self.submodule = MyShardedModel2(spec, group, init_rrefs)
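

# Illustrative (hypothetical) construction of the models above; not executed
# here. It assumes the calling test has already initialized a 4-rank process
# group (and RPC, if init_rrefs=True):
#
#     spec = _chunk_sharding_specs_list_for_test([0], seed=0)[0]
#     model = MyShardedModel1(spec=spec, init_rrefs=False)
#     # model.sharded_tensor1 and model.submodule.sharded_tensor2 are 10x20
#     # ShardedTensors chunked along dim 0 across the four CUDA devices.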