# Owner(s): ["module: inductor"]
"""Tests for Inductor's pass that moves CPU tensor constructors (e.g. torch.arange)
onto CUDA. Each check compiles a function and inspects the generated code for CPU
("cpp_fused") kernels to tell whether the constructors were moved."""

import functools
import unittest

import torch
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA


requires_multigpu = functools.partial(
    unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
)

aten = torch.ops.aten


class TestMoveConstructorsToCuda(TestCase):
    def _check_fn(self, func, expect_cpu, *args):
        # Compare eager and compiled results, then inspect the generated code:
        # expect_cpu=True means a CPU ("cpp_fused") kernel should remain,
        # expect_cpu=False means the constructors were moved to CUDA.
        out_eager = func(*args)

        out_compiled, code = run_and_get_code(torch.compile(func), *args)
        self.assertEqual(out_eager, out_compiled)

        self.assertEqual(len(code), 1)
        if expect_cpu:
            FileCheck().check("cpp_fused").run(code[0])
        else:
            FileCheck().check_not("cpp_fused").run(code[0])

    def test_simple(self):
        # A single arange used only as an index into a CUDA tensor can be
        # moved to CUDA, so no CPU kernel should be generated.
        def foo(x):
            return x[torch.arange(x.shape[0])]

        inp = torch.rand(32, 77, 512, device="cuda")

        self._check_fn(foo, False, inp)

    def test_output_failure(self):
        # tmp1 is returned directly as an output, so it must stay a CPU tensor;
        # a CPU kernel is still expected.
        def foo(x):
            tmp1 = torch.arange(x.shape[0])
            return tmp1, x[tmp1]

        inp = torch.rand(32, 77, 512, device="cuda")

        self._check_fn(foo, True, inp)

    def test_non_convertable_op_failure(self):
        # x is a CPU tensor, so x + y cannot be moved to CUDA;
        # a CPU kernel is expected.
        def foo(x):
            y = torch.arange(x.shape[0])
            return x + y, torch.ones([4], device="cuda")

        inp = torch.rand([100])

        self._check_fn(foo, True, inp)

    def test_multiple_constructors(self):
        # Multiple constructors feeding a CPU input; a CPU kernel is expected.
        def foo(x):
            tmp1 = torch.arange(x.shape[0])
            o1 = x[tmp1]
            tmp2 = torch.arange(x.shape[1]).view([1, x.shape[1]])
            o2 = x[tmp2]
            return o1, o2, o1 + o2

        inp = torch.rand([200, 200])
        self._check_fn(foo, True, inp)

    def test_sets_equiv(self):
        # The pass should behave the same regardless of the order in which the
        # constructors are created; neither version should produce a Triton kernel.
        @torch.compile()
        def foo(x):
            c1 = torch.ones([4], dtype=torch.long)
            c2 = torch.arange(-1, 3)
            return x[c1 + c2], c2 - 4 * 2

        inp = torch.rand([4]).cuda()
        out, code = run_and_get_code(foo, inp)
        FileCheck().check_not("triton.jit").run(code[0])

        # Same graph with the constructor order swapped.
        @torch.compile()
        def foo(x):
            c2 = torch.arange(-1, 3)
            c1 = torch.ones([4], dtype=torch.long)
            return x[c1 + c2], c2 - 4 * 2

        out, code = run_and_get_code(foo, inp)
        FileCheck().check_not("triton.jit").run(code[0])

    @requires_multigpu()
    def test_multi_gpu(self):
        def foo(x):
            return (
                x[torch.arange(x.shape[0])],
                torch.ones([4], device="cuda:0"),
                torch.ones([4], device="cuda:1"),
            )

        # Multi-gpu is not yet implemented, so a CPU kernel is still expected.
        inp = torch.rand([100], device="cuda")
        self._check_fn(foo, True, inp)

    def test_no_gpu(self):
        # Entirely CPU graph: there is nothing to move, so a CPU kernel is expected.
        def foo(x):
            return x[torch.arange(x.shape[0])]

        inp = torch.rand([100])
        self._check_fn(foo, True, inp)


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA:
        run_tests()