# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class DTensorInitOpsTest(DTensorTestBase):
    def _run_init_op(self, init_op, *args, **kwargs):
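        # Apply the same in-place init op to a plain local tensor and to a
        # DTensor built from it, using the same per-rank seed, and check that
        # the DTensor's local shard matches the locally initialized tensor.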
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        input_size = (8, 4)
        input_tensor = torch.randn(*input_size, device=self.device_type)
        dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        local_tensor_clone = torch.clone(input_tensor)
        torch.manual_seed(self.rank)
        local_tensor_clone = init_op(local_tensor_clone, *args, **kwargs)
        torch.manual_seed(self.rank)
        dtensor = init_op(dtensor, *args, **kwargs)
        self.assertEqual(local_tensor_clone, dtensor.to_local())

    @with_comms
    def test_init_ops(self):
        # NOTE: random init tests are moved to test_random_ops.py
        self._run_init_op(torch.nn.init.constant_, 2.4)


class DTensorConstructorTest(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    def _run_init_op(self, init_op, dist_init_op, eq_op, *args, **kwargs):
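        # Run a distributed constructor (e.g. ones/zeros/empty/full) alongside
        # its local torch counterpart and compare the resulting local shards
        # via eq_op, covering even sharding, uneven sharding, and an empty
        # (scalar) shape.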
        # 1d mesh test
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements_list = [[Shard(0)], [Shard(1)], [Shard(2)], [Replicate()]]

        # even sharding
        tensor_size = [4, 8, 12]
        for placements in placements_list:
            local_tensor_size = tensor_size.copy()
            if isinstance(placements[0], Shard):
                shard_dim = placements[0].dim
                local_tensor_size[shard_dim] //= self.world_size

            dist_tensor = dist_init_op(
                tensor_size,
                *args,
                **kwargs,
                device_mesh=device_mesh,
                placements=placements,
            )
            expected_tensor = init_op(local_tensor_size, *args, **kwargs)
            eq_op(expected_tensor, dist_tensor.to_local())

        # uneven sharding
        tensor_size = [5, 10, 15]
        for placements in placements_list:
            dist_tensor = dist_init_op(
                tensor_size,
                *args,
                **kwargs,
                device_mesh=device_mesh,
                placements=placements,
            )
            if isinstance(placements[0], Shard):
                shard_dim = placements[0].dim
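                # torch.chunk makes the last chunk(s) smaller and may return
                # fewer than world_size chunks, so later ranks can hold a
                # smaller shard or no shard at all; only compare on ranks
                # that actually own a chunk.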
                exp_tensor_list = list(
                    torch.chunk(
                        init_op(tensor_size, *args, **kwargs),
                        self.world_size,
                        dim=shard_dim,
                    )
                )
                if self.rank < len(exp_tensor_list):
                    eq_op(exp_tensor_list[self.rank], dist_tensor.to_local())
            else:
                exp_tensor = init_op(tensor_size, *args, **kwargs)
                eq_op(exp_tensor, dist_tensor.to_local())

        # empty shape
        local_tensor = dist_init_op(
            [], *args, **kwargs, device_mesh=device_mesh, placements=[Replicate()]
        ).to_local()
        expected_tensor = init_op([], *args, **kwargs)
        eq_op(expected_tensor, local_tensor)

    @with_comms
    def test_ones(self):
        self._run_init_op(
            torch.ones,
            torch.distributed._tensor.ones,
            self.assertEqual,
            requires_grad=True,
        )

    @with_comms
    def test_empty(self):
        self._run_init_op(
            torch.empty,
            torch.distributed._tensor.empty,
            lambda x, y: (x.shape == y.shape)
            and (x.dtype == y.dtype)
            and (x.layout == y.layout),
            requires_grad=True,
        )

    @with_comms
    def test_full(self):
        self._run_init_op(
            torch.full,
            torch.distributed._tensor.full,
            self.assertEqual,
            123.4,
            requires_grad=True,
        )

    @with_comms
    def test_zeros(self):
        self._run_init_op(
            torch.zeros,
            torch.distributed._tensor.zeros,
            self.assertEqual,
            requires_grad=True,
        )

    @with_comms
    def test_zeros_full_mesh(self):
        # construct a cuda device 1d mesh
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([8, 3]))

        local_tensor = torch.zeros(8, 3)
        self.assertEqual(dist_tensor.to_local(), local_tensor)

        self.assertEqual(dist_tensor.device.type, self.device_type)

        # 1d sharded unevenly
        size = [31, 3]
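        # 31 rows chunked across 4 ranks gives shards of [8, 8, 8, 7] rows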
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        if self.rank <= 2:
            self.assertEqual(local_tensor.size(), torch.Size([8, 3]))
            self.assertEqual(torch.zeros(8, 3), local_tensor)
        else:
            self.assertEqual(local_tensor.size(), torch.Size([7, 3]))
            self.assertEqual(torch.zeros(7, 3), local_tensor)

        # construct a cuda device mesh with 2d: shard, replicate
        mesh = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 2)
        )
        placements = [Shard(0), Replicate()]
        size = [32, 4]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([16, 4]))
        self.assertEqual(local_tensor, torch.zeros([16, 4]))

        # construct a cuda device mesh with 2d: shard, shard
        placements = [Shard(0), Shard(1)]
        size = [32, 4]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([16, 2]))
        self.assertEqual(local_tensor, torch.zeros([16, 2]))

        # 2d sharded unevenly
        placements = [Shard(0), Shard(1)]
        size = [31, 3]
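        # dim 0 (31) splits as [16, 15] over mesh dim 0 and dim 1 (3) splits
        # as [2, 1] over mesh dim 1, so each rank holds a different shard shape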
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        if self.rank == 0:
            self.assertEqual(local_tensor, torch.zeros([16, 2]))
        elif self.rank == 1:
            self.assertEqual(local_tensor, torch.zeros([16, 1]))
        elif self.rank == 2:
            self.assertEqual(local_tensor, torch.zeros([15, 2]))
        elif self.rank == 3:
            self.assertEqual(local_tensor, torch.zeros([15, 1]))

    @with_comms
    def test_zeros_submesh(self):
        # default world_size is 4
        # construct a cuda device 1d mesh, with no sub pg initialized
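        # ranks that are not part of the sub-mesh should end up with an empty local tensor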
        sub_mesh_list = [0, 3]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in sub_mesh_list:
            self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
            self.assertEqual(local_tensor, torch.zeros([16, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.zeros(0))

        # construct a cuda device 1d mesh: unevenly, with subpg initialized
        sub_mesh_list = [0, 1, 3]
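        # 32 rows over a 3-rank sub-mesh chunk as [11, 11, 10]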
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in sub_mesh_list:
            if self.rank != 3:
                self.assertEqual(local_tensor.size(), torch.Size([11, 3]))
                self.assertEqual(local_tensor, torch.zeros([11, 3]))
            else:
                self.assertEqual(local_tensor.size(), torch.Size([10, 3]))
                self.assertEqual(local_tensor, torch.zeros([10, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))

        # construct a cuda device 2d mesh, with no subpg initialized
        sub_mesh_list = [[0], [3]]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0), Shard(1)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in [0, 3]:
            self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
            self.assertEqual(local_tensor, torch.zeros([16, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))


if __name__ == "__main__":
    run_tests()