# Owner(s): ["module: inductor"]

import sys
import unittest
from unittest import mock

import torch
from torch._inductor.runtime.hints import TRITON_MAX_BLOCK
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


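# Exit cleanly when run directly without triton; under a test runner,
# raise SkipTest instead.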
try:
    import triton
except ImportError:
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires triton")  # noqa: B904

from torch._inductor import config
from torch._inductor.runtime.coordinate_descent_tuner import CoordescTuner


config.benchmark_kernel = True
config.coordinate_descent_tuning = True

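# Keep a reference to the real compare_config so the mock below can
# delegate to it for fields other than XBLOCK.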
orig_compare_config = CoordescTuner.compare_config


def mock_compare_config_prefer_larger_XBLOCK(
    self, func, candidate_config, best_config, best_timing
):
    """
    self is the CoordescTuner object.

    Pretend that a larger XBLOCK is always better: report a fake timing
    10% worse than the best for a smaller XBLOCK and 10% better for a
    larger one; delegate everything else to the real compare_config.
    """
37    if "XBLOCK" in candidate_config.kwargs:
38        assert "XBLOCK" in best_config.kwargs
39        if candidate_config.kwargs["XBLOCK"] < best_config.kwargs["XBLOCK"]:
40            func(candidate_config)  # run func so the launcher will be created
41            return False, best_timing * 1.1
42        elif candidate_config.kwargs["XBLOCK"] > best_config.kwargs["XBLOCK"]:
43            func(candidate_config)
44            return True, best_timing * 0.9
45
46    return orig_compare_config(self, func, candidate_config, best_config, best_timing)
47

class TestCoordinateDescentTuner(TestCase):
    def test_abs_function(self):
        """
        The benchmark result is simply abs(XBLOCK - 15).

        Starting from XBLOCK=1, the tuner explores power-of-two neighbour
        values and should settle on 16, the candidate closest to 15.
        """
        tuner = CoordescTuner()
        baseline_config = triton.Config({"XBLOCK": 1}, num_warps=8, num_stages=1)

        def func(config):
            return abs(config.kwargs["XBLOCK"] - 15)

        best_config = tuner.autotune(func, baseline_config)
        self.assertTrue(best_config.kwargs.get("XBLOCK") == 16, str(best_config))

    def test_no_neighbors(self):
        """
        Test the case where there are no available neighbor values for a field.
        """

        # a size hint of 1 for x caps the largest XBLOCK we try at 1
        tuner = CoordescTuner(size_hints=[1])
        baseline_config = triton.Config({"XBLOCK": 1}, num_warps=8, num_stages=1)

        def func(config):
            return abs(config.kwargs["XBLOCK"] - 15)

        best_config = tuner.autotune(func, baseline_config)
        self.assertTrue(best_config.kwargs.get("XBLOCK") == 1, str(best_config))

    def test_get_neighbour_values(self):
        tuner = CoordescTuner()

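        # num_stages neighbours are offset linearly (values below 1 are
        # dropped); num_warps neighbours scale by powers of two.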
        neighbours = tuner.get_neighbour_values("num_stages", 2, radius=2)
        self.assertEqual(set(neighbours), {1, 3, 4})
        neighbours = tuner.get_neighbour_values("num_warps", 2, radius=2)
        self.assertEqual(set(neighbours), {1, 4, 8})

    def test_persistent_reduction(self):
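        # f normalizes each row; reducing over 256 columns should compile to
        # a persistent reduction. With compare_config mocked to always prefer
        # a larger XBLOCK, the compiled output must still match eager mode.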
        def f(x):
            return x / x.sum(dim=-1, keepdim=True)

        with mock.patch.object(
            CoordescTuner, "compare_config", mock_compare_config_prefer_larger_XBLOCK
        ):
            x = torch.ones(2, 256).to(GPU_TYPE)
            expected = f(x)
            # the first call gets the correct result on a cache miss; we
            # don't know why yet
            _ = torch.compile(f)(x)
            actual = torch.compile(f)(x)
            self.assertTrue(
                torch.allclose(expected, actual, atol=1e-4, rtol=1e-4),
                f"Expected:\n{expected}\nActual:\n{actual}",
            )

    def test_value_too_large(self):
        # Simulate a reduction
        size_hints = [2**20, 2**20]

        tuner = CoordescTuner(size_hints=size_hints)

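        # a block size is "too large" once it exceeds the per-prefix
        # TRITON_MAX_BLOCK cap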
        max_block = TRITON_MAX_BLOCK
        self.assertFalse(tuner.value_too_large("XBLOCK", max_block["X"]))
        self.assertTrue(tuner.value_too_large("XBLOCK", max_block["X"] * 2))
        self.assertFalse(tuner.value_too_large("RBLOCK", max_block["R"]))
        self.assertTrue(tuner.value_too_large("RBLOCK", max_block["R"] * 2))


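# These tests exercise GPU kernels; run only on Linux machines with a GPU.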
if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()