# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed._tensor.api import distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.distributed_c10d import broadcast_object_list
from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    skip_unless_torch_gpu,
    with_comms,
)


class DistTensorRandomInitTest(DTensorTestBase):
    def _run_init_op(self, init_op, *args, **kwargs):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        input_size = (8, 4)

        # NOTE: random initialization currently behaves differently on CUDA
        # devices than on other devices. Unify this test once the behavior is unified.
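        # When the mesh does not support DTensor RNG, the test expects DTensor
        # random ops to behave like plain local torch ops: seeding the default
        # generator identically must reproduce identical local values
        # (if-branch below). On supported (CUDA) meshes, the DTensor RNG should
        # give each shard distinct values (else-branch).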
        if not is_rng_supported_mesh(device_mesh):
            input_tensor = torch.randn(*input_size, device=self.device_type)
            dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
            local_tensor_clone = torch.clone(input_tensor)
            torch.manual_seed(self.rank)
            local_tensor_clone = init_op(local_tensor_clone, *args, **kwargs)
            torch.manual_seed(self.rank)
            dtensor = init_op(dtensor, *args, **kwargs)
            self.assertEqual(local_tensor_clone, dtensor.to_local())
        else:
            # create DTensor from Tensor
            _tensor = torch.empty(*input_size, device="cuda")
            dtensor = distribute_tensor(_tensor, device_mesh, [Shard(1)])

            # DTensor random init
            dtensor = init_op(dtensor, *args, **kwargs)
            local_tensor = dtensor.to_local()

            # compare with local tensors from other ranks
            for other_rank in range(self.world_size):
                if self.rank != other_rank:
                    slice_idx = [
                        slice(input_size[0]),
                        slice(
                            other_rank * input_size[1], (other_rank + 1) * input_size[1]
                        ),
                    ]
                    # other rank should have a different local tensor
                    self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor)

    @with_comms
    def test_init_ops(self):
        self._run_init_op(
            torch.nn.init.kaiming_uniform_,
            a=0,
            mode="fan_in",
            nonlinearity="leaky_relu",
        )
        self._run_init_op(torch.nn.init.normal_, mean=1.5, std=0.8)
        self._run_init_op(torch.nn.init.uniform_, a=0, b=1.2)

        for dtype in (torch.float32, torch.float16):
            self._run_init_op(torch.rand_like, dtype=dtype)
            self._run_init_op(torch.randn_like, dtype=dtype)
            self._run_init_op(torch.randint_like, low=0, high=100, dtype=dtype)


class DistTensorRandomOpTest(DTensorTestBase):
    @with_comms
    @skip_unless_torch_gpu
    def test_rng_tracker_init(self):
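        # Each rank seeds its CUDA RNG differently and learns rank 0's seed via
        # broadcast; the DTensor RNG tracker, lazily initialized by the first
        # DTensor op below, is expected to adopt rank 0's seed.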
        torch.cuda.manual_seed(self.rank)
        object_list = [torch.cuda.initial_seed()]
        broadcast_object_list(object_list)
        seed_from_rank_0 = int(object_list[0])

        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        # seed synchronization happens after the first `distribute_tensor` call
        dtensor = distribute_tensor(
            torch.empty([self.world_size], device="cuda"), device_mesh, [Shard(0)]
        )
        self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))

    @with_comms
    @skip_unless_torch_gpu
    def test_manual_seed(self):
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        manual_seed(1234, device_mesh)
        self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
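        # manual_seed requires every rank in the mesh to pass the same seed;
        # self.rank differs across ranks, so the call below is expected to raise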
        with self.assertRaisesRegex(RuntimeError, "different seed values"):
            manual_seed(self.rank, device_mesh)

    @with_comms
    @skip_unless_torch_gpu
    def test_deterministic_dropout_1d(self):
        # The test suite sets every rank's seed to the same value, but in an
        # actual run the default random seed differs across ranks (it is a
        # random value). DTensor random ops should use the same random seed
        # even though the torch random generator keeps different seeds on ranks.
        torch.cuda.manual_seed(self.rank)
        # TODO: add test before/after enabling distribute region
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        size = [4, 4]

        dtensor = distribute_tensor(
            torch.empty(*size, device="cuda"), device_mesh, [Shard(1)]
        )

        # a random op call shifts the offset
        dtensor.uniform_(0, 1)

        # the dtensor is now replicated on all ranks
        dtensor = dtensor.redistribute(device_mesh, [Replicate()])

        dropout = torch.nn.Dropout(p=0.2)
        dtensor = dropout(dtensor)
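        # with the DTensor RNG in effect, every rank should draw the same
        # dropout mask for the replicated dtensor; the all-gather comparison
        # below verifies that the local tensors match across ranks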

        # allgather the local tensors
        local_tensor = funcol.all_gather_tensor(
            dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )
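        # group=(device_mesh, 0) selects mesh dim 0 (here, the whole 1-D mesh),
        # so the gather stacks every rank's 4x4 local tensor along dim 0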

        # compare with local tensors from other ranks
        self_slice = slice(4 * self.rank, 4 * self.rank + 4)
        for other_rank in range(self.world_size):
            if self.rank != other_rank:
                # other rank should have an identical local tensor
                other_slice = slice(4 * other_rank, 4 * other_rank + 4)
                self.assertEqual(
                    local_tensor[self_slice, :],
                    local_tensor[other_slice, :],
                )

    @with_comms
    @skip_unless_torch_gpu
    def test_deterministic_rand_1d(self):
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        size = [4, 4 * self.world_size]

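        # for each factory op: a Shard(1) placement should give every rank a
        # different local shard, while a Replicate placement should give
        # identical local tensors on all ranks even though the torch CUDA seeds
        # are set differently below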
        for fn in [
            torch.distributed._tensor.rand,
            torch.distributed._tensor.randn,
        ]:
            dtensor = fn(size, device_mesh=device_mesh, placements=[Shard(1)])
            local_tensor = funcol.all_gather_tensor(
                dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
            )

            # compare with local tensors from other ranks
            self_slice = slice(4 * self.rank, 4 * self.rank + 4)
            for other_rank in range(self.world_size):
                if self.rank != other_rank:
                    # other rank should have a different local tensor
                    other_slice = slice(4 * other_rank, 4 * other_rank + 4)
                    self.assertNotEqual(
                        local_tensor[self_slice, :],
                        local_tensor[other_slice, :],
                    )

            torch.cuda.manual_seed(self.rank)
            dtensor = fn(size, device_mesh=device_mesh, placements=[Replicate()])
            local_tensor = funcol.all_gather_tensor(
                dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
            )

            # compare with local tensors from other ranks
            self_slice = slice(4 * self.rank, 4 * self.rank + 4)
            for other_rank in range(self.world_size):
                if self.rank != other_rank:
                    # other rank should have an identical local tensor
                    other_slice = slice(4 * other_rank, 4 * other_rank + 4)
                    self.assertEqual(
                        local_tensor[self_slice, :],
                        local_tensor[other_slice, :],
                    )

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_deterministic_uniform_2d(self):
        mesh = torch.arange(self.world_size).reshape(2, 2)
        device_mesh = DeviceMesh(self.device_type, mesh)
        dtensor = distribute_tensor(
            torch.empty(
                *[self.world_size for _ in mesh.size()], device=self.device_type
            ),
            device_mesh,
            [Replicate(), Replicate()],
        )

        placements_list = [  # these placements should cover the 2-D sharding cases
            [Shard(0), Shard(1)],
            [Shard(1), Shard(0)],
            [Shard(0), Replicate()],
            [Replicate(), Shard(0)],
            [Shard(1), Replicate()],
            [Replicate(), Shard(1)],
            [Replicate(), Replicate()],
        ]

        shard_index_list = [
            {0: 0, 1: 1, 2: 2, 3: 3},
            {0: 0, 1: 2, 2: 1, 3: 3},
            {0: 0, 1: 0, 2: 1, 3: 1},
            {0: 0, 1: 1, 2: 0, 3: 1},
            {0: 0, 1: 0, 2: 1, 3: 1},
            {0: 0, 1: 1, 2: 0, 3: 1},
            {0: 0, 1: 0, 2: 0, 3: 0},
        ]
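        # shard_index_list[i][rank] is the expected linear index of this rank's
        # shard under placements_list[i]. E.g. for [Shard(0), Shard(1)] on the
        # 2x2 mesh, rank r sits at mesh coordinate (r // 2, r % 2), which is
        # also its shard coordinate, and the row-major linear index
        # coord[0] * 2 + coord[1] gives {0: 0, 1: 1, 2: 2, 3: 3}.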

        coordinate = device_mesh.get_coordinate()
        assert coordinate is not None

        for placements, shard_index in zip(placements_list, shard_index_list):
            dtensor = dtensor.redistribute(device_mesh, placements)

            # check shard information is correct
            shard_coord = [
                coordinate[mesh_dim] if mesh_dim >= 0 else 0
                for mesh_dim in dtensor._spec.dim_map
            ]

            shard_size = [
                device_mesh.size(mesh_dim) if mesh_dim >= 0 else 1
                for mesh_dim in dtensor._spec.dim_map
            ]

            shard_linear_idx = random._rng_tracker._calc_shard_linear_idx(
                shard_coord, shard_size
            )
            self.assertEqual(shard_linear_idx, shard_index[self.rank])

            # compute local size and offset
            _, local_shard_offset = compute_local_shape_and_global_offset(
                dtensor.shape, device_mesh, placements
            )

            # get the local shard size and local shard offset for each shard
            # local_shard_list_on_dim[i] has the list of all shards on that dim
            # as a tuple (local_shard_offset, local_shard_size)
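            # e.g. with world_size 4 the tensor is 4x4, and for placements
            # [Shard(0), Shard(1)] on the 2x2 mesh the loop below produces
            # local_shard_list_on_dim == [[(0, 2), (2, 2)], [(0, 2), (2, 2)]]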
            dtensor_shape = dtensor.shape
            local_shard_list_on_dim = [[(0, length)] for length in dtensor_shape]
            for idx, placement in enumerate(placements):
                if isinstance(placement, Shard):
                    mesh_dim_size = device_mesh.size(idx)
                    shard_dim = placement.dim
                    local_shard_list_on_dim[shard_dim] = []
                    for shard_idx_on_dim in range(mesh_dim_size):
                        shard_size, shard_offset = placement._local_shard_size_on_dim(
                            dtensor_shape[shard_dim],
                            mesh_dim_size,
                            shard_idx_on_dim,
                            return_offset=True,
                        )
                        local_shard_list_on_dim[shard_dim].append(
                            (shard_offset, shard_size)
                        )

            local_shard_comb = itertools.product(*local_shard_list_on_dim)

            # random op call
            dtensor.uniform_(0, 1)

            # the local shard
            local_tensor = dtensor.to_local()
            # gather the full (global) tensor from all ranks
            full_tensor = dtensor.full_tensor()
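            # the slice of full_tensor at this rank's own offset must equal
            # local_tensor, while every other shard's slice must differ,
            # showing that each shard drew distinct random values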

            # compare local tensor with each other shard
            for other_local_shard in local_shard_comb:
                other_local_shard_offset, _ = zip(*other_local_shard)
                slice_idx = [
                    slice(offset, offset + size) for offset, size in other_local_shard
                ]
                if local_shard_offset == other_local_shard_offset:
                    self.assertEqual(full_tensor[slice_idx], local_tensor)
                else:
                    self.assertNotEqual(full_tensor[slice_idx], local_tensor)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_meta_tensor_init(self):
        # The test suite sets every rank's seed to the same value, but in an
        # actual run the default random seed differs across ranks (it is a
        # random value). DTensor random ops should use the same random seed
        # even though the torch random generator keeps different seeds on
        # ranks; this ensures a Replicate DTensor gets the same initialized
        # values across ranks.
        torch.cuda.manual_seed(self.rank)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        size = [1024, 2048]
        meta_dtensor = distribute_tensor(
            torch.empty(*size, device="meta"), device_mesh, [Replicate()]
        )
        self.assertTrue(meta_dtensor.is_meta)
        dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
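        # empty_like on the meta DTensor gives a same-shape DTensor on the real
        # device (the test relies on it keeping the Replicate placement);
        # uniform_ below then fills it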

        # disable the distribute region for RNG
        random._rng_tracker.distribute_region_enabled = False
        dtensor.uniform_()

        # allgather the local tensors
        local_tensor = funcol.all_gather_tensor(
            dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )

        # compare with local tensors from other ranks
        self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024)
        for other_rank in range(self.world_size):
            # the RNG result on each rank differs even though the tensors are
            # supposed to be replicated
            if self.rank != other_rank:
                other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
                self.assertNotEqual(
                    local_tensor[self_slice, :], local_tensor[other_slice, :]
                )

        # enable the distribute region for RNG
        random._rng_tracker.distribute_region_enabled = True
        self.assertTrue(meta_dtensor.is_meta)
        dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
        dtensor.uniform_()

        # allgather the local tensors
        local_tensor = funcol.all_gather_tensor(
            dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )

        # compare with local tensors from other ranks
        for other_rank in range(self.world_size):
            # the RNG results on each rank are the same because the tensor is replicated
            if self.rank != other_rank:
                # other rank should have an identical local tensor
                other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
                self.assertEqual(
                    local_tensor[self_slice, :], local_tensor[other_slice, :]
                )


if __name__ == "__main__":
    run_tests()