1# Owner(s): ["module: inductor"] 2 3import sys 4import unittest 5from unittest import mock 6 7import torch 8from torch._inductor.runtime.hints import TRITON_MAX_BLOCK 9from torch._inductor.test_case import run_tests, TestCase 10from torch.testing._internal.common_utils import IS_LINUX 11from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU 12 13 14try: 15 import triton 16except ImportError: 17 if __name__ == "__main__": 18 sys.exit(0) 19 raise unittest.SkipTest("requires triton") # noqa: B904 20 21from torch._inductor import config 22from torch._inductor.runtime.coordinate_descent_tuner import CoordescTuner 23 24 25config.benchmark_kernel = True 26config.coordinate_descent_tuning = True 27 28orig_compare_config = CoordescTuner.compare_config 29 30 31def mock_compare_config_prefer_larger_XBLOCK( 32 self, func, candidate_config, best_config, best_timing 33): 34 """ 35 self is the CoordescTuner object 36 """ 37 if "XBLOCK" in candidate_config.kwargs: 38 assert "XBLOCK" in best_config.kwargs 39 if candidate_config.kwargs["XBLOCK"] < best_config.kwargs["XBLOCK"]: 40 func(candidate_config) # run func so the launcher will be created 41 return False, best_timing * 1.1 42 elif candidate_config.kwargs["XBLOCK"] > best_config.kwargs["XBLOCK"]: 43 func(candidate_config) 44 return True, best_timing * 0.9 45 46 return orig_compare_config(self, func, candidate_config, best_config, best_timing) 47 48 49class TestCoordinateDescentTuner(TestCase): 50 def test_abs_function(self): 51 """ 52 The benchmark result is simply abs(XBLOCK - 15) 53 """ 54 tuner = CoordescTuner() 55 baseline_config = triton.Config({"XBLOCK": 1}, num_warps=8, num_stages=1) 56 57 def func(config): 58 return abs(config.kwargs["XBLOCK"] - 15) 59 60 best_config = tuner.autotune(func, baseline_config) 61 self.assertTrue(best_config.kwargs.get("XBLOCK") == 16, str(best_config)) 62 63 def test_no_neighbors(self): 64 """ 65 Test the case that there is no available neighbor values for a field. 66 """ 67 68 # size hint for x being 1 limits the max XBLOCK we try to be 1 69 tuner = CoordescTuner(size_hints=[1]) 70 baseline_config = triton.Config({"XBLOCK": 1}, num_warps=8, num_stages=1) 71 72 def func(config): 73 return abs(config.kwargs["XBLOCK"] - 15) 74 75 best_config = tuner.autotune(func, baseline_config) 76 self.assertTrue(best_config.kwargs.get("XBLOCK") == 1, str(best_config)) 77 78 def test_get_neighbour_values(self): 79 tuner = CoordescTuner() 80 81 neighbours = tuner.get_neighbour_values("num_stages", 2, radius=2) 82 self.assertEqual(set(neighbours), {1, 3, 4}) 83 neighbours = tuner.get_neighbour_values("num_warps", 2, radius=2) 84 self.assertEqual(set(neighbours), {1, 4, 8}) 85 86 def test_persistent_reduction(self): 87 def f(x): 88 return x / x.sum(dim=-1, keepdim=True) 89 90 with mock.patch.object( 91 CoordescTuner, "compare_config", mock_compare_config_prefer_larger_XBLOCK 92 ): 93 x = torch.ones(2, 256).to(GPU_TYPE) 94 expected = f(x) 95 # the first call get correct result when cache miss. 
Don't know why yet 96 _ = torch.compile(f)(x) 97 actual = torch.compile(f)(x) 98 self.assertTrue( 99 torch.allclose(expected, actual, atol=1e-4, rtol=1e-4), 100 f"Expected:\n{expected}\nActual:\n{actual}", 101 ) 102 103 def test_value_too_large(self): 104 # Simulate a reduction 105 size_hints = [2**20, 2**20] 106 107 tuner = CoordescTuner(size_hints=size_hints) 108 109 max_block = TRITON_MAX_BLOCK 110 self.assertFalse(tuner.value_too_large("XBLOCK", max_block["X"])) 111 self.assertTrue(tuner.value_too_large("XBLOCK", max_block["X"] * 2)) 112 self.assertFalse(tuner.value_too_large("RBLOCK", max_block["R"])) 113 self.assertTrue(tuner.value_too_large("RBLOCK", max_block["R"] * 2)) 114 115 116if __name__ == "__main__": 117 if IS_LINUX and HAS_GPU: 118 run_tests() 119