# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)


# Number of forward/backward/optimizer-step iterations per test.
ITER_TIME = 10
# SGD learning rate shared by the DTensor model and the ground-truth model.
LR = 0.001


def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """Partition function for ``distribute_module``: replicate every
    parameter of ``module`` across ``device_mesh``.

    Dots in parameter names are replaced with underscores because
    ``register_parameter`` rejects dotted names.
    """
    # NOTE: use a distinct loop variable instead of shadowing the ``name``
    # parameter supplied by ``distribute_module``.
    for param_name, param in module.named_parameters():
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, [Replicate()])
        )
        module.register_parameter(param_name.replace(".", "_"), dist_param)


class DistConvolutionOpsTest(DTensorTestBase):
    """Numerically compare convolution training on DTensor vs. plain tensors."""

    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    def _train(
        self,
        model: nn.Module,
        input_list: torch.Tensor,
        grad_output_list: torch.Tensor,
        device_mesh=None,
        shard_spec=None,
    ) -> None:
        """Run ``ITER_TIME`` SGD steps on ``model``.

        When ``device_mesh``/``shard_spec`` are provided, both the input and
        the output gradient are converted to sharded DTensors before use;
        otherwise the model is trained on plain tensors.
        """
        optimizer = torch.optim.SGD(model.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            grad_output = grad_output_list[i].to(self.device_type)
            if device_mesh is not None:
                inp = distribute_tensor(inp, device_mesh, shard_spec)
                grad_output = distribute_tensor(
                    grad_output, device_mesh, shard_spec
                )
            output = model(inp)
            output.backward(grad_output)
            optimizer.step()

    def _assert_params_close(self, model: nn.Module, model_gt: nn.Module) -> None:
        """Assert absolute and relative MSE between the DTensor model's local
        parameters and the ground-truth model's parameters are <= 1e-6."""
        for dist_param, gt_param, label in (
            (model.weight, model_gt.weight, "weight"),
            (model.bias, model_gt.bias, "bias"),
        ):
            diff_abs = dist_param.to_local() - gt_param
            # 1e-8 guards against division by zero for near-zero parameters.
            diff_rel = diff_abs / (torch.abs(gt_param) + 1e-8)
            mse_abs = torch.mean(diff_abs * diff_abs).item()
            mse_rel = torch.mean(diff_rel * diff_rel).item()
            self.assertLessEqual(
                mse_abs,
                1e-6,
                f"Too large absolute mse for {label} tensor, expected less equal 1e-6, got {mse_abs}",
            )
            self.assertLessEqual(
                mse_rel,
                1e-6,
                f"Too large relative mse for {label} tensor, expected less equal 1e-6, got {mse_rel}",
            )

    @with_comms
    def test_downsampling_convolution(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(3)]

        input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(3, 256, kernel_size=4, stride=4, padding=0).to(
            self.device_type
        )
        # Deterministic init so both models start from identical parameters.
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)
        model_gt = copy.deepcopy(model).to(self.device_type)

        # training with dtensor
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        self._train(model, input_list, grad_output_list, device_mesh, shard_spec)

        # training with plain tensor
        self._train(model_gt, input_list, grad_output_list)

        self._assert_params_close(model, model_gt)

    # TODO: test_depthwise_convolution is broken in CI with gloo backend.
    # Temporarily disable it to unblock CI.
    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_depthwise_convolution(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(3)]

        input_list = torch.rand(ITER_TIME, 7, 256, 128, 256)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(256, 256, kernel_size=7, padding=3, groups=256).to(
            self.device_type
        )
        # Deterministic init so both models start from identical parameters.
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)
        model_gt = copy.deepcopy(model).to(self.device_type)

        # training with dtensor
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        self._train(model, input_list, grad_output_list, device_mesh, shard_spec)

        # training with plain tensor
        self._train(model_gt, input_list, grad_output_list)

        self._assert_params_close(model, model_gt)


if __name__ == "__main__":
    run_tests()