# Owner(s): ["oncall: jit"] import torch import torch._C from torch.testing import FileCheck from torch.testing._internal.jit_utils import JitTestCase class TestGraphRewritePasses(JitTestCase): def test_fuse_linear(self): class FunctionalLinear(torch.nn.Module): def __init__(self, weight, bias): super().__init__() self.weight = weight self.bias = bias def forward(self, x): res = torch.matmul(x, self.weight.t()) if self.bias is not None: res.add_(self.bias) return res x1 = torch.rand(3) w1 = torch.rand(5, 3) b1 = torch.rand(5) for has_bias in [True, False]: bias = b1 if has_bias else None model = torch.jit.trace(FunctionalLinear(w1, bias), [x1]) for node in model.graph.nodes(): if node.kind() == "aten::matmul": source_range_1 = node.sourceRange() torch._C._jit_pass_fuse_linear(model.graph) for node in model.graph.nodes(): if node.kind() == "aten::linear": source_range_2 = node.sourceRange() FileCheck().check("aten::linear").run(model.graph) check_not = ["aten::matmul", "aten::addmm", "aten::add_", "aten::t("] for cn in check_not: FileCheck().check_not(cn).run(model.graph) self.assertTrue(source_range_1 == source_range_2) # make sure it runs model(x1) # check matmuls are not fused class Matmul(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = weight def forward(self, x): return torch.matmul(x, self.weight) x = torch.rand(5, 6, 5) w = torch.rand(5, 5, 100) model = torch.jit.trace(Matmul(w), [x]) torch._C._jit_pass_fuse_linear(model.graph) # check 3d matmul is not fused FileCheck().check("aten::matmul").run(model.graph) FileCheck().check_not("aten::linear").run(model.graph) # make sure it runs model(x)