# Owner(s): ["module: inductor"]

import logging

import torch
import torch._inductor
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
from torch.testing._internal.triton_utils import requires_gpu


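# Toy modules whose forward passes lower to the matmul variants exercised
# below: a Linear layer (addmm with bias, mm without), a batched matmul, and a
# plain matmul.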
class MyModule(torch.nn.Module):
    def __init__(
        self, n_input: int, n_output: int, has_bias: bool, device=GPU_TYPE
    ) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(n_input, n_output, bias=has_bias, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
29
30class MyModule2(torch.nn.Module):
31    def __init__(self) -> None:
32        super().__init__()
33
34    def forward(self, input1, input2):
35        output = torch.bmm(input1, input2)
36        return output
37
38
39class MyModule3(torch.nn.Module):
40    def __init__(self) -> None:
41        super().__init__()
42
43    def forward(self, input1, input2):
44        output = torch.mm(input1, input2)
45        return output
46
47
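# The decompose_mm_pass post-grad pass is expected to rewrite memory-bound
# matmuls (small inner dimension, large outer dimension) into pointwise
# multiply + reduction kernels. A minimal sketch of the kind of algebraic
# rewrite involved, assuming a broadcast-multiply-then-sum decomposition:
#
#   torch.mm(a, b)   ~= (a.unsqueeze(-1) * b.unsqueeze(0)).sum(dim=1)
#   torch.bmm(a, b)  ~= (a.unsqueeze(-1) * b.unsqueeze(1)).sum(dim=2)
#
# The tests below only assert on the pass's counters and numerics, not on the
# exact rewrite.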
@requires_gpu
@torch._inductor.config.patch(
    post_grad_fusion_options={
        "decompose_mm_pass": {},
    }
)
@instantiate_parametrized_tests
class TestDecomposeMemMM(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(ref_dict) != len(res_dict):
            return False
        for key1 in ref_dict.keys():
            # torch.compile wraps the module, so traced parameter names carry
            # an "_orig_mod." prefix.
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key2} does not exist in traced module"
            if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol):
                return False
        return True
    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

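    # Forward decomposes one bmm. Backward computes one bmm per input
    # gradient, so on a decomposing run the counter is expected to go from 1
    # to 3 after .backward(). (All count expectations in this file assume
    # CUDA; elsewhere HAS_CUDA gates them to 0.)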
    @parametrize(
        "b,m,k,n,should_decompose",
        [(10240, 2, 2, 2, True), (10240, 2, 32, 32, False), (2000, 2, 2, 2, False)],
    )
    def test_decompose_bmm(self, b, m, k, n, should_decompose):
        torch._logging.set_logs(inductor=logging.DEBUG)
        mat1 = torch.randn(b, m, k, device=GPU_TYPE).requires_grad_(True)
        mat2 = torch.randn(b, k, n, device=GPU_TYPE).requires_grad_(True)

        counters.clear()

        module = MyModule2().to(GPU_TYPE)
        traced = torch.compile(module)
        input = [mat1, mat2]
        ref = module(*input)
        res = traced(*input)

        self.compare_pred(module, traced, input)

        expected_val = 1 if should_decompose and HAS_CUDA else 0
        self.assertEqual(
            counters["inductor"]["decompose_bmm"],
            expected_val,
        )

        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)

        expected_val = 3 if should_decompose and HAS_CUDA else 0
        self.assertEqual(
            counters["inductor"]["decompose_bmm"],
            expected_val,
        )
        counters.clear()

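    # With a bias, Linear lowers to addmm (tracked by decompose_addmm);
    # without one, to mm. Backward is expected to decompose exactly one more
    # mm: the weight-gradient mm contracts over m = 20480, so only the
    # input-gradient mm remains memory-bound.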
    @parametrize(
        "m,k,n,should_decompose",
        [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
    )
    @parametrize("has_bias", [True, False])
    def test_decompose_linear(self, m, k, n, has_bias, should_decompose):
        torch._logging.set_logs(inductor=logging.DEBUG)
        input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True)

        counters.clear()

        module = MyModule(k, n, has_bias).to(GPU_TYPE)
        traced = torch.compile(module)
        input = [input]
        ref = module(*input)
        res = traced(*input)

        self.compare_pred(module, traced, input)

        expected_val = 1 if should_decompose and HAS_CUDA else 0
        if has_bias:
            self.assertEqual(
                counters["inductor"]["decompose_addmm"],
                expected_val,
            )
        else:
            self.assertEqual(
                counters["inductor"]["decompose_mm"],
                expected_val,
            )
        decompose_mm_fwd = counters["inductor"]["decompose_mm"]

        ref.sum().backward()
        res.sum().backward()

        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)

        self.assertEqual(
            counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
            expected_val,
        )
        counters.clear()

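    # Same as test_decompose_linear, but under bf16 autocast, to check that
    # the pass still fires on mixed-precision graphs.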
    @parametrize(
        "m,k,n,should_decompose",
        [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
    )
    @parametrize("has_bias", [True, False])
    def test_decompose_linear_mixed_precision(
        self, m, k, n, has_bias, should_decompose
    ):
        with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16):
            torch._logging.set_logs(inductor=logging.DEBUG)
            input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True)

            counters.clear()

            module = MyModule(k, n, has_bias).to(GPU_TYPE)
            traced = torch.compile(module)
            input = [input]
            ref = module(*input)
            res = traced(*input)

            self.compare_pred(module, traced, input)

            expected_val = 1 if should_decompose and HAS_CUDA else 0
            if has_bias:
                self.assertEqual(
                    counters["inductor"]["decompose_addmm"],
                    expected_val,
                )
            else:
                self.assertEqual(
                    counters["inductor"]["decompose_mm"],
                    expected_val,
                )
            decompose_mm_fwd = counters["inductor"]["decompose_mm"]

            ref.sum().backward()
            res.sum().backward()

            self.compare_parameters(module, traced)
            self.compare_gradients(module, traced)

            self.assertEqual(
                counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
                expected_val,
            )
            counters.clear()

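    # Plain mm: the forward decomposes once. In backward, the mat1 gradient
    # ((m, n) @ (n, k) with n = 2) is memory-bound, while the mat2 gradient
    # contracts over m = 20480, so the counter is expected to grow by exactly
    # one.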
    @parametrize(
        "m,k,n,should_decompose",
        [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
    )
    @parametrize("has_bias", [True, False])
    def test_decompose_mm(self, m, k, n, has_bias, should_decompose):
        torch._logging.set_logs(inductor=logging.DEBUG)
        mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True)
        mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True)

        counters.clear()

        module = MyModule3().to(GPU_TYPE)
        traced = torch.compile(module)
        input = [mat1, mat2]
        ref = module(*input)
        res = traced(*input)

        self.compare_pred(module, traced, input)

        expected_val = 1 if should_decompose and HAS_CUDA else 0
        self.assertEqual(
            counters["inductor"]["decompose_mm"],
            expected_val,
        )
        decompose_mm_fwd = counters["inductor"]["decompose_mm"]

        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)

        self.assertEqual(
            counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
            expected_val,
        )
        counters.clear()

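    # Same as test_decompose_mm, but under bf16 autocast.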
    @parametrize(
        "m,k,n,should_decompose",
        [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
    )
    @parametrize("has_bias", [True, False])
    def test_decompose_mm_mixed_precision(self, m, k, n, has_bias, should_decompose):
        with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16):
            torch._logging.set_logs(inductor=logging.DEBUG)
            mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True)
            mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True)

            counters.clear()

            module = MyModule3().to(GPU_TYPE)
            traced = torch.compile(module)
            input = [mat1, mat2]
            ref = module(*input)
            res = traced(*input)

            self.compare_pred(module, traced, input)

            expected_val = 1 if should_decompose and HAS_CUDA else 0
            self.assertEqual(
                counters["inductor"]["decompose_mm"],
                expected_val,
            )
            decompose_mm_fwd = counters["inductor"]["decompose_mm"]

            ref.sum().backward()
            res.sum().backward()
            self.compare_parameters(module, traced)
            self.compare_gradients(module, traced)

            self.assertEqual(
                counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
                expected_val,
            )
            counters.clear()

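    # With dynamic shapes, the final check folds forward and backward into one
    # total: with a bias the forward lowers to addmm, so decompose_mm counts
    # only the backward input-gradient mm (1); without a bias the forward mm
    # decomposes too (2).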
    @parametrize("m,k,n,should_decompose", [(20480, 5, 2, True)])
    @parametrize("has_bias", [True, False])
    def test_dynamic_shape(self, m, k, n, has_bias, should_decompose):
        torch._logging.set_logs(inductor=logging.DEBUG)
        input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True)

        counters.clear()

        module = MyModule(k, n, has_bias).to(GPU_TYPE)
        traced = torch.compile(module, dynamic=True)
        input = [input]
        ref = module(*input)
        res = traced(*input)

        self.compare_pred(module, traced, input)

        expected_val = 1 if should_decompose and HAS_CUDA else 0
        if has_bias:
            self.assertEqual(
                counters["inductor"]["decompose_addmm"],
                expected_val,
            )

        ref.sum().backward()
        res.sum().backward()

        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)

        expected_val = 0
        if HAS_CUDA:
            expected_val = 1 if has_bias else 2

        self.assertEqual(
            counters["inductor"]["decompose_mm"],
            expected_val,
        )
        counters.clear()

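    # The transposed operand has to be realized (made contiguous) before the
    # decomposition can apply, which is expected to show up as a second kernel
    # on CUDA; the XPU stack fuses the copy and produces a single kernel.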
    def test_realize_input(self):
        m = 20480
        k = 5
        n = 2
        torch._logging.set_logs(inductor=logging.DEBUG)
        input1 = torch.randn(m, k, device=GPU_TYPE).T.contiguous()
        input2 = torch.randn(k, n, device=GPU_TYPE)

        @torch.compile()
        def foo(x, y):
            return x.T.contiguous() @ y

        out, code = run_and_get_code(foo, input1, input2)

        if GPU_TYPE == "xpu":
            # only 1 kernel generated on the XPU stack
            FileCheck().check_count(".run(", 1, exactly=True).run(code[0])
        else:
            # two kernels generated
            FileCheck().check_count(".run(", 2, exactly=True).run(code[0])


if __name__ == "__main__":
    run_tests()