1# Owner(s): ["module: fx"] 2 3import torch 4import torch.fx as fx 5from torch.fx.passes.infra.pass_base import PassBase, PassResult 6from torch.fx.passes.infra.pass_manager import ( 7 _topological_sort_passes, 8 pass_result_wrapper, 9 PassManager, 10 this_before_that_pass_constraint, 11) 12from torch.testing._internal.common_utils import TestCase 13 14 15# Pass that uses PassBase and returns a PassResult (best scenario) 16class ReplaceAddWithMulPass(PassBase): 17 def call(self, gm) -> PassResult: 18 modified = False 19 for node in gm.graph.nodes: 20 if node.op == "call_function" and node.target == torch.add: 21 node.target = torch.mul 22 modified = True 23 return PassResult(gm, modified) 24 25 26# Pass that is a callable and returns a PassResult 27def replace_mul_with_div_pass(gm) -> PassResult: 28 modified = False 29 for node in gm.graph.nodes: 30 if node.op == "call_function" and node.target == torch.mul: 31 node.target = torch.div 32 modified = True 33 return PassResult(gm, modified) 34 35 36# Pass that is a PassBase and does not return a PassResult 37# Need to wrap with pass_result_wrapper or else it will fail 38class ReplaceDivWithSubPass(PassBase): 39 def call(self, gm) -> PassResult: 40 for node in gm.graph.nodes: 41 if node.op == "call_function" and node.target == torch.div: 42 node.target = torch.sub 43 44 45# Pass that is a callable and does not return a PassResult 46# Need to wrap with pass_result_wrapper or else it will fail 47def replace_sub_with_add_pass(gm) -> PassResult: 48 for node in gm.graph.nodes: 49 if node.op == "call_function" and node.target == torch.sub: 50 node.target = torch.add 51 52 53class AddModule(torch.nn.Module): 54 def forward(self, x): 55 y = torch.add(x, x) 56 z = torch.add(y, x) 57 return z 58 59 60class TestPassManager(TestCase): 61 def test_pass_manager(self): 62 """ 63 Tests that the pass manager runs the passes correctly. 64 """ 65 66 m = AddModule() 67 traced_m = torch.fx.symbolic_trace(m) 68 pm = PassManager( 69 passes=[ 70 ReplaceAddWithMulPass(), 71 replace_mul_with_div_pass, 72 pass_result_wrapper(ReplaceDivWithSubPass()), 73 pass_result_wrapper(replace_sub_with_add_pass), 74 ], 75 steps=5, 76 ) 77 78 pm.validate_constraints() 79 self.assertEqual(len(pm.passes), 4) 80 81 res = pm(traced_m) 82 modified_m = res.graph_module 83 assert isinstance(modified_m, fx.GraphModule) 84 85 # Check that all call_function nodes are divs 86 for node in modified_m.graph.nodes: 87 if node.op == "call_function": 88 self.assertEqual(node.target, torch.add) 89 90 def test_this_before_that_pass_constraint(self): 91 """ 92 Tests the construction of constraints 93 """ 94 passes = [lambda x: 2 * x for _ in range(10)] 95 pm = PassManager(passes) 96 97 # add unfulfillable constraint 98 pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) 99 100 with self.assertRaises(RuntimeError): 101 pm.validate_constraints() 102 103 def test_pass_manager_checks(self): 104 """ 105 Tests that users can add in check functions correctly 106 """ 107 m = AddModule() 108 traced_m = fx.symbolic_trace(m) 109 pm = PassManager(passes=[ReplaceAddWithMulPass(), replace_mul_with_div_pass]) 110 111 def check_div_target(graph_module): 112 for node in graph_module.graph.nodes: 113 if node.op == "call_function" and node.target != torch.div: 114 raise ValueError("Target should be div!") 115 116 pm.add_checks(check_div_target) 117 118 with self.assertRaises(ValueError): 119 pm(traced_m) 120 121 def test_pass_manager_bad_checks(self): 122 """ 123 Checks that we error if we pass in a check function with the wrong parameters 124 """ 125 126 def check_bad_args(graph_module, i): 127 pass 128 129 pm = PassManager() 130 self.assertRaises(TypeError, pm.add_checks, check_bad_args) 131 132 def test_topological_sort(self): 133 """ 134 Tests that passes are correctly ordered based on contraints. 135 """ 136 137 def pass0(x): 138 return x 139 140 def pass1(x): 141 return x + 1 142 143 def pass2(x): 144 return x + 2 145 146 def pass3(x): 147 return x + 3 148 149 def pass4(x): 150 return x + 4 151 152 def pass5(x): 153 return x + 5 154 155 # Not passing any constraints should keep the original order 156 passes = [pass0, pass1, pass2, pass3, pass4, pass5] 157 sorted = _topological_sort_passes(passes, []) 158 self.assertEqual(sorted, passes) 159 160 # Graph that we are constructing: 161 # 5 ----> 0 <---- 4 162 # | | 163 # +-> 2 -> 3 -> 1 <-+ 164 # Which has a possible topological order of: [4, 5, 0, 2, 3, 1] 165 passes = [pass0, pass1, pass2, pass3, pass4, pass5] 166 constraints = [ 167 this_before_that_pass_constraint(pass5, pass0), 168 this_before_that_pass_constraint(pass5, pass2), 169 this_before_that_pass_constraint(pass4, pass0), 170 this_before_that_pass_constraint(pass4, pass1), 171 this_before_that_pass_constraint(pass2, pass3), 172 this_before_that_pass_constraint(pass3, pass1), 173 ] 174 sorted = _topological_sort_passes(passes, constraints) 175 self.assertEqual(sorted, [pass4, pass5, pass0, pass2, pass3, pass1]) 176 177 # Circular dependency should result in the circular_dep flag being set 178 passes = [pass0, pass1, pass2] 179 constraints = [ 180 this_before_that_pass_constraint(passes[0], passes[1]), 181 this_before_that_pass_constraint(passes[1], passes[2]), 182 this_before_that_pass_constraint(passes[2], passes[0]), 183 ] 184 with self.assertRaises(RuntimeError) as e: 185 _topological_sort_passes(passes, constraints) 186 expected_error_msg = ( 187 f"Circular dependency detected within the following passes: {passes}" 188 ) 189 self.assertEqual(e.exception.args[0], expected_error_msg) 190 191 def test_pass_manager_error(self): 192 """ 193 Tests error catching + debug 194 """ 195 196 def pass_fail(graph_module): 197 raise RuntimeError("bad") 198 199 m = AddModule() 200 traced_m = torch.fx.symbolic_trace(m) 201 pm = PassManager( 202 passes=[ 203 ReplaceAddWithMulPass(), 204 replace_mul_with_div_pass, 205 ReplaceDivWithSubPass(), 206 pass_result_wrapper(replace_sub_with_add_pass), 207 ], 208 ) 209 210 # Comment out this line to see the actual error message 211 error_msg = ( 212 "ReplaceDivWithSubPass.*ReplaceAddWithMulPass.*replace_mul_with_div_pass" 213 ) 214 with self.assertRaisesRegex(Exception, error_msg): 215 pm(traced_m) 216 217 pm = PassManager( 218 passes=[ 219 ReplaceAddWithMulPass(), 220 replace_mul_with_div_pass, 221 pass_result_wrapper(ReplaceDivWithSubPass()), 222 pass_result_wrapper(replace_sub_with_add_pass), 223 pass_fail, 224 ], 225 ) 226 227 # Comment out this line to see the actual error message 228 error_msg = "pass_fail.*ReplaceAddWithMulPass.*replace_mul_with_div_pass.*ReplaceDivWithSubPass.*replace_sub_with_add_pass" 229 with self.assertRaisesRegex(Exception, error_msg): 230 pm(traced_m) 231