1# Copyright (c) Meta Platforms, Inc. and affiliates 2# Owner(s): ["oncall: distributed"] 3 4 5import torch 6import torch.distributed as dist 7from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate 8from torch.testing._internal.common_utils import run_tests 9from torch.testing._internal.distributed._tensor.common_dtensor import ( 10 DTensorTestBase, 11 with_comms, 12) 13 14 15ITER_TIME = 10 16LR = 0.001 17 18 19class DistOtherOpsTest(DTensorTestBase): 20 @property 21 def world_size(self) -> int: 22 # hard code world size to 2 23 return 2 24 25 @with_comms 26 def test_slice(self): 27 device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 28 shard_spec = [Replicate()] 29 30 input_list = torch.rand(ITER_TIME, 1024, 10) 31 grad_output_list = torch.rand(ITER_TIME, 1024, 5) * 1e-3 32 33 for i in range(ITER_TIME): 34 inp = input_list[i].to(self.device_type).requires_grad_() 35 grad_output = grad_output_list[i].to(self.device_type) 36 37 # droppath with dtensor 38 inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec) 39 grad_output_dtensor = distribute_tensor( 40 grad_output, device_mesh, shard_spec 41 ) 42 output = inp_dtensor[:, :5] 43 output.backward(grad_output_dtensor) 44 45 # nll with plain tensor 46 output_gt = inp[:, :5] 47 output_gt.backward(grad_output) 48 49 output_diff_abs = output.to_local() - output_gt 50 output_diff_rel = output_diff_abs / (torch.abs(output_gt) + 1e-8) 51 output_mse_abs = torch.mean(output_diff_abs * output_diff_abs).item() 52 output_mse_rel = torch.mean(output_diff_rel * output_diff_rel).item() 53 54 grad_diff_abs = inp_dtensor.grad.to_local() - inp.grad 55 grad_diff_rel = grad_diff_abs / (torch.abs(inp.grad) + 1e-8) 56 grad_mse_abs = torch.mean(grad_diff_abs * grad_diff_abs).item() 57 grad_mse_rel = torch.mean(grad_diff_rel * grad_diff_rel).item() 58 59 self.assertTrue( 60 output_mse_abs <= 1e-6, 61 f"Too large absolute mse for output, expected less equal 1e-6, got {output_mse_abs}", 62 ) 63 self.assertTrue( 64 output_mse_rel <= 1e-6, 65 f"Too large relative mse for output, expected less equal 1e-6, got {output_mse_rel}", 66 ) 67 self.assertTrue( 68 grad_mse_abs <= 1e-6, 69 f"Too large absolute mse for gradient, expected less equal 1e-6, got {grad_mse_abs}", 70 ) 71 self.assertTrue( 72 grad_mse_rel <= 1e-6, 73 f"Too large relative mse for gradient, expected less equal 1e-6, got {grad_mse_rel}", 74 ) 75 76 @with_comms 77 def test_bernoulli(self): 78 rank = dist.get_rank() 79 device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 80 shard_spec = [Replicate()] 81 82 input_list = torch.rand(ITER_TIME, 1024, 10) 83 grad_output_list = torch.rand(ITER_TIME, 1024, 10) * 1e-3 84 85 for i in range(ITER_TIME): 86 inp = input_list[i].to(self.device_type).requires_grad_() 87 grad_output = grad_output_list[i].to(self.device_type) 88 89 # bernoulli with dtensor 90 inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec) 91 grad_output_dtensor = distribute_tensor( 92 grad_output, device_mesh, shard_spec 93 ) 94 output = torch.bernoulli(inp_dtensor) 95 output.backward(grad_output_dtensor) 96 97 send_output_tensor = output.to_local() 98 recv_output_tensor = torch.zeros_like(send_output_tensor) 99 100 send_grad_tensor = inp_dtensor.grad.to_local() 101 recv_grad_tensor = torch.zeros_like(send_grad_tensor) 102 103 send_op_1 = dist.P2POp(dist.isend, send_output_tensor, 1 ^ rank) 104 send_op_2 = dist.P2POp(dist.isend, send_grad_tensor, 1 ^ rank) 105 recv_op_1 = dist.P2POp(dist.irecv, recv_output_tensor, 1 ^ rank) 106 recv_op_2 = dist.P2POp(dist.irecv, recv_grad_tensor, 1 ^ rank) 107 108 reqs = dist.batch_isend_irecv([send_op_1, send_op_2, recv_op_1, recv_op_2]) 109 for req in reqs: 110 req.wait() 111 112 output_diff_abs = send_output_tensor - recv_output_tensor 113 output_diff_rel = output_diff_abs / (torch.abs(recv_output_tensor) + 1e-8) 114 output_mse_abs = torch.mean(output_diff_abs * output_diff_abs).item() 115 output_mse_rel = torch.mean(output_diff_rel * output_diff_rel).item() 116 117 grad_diff_abs = send_grad_tensor - recv_grad_tensor 118 grad_diff_rel = grad_diff_abs / (torch.abs(recv_grad_tensor) + 1e-8) 119 grad_mse_abs = torch.mean(grad_diff_abs * grad_diff_abs).item() 120 grad_mse_rel = torch.mean(grad_diff_rel * grad_diff_rel).item() 121 122 self.assertTrue( 123 output_mse_abs <= 1e-6, 124 f"Too large absolute mse for output, expected less equal 1e-6, got {output_mse_abs}", 125 ) 126 self.assertTrue( 127 output_mse_rel <= 1e-6, 128 f"Too large relative mse for output, expected less equal 1e-6, got {output_mse_rel}", 129 ) 130 self.assertTrue( 131 grad_mse_abs <= 1e-6, 132 f"Too large absolute mse for gradient, expected less equal 1e-6, got {grad_mse_abs}", 133 ) 134 self.assertTrue( 135 grad_mse_rel <= 1e-6, 136 f"Too large relative mse for gradient, expected less equal 1e-6, got {grad_mse_rel}", 137 ) 138 139 @with_comms 140 def test_nll(self): 141 device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) 142 shard_spec = [Replicate()] 143 144 pred_list = torch.rand(ITER_TIME, 1024, 10) 145 target_list = torch.randint(0, 10, (ITER_TIME, 1024), dtype=torch.long) 146 147 criterion = torch.nn.CrossEntropyLoss() 148 149 for i in range(ITER_TIME): 150 pred = pred_list[i].to(self.device_type).requires_grad_() 151 target = target_list[i].to(self.device_type) 152 153 # nll with dtensor 154 pred_dtensor = distribute_tensor(pred, device_mesh, shard_spec) 155 target_dtensor = distribute_tensor(target, device_mesh, shard_spec) 156 loss = criterion(pred_dtensor, target_dtensor) 157 loss.backward() 158 159 # nll with plain tensor 160 loss_gt = criterion(pred, target) 161 loss_gt.backward() 162 163 loss_diff_abs = loss.to_local() - loss_gt 164 loss_diff_rel = loss_diff_abs / (torch.abs(loss_gt) + 1e-8) 165 loss_mse_abs = torch.mean(loss_diff_abs * loss_diff_abs).item() 166 loss_mse_rel = torch.mean(loss_diff_rel * loss_diff_rel).item() 167 168 grad_diff_abs = pred_dtensor.grad.to_local() - pred.grad 169 grad_diff_rel = grad_diff_abs / (torch.abs(pred.grad) + 1e-8) 170 grad_mse_abs = torch.mean(grad_diff_abs * grad_diff_abs).item() 171 grad_mse_rel = torch.mean(grad_diff_rel * grad_diff_rel).item() 172 173 self.assertTrue( 174 loss_mse_abs <= 1e-6, 175 f"Too large absolute mse for loss, expected less equal 1e-6, got {loss_mse_abs}", 176 ) 177 self.assertTrue( 178 loss_mse_rel <= 1e-6, 179 f"Too large relative mse for loss, expected less equal 1e-6, got {loss_mse_rel}", 180 ) 181 self.assertTrue( 182 grad_mse_abs <= 1e-6, 183 f"Too large absolute mse for gradient, expected less equal 1e-6, got {grad_mse_abs}", 184 ) 185 self.assertTrue( 186 grad_mse_rel <= 1e-6, 187 f"Too large relative mse for gradient, expected less equal 1e-6, got {grad_mse_rel}", 188 ) 189 190 191if __name__ == "__main__": 192 run_tests() 193