# Owner(s): ["module: inductor"] import os import re import unittest import torch from torch import nn from torch._dynamo.testing import reset_rng_state from torch._inductor import config, test_operators from torch._inductor.codegen.multi_kernel import MultiKernelCall from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_code from torch.nn import functional as F from torch.testing import make_tensor from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, skipIfXpu, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU class TransformerSnippet(nn.Module): def __init__(self) -> None: super().__init__() self.ln1 = nn.LayerNorm(64) self.ln2 = nn.LayerNorm(64) def forward(self, x1, x2): x1 = F.dropout(x1, 0.1) x2 = F.dropout(self.ln1(x2), 0.1) return self.ln2(x1 + x2) def example_inputs(self): return (torch.randn(2, 64).to(GPU_TYPE), torch.randn(2, 64).to(GPU_TYPE)) def _contains_multi_kernel_code(wrapper_code: str): return ( re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code) is not None ) def make_cpp_wrapper_test(orig_test, **extra_args): """ Wrap an existing test into a new test with cpp-wrapper enabled. Make this as a free function rather than staticmethod in MultiKernelTest. Otherwise we get 'TypeError: 'staticmethod' object is not callable' error in py3.8. (py3.10 works) """ @config.patch("cpp_wrapper", True) @skipIfXpu(msg="cpp wrapper doesn't currently work on the XPU stack") def fn(self): # The same kernel may have been compiled by previous tests with # cpp_wrapper disabled. Clear the cache so we go ahead to re-compile # the kernel with cpp_wrapper enabled. from torch._inductor import codecache codecache.PyCodeCache.cache_clear() return orig_test(self, **extra_args) return fn @config.patch( { "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")), "benchmark_kernel": True, } ) @instantiate_parametrized_tests class MultiKernelTest(TestCase): def test_softmax(self, expect_multi_kernel=True): x = torch.rand(2, 1024).to(GPU_TYPE) ref = torch.softmax(x, -1) compiled_fn = torch.compile(torch.softmax) act, wrapper_code = run_and_get_code(compiled_fn, x, -1) # wrapper_code will contains 2 entries if cpp_wrapper=True. # One for the first pass and one for the second pass. # We mainly care about the wrapper for the final pass here. wrapper_code = wrapper_code[-1] self.assertEqual(ref, act) if expect_multi_kernel: self.assertTrue(_contains_multi_kernel_code(wrapper_code)) else: # Skip verifying the wrapper_code in fbcode since we may fail # compiling the cpp wrapper cuda code due to lacking proper setup of # cuda compiler in fbcode environment. In that case, the last # collected wrapper_code will corresponds to the first pass # cpp-wrapper codegen which contains the multi-kernel. if not config.is_fbcode(): self.assertFalse(_contains_multi_kernel_code(wrapper_code)) @parametrize("force_kernel", (0, 1)) @unittest.mock.patch.dict( os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"} ) def test_softmax_force_non_persistent_reduction(self, force_kernel): """ Force a specific sub-kernel being picked by mocking the benchmark result. 
""" x = torch.rand(2, 1024).to(GPU_TYPE) mock_latency = [0.2, 0.2] mock_latency[force_kernel] = 0.1 # this make sure force_kernel will be picked def f(x): return torch.softmax(x, -1) + force_kernel orig_run = MultiKernelCall.run picked_kernel = None def mock_run(self, *args, **kwargs): out = orig_run(self, *args, **kwargs) nonlocal picked_kernel picked_kernel = self.picked_kernel return out with unittest.mock.patch.object( MultiKernelCall, "run", mock_run ), unittest.mock.patch.object( MultiKernelCall, "benchmark_sub_kernels", lambda *args, **kwargs: mock_latency, ): torch.compile(f)(x) self.assertEqual(picked_kernel, force_kernel) @config.patch("warn_mix_layout", True) def test_softmax_warn_mixed_layout(self): self.test_softmax() test_softmax_cpp_wrapper = make_cpp_wrapper_test( test_softmax, expect_multi_kernel=False ) def test_layernorm(self): ln = nn.LayerNorm(1024).to(GPU_TYPE) x = torch.rand(2, 1024).to(GPU_TYPE) ref = ln(x) act = torch.compile(ln)(x) self.assertEqual(ref, act, atol=1e-4, rtol=1e-4) def test_inplace_update(self): """ Inductor generate inplace kernel for mul. """ def f(x, y): return x.sum(dim=-1, keepdims=True) * (y @ y) x = torch.rand(1024, 1024).to(GPU_TYPE) y = torch.rand(1024, 1024).to(GPU_TYPE) ref = f(x, y) act = torch.compile(f)(x, y) self.assertEqual(ref, act) def test_transformer_snippet(self): model = TransformerSnippet().to(GPU_TYPE) x = model.example_inputs() def f(*x): y = model(*x) return y reset_rng_state() ref = f(*x) opt_f = torch.compile(f) reset_rng_state() act = opt_f(*x) # don't compare tensor if using inductor random number generator. # inductor random number implementation is different to eager. # We should fallback to eager if we want to test accuracy. if config.fallback_random: self.assertEqual(ref, act, atol=1e-4, rtol=1e-4) def test_transformer_snippet_with_fallback_random(self): """ Same as test_transformer_snippet but fallback the random number generator to eager so we can check accuracy. """ with config.patch("fallback_random", True): self.test_transformer_snippet() def test_batchnorm_training(self): """ For training, batchnorm will tracking running mean/variance during forward pass. The kernel generated by inductor currently will pass in those tensors twice as arguments: once for input and once for output. They are ruled out as in-out argument because they are considered as graph inputs. Multi-kernel previously assumes that we never pass the same argument mutli times for a kernel. No mater if we change inductor behavior to assure that, it's better to make multi-kernel being able to handle those cases. """ bn = nn.BatchNorm2d(3).to(GPU_TYPE) @torch.compile def f(x): bn(x).sum().backward() _, (wrapper_code, _) = run_and_get_code( f, torch.randn(2, 3, 8, 8, device=GPU_TYPE) ) self.assertTrue(_contains_multi_kernel_code(wrapper_code)) def test_pass_same_arg_multi_times(self): """ A super simple example that simulate how BatchNorm update the running stats. Inductor currently pass the same tensor multiple times for the generated kernel: once for input and once for output. 
        Here is a paste of the generated kernel (without multi-kernel enabled):
        https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9
        """

        def f(x, y):
            x = x.sum(dim=1, keepdim=False)
            y.copy_(y * 0.9 + x * 0.1)

        x = torch.randn(8, 16, device=GPU_TYPE)
        y = torch.randn(8, device=GPU_TYPE)
        y_ref = y.clone()

        ref = f(x, y_ref)
        act = torch.compile(f)(x, y)
        self.assertEqual(y_ref, y)

    def test_reduction_scratch_buffer(self, force_multi_kernel=1):
        """
        The explicitly realized buffer in the test function will be passed in
        as a scratch buffer for the non-persistent reduction kernel, but can
        be skipped for the persistent reduction kernel.

        This causes different argument lists for the non-persistent reduction
        kernel and the persistent reduction kernel.

        Check the documentation around torch._inductor.config.triton.multi_kernel
        for how to interpret the force_multi_kernel argument.
        """

        def f(x):
            x = x.sum(dim=-1, keepdim=True) + x
            x = test_operators.realize(x)
            x = x.sum(dim=-1, keepdim=True) + x
            return x

        x = torch.rand(16, 16, device=GPU_TYPE)
        ref = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            act = torch.compile(f)(x)
        self.assertEqual(ref, act)

    def test_split_scan(self, force_multi_kernel=1):
        def f(x):
            x = x.view(-1)
            return torch.cumsum(x, 0)

        x = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    def test_sort_disables_multi_kernel(self, force_multi_kernel=1):
        """
        Sort currently requires a persistent kernel, so multi-kernel is not
        possible. Make sure this falls back gracefully.
        """

        def f(x):
            return x.sort(-1).values

        x = torch.rand(32, 32, device=GPU_TYPE)
        expect = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            actual = torch.compile(f)(x)
        self.assertEqual(expect, actual)

    # Use benchmarking to pick the faster kernel.
    test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
        test_reduction_scratch_buffer, force_multi_kernel=1
    )
    # Force picking the persistent reduction. This is a good test since the
    # persistent reduction uses fewer call arguments than the corresponding
    # non-persistent reduction.
    test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
    )
    # Force picking the non-persistent reduction.
    test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
    )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()