# Owner(s): ["oncall: jit"] import torch from torch.testing import FileCheck from torch.testing._internal.jit_utils import JitTestCase if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestBatchMM(JitTestCase): @staticmethod def _get_test_tensors(n: int): return [ torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]]) if x % 2 == 0 else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]]) for x in range(n) ] def test_batch_mm_no_mutation(self): def test_batch_mm( T1: torch.Tensor, T2: torch.Tensor, T3: torch.Tensor, T4: torch.Tensor, T5: torch.Tensor, T6: torch.Tensor, T7: torch.Tensor, T8: torch.Tensor, ): return ( torch.mm(T1, T2) + torch.mm(T3, T4) + torch.mm(T5, T6) + torch.mm(T7, T8) ) test_batch_mm_scripted = torch.jit.script(test_batch_mm) tensors = TestBatchMM._get_test_tensors(8) expected = test_batch_mm(*tensors) FileCheck().check_count("aten::mm", 4, exactly=True).run( test_batch_mm_scripted.graph ) self.run_pass("batch_mm", test_batch_mm_scripted.graph) FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run( test_batch_mm_scripted.graph ) actual = test_batch_mm_scripted(*tensors) self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9) def test_batch_mm_permitted_mutation(self): def test_batch_mm( T1: torch.Tensor, T2: torch.Tensor, T3: torch.Tensor, T4: torch.Tensor, T5: torch.Tensor, T6: torch.Tensor, T7: torch.Tensor, T8: torch.Tensor, ): result = {} result["product"] = ( torch.mm(T1, T2) + torch.mm(T3, T4) + torch.mm(T5, T6) + torch.mm(T7, T8) ) result["constant"] = torch.tensor([42.0]) return result test_batch_mm_scripted = torch.jit.script(test_batch_mm) tensors = TestBatchMM._get_test_tensors(8) expected = test_batch_mm(*tensors) FileCheck().check_count("aten::mm", 4, exactly=True).run( test_batch_mm_scripted.graph ) self.run_pass("batch_mm", test_batch_mm_scripted.graph) FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run( test_batch_mm_scripted.graph ) actual = test_batch_mm_scripted(*tensors) self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9) def test_batch_mm_prohibited_mutation(self): @torch.jit.script def test_batch_mm(n: int): T1 = torch.zeros((n, n)) T2 = torch.zeros((n, n)) T3 = torch.zeros((n, n)) T4 = torch.zeros((n, n)) T5 = torch.zeros((n, n)) T6 = torch.zeros((n, n)) T7 = torch.zeros((n, n)) T8 = torch.zeros((n, n)) torch.relu_(T1) result = ( torch.mm(T1, T2) + torch.mm(T3, T4) + torch.mm(T5, T6) + torch.mm(T7, T8) ) return result FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph) self.run_pass("batch_mm", test_batch_mm.graph) FileCheck().check_count("aten::mm", 4, exactly=True).check_not( "prim::MMTreeReduce" ).run(test_batch_mm.graph) def test_batch_mm_prohibited_mutation_multiple_adds(self): @torch.jit.script def test_batch_mm(n: int): T1 = torch.zeros((n, n)) T2 = torch.zeros((n, n)) T3 = torch.zeros((n, n)) T4 = torch.zeros((n, n)) T5 = torch.zeros((n, n)) T6 = torch.zeros((n, n)) T7 = torch.zeros((n, n)) T8 = torch.zeros((n, n)) T9 = torch.zeros((n, n)) T10 = torch.zeros((n, n)) torch.relu_(T1) result = {} result["no_mutated_parameters"] = ( torch.mm(T2, T3) + torch.mm(T4, T5) + torch.mm(T6, T7) + torch.mm(T8, T9) ) result["all_parameters"] = ( torch.mm(T1, T2) + torch.mm(T3, T4) + torch.mm(T5, T6) + torch.mm(T7, T8) + torch.mm(T9, T10) ) return result self.run_pass("batch_mm", test_batch_mm.graph) FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count( "aten::mm", 5, exactly=True ).run(test_batch_mm.graph) def test_batch_mm_prohibited_mutation_if_node(self): @torch.jit.script def test_batch_mm(n: int, use_t1: bool): T1 = torch.zeros((n, n)) T2 = torch.zeros((n, n)) T3 = torch.zeros((n, n)) T4 = torch.zeros((n, n)) T5 = torch.zeros((n, n)) T6 = torch.zeros((n, n)) T7 = torch.zeros((n, n)) T8 = torch.zeros((n, n)) T9 = torch.zeros((n, n)) T10 = torch.zeros((n, n)) if use_t1: torch.relu_(T1) return ( torch.mm(T1, T2) + torch.mm(T3, T4) + torch.mm(T5, T6) + torch.mm(T7, T8) + torch.mm(T9, T10) ) else: return ( torch.mm(T2, T3) + torch.mm(T4, T5) + torch.mm(T6, T7) + torch.mm(T8, T9) ) self.run_pass("batch_mm", test_batch_mm.graph) FileCheck().check_count("aten::mm", 5, exactly=True).check_count( "prim::MMTreeReduce", 1, exactly=True ).run(test_batch_mm.graph) def test_batch_mm_side_permitted_mutation(self): @torch.jit.script def test_batch_mm(n: int): result = {} A = torch.zeros((n, n)) T1 = torch.zeros((n, n)) T2 = torch.zeros((n, n)) T3 = torch.zeros((n, n)) T4 = torch.zeros((n, n)) T5 = torch.zeros((n, n)) T6 = torch.zeros((n, n)) T7 = torch.zeros((n, n)) T8 = torch.zeros((n, n)) result["T1"] = torch.mm(A, T1) result["T2"] = torch.mm(A, T2) result["T3"] = torch.mm(A, T3) result["T4"] = torch.mm(A, T4) result["T5"] = torch.mm(A, T5) result["T6"] = torch.mm(A, T6) result["T7"] = torch.mm(A, T7) result["T8"] = torch.mm(A, T8) return result FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph) self.run_pass("batch_mm", test_batch_mm.graph) FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not( "aten::mm" ).run(test_batch_mm.graph) def test_batch_mm_side_prohibited_mutation_uncommon_side(self): @torch.jit.script def test_batch_mm(n: int): A = torch.zeros((n, n)) T1 = torch.zeros((n, n)) T2 = torch.zeros((n, n)) T3 = torch.zeros((n, n)) T4 = torch.zeros((n, n)) T5 = torch.zeros((n, n)) T6 = torch.zeros((n, n)) T7 = torch.zeros((n, n)) T8 = torch.zeros((n, n)) T9 = torch.zeros((n, n)) T10 = torch.zeros((n, n)) torch.relu_(T1) result = {} result["T1"] = torch.mm(A, T1) result["T2"] = torch.mm(A, T2) result["T3"] = torch.mm(A, T3) result["T4"] = torch.mm(A, T4) result["T5"] = torch.mm(A, T5) result["T6"] = torch.mm(A, T6) result["T7"] = torch.mm(A, T7) result["T8"] = torch.mm(A, T8) result["T9"] = torch.mm(A, T9) result["T10"] = torch.mm(A, T10) return result FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph) self.run_pass("batch_mm", test_batch_mm.graph) FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph) FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run( test_batch_mm.graph ) def test_batch_mm_side_prohibited_mutation_common_side(self): @torch.jit.script def test_batch_mm(n: int): A = torch.zeros((n, n)) T1 = torch.zeros((n, n)) T2 = torch.zeros((n, n)) T3 = torch.zeros((n, n)) T4 = torch.zeros((n, n)) T5 = torch.zeros((n, n)) T6 = torch.zeros((n, n)) T7 = torch.zeros((n, n)) T8 = torch.zeros((n, n)) T9 = torch.zeros((n, n)) T10 = torch.zeros((n, n)) torch.relu_(A) result = {} result["T1"] = torch.mm(A, T1) result["T2"] = torch.mm(A, T2) result["T3"] = torch.mm(A, T3) result["T4"] = torch.mm(A, T4) result["T5"] = torch.mm(A, T5) result["T6"] = torch.mm(A, T6) result["T7"] = torch.mm(A, T7) result["T8"] = torch.mm(A, T8) result["T9"] = torch.mm(A, T9) result["T10"] = torch.mm(A, T10) return result FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph) self.run_pass("batch_mm", test_batch_mm.graph) FileCheck().check_count("aten::mm", 10, exactly=True).check_not( "prim::MMBatchSide" ).run(test_batch_mm.graph)