# Owner(s): ["module: inductor"]

import functools
import unittest

import torch
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA


# Decorator factory: skip a test unless multiple CUDA devices are available.
requires_multigpu = functools.partial(
    unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
)

aten = torch.ops.aten


class TestMoveConstructorsToCuda(TestCase):
    """Tests for the inductor optimization that moves CPU tensor
    constructors (e.g. ``torch.arange``) onto CUDA when doing so is safe.

    Each test checks, via the generated code, whether a CPU kernel
    ("cpp_fused") was emitted — its presence means the constructor stayed
    on CPU; its absence means it was successfully moved to CUDA.
    """

    def _check_fn(self, func, expect_cpu, *args):
        """Compile ``func``, verify eager/compiled parity, and check the
        generated code for a CPU kernel.

        Args:
            func: the function under test.
            expect_cpu: True if a CPU ("cpp_fused") kernel is expected to
                remain (the move-to-cuda pass should NOT fire); False if
                all code should be CUDA-only.
            *args: inputs forwarded to ``func``.
        """
        out_eager = func(*args)

        out_compiled, code = run_and_get_code(torch.compile(func), *args)
        self.assertEqual(out_eager, out_compiled)

        # Exactly one code artifact is expected.  Use a unittest assertion
        # (not a bare `assert`) so the check survives `python -O` and
        # reports a useful message on failure.
        self.assertEqual(len(code), 1)
        if expect_cpu:
            FileCheck().check("cpp_fused").run(code[0])
        else:
            FileCheck().check_not("cpp_fused").run(code[0])

    def test_simple(self):
        # The arange feeds only a CUDA indexing op, so it can be moved
        # to CUDA: no CPU kernel expected.
        def foo(x):
            return x[torch.arange(x.shape[0])]

        inp = torch.rand(32, 77, 512, device="cuda")

        self._check_fn(foo, False, inp)

    def test_output_failure(self):
        # The constructor's result is itself a graph output, so it must
        # stay on CPU to preserve the output's device.
        def foo(x):
            tmp1 = torch.arange(x.shape[0])
            return tmp1, x[tmp1]

        inp = torch.rand(32, 77, 512, device="cuda")

        self._check_fn(foo, True, inp)

    def test_non_convertable_op_failure(self):
        # `x` is a CPU tensor, so `x + y` pins the arange to CPU; the pass
        # must not move it even though a CUDA constructor is also present.
        def foo(x):
            y = torch.arange(x.shape[0])
            return x + y, torch.ones([4], device="cuda")

        inp = torch.rand([100])

        self._check_fn(foo, True, inp)

    def test_multiple_constructors(self):
        # Multiple CPU constructors whose users are CPU ops: all stay put.
        def foo(x):
            tmp1 = torch.arange(x.shape[0])
            o1 = x[tmp1]
            tmp2 = torch.arange(x.shape[1]).view([1, x.shape[1]])
            o2 = x[tmp2]
            return o1, o2, o1 + o2

        inp = torch.rand([200, 200])
        self._check_fn(foo, True, inp)

    def test_sets_equiv(self):
        """The pass's result should not depend on constructor ordering:
        both orderings of the two CPU constructors must compile to code
        with no triton kernels (everything moved/folded)."""

        @torch.compile()
        def foo(x):
            c1 = torch.ones([4], dtype=torch.long)
            c2 = torch.arange(-1, 3)
            return x[c1 + c2], c2 - 4 * 2

        inp = torch.rand([4]).cuda()
        out, code = run_and_get_code(foo, inp)
        FileCheck().check_not("triton.jit").run(code[0])

        # Same graph with the constructors declared in the opposite order.
        @torch.compile()
        def foo(x):
            c2 = torch.arange(-1, 3)
            c1 = torch.ones([4], dtype=torch.long)
            return x[c1 + c2], c2 - 4 * 2

        out, code = run_and_get_code(foo, inp)
        FileCheck().check_not("triton.jit").run(code[0])

    @requires_multigpu()
    def test_multi_gpu(self):
        # Outputs live on two different CUDA devices; the pass cannot pick
        # a single target device, so the CPU constructor stays on CPU.
        def foo(x):
            return (
                x[torch.arange(x.shape[0])],
                torch.ones([4], device="cuda:0"),
                torch.ones([4], device="cuda:1"),
            )

        # nyi, multi-gpu
        inp = torch.rand([100], device="cuda")
        self._check_fn(foo, True, inp)

    def test_no_gpu(self):
        # Entirely-CPU graph: nothing to move, CPU kernel expected.
        def foo(x):
            return x[torch.arange(x.shape[0])]

        inp = torch.rand([100])
        self._check_fn(foo, True, inp)


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA:
        run_tests()