# Owner(s): ["module: inductor"]
import sys
import unittest

import torch
from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

try:
    import triton  # noqa: F401
    import triton.language as tl
except ImportError:
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires triton")  # noqa: B904

from torch._inductor import config
from torch._inductor.runtime.hints import (
    DeviceProperties,
    HeuristicType,
    TRITON_MAX_BLOCK,
)
from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_heuristics import CachingAutotuner, triton_config
from torch._inductor.test_case import run_tests, TestCase


class TestTritonHeuristics(TestCase):
    device_type = GPU_TYPE

    def test_triton_config(self):
        """
        Make sure block size does not exceed the maximum defined in inductor config.
        """
        cfg = triton_config([2048, 2], 64, 64)
        for label in "XYZ":
            key = f"{label}BLOCK"
            if key not in cfg.kwargs:
                continue
            self.assertTrue(cfg.kwargs[key] <= TRITON_MAX_BLOCK[label])

    def _test_artificial_zgrid(self):
        def forward(primals_1, primals_2, primals_5):
            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
            primals_5 = None
            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
            clone = torch.ops.aten.clone.default(
                permute, memory_format=torch.contiguous_format
            )
            permute = None
            view_1 = torch.ops.aten.reshape.default(clone, [-1, 4])
            clone = None
            permute_1 = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm = torch.ops.aten.addmm.default(primals_2, view_1, permute_1)
            primals_2 = None
            return addmm

        s0 = 16777472
        s1 = 8

        args = [
            torch.rand([2, 4], device=GPU_TYPE),
            torch.rand([2], device=GPU_TYPE),
            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        torch._dynamo.mark_dynamic(args[-1], 0)
        foo_c = torch.compile(forward)

        self.assertEqual(forward(*args), foo_c(*args))

        args = [
            torch.rand([2, 4], device=GPU_TYPE),
            torch.rand([2], device=GPU_TYPE),
            torch.rand([s0, s1], device=GPU_TYPE),
        ]
        self.assertEqual(forward(*args), foo_c(*args))

    @skipIfXpu
    def test_artificial_zgrid(self):
        self._test_artificial_zgrid()

    @skipIfXpu
    @config.patch("cpp_wrapper", True)
    def test_artificial_grid_cpp_wrapper(self):
        self._test_artificial_zgrid()

    def _get_cos_kernel_caching_autotuner_args(self):
        from triton.compiler.compiler import AttrsDescriptor

        @triton.jit
        def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
            xnumel = 16
            xoffset = tl.program_id(0) * XBLOCK
            xindex = xoffset + tl.arange(0, XBLOCK)[:]
            xmask = xindex < xnumel
            x0 = xindex
            tmp0 = tl.load(in_ptr0 + (x0), xmask)
            tmp1 = tl_math.cos(tmp0)
            tl.store(out_ptr0 + (x0), tmp1, xmask)

        triton_meta = {
            "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
            "device": DeviceProperties.create(torch.device("cuda")),
            "constants": {},
            "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())],
        }

        configs = [
            triton_config([16], 64),
            triton_config([256], 64),
        ]

        inductor_meta = {}

        return {
            "fn": triton_,
            "triton_meta": triton_meta,
            "configs": configs,
            "save_cache_hook": False,
            "mutated_arg_names": [],
            "heuristic_type": HeuristicType.POINTWISE,
            "inductor_meta": inductor_meta,
        }

    @skipIfXpu
    def test_pre_hook_assert(self):
        # assert if any of the configs passed to the CachingAutotuner have pre-hooks
        args = self._get_cos_kernel_caching_autotuner_args()

        def pre_hook(kwargs):
            if "in_ptr0" in kwargs:
                kwargs["in_ptr0"].zero_()

        for cfg in args["configs"]:
            cfg.pre_hook = pre_hook

        with self.assertRaisesRegex(AssertionError, "pre_hook"):
            autotuner = CachingAutotuner(**args)


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()