# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)


ITER_TIME = 10
LR = 0.001
26
def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """``distribute_module`` partition hook: replace every parameter of
    ``module`` (recursively) with a DTensor parameter replicated over
    ``device_mesh``.

    Args:
        name: fully-qualified module name supplied by ``distribute_module``
            (unused here).
        module: module whose parameters are converted in place.
        device_mesh: mesh over which each parameter is replicated.
    """
    # Materialize the generator first: register_parameter() mutates the
    # parameter dict that named_parameters() is iterating over, which can
    # raise "dictionary changed size during iteration".
    for param_name, param in list(module.named_parameters()):
        dist_spec = [Replicate()]
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        # Dots are illegal in registered parameter names, so flatten
        # "sub.weight" -> "sub_weight".
        flat_name = "_".join(param_name.split("."))
        module.register_parameter(flat_name, dist_param)
39
40
class DistConvolutionOpsTest(DTensorTestBase):
    """Numerically compares convolution training with DTensor-sharded
    activations against plain-tensor training on a 2-rank mesh.

    Each test trains the same Conv2d twice for ITER_TIME steps of SGD --
    once through ``distribute_module`` with inputs sharded on the width
    dimension, once with ordinary tensors -- and asserts the resulting
    parameters agree to within an MSE of 1e-6.
    """

    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

    def _check_param(self, param_name: str, actual, expected) -> None:
        """Assert absolute and relative MSE between two parameter tensors
        are both <= 1e-6.

        Args:
            param_name: label used in the failure message ("weight"/"bias").
            actual: locally materialized parameter from the DTensor model.
            expected: corresponding parameter from the ground-truth model.
        """
        diff_abs = actual - expected
        # 1e-8 guards against division by zero for near-zero parameters.
        diff_rel = diff_abs / (torch.abs(expected) + 1e-8)
        mse_abs = torch.mean(diff_abs * diff_abs).item()
        mse_rel = torch.mean(diff_rel * diff_rel).item()
        self.assertTrue(
            mse_abs <= 1e-6,
            f"Too large absolute mse for {param_name} tensor, expected less equal 1e-6, got {mse_abs}",
        )
        self.assertTrue(
            mse_rel <= 1e-6,
            f"Too large relative mse for {param_name} tensor, expected less equal 1e-6, got {mse_rel}",
        )

    def _run_and_compare(self, model, input_list, grad_output_list) -> None:
        """Train ``model`` twice -- distributed via DTensor and with plain
        tensors -- then assert weight and bias match within tolerance.

        Args:
            model: an ``nn.Conv2d`` already moved to ``self.device_type``.
            input_list: stacked per-iteration inputs, shape
                (ITER_TIME, N, C, H, W).
            grad_output_list: matching stacked output gradients.
        """
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        # Shard activations along the width dimension (dim 3 of NCHW).
        shard_spec = [Shard(3)]

        # Snapshot the initialized model before distribution so both runs
        # start from identical parameters.
        model_gt = copy.deepcopy(model).to(self.device_type)

        # training with dtensor
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            output = model(inp_dtensor)
            grad_output = grad_output_list[i].to(self.device_type)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output.backward(grad_output_dtensor)
            optimizer.step()

        # training with plain tensor (ground truth)
        optimizer_gt = torch.optim.SGD(model_gt.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer_gt.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            output = model_gt(inp)
            grad_output = grad_output_list[i].to(self.device_type)
            output.backward(grad_output)
            optimizer_gt.step()

        self._check_param("weight", model.weight.to_local(), model_gt.weight)
        self._check_param("bias", model.bias.to_local(), model_gt.bias)

    @with_comms
    def test_downsampling_convolution(self):
        # Strided conv: 512x1024 input downsampled to 128x256 output.
        input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024)
        # Small gradients keep the SGD trajectory numerically tame.
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(3, 256, kernel_size=4, stride=4, padding=0).to(
            self.device_type
        )
        # Deterministic init so the two training runs are comparable.
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)

        self._run_and_compare(model, input_list, grad_output_list)

    # TODO: test_depthwise_convolution is broken in CI with gloo backend.
    # Temporarily disable it to unblock CI.
    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_depthwise_convolution(self):
        # Depthwise conv (groups == channels) preserving spatial size.
        input_list = torch.rand(ITER_TIME, 7, 256, 128, 256)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(256, 256, kernel_size=7, padding=3, groups=256).to(
            self.device_type
        )
        # Deterministic init so the two training runs are comparable.
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)

        self._run_and_compare(model, input_list, grad_output_list)
183
184
# Launch via torch's common test runner when executed directly (handles
# multi-process spawning for the distributed test base class).
if __name__ == "__main__":
    run_tests()
187