1# mypy: allow-untyped-defs 2import collections 3import logging 4import operator 5from collections import OrderedDict 6from typing import ( 7 Any, 8 DefaultDict, 9 Deque, 10 Dict, 11 Iterable, 12 Iterator, 13 List, 14 Optional, 15 Set, 16 Tuple, 17) 18 19import torch 20from torch._dynamo.utils import counters, optimus_scuba_log 21from torch._utils_internal import upload_graph 22from torch.fx.passes.graph_transform_observer import GraphTransformObserver 23 24from .. import config 25from ..pattern_matcher import ( 26 CallFunctionVarArgs, 27 get_arg_value, 28 stable_topological_sort, 29) 30 31 32try: 33 # importing this will register fbgemm lowerings for inductor 34 import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401 35 36 has_fbgemm = True 37except Exception: 38 has_fbgemm = False 39 40aten = torch.ops.aten 41 42log = logging.getLogger(__name__) 43 44MIN_FUSE_SET_SIZE = 5 45MAX_FUSE_SET_SIZE = 300 46MAX_FUSE_SEARCH_DEPTH = 5 47# The maximum tensor size that can go into the fusion group 48MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 49# Whether we only fuse nodes with same parent node 50FUSE_NODES_WITH_SAME_PARENT = False 51# Whether we enable the add broadcast in batch linear 52SHAPE_BROADCAST_BATCH_LINEAR = False 53# Whether we enable the fuse nodes with same users 54Fuse_NODES_WITH_SAME_USERS = False 55 56# exclude these nodes from BFS 57# excluding get item improves optimizer compilation time by 60s 58SEARCH_EXCLUSIONS = {operator.getitem} 59 60 61default_graph_search_options = { 62 "min_fuse_set_size": MIN_FUSE_SET_SIZE, 63 "max_fuse_set_size": MAX_FUSE_SET_SIZE, 64 "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH, 65 "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR, 66 "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT, 67 "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR, 68 "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS, 69} 70 71graph_search_options = default_graph_search_options 72 73 74def update_stack_example_value(node, metadata, dim=0, op=torch.stack): 75 """ 76 Update the example value of the node in the graph to enable followup split cat opt. 77 """ 78 if node is not None and hasattr(node, "meta"): 79 if op == torch.stack: 80 example_value = torch.stack(metadata, dim=dim) 81 elif op == torch.unbind: 82 example_value = torch.unbind(metadata, dim=dim) # type: ignore[assignment] 83 else: 84 return 85 node.meta["example_value"] = example_value 86 87 88def update_pointwise_example_value(pointwise_node, input, other, op): 89 """ 90 Update the example value of the add node in the graph to enable followup split cat opt. 91 """ 92 if pointwise_node is not None and hasattr(pointwise_node, "meta"): 93 if op == torch.add: 94 example_value = torch.add(input, other) 95 elif op == torch.mul: 96 example_value = torch.mul(input, other) 97 else: 98 return 99 pointwise_node.meta["example_value"] = example_value 100 101 102class GroupBatchFusionBase: 103 def __init__(self, **kwargs) -> None: 104 self.graph_search_options = kwargs.pop( 105 "graph_search_options", default_graph_search_options 106 ) 107 108 def match(self, node): 109 raise NotImplementedError("match called on base") 110 111 def fuse(self, graph, subset): 112 raise NotImplementedError("fuse called on base") 113 114 115PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {} 116POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {} 117 118 119def register_fusion(name: str, pre_grad=True): 120 def decorator(fusion_cls: GroupBatchFusionBase): 121 if pre_grad: 122 PRE_GRAD_FUSIONS[name] = fusion_cls 123 else: 124 POST_GRAD_FUSIONS[name] = fusion_cls 125 return fusion_cls 126 127 return decorator 128 129 130def list_group_batch_fusions(pre_grad=True) -> List[str]: 131 if pre_grad: 132 return list(PRE_GRAD_FUSIONS.keys()) 133 else: 134 return list(POST_GRAD_FUSIONS.keys()) 135 136 137def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any: 138 unsqueezed_inputs = [] 139 unsqueezed_inputs_meta = [] 140 for input_tensor in input_tensors: 141 unsqueezed_input = graph.call_function( 142 aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0} 143 ) 144 unsqueezed_inputs.append(unsqueezed_input) 145 unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0) # type: ignore[assignment] 146 unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"]) 147 stacked_inputs = graph.call_function( 148 aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0} 149 ) 150 stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0) # type: ignore[assignment] 151 return stacked_inputs 152 153 154class GroupFusion(GroupBatchFusionBase): 155 """ 156 Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm. 157 """ 158 159 160class BatchFusion(GroupBatchFusionBase): 161 """ 162 Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm. 163 """ 164 165 166class BatchPointwiseOpsFusionFactory(BatchFusion): 167 def __init__(self, op, **kwargs) -> None: 168 super().__init__(**kwargs) 169 self.op = op 170 171 172@register_fusion("batch_linear_post_grad", pre_grad=False) 173class PostGradBatchLinearFusion(BatchFusion): 174 """ 175 Fuse ops in a batch way in post grad (aten level). 176 """ 177 178 def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool: 179 # pyre-fixme[7]: Incompatible return type 180 return ( 181 node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0 # type: ignore[return-value] 182 ) 183 184 def _is_input_2d(self, input: torch.fx.Node) -> bool: 185 input_shapes = input.meta["val"].shape 186 return ( 187 len(input_shapes) == 2 188 and isinstance(input_shapes[0], int) 189 and isinstance(input_shapes[1], int) 190 ) 191 192 def match( 193 self, node: torch.fx.Node 194 ) -> Optional[Tuple[str, int, int, int, bool, str]]: 195 if CallFunctionVarArgs(aten.mm).match(node): 196 input_m, weight_m = node.args 197 bias_m = None 198 199 elif CallFunctionVarArgs(aten.addmm.default).match( 200 node 201 ) and self._addmm_node_can_be_fused(node): 202 bias_m, input_m, weight_m = node.args 203 else: 204 return None 205 # get the user of the node 206 if self.graph_search_options.get("fuse_nodes_with_same_users", False): 207 users = [user.target for user in node.users.keys()] 208 else: 209 users = "" # type: ignore[assignment] 210 # only handle the cases where inputs are 2D tensors 211 if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type] 212 return None 213 m, k = input_m.meta["val"].shape # type: ignore[union-attr] 214 n = weight_m.meta["val"].shape[1] # type: ignore[union-attr] 215 batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users)) 216 return batch_key 217 218 def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): 219 batch_inputs = [] 220 batch_weights = [] 221 batch_biases = [] 222 batch_nodes = [] 223 batch_inputs_meta = [] 224 batch_weights_meta = [] 225 batch_biases_meta = [] 226 227 for node in subset: 228 if CallFunctionVarArgs(aten.addmm.default).match(node): 229 bias, input, weight = node.args 230 elif CallFunctionVarArgs(aten.mm.default).match(node): 231 input, weight = node.args 232 bias = None 233 batch_nodes.append(node) 234 batch_inputs.append(input) # type: ignore[possibly-undefined] 235 batch_weights.append(weight) # type: ignore[possibly-undefined] 236 batch_biases.append(bias) # type: ignore[possibly-undefined] 237 batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] 238 batch_weights_meta.append(weight.meta) # type: ignore[possibly-undefined, union-attr] 239 if bias is not None: # type: ignore[possibly-undefined] 240 batch_biases_meta.append(bias.meta) # type: ignore[possibly-undefined, union-attr] 241 else: 242 batch_biases_meta.append(None) 243 244 with graph.inserting_before(subset[-1]): 245 fused_inputs = decompose_stack(graph, batch_inputs) 246 fused_weights = decompose_stack(graph, batch_weights) 247 fused_inputs_meta_val = torch.stack( 248 [input["val"] for input in batch_inputs_meta] 249 ) 250 fused_weights_meta_val = torch.stack( 251 [weight["val"] for weight in batch_weights_meta] 252 ) 253 fused_bmm = graph.call_function( 254 aten.bmm, 255 args=(fused_inputs, fused_weights), 256 ) 257 fused_bmm.meta["val"] = aten.bmm( 258 fused_inputs_meta_val, fused_weights_meta_val 259 ) 260 for i, original_mm in enumerate(batch_nodes): 261 has_bias = False 262 with graph.inserting_after(fused_bmm): 263 new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i))) 264 new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i) 265 if batch_biases[i]: 266 has_bias = True 267 # broadcast the bias to the same shape as the mm output 268 if self.graph_search_options.get( 269 "shape_broadcast_batch_linear", False 270 ): 271 broadcast_shape = torch.broadcast_shapes( 272 batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape 273 ) 274 broadcast_bias = graph.call_function( 275 aten.broadcast_to.default, 276 args=(batch_biases[i],), 277 kwargs={"size": broadcast_shape}, 278 ) 279 broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment] 280 new_bias_add = graph.call_function( 281 aten.add.Tensor, args=((broadcast_bias, new_mm)) 282 ) 283 new_bias_add.meta["val"] = aten.add.Tensor( 284 broadcast_bias.meta["val"], new_mm.meta["val"] 285 ) 286 else: 287 new_bias_add = graph.call_function( 288 aten.add, args=((batch_biases[i], new_mm)) 289 ) 290 new_bias_add.meta["val"] = aten.add.Tensor( 291 batch_biases_meta[i]["val"], new_mm.meta["val"] 292 ) 293 new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined] 294 original_mm.replace_all_uses_with(new_mm_cont) 295 new_mm_cont.meta.update(original_mm.meta) 296 graph.erase_node(original_mm) 297 counters["inductor"]["batch_linear_post_grad"] += 1 298 299 300@register_fusion("group_linear", pre_grad=False) 301class GroupLinearFusion(GroupFusion): 302 def _addmm_node_can_be_fused(self, node: torch.fx.Node): 303 input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] 304 weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr] 305 return ( 306 node.kwargs.get("beta", 1.0) == 1.0 307 and node.kwargs.get("alpha", 1.0) == 1.0 308 and len(input_shape) == 2 309 and len(weight_shape) == 2 310 and all(x % 2 == 0 for x in input_shape + weight_shape) 311 and all( 312 shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] 313 for shape in input_shape + weight_shape 314 ) 315 ) 316 317 def _mm_node_can_be_fused(self, node: torch.fx.Node): 318 input_shape = node.args[0].meta["val"].shape # type: ignore[union-attr] 319 weight_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] 320 return ( 321 len(input_shape) == 2 322 and len(weight_shape) == 2 323 and all(x % 2 == 0 for x in input_shape + weight_shape) 324 and all( 325 shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] 326 for shape in input_shape + weight_shape 327 ) 328 ) 329 330 def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]: 331 if CallFunctionVarArgs(aten.mm.default).match( 332 node 333 ) and self._mm_node_can_be_fused(node): 334 group_key = ("group_linear", True) 335 elif CallFunctionVarArgs(aten.addmm.default).match( 336 node 337 ) and self._addmm_node_can_be_fused(node): 338 bias = node.args[0] 339 group_key = ("group_linear", bias is None) 340 else: 341 group_key = None 342 return group_key 343 344 def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): 345 group_inputs = [] 346 group_weights = [] 347 group_biases = [] 348 group_nodes = [] 349 for node in subset: 350 if CallFunctionVarArgs(aten.addmm.default).match(node): 351 bias, input, weight = node.args 352 else: 353 assert CallFunctionVarArgs(aten.mm.default).match(node) 354 input, weight = node.args 355 bias = None 356 357 group_nodes.append(node) 358 group_inputs.append(input) 359 group_weights.append(weight) 360 group_biases.append(bias) 361 362 if all(bias is None for bias in group_biases): 363 group_biases = None # type: ignore[assignment] 364 365 with graph.inserting_before(subset[0]): 366 fused_mm = graph.call_function( 367 torch.ops.fbgemm.gmm.default, 368 args=(group_inputs, group_weights, group_biases), 369 kwargs={"smart_fused": True}, 370 ) 371 372 for i, original_mm in enumerate(group_nodes): 373 with graph.inserting_after(fused_mm): 374 new_mm = graph.call_function(operator.getitem, args=(fused_mm, i)) 375 original_mm.replace_all_uses_with(new_mm) 376 new_mm.meta.update(original_mm.meta) 377 graph.erase_node(original_mm) 378 counters["inductor"]["group_linear"] += 1 379 380 381class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory): 382 """ 383 Batch pointwise math operator (e.g., add, mul) in post grad pass. 384 """ 385 386 def __init__(self, op, **kwargs) -> None: 387 super().__init__(op, **kwargs) 388 self.op = op 389 390 def _pointwise_node_can_be_fused(self, node: torch.fx.Node): 391 # note: we only consider the case where the inputs are tensors 392 # for mixed precision training, we need to make sure the inputs 393 # of the aten.cat when do the stack should be the same dtype 394 # otherwise, the output of the aten.cat may be not the same as 395 # its inputs, and cause dtype not same error in mm or addmm 396 input, other = node.args 397 return ( 398 input.meta["val"].shape == other.meta["val"].shape # type: ignore[union-attr] 399 if hasattr(input, "meta") 400 and hasattr(other, "meta") 401 and "val" in input.meta # type: ignore[union-attr] 402 and "val" in other.meta # type: ignore[union-attr] 403 else False 404 ) 405 406 def match(self, node: torch.fx.Node): 407 if CallFunctionVarArgs(self.op).match( 408 node 409 ) and self._pointwise_node_can_be_fused(node): 410 alpha = node.kwargs.get("alpha", 1.0) 411 rounding_mode = node.kwargs.get("rounding_mode", None) 412 input, other = node.args 413 shape = list(input.meta["val"].shape) # type: ignore[union-attr] 414 if self.graph_search_options.get("fuse_nodes_with_same_parent", False): 415 # only consider the linear case so far 416 # pyre-fixme[16] 417 if input.target == aten.select or other.target == aten.select: # type: ignore[union-attr] 418 parent = ( 419 # pyre-fixme[16] 420 input.args[0] # type: ignore[union-attr] 421 # pyre-fixme[16] 422 if input.target == aten.select # type: ignore[union-attr] 423 else other.args[0] # type: ignore[union-attr] 424 ) 425 else: 426 parent = "" 427 else: 428 parent = "" 429 group_key = ( 430 "batch_aten_" + self.op.__name__.lower().split(".")[0], 431 str(shape), 432 str(input.meta["val"].dtype), # type: ignore[union-attr] 433 str(other.meta["val"].dtype), # type: ignore[union-attr] 434 str(alpha), 435 str(rounding_mode), 436 str(parent), 437 ) 438 else: 439 group_key = None 440 return group_key 441 442 def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): 443 batch_inputs, batch_others = [], [] 444 alpha = subset[0].kwargs.get("alpha", 1.0) 445 batch_inputs_meta, batch_others_meta = [], [] 446 447 for node in subset: 448 input, other = node.args 449 batch_inputs.append(input) 450 batch_others.append(other) 451 batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] 452 batch_others_meta.append(other.meta) # type: ignore[possibly-undefined, union-attr] 453 454 with graph.inserting_before(subset[0]): 455 stack_inputs = decompose_stack(graph, batch_inputs) 456 stack_others = decompose_stack(graph, batch_others) 457 stack_inputs_meta = torch.stack( 458 [input["val"] for input in batch_inputs_meta] 459 ) 460 stack_others_meta = torch.stack( 461 [other["val"] for other in batch_others_meta] 462 ) 463 464 batch_op = graph.call_function( 465 self.op, 466 args=(stack_inputs, stack_others), 467 kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {}, 468 ) 469 batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta) 470 for i, original_add in enumerate(subset): 471 with graph.inserting_after(batch_op): 472 new_add = graph.call_function( 473 torch.ops.aten.select, args=((batch_op, 0, i)) 474 ) 475 original_add.replace_all_uses_with(new_add) 476 new_add.meta.update(original_add.meta) 477 graph.erase_node(original_add) 478 counters["inductor"][ 479 "batch_aten_" + self.op.__name__.lower().split(".")[0] 480 ] += 1 481 482 483@register_fusion("batch_linear_lhs") 484class BatchLinearLHSFusion(BatchFusion): 485 """ 486 Batch linear left-hand side fusion. This pass tries to fuse the following patterns: 487 488 torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn) 489 -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1)) 490 491 We have a separate pass to eliminate contiguous transpose in a generic way. 492 """ 493 494 def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]: 495 if CallFunctionVarArgs(torch.nn.functional.linear).match( 496 node 497 ) and is_linear_node_can_be_fused(node): 498 input = get_arg_value(node, 0, "input") 499 bias = get_arg_value(node, 2, "bias") 500 group_key = ("batch_linear_lhs", bias is None, input) 501 else: 502 group_key = None 503 return group_key 504 505 def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): 506 batch_nodes = [] 507 batch_input = None 508 batch_weights, batch_weights_meta = [], [] 509 batch_biases, batch_biases_meta = [], [] 510 split_sections = [] 511 for node in subset: 512 input = get_arg_value(node, 0, "input") 513 weight = get_arg_value(node, 1, "weight") 514 bias = get_arg_value(node, 2, "bias") 515 batch_nodes.append(node) 516 if batch_input is None: 517 batch_input = input 518 else: 519 assert batch_input is input 520 batch_weights.append(weight) 521 batch_weights_meta.append(weight.meta["example_value"]) 522 if bias: 523 batch_biases.append(bias) 524 batch_biases_meta.append(bias.meta["example_value"]) 525 split_sections.append(weight.meta["example_value"].shape[0]) 526 527 with graph.inserting_before(subset[0]): 528 cat_weights = graph.call_function( 529 torch.cat, args=(batch_weights,), kwargs={"dim": 0} 530 ) 531 cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0) 532 transposed_weights = graph.call_function( 533 torch.transpose, args=(cat_weights, 0, 1) 534 ) 535 transposed_weights.meta["example_value"] = torch.transpose( 536 cat_weights.meta["example_value"], 0, 1 537 ) 538 if len(batch_biases) > 0: 539 cat_biases = graph.call_function( 540 torch.cat, args=(batch_biases,), kwargs={"dim": 0} 541 ) 542 cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0) 543 fused_lhs = graph.call_function( 544 torch.addmm, 545 args=(cat_biases, batch_input, transposed_weights), 546 ) 547 fused_lhs.meta["example_value"] = torch.addmm( 548 cat_biases.meta["example_value"], 549 batch_input.meta["example_value"], # type: ignore[union-attr] 550 transposed_weights.meta["example_value"], 551 ) 552 else: 553 fused_lhs = graph.call_function( 554 torch.mm, 555 args=(batch_input, transposed_weights), 556 ) 557 fused_lhs.meta["example_value"] = torch.mm( 558 batch_input.meta["example_value"], # type: ignore[union-attr] 559 transposed_weights.meta["example_value"], 560 ) 561 fused_lhs_list = graph.call_function( 562 torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1} 563 ) 564 565 for i, node in enumerate(batch_nodes): 566 with graph.inserting_after(fused_lhs_list): 567 new_node = graph.call_function( 568 operator.getitem, args=(fused_lhs_list, i) 569 ) 570 node.replace_all_uses_with(new_node) 571 new_node.meta.update(node.meta) 572 graph.erase_node(node) 573 counters["inductor"]["batch_linear_lhs"] += 1 574 575 576def is_node_meta_valid(node: Optional[torch.fx.Node]): 577 return node is None or "example_value" in node.meta or "val" in node.meta 578 579 580# Poor person's check for if a node in the graph mutates its input. 581# (the graph is torch IR, so we will see torch fns and python operators) 582def _is_mutable_node(tgt): 583 if str(tgt).endswith("_"): 584 # e.g. torch.mul_, torch.Tensor.mul_ 585 return True 586 if ( 587 hasattr(tgt, "__module__") 588 and tgt.__module__ == "_operator" 589 and tgt.__name__.startswith("i") 590 ): 591 # e.g. operator.iand, operator.imul 592 return True 593 return False 594 595 596def is_linear_node_can_be_fused(node: torch.fx.Node): 597 input = get_arg_value(node, 0, "input") 598 weight = get_arg_value(node, 1, "weight") 599 return ( 600 is_node_meta_valid(node) 601 and is_node_meta_valid(input) 602 and is_node_meta_valid(weight) 603 and len(input.meta["example_value"].shape) == 2 604 and len(weight.meta["example_value"].shape) == 2 605 # the mm -> bmm transform adds an unbind() op, 606 # which is not safe for autograd when the output of the mm is mutated. 607 # don't pattern match if any users of the mm mutate the input. 608 and not any(_is_mutable_node(user.target) for user in node.users) 609 ) 610 611 612@register_fusion("batch_linear") 613class PreGradBatchLinearFusion(BatchFusion): 614 """ 615 Batch linear fusion in pre grad pass. 616 Fuse linear with same size with torch.baddmm 617 """ 618 619 def _getitem_args(self, getitem_node: torch.fx.Node): 620 if getitem_node.target != operator.__getitem__ or ( 621 getitem_node.op != "call_function" 622 ): 623 return None 624 return getitem_node.args[0] 625 626 def match(self, node: torch.fx.Node): 627 if CallFunctionVarArgs(torch.nn.functional.linear).match( 628 node 629 ) and is_linear_node_can_be_fused(node): 630 input = get_arg_value(node, 0, "input") 631 weight = get_arg_value(node, 1, "weight") 632 bias = get_arg_value(node, 2, "bias") 633 if self.graph_search_options.get("fuse_nodes_with_same_users", False): 634 users = [user.target for user in node.users.keys()] 635 else: 636 users = "" # type: ignore[assignment] 637 group_key = ( 638 "batch_linear", 639 self._getitem_args(input), 640 str(input.meta["example_value"].shape), 641 str(weight.meta["example_value"].shape), 642 bias is None, 643 str(users), 644 ) 645 else: 646 group_key = None 647 return group_key 648 649 def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): 650 batch_nodes = [] 651 batch_inputs = [] 652 batch_weights = [] 653 batch_biases = [] 654 batch_inputs_metadata = [] 655 batch_weights_metadata = [] 656 batch_biases_metadata = [] 657 for node in subset: 658 batch_nodes.append(node) 659 input = get_arg_value(node, 0, "input") 660 batch_inputs.append(input) 661 batch_inputs_metadata.append(input.meta["example_value"]) 662 weight = get_arg_value(node, 1, "weight") 663 batch_weights.append(weight) 664 batch_weights_metadata.append(weight.meta["example_value"]) 665 bias = get_arg_value(node, 2, "bias") 666 batch_biases.append(bias) 667 if bias is not None and hasattr(bias, "meta"): 668 batch_biases_metadata.append(bias.meta["example_value"]) 669 670 with graph.inserting_before(subset[0]): 671 stack_inputs = graph.call_function( 672 torch.stack, args=(batch_inputs,), kwargs={"dim": 0} 673 ) 674 update_stack_example_value(stack_inputs, batch_inputs_metadata) 675 stack_weights = graph.call_function( 676 torch.stack, args=(batch_weights,), kwargs={"dim": 0} 677 ) 678 update_stack_example_value(stack_weights, batch_weights_metadata) 679 transpose_weight = graph.call_function( 680 torch.transpose, args=(stack_weights, 1, 2) 681 ) 682 transpose_weight.meta["example_value"] = torch.transpose( 683 stack_weights.meta["example_value"], 1, 2 684 ) 685 if all(bias is None for bias in batch_biases): 686 bmm = graph.call_function( 687 torch.bmm, 688 args=(stack_inputs, transpose_weight), 689 ) 690 bmm.meta["example_value"] = torch.bmm( 691 stack_inputs.meta["example_value"], 692 transpose_weight.meta["example_value"], 693 ) 694 bmm_meta = bmm.meta["example_value"] 695 else: 696 stack_biases = graph.call_function( 697 torch.stack, args=(batch_biases,), kwargs={"dim": 0} 698 ) 699 update_stack_example_value(stack_biases, batch_biases_metadata) 700 unsqueeze_biases = graph.call_function( 701 torch.unsqueeze, args=(stack_biases, 1) 702 ) 703 unsqueeze_biases.meta["example_value"] = torch.unsqueeze( 704 stack_biases.meta["example_value"], 1 705 ) 706 bmm = graph.call_function( 707 torch.baddbmm, 708 args=(unsqueeze_biases, stack_inputs, transpose_weight), 709 ) 710 try: 711 # it will have runtime error to broadcast when it has dynamic shape included 712 # in the meta data, so we need to skip the update meta data 713 bmm.meta["example_value"] = torch.baddbmm( 714 unsqueeze_biases.meta["example_value"], 715 stack_inputs.meta["example_value"], 716 transpose_weight.meta["example_value"], 717 ) 718 bmm_meta = bmm.meta["example_value"] 719 except Exception as e: 720 log.debug( 721 f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004 722 ) 723 bmm_meta = None 724 725 bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0}) 726 if bmm_meta is not None: 727 bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) 728 for i, linear in enumerate(batch_nodes): 729 with graph.inserting_after(bmm): 730 getitem = graph.call_function(operator.getitem, args=(bmm, i)) 731 linear.replace_all_uses_with(getitem) 732 getitem.meta.update(linear.meta) 733 graph.erase_node(linear) 734 counters["inductor"]["batch_linear"] += 1 735 736 737@register_fusion("batch_layernorm") 738class BatchLayernormFusion(BatchFusion): 739 """ 740 Batch layer norm fusion in pre grad pass 741 """ 742 743 def match(self, node: torch.fx.Node): 744 if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node): 745 input = get_arg_value(node, 0, "input") 746 weight = get_arg_value(node, 2, "weight") 747 bias = get_arg_value(node, 3, "bias") 748 if self.graph_search_options.get("fuse_nodes_with_same_users", False): 749 users = [user.target for user in node.users.keys()] 750 else: 751 users = "" # type: ignore[assignment] 752 group_key = ( 753 ( 754 "batch_layernorm", 755 str(input.meta["example_value"].shape), 756 str(weight.meta["example_value"].shape) 757 if weight is not None 758 else "", 759 str(bias.meta["example_value"].shape) if bias is not None else "", 760 str(get_arg_value(node, 1, "normalized_shape")), 761 str(get_arg_value(node, 4, "eps")), 762 str(users), 763 ) 764 if "example_value" in input.meta 765 and is_node_meta_valid(weight) 766 and is_node_meta_valid(bias) 767 else None 768 ) 769 else: 770 group_key = None 771 return group_key 772 773 def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): 774 group_inputs = [] 775 group_shapes = [] 776 group_weights = [] 777 group_biases = [] 778 group_epss = [] 779 group_nodes = [] 780 group_inputs_metadata = [] 781 group_biases_metadata = [] 782 group_weights_metadata = [] 783 for node in subset: 784 group_nodes.append(node) 785 input = get_arg_value(node, 0, "input") 786 group_inputs.append(input) 787 group_inputs_metadata.append(input.meta["example_value"]) 788 group_shapes.append(get_arg_value(node, 1, "normalized_shape")) 789 weight = get_arg_value(node, 2, "weight") 790 group_weights.append(weight) 791 if weight is not None and hasattr(weight, "meta"): 792 group_weights_metadata.append(weight.meta["example_value"]) 793 bias = get_arg_value(node, 3, "bias") 794 group_biases.append(bias) 795 if bias is not None and hasattr(bias, "meta"): 796 group_biases_metadata.append(bias.meta["example_value"]) 797 eps = get_arg_value(node, 4, "eps") 798 if eps is None: 799 eps = 1e-5 800 group_epss.append(eps) 801 stack_dim = -1 - len(group_shapes[-1]) 802 803 if all(bias is None for bias in group_biases): 804 group_biases = None # type: ignore[assignment] 805 if all(weight is None for weight in group_weights): 806 group_weights = None # type: ignore[assignment] 807 assert all( 808 eps == group_epss[0] for eps in group_epss 809 ), "all epsilon values must be equal" 810 811 with graph.inserting_before(subset[0]): 812 stack_input = graph.call_function( 813 torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim} 814 ) 815 update_stack_example_value(stack_input, group_inputs_metadata, stack_dim) 816 if group_weights is not None: 817 stack_weight = graph.call_function( 818 torch.stack, args=(group_weights,), kwargs={"dim": 0} 819 ) 820 update_stack_example_value(stack_weight, group_weights_metadata) 821 else: 822 stack_weight = None 823 if group_biases is not None: 824 stack_bias = graph.call_function( 825 torch.stack, args=(group_biases,), kwargs={"dim": 0} 826 ) 827 update_stack_example_value(stack_bias, group_biases_metadata) 828 else: 829 stack_bias = None 830 831 batch_layer_norm = graph.call_function( 832 torch.nn.functional.layer_norm, 833 args=(stack_input, group_shapes[-1]), 834 kwargs={"eps": group_epss[-1]}, 835 ) 836 batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"] 837 838 if group_weights is not None and group_biases is not None: 839 previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] 840 batch_layer_norm = graph.call_function( 841 torch.mul, args=(stack_weight, batch_layer_norm) 842 ) 843 update_pointwise_example_value( 844 batch_layer_norm, 845 stack_weight.meta["example_value"], 846 previous_batch_layer_norm_meta, 847 torch.mul, 848 ) 849 previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] 850 batch_layer_norm = graph.call_function( 851 torch.add, args=(stack_bias, batch_layer_norm) 852 ) 853 update_pointwise_example_value( 854 batch_layer_norm, 855 stack_bias.meta["example_value"], 856 previous_batch_layer_norm_meta, 857 torch.add, 858 ) 859 elif group_weights is not None and group_biases is None: 860 previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] 861 batch_layer_norm = graph.call_function( 862 torch.mul, args=(stack_weight, batch_layer_norm) 863 ) 864 update_pointwise_example_value( 865 batch_layer_norm, 866 stack_weight.meta["example_value"], 867 previous_batch_layer_norm_meta, 868 torch.mul, 869 ) 870 elif group_weights is None and group_biases is not None: 871 previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] 872 batch_layer_norm = graph.call_function( 873 torch.add, args=(stack_bias, batch_layer_norm) 874 ) 875 update_pointwise_example_value( 876 batch_layer_norm, 877 stack_bias.meta["example_value"], 878 previous_batch_layer_norm_meta, 879 torch.add, 880 ) 881 882 batch_layer_norm_unbind = graph.call_function( 883 torch.unbind, 884 args=(batch_layer_norm,), 885 kwargs={"dim": stack_dim}, 886 ) 887 update_stack_example_value( 888 batch_layer_norm_unbind, 889 batch_layer_norm.meta["example_value"], 890 op=torch.unbind, 891 dim=stack_dim, 892 ) 893 894 for i, node in enumerate(group_nodes): 895 with graph.inserting_after(batch_layer_norm_unbind): 896 new_node = graph.call_function( 897 operator.getitem, args=(batch_layer_norm_unbind, i) 898 ) 899 node.replace_all_uses_with(new_node) 900 new_node.meta.update(node.meta) 901 graph.erase_node(node) 902 counters["inductor"]["batch_layernorm"] += 1 903 904 905class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory): 906 """ 907 Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass. 908 We fuse it in random place, and the introduced stack node may be merged in split cat. 909 """ 910 911 def __init__(self, op, **kwargs) -> None: 912 super().__init__(op, **kwargs) 913 self.op = op 914 915 def match(self, node: torch.fx.Node): 916 input = get_arg_value(node, 0, "input") 917 if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): 918 if self.graph_search_options.get("fuse_nodes_with_same_parent", False): 919 # pyre-fixme[16] 920 parent = node.args[0] 921 parent = parent.target if parent is not None else "" # type: ignore[union-attr] 922 else: 923 parent = "" 924 # for relu op, we also use the inplace to construct the key 925 group_key = ( 926 "batch_" + self.op.__name__.lower().split(".")[0], 927 str(input.meta["example_value"].shape), 928 str(node.kwargs.get("inplace", False)), 929 str(parent), 930 ) 931 else: 932 group_key = None 933 return group_key 934 935 def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): 936 batch_nodes = [] 937 batch_inputs = [] 938 batch_inputs_metadata = [] 939 940 for node in subset: 941 batch_nodes.append(node) 942 input = get_arg_value(node, 0, "input") 943 batch_inputs.append(input) 944 batch_inputs_metadata.append(input.meta["example_value"]) 945 946 with graph.inserting_before(subset[0]): 947 stack_inputs = graph.call_function( 948 torch.stack, args=(batch_inputs,), kwargs={"dim": 0} 949 ) 950 update_stack_example_value(stack_inputs, batch_inputs_metadata) 951 if self.op == torch.nn.functional.relu: 952 batch_op = graph.call_function( 953 self.op, 954 args=(stack_inputs,), 955 kwargs={"inplace": subset[0].kwargs.get("inplace", False)}, 956 ) 957 batch_op.meta["example_value"] = self.op( 958 stack_inputs.meta["example_value"], 959 inplace=subset[0].kwargs.get("inplace", False), 960 ) 961 else: 962 batch_op = graph.call_function( 963 self.op, 964 args=(stack_inputs,), 965 ) 966 batch_op.meta["example_value"] = self.op( 967 stack_inputs.meta["example_value"] 968 ) 969 unbind_op = graph.call_function( 970 torch.unbind, args=(batch_op,), kwargs={"dim": 0} 971 ) 972 unbind_op.meta["example_value"] = torch.unbind( 973 batch_op.meta["example_value"], dim=0 974 ) 975 for i, node in enumerate(batch_nodes): 976 with graph.inserting_after(unbind_op): 977 getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) 978 node.replace_all_uses_with(getitem) 979 getitem.meta.update(node.meta) 980 graph.erase_node(node) 981 counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 982 983 984class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): 985 """ 986 Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass. 987 The introduced stack node may be merged in split cat. 988 """ 989 990 def __init__(self, op, **kwargs) -> None: 991 super().__init__(op, **kwargs) 992 self.op = op 993 994 def match(self, node: torch.fx.Node): 995 input = get_arg_value(node, 0, "input") 996 if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): 997 # for relu op, we also use the inplace to construct the key 998 # we batch the ops with same parent to enable followup split cat 999 parent = node.args[0] 1000 parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr] 1001 group_key = ( 1002 "batch_aten_" + self.op.__name__.lower().split(".")[0], 1003 str(input.meta["val"].shape), 1004 str(node.kwargs.get("inplace", False)), 1005 # pyre-fixme[16] 1006 str(parent), 1007 ) 1008 else: 1009 group_key = None 1010 return group_key 1011 1012 def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): 1013 batch_nodes = [] 1014 batch_inputs = [] 1015 batch_inputs_metadata = [] 1016 1017 for node in subset: 1018 batch_nodes.append(node) 1019 input = get_arg_value(node, 0, "input") 1020 batch_inputs.append(input) 1021 batch_inputs_metadata.append(input.meta["val"]) 1022 1023 with graph.inserting_before(subset[0]): 1024 stack_inputs = decompose_stack(graph, batch_inputs) 1025 update_stack_example_value(stack_inputs, batch_inputs_metadata) 1026 batch_op = graph.call_function( 1027 self.op, 1028 args=(stack_inputs,), 1029 ) 1030 for i, node in enumerate(batch_nodes): 1031 with graph.inserting_after(batch_op): 1032 getitem = graph.call_function(aten.select, args=(batch_op, 0, i)) 1033 node.replace_all_uses_with(getitem) 1034 getitem.meta.update(node.meta) 1035 graph.erase_node(node) 1036 counters["inductor"][ 1037 "batch_aten_" + self.op.__name__.lower().split(".")[0] 1038 ] += 1 1039 1040 1041@register_fusion("batch_tanh") 1042class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): 1043 def __init__(self, **kwargs) -> None: 1044 super().__init__(torch.tanh, **kwargs) 1045 1046 1047@register_fusion("batch_sigmoid") 1048class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion): 1049 def __init__(self, **kwargs) -> None: 1050 super().__init__(torch.sigmoid, **kwargs) 1051 1052 1053@register_fusion("batch_relu") 1054class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion): 1055 def __init__(self, **kwargs) -> None: 1056 super().__init__(torch.nn.functional.relu, **kwargs) 1057 1058 1059@register_fusion("batch_aten_tanh", pre_grad=False) 1060class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion): 1061 def __init__(self, **kwargs) -> None: 1062 super().__init__(aten.tanh.default, **kwargs) 1063 1064 1065@register_fusion("batch_aten_sigmoid", pre_grad=False) 1066class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion): 1067 def __init__(self, **kwargs) -> None: 1068 super().__init__(aten.sigmoid.default, **kwargs) 1069 1070 1071@register_fusion("batch_aten_relu", pre_grad=False) 1072class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion): 1073 def __init__(self, **kwargs) -> None: 1074 super().__init__(aten.relu.default, **kwargs) 1075 1076 1077@register_fusion("batch_aten_add", pre_grad=False) 1078class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion): 1079 def __init__(self, **kwargs) -> None: 1080 super().__init__(aten.add.Tensor, **kwargs) 1081 1082 1083@register_fusion("batch_aten_sub", pre_grad=False) 1084class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion): 1085 def __init__(self, **kwargs) -> None: 1086 super().__init__(aten.sub.Tensor, **kwargs) 1087 1088 1089@register_fusion("batch_aten_div", pre_grad=False) 1090class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion): 1091 def __init__(self, **kwargs) -> None: 1092 super().__init__(aten.div.Tensor, **kwargs) 1093 1094 1095@register_fusion("batch_aten_mul", pre_grad=False) 1096class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion): 1097 def __init__(self, **kwargs) -> None: 1098 super().__init__(aten.mul.Tensor, **kwargs) 1099 1100 1101class _OrderedSet: 1102 def __init__(self, param=None) -> None: 1103 if param: 1104 self.rep = OrderedDict(dict.fromkeys(param)) 1105 else: 1106 self.rep = OrderedDict() 1107 1108 def __contains__(self, o) -> bool: 1109 return o in self.rep 1110 1111 def __len__(self) -> int: 1112 return self.rep.__len__() 1113 1114 def append(self, o): 1115 self.rep[o] = None 1116 1117 def __iter__(self): 1118 return self.rep.keys().__iter__() 1119 1120 1121def find_independent_subset_greedy( 1122 node_list: Iterable[torch.fx.Node], 1123 graph_search_options: Dict[str, Any], 1124) -> Iterator[Iterable[torch.fx.Node]]: 1125 """ 1126 Yields a list of subsets of `node_list` where no element in the subset 1127 depends on any other element in the subset. This results in a set of 1128 independent nodes which can be fused together. 1129 1130 The order of `node_list` is preserved within each subset so we can benefit 1131 from split-cat elimination in later passes. 1132 1133 During iteration it is only safe to mutate the graph by changing the nodes 1134 that have been returned. 1135 1136 graph_search_options: 1137 - min_fuse_set_size: Minimum size of the subset to consider. Subsets below 1138 this size will be ignored. 1139 - max_fuse_set_size: Maximum size of the subset to consider. Subsets will 1140 be broken to be at most this size. 1141 """ 1142 1143 # Compute all the children of `node` which are members of 1144 # `interesting_nodes`. 1145 def find_dependent_nodes(node, interesting_nodes): 1146 visited_node_set: Set[torch.fx.Node] = {node} 1147 dep_set: Set[torch.fx.Node] = set() 1148 1149 work = [node] 1150 while work: 1151 node = work.pop() 1152 for input_node in node.all_input_nodes: 1153 if input_node in interesting_nodes: 1154 dep_set.add(input_node) 1155 1156 if input_node not in visited_node_set: 1157 visited_node_set.add(input_node) 1158 work.append(input_node) 1159 1160 return dep_set 1161 1162 min_fuse_set_size = graph_search_options["min_fuse_set_size"] 1163 max_fuse_set_size = graph_search_options["max_fuse_set_size"] 1164 1165 # node_list needs to be a set because we only track the nodes that are left 1166 # in it (and we want to do the `in` on a set, not a list). But we want to 1167 # keep the correct order. 1168 node_list = _OrderedSet(node_list) 1169 1170 cache: Dict[torch.fx.Node, Set[torch.fx.Node]] = {} 1171 while node_list: 1172 subset: List[torch.fx.Node] = [] 1173 subset_deps: Set[torch.fx.Node] = set() 1174 1175 next_round_node_list = _OrderedSet() 1176 for node in node_list: 1177 if len(subset) >= max_fuse_set_size or node in subset_deps: 1178 next_round_node_list.append(node) 1179 continue 1180 1181 dep_set = cache.pop(node, None) 1182 if dep_set is None: 1183 dep_set = find_dependent_nodes(node, node_list) 1184 1185 if not dep_set.intersection(subset): 1186 subset.append(node) 1187 subset_deps.update(dep_set) 1188 else: 1189 next_round_node_list.append(node) 1190 cache[node] = dep_set 1191 1192 if len(subset) >= min_fuse_set_size: 1193 # Careful here - the caller uses the subsets to fuse nodes together 1194 # so we need to clear any cache entry that contains one of the 1195 # returned nodes because the dependency list could be different 1196 # (larger) after the merge. 1197 cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)} 1198 yield subset 1199 1200 node_list = next_round_node_list 1201 1202 1203def get_fusion_candidates( 1204 rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node] 1205) -> DefaultDict[Any, List[torch.fx.Node]]: 1206 """ 1207 Search fusion candidates for a specific rule using BFS starting from the root node. 1208 We only search the subgraph within graph_search_options["max_fuse_search_depth"]. 1209 """ 1210 q: Deque[Tuple[int, torch.fx.Node]] = collections.deque() 1211 1212 candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict( 1213 list 1214 ) 1215 1216 if root_node.target in SEARCH_EXCLUSIONS: 1217 return candidate_dict 1218 1219 visited_set: Set[torch.fx.Node] = set() 1220 1221 for next_node in root_node.all_input_nodes: 1222 q.append((1, next_node)) 1223 visited_set.add(next_node) 1224 1225 while len(q) > 0: 1226 depth, node = q.popleft() 1227 1228 if node in fused_set: 1229 continue 1230 1231 key = rule.match(node) 1232 if key is not None: 1233 candidate_nodes = candidate_dict[key] 1234 if node not in candidate_nodes: 1235 candidate_nodes.append(node) 1236 else: 1237 if depth < rule.graph_search_options["max_fuse_search_depth"]: 1238 for next_node in node.all_input_nodes: 1239 if next_node not in visited_set: 1240 visited_set.add(next_node) 1241 q.append((depth + 1, next_node)) 1242 1243 return candidate_dict 1244 1245 1246def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase): 1247 stable_topological_sort(graph) # type: ignore[arg-type] 1248 fused_set: Set[torch.fx.Node] = set() 1249 log_to_scuba = False 1250 1251 for node in reversed(graph.nodes): 1252 candidates = get_fusion_candidates(rule, node, fused_set) 1253 1254 for key, candidate_nodes in candidates.items(): 1255 if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]: 1256 continue 1257 1258 for subset in find_independent_subset_greedy( 1259 candidate_nodes, rule.graph_search_options 1260 ): 1261 rule.fuse(graph, subset) 1262 fused_set.update(subset) 1263 log.debug( 1264 f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}" # noqa: G004 1265 ) 1266 log_to_scuba = True 1267 if log_to_scuba: 1268 optimus_scuba_log[rule.__class__.__name__] = upload_graph(graph) 1269 1270 1271def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True): 1272 fusions: List[GroupBatchFusionBase] = [] 1273 for name, options in config_options.items(): 1274 # we skip all patterns from pattern_matcher passes (e.g., split_cat) 1275 if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS: 1276 continue 1277 fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name] 1278 _options = graph_search_options.copy() 1279 _options.update(options) 1280 fusions.append(fusion_cls(graph_search_options=_options)) # type: ignore[operator] 1281 return fusions 1282 1283 1284def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): 1285 fusions: List[GroupBatchFusionBase] = [] 1286 # we keep all current pre grad fusions to keep 1287 # current implementation, will remove this later 1288 if pre_grad: 1289 fusions += generate_fusion_from_config( 1290 config.pre_grad_fusion_options, pre_grad=True 1291 ) 1292 else: 1293 fbgemm_fusion_keys = [ 1294 x 1295 for x in config.post_grad_fusion_options 1296 if config.post_grad_fusion_options[x].get("require_fbgemm", False) 1297 ] 1298 fbgemm_fusions = { 1299 fusion: config.post_grad_fusion_options[fusion] 1300 for fusion in fbgemm_fusion_keys 1301 } 1302 non_fbgemm_fusions = { 1303 fusion: config.post_grad_fusion_options[fusion] 1304 for fusion in config.post_grad_fusion_options.keys() 1305 if fusion not in fbgemm_fusion_keys 1306 } 1307 fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False) 1308 if has_fbgemm: 1309 fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False) 1310 1311 for i, rule in enumerate(fusions): 1312 with GraphTransformObserver( 1313 graph.owning_module, 1314 f"group_batch_fusion_{i}", 1315 config.trace.log_url_for_graph_xform, 1316 ): 1317 apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] 1318