# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
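#
# Sanity checks for assorted DTensor ops (slice, bernoulli, and
# cross-entropy / nll loss). Each test compares the DTensor result and
# gradient against a plain-tensor reference (or against the peer rank, for
# the random op) using absolute and relative MSE with a 1e-6 tolerance.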


import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


# Number of random input/gradient pairs exercised by each test.
ITER_TIME = 10
LR = 0.001


class DistOtherOpsTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    @with_comms
    def test_slice(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        input_list = torch.rand(ITER_TIME, 1024, 10)
        grad_output_list = torch.rand(ITER_TIME, 1024, 5) * 1e-3

        for i in range(ITER_TIME):
            inp = input_list[i].to(self.device_type).requires_grad_()
            grad_output = grad_output_list[i].to(self.device_type)

            # slice with dtensor
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
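            # Slicing the replicated DTensor should match slicing the plain
            # tensor; only the first 5 columns receive gradient from backward.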
            output = inp_dtensor[:, :5]
            output.backward(grad_output_dtensor)

            # slice with plain tensor
            output_gt = inp[:, :5]
            output_gt.backward(grad_output)

            output_diff_abs = output.to_local() - output_gt
            output_diff_rel = output_diff_abs / (torch.abs(output_gt) + 1e-8)
            output_mse_abs = torch.mean(output_diff_abs * output_diff_abs).item()
            output_mse_rel = torch.mean(output_diff_rel * output_diff_rel).item()

            grad_diff_abs = inp_dtensor.grad.to_local() - inp.grad
            grad_diff_rel = grad_diff_abs / (torch.abs(inp.grad) + 1e-8)
            grad_mse_abs = torch.mean(grad_diff_abs * grad_diff_abs).item()
            grad_mse_rel = torch.mean(grad_diff_rel * grad_diff_rel).item()

            self.assertTrue(
                output_mse_abs <= 1e-6,
                f"Absolute MSE for output too large; expected <= 1e-6, got {output_mse_abs}",
            )
            self.assertTrue(
                output_mse_rel <= 1e-6,
                f"Relative MSE for output too large; expected <= 1e-6, got {output_mse_rel}",
            )
            self.assertTrue(
                grad_mse_abs <= 1e-6,
                f"Absolute MSE for gradient too large; expected <= 1e-6, got {grad_mse_abs}",
            )
            self.assertTrue(
                grad_mse_rel <= 1e-6,
                f"Relative MSE for gradient too large; expected <= 1e-6, got {grad_mse_rel}",
            )

    @with_comms
    def test_bernoulli(self):
        rank = dist.get_rank()
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        input_list = torch.rand(ITER_TIME, 1024, 10)
        grad_output_list = torch.rand(ITER_TIME, 1024, 10) * 1e-3

        for i in range(ITER_TIME):
            inp = input_list[i].to(self.device_type).requires_grad_()
            grad_output = grad_output_list[i].to(self.device_type)

            # bernoulli with dtensor
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
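            # The test expects bernoulli on a replicated DTensor to draw the
            # same samples on every rank, so each rank's result and gradient
            # are exchanged with the peer rank and compared below.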
            output = torch.bernoulli(inp_dtensor)
            output.backward(grad_output_dtensor)

            send_output_tensor = output.to_local()
            recv_output_tensor = torch.zeros_like(send_output_tensor)

            send_grad_tensor = inp_dtensor.grad.to_local()
            recv_grad_tensor = torch.zeros_like(send_grad_tensor)

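            # With world_size fixed at 2, `1 ^ rank` is the peer rank (0 <-> 1):
            # each rank sends its local output/gradient and receives the peer's.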
            send_op_1 = dist.P2POp(dist.isend, send_output_tensor, 1 ^ rank)
            send_op_2 = dist.P2POp(dist.isend, send_grad_tensor, 1 ^ rank)
            recv_op_1 = dist.P2POp(dist.irecv, recv_output_tensor, 1 ^ rank)
            recv_op_2 = dist.P2POp(dist.irecv, recv_grad_tensor, 1 ^ rank)

            reqs = dist.batch_isend_irecv([send_op_1, send_op_2, recv_op_1, recv_op_2])
            for req in reqs:
                req.wait()

            output_diff_abs = send_output_tensor - recv_output_tensor
            output_diff_rel = output_diff_abs / (torch.abs(recv_output_tensor) + 1e-8)
            output_mse_abs = torch.mean(output_diff_abs * output_diff_abs).item()
            output_mse_rel = torch.mean(output_diff_rel * output_diff_rel).item()

            grad_diff_abs = send_grad_tensor - recv_grad_tensor
            grad_diff_rel = grad_diff_abs / (torch.abs(recv_grad_tensor) + 1e-8)
            grad_mse_abs = torch.mean(grad_diff_abs * grad_diff_abs).item()
            grad_mse_rel = torch.mean(grad_diff_rel * grad_diff_rel).item()

            self.assertTrue(
                output_mse_abs <= 1e-6,
                f"Absolute MSE for output too large; expected <= 1e-6, got {output_mse_abs}",
            )
            self.assertTrue(
                output_mse_rel <= 1e-6,
                f"Relative MSE for output too large; expected <= 1e-6, got {output_mse_rel}",
            )
            self.assertTrue(
                grad_mse_abs <= 1e-6,
                f"Absolute MSE for gradient too large; expected <= 1e-6, got {grad_mse_abs}",
            )
            self.assertTrue(
                grad_mse_rel <= 1e-6,
                f"Relative MSE for gradient too large; expected <= 1e-6, got {grad_mse_rel}",
            )

    @with_comms
    def test_nll(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        pred_list = torch.rand(ITER_TIME, 1024, 10)
        target_list = torch.randint(0, 10, (ITER_TIME, 1024), dtype=torch.long)

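        # CrossEntropyLoss combines log_softmax and nll_loss; with replicated
        # DTensor inputs, the loss and gradients should match the plain-tensor
        # reference computed below.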
        criterion = torch.nn.CrossEntropyLoss()

        for i in range(ITER_TIME):
            pred = pred_list[i].to(self.device_type).requires_grad_()
            target = target_list[i].to(self.device_type)

            # nll with dtensor
            pred_dtensor = distribute_tensor(pred, device_mesh, shard_spec)
            target_dtensor = distribute_tensor(target, device_mesh, shard_spec)
            loss = criterion(pred_dtensor, target_dtensor)
            loss.backward()

            # nll with plain tensor
            loss_gt = criterion(pred, target)
            loss_gt.backward()

            loss_diff_abs = loss.to_local() - loss_gt
            loss_diff_rel = loss_diff_abs / (torch.abs(loss_gt) + 1e-8)
            loss_mse_abs = torch.mean(loss_diff_abs * loss_diff_abs).item()
            loss_mse_rel = torch.mean(loss_diff_rel * loss_diff_rel).item()

            grad_diff_abs = pred_dtensor.grad.to_local() - pred.grad
            grad_diff_rel = grad_diff_abs / (torch.abs(pred.grad) + 1e-8)
            grad_mse_abs = torch.mean(grad_diff_abs * grad_diff_abs).item()
            grad_mse_rel = torch.mean(grad_diff_rel * grad_diff_rel).item()

            self.assertTrue(
                loss_mse_abs <= 1e-6,
                f"Absolute MSE for loss too large; expected <= 1e-6, got {loss_mse_abs}",
            )
            self.assertTrue(
                loss_mse_rel <= 1e-6,
                f"Relative MSE for loss too large; expected <= 1e-6, got {loss_mse_rel}",
            )
            self.assertTrue(
                grad_mse_abs <= 1e-6,
                f"Absolute MSE for gradient too large; expected <= 1e-6, got {grad_mse_abs}",
            )
            self.assertTrue(
                grad_mse_rel <= 1e-6,
                f"Relative MSE for gradient too large; expected <= 1e-6, got {grad_mse_rel}",
            )


if __name__ == "__main__":
    run_tests()