# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist

from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)

# Number of randomized iterations each test repeats.
ITER_TIME = 10
# NOTE(review): unused in this file; kept for parity with sibling tests.
LR = 0.001


class DistOtherOpsTest(DTensorTestBase):
    """Numerical parity tests for assorted ops on replicated DTensors.

    Each test runs an op on a DTensor with a ``Replicate()`` placement and
    compares outputs and gradients against a reference — either the same op
    on a plain tensor, or (for random ops) the result produced on the peer
    rank — asserting agreement within a mean-squared-error tolerance.
    """

    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    def _assert_mse(self, actual, expected, label: str) -> None:
        """Assert ``actual`` matches ``expected`` within 1e-6 mean squared
        error, both in absolute terms and relative to ``expected``.

        ``label`` names the quantity ("output"/"gradient"/"loss") in the
        failure messages, which are kept byte-identical to the originals.
        """
        diff_abs = actual - expected
        # 1e-8 guards against division by zero where ``expected`` is 0.
        diff_rel = diff_abs / (torch.abs(expected) + 1e-8)
        mse_abs = torch.mean(diff_abs * diff_abs).item()
        mse_rel = torch.mean(diff_rel * diff_rel).item()
        self.assertTrue(
            mse_abs <= 1e-6,
            f"Too large absolute mse for {label}, expected less equal 1e-6, got {mse_abs}",
        )
        self.assertTrue(
            mse_rel <= 1e-6,
            f"Too large relative mse for {label}, expected less equal 1e-6, got {mse_rel}",
        )

    @with_comms
    def test_slice(self):
        """Slicing a replicated DTensor matches plain-tensor slicing in
        both forward output and backward gradient."""
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        input_list = torch.rand(ITER_TIME, 1024, 10)
        grad_output_list = torch.rand(ITER_TIME, 1024, 5) * 1e-3

        for i in range(ITER_TIME):
            inp = input_list[i].to(self.device_type).requires_grad_()
            grad_output = grad_output_list[i].to(self.device_type)

            # slice with dtensor
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output = inp_dtensor[:, :5]
            output.backward(grad_output_dtensor)

            # slice with plain tensor (ground truth)
            output_gt = inp[:, :5]
            output_gt.backward(grad_output)

            self._assert_mse(output.to_local(), output_gt, "output")
            self._assert_mse(inp_dtensor.grad.to_local(), inp.grad, "gradient")

    @with_comms
    def test_bernoulli(self):
        """``torch.bernoulli`` on a replicated DTensor produces identical
        samples (and gradients) on both ranks.

        NOTE(review): this presumes replicated DTensors keep RNG state in
        sync across ranks, so the peer's draw is the reference value.
        """
        rank = dist.get_rank()
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        input_list = torch.rand(ITER_TIME, 1024, 10)
        grad_output_list = torch.rand(ITER_TIME, 1024, 10) * 1e-3

        for i in range(ITER_TIME):
            inp = input_list[i].to(self.device_type).requires_grad_()
            grad_output = grad_output_list[i].to(self.device_type)

            # bernoulli with dtensor
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output = torch.bernoulli(inp_dtensor)
            output.backward(grad_output_dtensor)

            # Exchange local results with the peer; world size is fixed at 2,
            # so ``1 ^ rank`` flips 0<->1 to address the other rank.
            send_output_tensor = output.to_local()
            recv_output_tensor = torch.zeros_like(send_output_tensor)
            send_grad_tensor = inp_dtensor.grad.to_local()
            recv_grad_tensor = torch.zeros_like(send_grad_tensor)

            p2p_ops = [
                dist.P2POp(dist.isend, send_output_tensor, 1 ^ rank),
                dist.P2POp(dist.isend, send_grad_tensor, 1 ^ rank),
                dist.P2POp(dist.irecv, recv_output_tensor, 1 ^ rank),
                dist.P2POp(dist.irecv, recv_grad_tensor, 1 ^ rank),
            ]
            reqs = dist.batch_isend_irecv(p2p_ops)
            for req in reqs:
                req.wait()

            self._assert_mse(send_output_tensor, recv_output_tensor, "output")
            self._assert_mse(send_grad_tensor, recv_grad_tensor, "gradient")

    @with_comms
    def test_nll(self):
        """CrossEntropyLoss on replicated DTensors matches the plain-tensor
        loss and input gradient."""
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Replicate()]

        pred_list = torch.rand(ITER_TIME, 1024, 10)
        target_list = torch.randint(0, 10, (ITER_TIME, 1024), dtype=torch.long)

        criterion = torch.nn.CrossEntropyLoss()

        for i in range(ITER_TIME):
            pred = pred_list[i].to(self.device_type).requires_grad_()
            target = target_list[i].to(self.device_type)

            # nll with dtensor
            pred_dtensor = distribute_tensor(pred, device_mesh, shard_spec)
            target_dtensor = distribute_tensor(target, device_mesh, shard_spec)
            loss = criterion(pred_dtensor, target_dtensor)
            loss.backward()

            # nll with plain tensor (ground truth)
            loss_gt = criterion(pred, target)
            loss_gt.backward()

            self._assert_mse(loss.to_local(), loss_gt, "loss")
            self._assert_mse(pred_dtensor.grad.to_local(), pred.grad, "gradient")


if __name__ == "__main__":
    run_tests()