# Owner(s): ["module: inductor"] import contextlib import sympy import torch import torch._inductor.config as inductor_config from torch._inductor.codegen import triton_utils from torch._inductor.codegen.common import SizeArg from torch._inductor.graph import GraphLowering from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.virtualized import V from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU class TestCodegenTriton(InductorTestCase): def setUp(self): super().setUp() class DummyModule(torch.nn.Module): def forward(self, x): return x * 2 self._gm = torch.fx.symbolic_trace(DummyModule()) self._graph = GraphLowering(self._gm) self._stack = contextlib.ExitStack() self._stack.enter_context(V.set_graph_handler(self._graph)) def tearDown(self): self._stack.close() super().tearDown() @inductor_config.patch("triton.divisible_by_16", True) def test_config_of_sizearg(self): two = sympy.Integer(2) eight = sympy.Integer(8) sixteen = sympy.Integer(16) s0 = sympy.Symbol("s0", positive=True, integer=True) s1 = sympy.Symbol("s1", positive=True, integer=True) self.assertEqual( (2,), triton_utils.config_of( [ SizeArg("A", two), # no SizeArg("B", eight), # no SizeArg("C", sixteen), # yes SizeArg("D", s0), # no SizeArg("E", s1), # no ] ).divisible_by_16, ) self.assertEqual( (0, 2, 4, 5, 6), triton_utils.config_of( [ SizeArg("A", two * eight), # 0: yes SizeArg("B", eight * s0), # 1: no SizeArg("C", two * eight * s0), # 2: yes SizeArg("D", s0 * s1), # 3: no SizeArg("E", sixteen * s0), # 4: yes SizeArg("F", sixteen * eight * s0 * s1), # 5: yes SizeArg("G", two * eight * s0 * s1), # 6: yes ] ).divisible_by_16, ) if __name__ == "__main__": from torch._inductor.test_case import run_tests if HAS_CPU or HAS_GPU: run_tests("sympy")