# mypy: allow-untyped-defs
import collections
from collections import defaultdict
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree


aten = torch.ops.aten

# We would like to split modules into two subgraphs for runtime weight updates
# to work correctly. The use case and more information can be found at:
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
META_TAG = "MODULE_TYPE"
MODULE_TAG = "_MAIN_MODULE"
CONST_MODULE_TAG = "_CONST_MODULE"


def replace_node_with_constant(gm, node, constant, name=None):
    g = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0
        i = gm._frozen_param_count

        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)


class ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        if (
            node.target == torch.ops.prims.convert_element_type.default
            and node.args[0].op == "get_attr"  # type: ignore[union-attr]
            and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
            and node.args[1] == torch.bfloat16
        ):
            # For int8_weight -> dq -> bf16_weight: keep the weight in int8
            # and leave the dequant in the graph instead of folding it.
            return True
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into an int8_weight and leave the dq in the
            # graph to be fused.
            return True
        return False

    def node_to_last_non_output_use(self):
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in the output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use
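
    # run_node decides, per node, whether the node's output can be burned into
    # the graph as a constant: placeholders start out as `unknown_value`,
    # unknown-ness propagates to every node that consumes one, and a few
    # targets are skipped outright. Any remaining tensor output is recorded in
    # `node_replacements`.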
    def run_node(self, node):
        if node.target == "output":
            # Because we remove nodes from env on their last non-output use,
            # re-define them now or we'll get an error in the interpreter.
            def set_env(arg):
                self.env[arg] = self.unknown_value

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # We check the type before checking equality because when
        # flattened_inputs contains a ScriptObject, comparing values of
        # different types raises a type error.
        if any(
            type(self.unknown_value) == type(input_) and self.unknown_value == input_
            for input_ in flattened_inputs
        ):
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # TODO - constant folding a triton kernel returns the inputs -- fix this
        if (
            node.op == "call_function"
            and node.name == "triton_kernel_wrapper_functional_proxy"
        ):
            return self.unknown_value

        # Skip constructors, since inductor already generates optimal code for
        # them and turning them into tensors would add a global memory read.
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or be on inputs that we did
        # not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if out.device.type == "meta":
                return out

            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self):
        env = {}
        for n in self.module.graph.find_nodes(op="placeholder"):
            env[n] = self.unknown_value
        return super().run(initial_env=env)


def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node, constant in cf.node_replacements.items():
            if constraint_fn is not None and not constraint_fn(node):
                continue
            replace_node_with_constant(gm, node, constant)

        erased_params = []
        # Collect the users of each attr by walking the graph rather than
        # reading node.users, because distinct get_attr nodes (such as
        # _tensor_constant0 and _tensor_constant0_1 below) may refer to the
        # same underlying tensor:
        #
        # opcode         name                 target            args                         kwargs
        # -------------  -------------------  ----------------  ---------------------------  ------
        # placeholder    arg0_1               arg0              ()                           {}
        # get_attr       _tensor_constant0    state             ()                           {}
        # call_function  add                  aten.add.Tensor   (arg0_1, _tensor_constant0)  {}
        # get_attr       _tensor_constant0_1  state             ()                           {}
        # call_function  add_                 aten.add_.Tensor  (_tensor_constant0_1, 1)     {}
        # output         output               output            ([add],)                     {}
        get_attr_node_users = defaultdict(list)
        for node in gm.graph.nodes:
            if node.op == "get_attr":
                get_attr_node_users[node.target].extend(node.users.keys())
        for node in gm.graph.find_nodes(op="get_attr"):
            if len(get_attr_node_users[node.target]) == 0:
                if hasattr(gm, node.target):
                    delattr(gm, node.target)
                erased_params.append(node)
        for node in erased_params:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
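

# An illustrative usage sketch for constant_fold. The toy module below is
# hypothetical (illustration only, not part of this file): the subexpression
# `self.a + self.b` depends only on attributes, so it is precomputed into a
# frozen buffer, while `x + ...` consumes a runtime input and stays put.
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.a = torch.nn.Parameter(torch.randn(4))
#             self.b = torch.nn.Parameter(torch.randn(4))
#
#         def forward(self, x):
#             return x + (self.a + self.b)
#
#     gm = torch.fx.symbolic_trace(M())
#     constant_fold(gm)
#     # `a + b` is now a single get_attr to `_frozen_param0`; the no-longer-used
#     # `a` and `b` attributes are deleted from gm.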


def constant_graph_tag(gm: torch.fx.GraphModule):
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node in gm.graph.nodes:
            if (
                node.op == "get_attr"
                or node in cf.node_replacements
                or node in cf.replaced_uses
            ):
                node.meta[META_TAG] = CONST_MODULE_TAG
            else:
                node.meta[META_TAG] = MODULE_TAG


def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Construct a GraphModule that corresponds to the constant-foldable
    portion of the provided gm.
    """

    constant_graph_tag(gm)
    # Rewrite the tags: if a constant is consumed directly, without any
    # folding opportunity, we keep it in the main gm.
    for node in gm.graph.find_nodes(op="get_attr"):
        used_to_fold = False
        for u in node.users:
            if u.meta[META_TAG] == CONST_MODULE_TAG:
                used_to_fold = True
                break
        if not used_to_fold:
            node.meta[META_TAG] = MODULE_TAG

    new_graph = torch.fx.Graph()

    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []
    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
        node_remapping[node] = new_node

        for user in node.users:
            if user.meta[META_TAG] == MODULE_TAG:
                output_nodes.append(new_node)
                break

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)

    return new_gm
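

# Minimal illustrative sketch of the subgraph split performed above. The toy
# module is hypothetical and exists only to demonstrate the intended flow.
if __name__ == "__main__":

    class _Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.randn(4, 4))

        def forward(self, x):
            # `self.w.t()` is constant-foldable; the matmul is not, since it
            # consumes the runtime input `x`.
            return x @ self.w.t()

    toy_gm = torch.fx.symbolic_trace(_Toy())
    const_gm = run_and_get_constant_graph(toy_gm)
    # The extracted constant graph contains only the get_attr of `w` and the
    # transpose, returning the transposed weight as its single output.
    print(const_gm.graph)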