# Owner(s): ["module: inductor"] from typing import Any, Callable import torch from torch._inductor.fx_passes.pre_grad import ( linear_permute_fusion, linear_transpose, permute_linear_fusion, permute_matmul_fusion, sink_cat_after_pointwise, transpose_linear, transpose_matmul, ) from torch._inductor.test_case import run_tests, TestCase from torch.fx.passes.shape_prop import ShapeProp PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule] def chain_passes(*passes: PassFunc) -> PassFunc: def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule: for pass_ in passes: if isinstance(module, torch.fx.GraphModule): ShapeProp(module).propagate(*input) module = pass_(module) return module return parent_pass def count_call(module: torch.fx.GraphModule, op: str, target_op: Any) -> int: return sum( 1 if (n.op == op and n.target == target_op) else 0 for n in module.graph.nodes ) def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int: return count_call(module, "call_function", target_op) def count_call_method(module: torch.fx.GraphModule, target_op: Any) -> int: return count_call(module, "call_method", target_op) class TestFxFusion(TestCase): def test_sink_cat_after_pointwise(self): def test_kwarg(x, y): return torch.cat([x, y], dim=-1).view(-1).view(128).tanh() def test_arg(x, y): return torch.cat([x, y], -1).view(-1).view(128).tanh() def test_arg2(x, y): return torch.cat([x, y]).view(-1).view(128).tanh() def test_kwarg2(x, y): return torch.cat(tensors=[x, y], dim=0).tanh() def test_kwarg3(x, y): return torch.cat(tensors=[x, y], dim=0).view(128).tanh() trace_func = chain_passes(torch.fx.symbolic_trace, sink_cat_after_pointwise) inputs = [ torch.randn(8, 8), torch.randn(8, 8), ] for f in [test_kwarg, test_arg, test_arg2, test_kwarg2, test_kwarg3]: traced = trace_func(f, inputs) torch.testing.assert_close(f(*inputs), traced(*inputs)) self.assertEqual(count_call_method(traced, "tanh"), 2) def test_linear_permute_fusion(self): class TestModule(torch.nn.Module): def __init__(self, k: int, n: int, has_bias: bool): super().__init__() self.weight = torch.nn.Parameter(torch.randn(n, k)) self.has_bias = has_bias if has_bias: self.bias = torch.nn.Parameter(torch.randn(n)) def forward(self, input: torch.Tensor): if self.has_bias: a0 = torch.nn.functional.linear(input, self.weight, self.bias) else: a0 = torch.nn.functional.linear(input, self.weight) b0 = a0.permute(0, 2, 1) return b0 m, k, n = 16, 8, 4 trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) for has_bias in [True, False]: module = TestModule(k, n, has_bias).eval() input = torch.randn(6, m, k) traced = trace_func(module, [input]) num_linear = count_call_function(traced, torch.nn.functional.linear) num_linear_transpose = count_call_function(traced, linear_transpose) self.assertEqual(num_linear, 0) self.assertEqual(num_linear_transpose, 1) torch.testing.assert_close(module(input), traced(input)) def test_permute_linear_fusion(self): class TestModule(torch.nn.Module): def __init__(self, k: int, n: int, has_bias: bool): super().__init__() self.weight = torch.nn.Parameter(torch.randn(n, k)) self.has_bias = has_bias if has_bias: self.bias = torch.nn.Parameter(torch.randn(n)) def forward(self, input: torch.Tensor): input1 = input.permute(0, 2, 1) if self.has_bias: return torch.nn.functional.linear(input1, self.weight, self.bias) return torch.nn.functional.linear(input1, self.weight) m, k, n = 16, 8, 4 trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) for has_bias in [True, False]: module = TestModule(k, n, has_bias).eval() input = torch.randn(6, k, m) traced = trace_func(module, [input]) num_linear = count_call_function(traced, torch.nn.functional.linear) num_transpose_linear = count_call_function(traced, transpose_linear) self.assertEqual(num_linear, 0) self.assertEqual(num_transpose_linear, 1) torch.testing.assert_close(module(input), traced(input)) def test_permute_bmm_fusion(self): class TestModule(torch.nn.Module): def __init__(self, batch: int, k: int, n: int): super().__init__() self.other = torch.randn(batch, k, n) def forward(self, input: torch.Tensor): input1 = input.permute(0, 2, 1) output = torch.bmm(input1, self.other) return output batch, m, k, n = 6, 16, 8, 4 trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) module = TestModule(batch, k, n).eval() input = torch.randn(batch, k, m) traced = trace_func(module, [input]) num_bmm = count_call_function(traced, torch.bmm) num_transpose_matmul = count_call_function(traced, transpose_matmul) self.assertEqual(num_bmm, 0) self.assertEqual(num_transpose_matmul, 1) torch.testing.assert_close(module(input), traced(input)) if __name__ == "__main__": run_tests()