# mypy: allow-untyped-defs
import collections
import logging
import operator
from collections import OrderedDict
from typing import (
    Any,
    DefaultDict,
    Deque,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
)

import torch
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._utils_internal import upload_graph
from torch.fx.passes.graph_transform_observer import GraphTransformObserver

from .. import config
from ..pattern_matcher import (
    CallFunctionVarArgs,
    get_arg_value,
    stable_topological_sort,
)


try:
    # importing this will register fbgemm lowerings for inductor
    import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings  # noqa: F401

    has_fbgemm = True
except Exception:
    has_fbgemm = False

aten = torch.ops.aten

log = logging.getLogger(__name__)

MIN_FUSE_SET_SIZE = 5
MAX_FUSE_SET_SIZE = 300
MAX_FUSE_SEARCH_DEPTH = 5
# The maximum tensor size that can go into the fusion group
MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096
# Whether we only fuse nodes with the same parent node
FUSE_NODES_WITH_SAME_PARENT = False
# Whether we enable broadcasting of the bias add in batch linear
SHAPE_BROADCAST_BATCH_LINEAR = False
# Whether we fuse nodes with the same users
FUSE_NODES_WITH_SAME_USERS = False

# exclude these nodes from BFS;
# excluding getitem improves optimizer compilation time by 60s
SEARCH_EXCLUSIONS = {operator.getitem}


default_graph_search_options = {
    "min_fuse_set_size": MIN_FUSE_SET_SIZE,
    "max_fuse_set_size": MAX_FUSE_SET_SIZE,
    "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH,
    "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR,
    "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT,
    "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR,
    "fuse_nodes_with_same_users": FUSE_NODES_WITH_SAME_USERS,
}

graph_search_options = default_graph_search_options


def update_stack_example_value(node, metadata, dim=0, op=torch.stack):
    """
    Update the example value of the node in the graph to enable follow-up
    split-cat optimization.
    """
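    # Usage sketch (illustrative, not executed at import): after building a
    # stack node in the graph, callers refresh its metadata like
    #   update_stack_example_value(stack_node, [t.meta["example_value"] for t in inputs])
    # so that follow-up split-cat passes can reason about shapes.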
    if node is not None and hasattr(node, "meta"):
        if op == torch.stack:
            example_value = torch.stack(metadata, dim=dim)
        elif op == torch.unbind:
            example_value = torch.unbind(metadata, dim=dim)  # type: ignore[assignment]
        else:
            return
        node.meta["example_value"] = example_value


def update_pointwise_example_value(pointwise_node, input, other, op):
    """
    Update the example value of the pointwise (add/mul) node in the graph to
    enable follow-up split-cat optimization.
    """
    if pointwise_node is not None and hasattr(pointwise_node, "meta"):
        if op == torch.add:
            example_value = torch.add(input, other)
        elif op == torch.mul:
            example_value = torch.mul(input, other)
        else:
            return
        pointwise_node.meta["example_value"] = example_value


class GroupBatchFusionBase:
    def __init__(self, **kwargs) -> None:
        self.graph_search_options = kwargs.pop(
            "graph_search_options", default_graph_search_options
        )

    def match(self, node):
        raise NotImplementedError("match called on base")

    def fuse(self, graph, subset):
        raise NotImplementedError("fuse called on base")


PRE_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {}
POST_GRAD_FUSIONS: Dict[str, GroupBatchFusionBase] = {}


def register_fusion(name: str, pre_grad=True):
    def decorator(fusion_cls: GroupBatchFusionBase):
        if pre_grad:
            PRE_GRAD_FUSIONS[name] = fusion_cls
        else:
            POST_GRAD_FUSIONS[name] = fusion_cls
        return fusion_cls

    return decorator


def list_group_batch_fusions(pre_grad=True) -> List[str]:
    if pre_grad:
        return list(PRE_GRAD_FUSIONS.keys())
    else:
        return list(POST_GRAD_FUSIONS.keys())


def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
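    """
    Build the decomposed form of torch.stack(input_tensors, dim=0) directly in
    the FX graph: unsqueeze each input at dim 0, then aten.cat along dim 0,
    propagating the "val" metadata so later passes can reason about shapes.
    """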
    unsqueezed_inputs = []
    unsqueezed_inputs_meta = []
    for input_tensor in input_tensors:
        unsqueezed_input = graph.call_function(
            aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0}
        )
        unsqueezed_inputs.append(unsqueezed_input)
        unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0)  # type: ignore[assignment]
        unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"])
    stacked_inputs = graph.call_function(
        aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0}
    )
    stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0)  # type: ignore[assignment]
    return stacked_inputs


class GroupFusion(GroupBatchFusionBase):
    """
    Fuse ops in a group way, e.g., fuse mm/addmm of arbitrary input shapes with
    fbgemm.gmm.
    """


class BatchFusion(GroupBatchFusionBase):
    """
    Fuse ops in a batch way, e.g., fuse mm/addmm of the same input shapes with
    bmm.
    """


class BatchPointwiseOpsFusionFactory(BatchFusion):
    def __init__(self, op, **kwargs) -> None:
        super().__init__(**kwargs)
        self.op = op


@register_fusion("batch_linear_post_grad", pre_grad=False)
class PostGradBatchLinearFusion(BatchFusion):
    """
    Fuse ops in a batch way in post grad (aten level).
    """
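    # Illustrative sketch (not executed): two aten.mm nodes with matching
    # (m, k) x (k, n) shapes,
    #     y0 = aten.mm(x0, w0); y1 = aten.mm(x1, w1)
    # are rewritten into one batched matmul whose stack is built via
    # decompose_stack (unsqueeze + cat),
    #     fused = aten.bmm(stack([x0, x1]), stack([w0, w1]))
    #     y0 = aten.select(fused, 0, 0); y1 = aten.select(fused, 0, 1)
    # with any biases re-applied through aten.add afterwards.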

    def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool:
        # pyre-fixme[7]: Incompatible return type
        return (
            node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0  # type: ignore[return-value]
        )

    def _is_input_2d(self, input: torch.fx.Node) -> bool:
        input_shapes = input.meta["val"].shape
        return (
            len(input_shapes) == 2
            and isinstance(input_shapes[0], int)
            and isinstance(input_shapes[1], int)
        )

    def match(
        self, node: torch.fx.Node
    ) -> Optional[Tuple[str, int, int, int, bool, str]]:
        if CallFunctionVarArgs(aten.mm).match(node):
            input_m, weight_m = node.args
            bias_m = None

        elif CallFunctionVarArgs(aten.addmm.default).match(
            node
        ) and self._addmm_node_can_be_fused(node):
            bias_m, input_m, weight_m = node.args
        else:
            return None
        # get the users of the node
        if self.graph_search_options.get("fuse_nodes_with_same_users", False):
            users = [user.target for user in node.users.keys()]
        else:
            users = ""  # type: ignore[assignment]
        # only handle the cases where inputs are 2D tensors
        if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m):  # type: ignore[arg-type]
            return None
        m, k = input_m.meta["val"].shape  # type: ignore[union-attr]
        n = weight_m.meta["val"].shape[1]  # type: ignore[union-attr]
        batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users))
        return batch_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        batch_inputs = []
        batch_weights = []
        batch_biases = []
        batch_nodes = []
        batch_inputs_meta = []
        batch_weights_meta = []
        batch_biases_meta = []

        for node in subset:
            if CallFunctionVarArgs(aten.addmm.default).match(node):
                bias, input, weight = node.args
            elif CallFunctionVarArgs(aten.mm.default).match(node):
                input, weight = node.args
                bias = None
            batch_nodes.append(node)
            batch_inputs.append(input)  # type: ignore[possibly-undefined]
            batch_weights.append(weight)  # type: ignore[possibly-undefined]
            batch_biases.append(bias)  # type: ignore[possibly-undefined]
            batch_inputs_meta.append(input.meta)  # type: ignore[possibly-undefined, union-attr]
            batch_weights_meta.append(weight.meta)  # type: ignore[possibly-undefined, union-attr]
            if bias is not None:  # type: ignore[possibly-undefined]
                batch_biases_meta.append(bias.meta)  # type: ignore[possibly-undefined, union-attr]
            else:
                batch_biases_meta.append(None)

        with graph.inserting_before(subset[-1]):
            fused_inputs = decompose_stack(graph, batch_inputs)
            fused_weights = decompose_stack(graph, batch_weights)
            fused_inputs_meta_val = torch.stack(
                [input["val"] for input in batch_inputs_meta]
            )
            fused_weights_meta_val = torch.stack(
                [weight["val"] for weight in batch_weights_meta]
            )
            fused_bmm = graph.call_function(
                aten.bmm,
                args=(fused_inputs, fused_weights),
            )
            fused_bmm.meta["val"] = aten.bmm(
                fused_inputs_meta_val, fused_weights_meta_val
            )
        for i, original_mm in enumerate(batch_nodes):
            has_bias = False
            with graph.inserting_after(fused_bmm):
                new_mm = graph.call_function(aten.select, args=(fused_bmm, 0, i))
                new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i)
                if batch_biases[i]:
                    has_bias = True
                    # broadcast the bias to the same shape as the mm output
                    if self.graph_search_options.get(
                        "shape_broadcast_batch_linear", False
                    ):
                        broadcast_shape = torch.broadcast_shapes(
                            batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape
                        )
                        broadcast_bias = graph.call_function(
                            aten.broadcast_to.default,
                            args=(batch_biases[i],),
                            kwargs={"size": broadcast_shape},
                        )
                        broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape)  # type: ignore[assignment]
                        new_bias_add = graph.call_function(
                            aten.add.Tensor, args=(broadcast_bias, new_mm)
                        )
                        new_bias_add.meta["val"] = aten.add.Tensor(
                            broadcast_bias.meta["val"], new_mm.meta["val"]
                        )
                    else:
                        new_bias_add = graph.call_function(
                            aten.add, args=(batch_biases[i], new_mm)
                        )
                        new_bias_add.meta["val"] = aten.add.Tensor(
                            batch_biases_meta[i]["val"], new_mm.meta["val"]
                        )
            new_mm_cont = new_bias_add if has_bias else new_mm  # type: ignore[possibly-undefined]
            original_mm.replace_all_uses_with(new_mm_cont)
            new_mm_cont.meta.update(original_mm.meta)
            graph.erase_node(original_mm)
        counters["inductor"]["batch_linear_post_grad"] += 1


@register_fusion("group_linear", pre_grad=False)
class GroupLinearFusion(GroupFusion):
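    """
    Group mm/addmm nodes of (possibly different) 2D shapes into a single
    fbgemm.gmm call. This assumes the fbgemm inductor lowerings were
    registered successfully (see the has_fbgemm guard above).
    """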
    def _addmm_node_can_be_fused(self, node: torch.fx.Node):
        input_shape = node.args[1].meta["val"].shape  # type: ignore[union-attr]
        weight_shape = node.args[2].meta["val"].shape  # type: ignore[union-attr]
        return (
            node.kwargs.get("beta", 1.0) == 1.0
            and node.kwargs.get("alpha", 1.0) == 1.0
            and len(input_shape) == 2
            and len(weight_shape) == 2
            and all(x % 2 == 0 for x in input_shape + weight_shape)
            and all(
                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
                for shape in input_shape + weight_shape
            )
        )

    def _mm_node_can_be_fused(self, node: torch.fx.Node):
        input_shape = node.args[0].meta["val"].shape  # type: ignore[union-attr]
        weight_shape = node.args[1].meta["val"].shape  # type: ignore[union-attr]
        return (
            len(input_shape) == 2
            and len(weight_shape) == 2
            and all(x % 2 == 0 for x in input_shape + weight_shape)
            and all(
                shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"]
                for shape in input_shape + weight_shape
            )
        )

    def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool]]:
        if CallFunctionVarArgs(aten.mm.default).match(
            node
        ) and self._mm_node_can_be_fused(node):
            group_key = ("group_linear", True)
        elif CallFunctionVarArgs(aten.addmm.default).match(
            node
        ) and self._addmm_node_can_be_fused(node):
            bias = node.args[0]
            group_key = ("group_linear", bias is None)
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        group_inputs = []
        group_weights = []
        group_biases = []
        group_nodes = []
        for node in subset:
            if CallFunctionVarArgs(aten.addmm.default).match(node):
                bias, input, weight = node.args
            else:
                assert CallFunctionVarArgs(aten.mm.default).match(node)
                input, weight = node.args
                bias = None

            group_nodes.append(node)
            group_inputs.append(input)
            group_weights.append(weight)
            group_biases.append(bias)

        if all(bias is None for bias in group_biases):
            group_biases = None  # type: ignore[assignment]

        with graph.inserting_before(subset[0]):
            fused_mm = graph.call_function(
                torch.ops.fbgemm.gmm.default,
                args=(group_inputs, group_weights, group_biases),
                kwargs={"smart_fused": True},
            )

        for i, original_mm in enumerate(group_nodes):
            with graph.inserting_after(fused_mm):
                new_mm = graph.call_function(operator.getitem, args=(fused_mm, i))
            original_mm.replace_all_uses_with(new_mm)
            new_mm.meta.update(original_mm.meta)
            graph.erase_node(original_mm)
        counters["inductor"]["group_linear"] += 1


class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise math operators (e.g., add, mul) in the post grad pass.
    """
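    # Sketch (illustrative): n binary ops y_i = op(a_i, b_i) with identical
    # shapes and dtypes are rewritten as op(stack(a), stack(b)), and each
    # result is read back with aten.select(batch_op, 0, i).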

    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        self.op = op

    def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
        # Note: we only consider the case where the inputs are tensors.
        # For mixed-precision training, the inputs of the aten.cat that builds
        # the stack must share the same dtype; otherwise the output of aten.cat
        # may not match its inputs and cause a dtype-mismatch error in mm or
        # addmm.
        input, other = node.args
        return (
            input.meta["val"].shape == other.meta["val"].shape  # type: ignore[union-attr]
            if hasattr(input, "meta")
            and hasattr(other, "meta")
            and "val" in input.meta  # type: ignore[union-attr]
            and "val" in other.meta  # type: ignore[union-attr]
            else False
        )

    def match(self, node: torch.fx.Node):
        if CallFunctionVarArgs(self.op).match(
            node
        ) and self._pointwise_node_can_be_fused(node):
            alpha = node.kwargs.get("alpha", 1.0)
            rounding_mode = node.kwargs.get("rounding_mode", None)
            input, other = node.args
            shape = list(input.meta["val"].shape)  # type: ignore[union-attr]
            if self.graph_search_options.get("fuse_nodes_with_same_parent", False):
                # only consider the linear case so far
                # pyre-fixme[16]
                if input.target == aten.select or other.target == aten.select:  # type: ignore[union-attr]
                    parent = (
                        # pyre-fixme[16]
                        input.args[0]  # type: ignore[union-attr]
                        # pyre-fixme[16]
                        if input.target == aten.select  # type: ignore[union-attr]
                        else other.args[0]  # type: ignore[union-attr]
                    )
                else:
                    parent = ""
            else:
                parent = ""
            group_key = (
                "batch_aten_" + self.op.__name__.lower().split(".")[0],
                str(shape),
                str(input.meta["val"].dtype),  # type: ignore[union-attr]
                str(other.meta["val"].dtype),  # type: ignore[union-attr]
                str(alpha),
                str(rounding_mode),
                str(parent),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        batch_inputs, batch_others = [], []
        alpha = subset[0].kwargs.get("alpha", 1.0)
        batch_inputs_meta, batch_others_meta = [], []

        for node in subset:
            input, other = node.args
            batch_inputs.append(input)
            batch_others.append(other)
            batch_inputs_meta.append(input.meta)  # type: ignore[possibly-undefined, union-attr]
            batch_others_meta.append(other.meta)  # type: ignore[possibly-undefined, union-attr]

        with graph.inserting_before(subset[0]):
            stack_inputs = decompose_stack(graph, batch_inputs)
            stack_others = decompose_stack(graph, batch_others)
            stack_inputs_meta = torch.stack(
                [input["val"] for input in batch_inputs_meta]
            )
            stack_others_meta = torch.stack(
                [other["val"] for other in batch_others_meta]
            )

            batch_op = graph.call_function(
                self.op,
                args=(stack_inputs, stack_others),
                kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
            )
            batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta)
            for i, original_add in enumerate(subset):
                with graph.inserting_after(batch_op):
                    new_add = graph.call_function(
                        torch.ops.aten.select, args=(batch_op, 0, i)
                    )
                original_add.replace_all_uses_with(new_add)
                new_add.meta.update(original_add.meta)
                graph.erase_node(original_add)
        counters["inductor"][
            "batch_aten_" + self.op.__name__.lower().split(".")[0]
        ] += 1


@register_fusion("batch_linear_lhs")
class BatchLinearLHSFusion(BatchFusion):
    """
    Batch linear left-hand side fusion. This pass tries to fuse the following
    patterns:

        torch.nn.functional.linear(x, w1), linear(x, w2), ..., linear(x, wn)
        -> torch.mm(x, torch.cat([w1, w2, ..., wn]).transpose(0, 1))

    We have a separate pass to eliminate the contiguous transpose in a generic
    way.
    """

    def match(self, node: torch.fx.Node) -> Optional[Tuple[str, bool, Any]]:
        if CallFunctionVarArgs(torch.nn.functional.linear).match(
            node
        ) and is_linear_node_can_be_fused(node):
            input = get_arg_value(node, 0, "input")
            bias = get_arg_value(node, 2, "bias")
            group_key = ("batch_linear_lhs", bias is None, input)
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        batch_nodes = []
        batch_input = None
        batch_weights, batch_weights_meta = [], []
        batch_biases, batch_biases_meta = [], []
        split_sections = []
        for node in subset:
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 1, "weight")
            bias = get_arg_value(node, 2, "bias")
            batch_nodes.append(node)
            if batch_input is None:
                batch_input = input
            else:
                assert batch_input is input
            batch_weights.append(weight)
            batch_weights_meta.append(weight.meta["example_value"])
            if bias:
                batch_biases.append(bias)
                batch_biases_meta.append(bias.meta["example_value"])
            split_sections.append(weight.meta["example_value"].shape[0])

        with graph.inserting_before(subset[0]):
            cat_weights = graph.call_function(
                torch.cat, args=(batch_weights,), kwargs={"dim": 0}
            )
            cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0)
            transposed_weights = graph.call_function(
                torch.transpose, args=(cat_weights, 0, 1)
            )
            transposed_weights.meta["example_value"] = torch.transpose(
                cat_weights.meta["example_value"], 0, 1
            )
            if len(batch_biases) > 0:
                cat_biases = graph.call_function(
                    torch.cat, args=(batch_biases,), kwargs={"dim": 0}
                )
                cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0)
                fused_lhs = graph.call_function(
                    torch.addmm,
                    args=(cat_biases, batch_input, transposed_weights),
                )
                fused_lhs.meta["example_value"] = torch.addmm(
                    cat_biases.meta["example_value"],
                    batch_input.meta["example_value"],  # type: ignore[union-attr]
                    transposed_weights.meta["example_value"],
                )
            else:
                fused_lhs = graph.call_function(
                    torch.mm,
                    args=(batch_input, transposed_weights),
                )
                fused_lhs.meta["example_value"] = torch.mm(
                    batch_input.meta["example_value"],  # type: ignore[union-attr]
                    transposed_weights.meta["example_value"],
                )
            fused_lhs_list = graph.call_function(
                torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1}
            )

        for i, node in enumerate(batch_nodes):
            with graph.inserting_after(fused_lhs_list):
                new_node = graph.call_function(
                    operator.getitem, args=(fused_lhs_list, i)
                )
            node.replace_all_uses_with(new_node)
            new_node.meta.update(node.meta)
            graph.erase_node(node)
        counters["inductor"]["batch_linear_lhs"] += 1


def is_node_meta_valid(node: Optional[torch.fx.Node]):
    return node is None or "example_value" in node.meta or "val" in node.meta


# Poor person's check for whether a node in the graph mutates its input.
# (The graph is torch IR, so we will see torch fns and python operators.)
def _is_mutable_node(tgt):
    if str(tgt).endswith("_"):
        # e.g. torch.mul_, torch.Tensor.mul_
        return True
    if (
        hasattr(tgt, "__module__")
        and tgt.__module__ == "_operator"
        and tgt.__name__.startswith("i")
    ):
        # e.g. operator.iand, operator.imul
        return True
    return False


def is_linear_node_can_be_fused(node: torch.fx.Node):
    input = get_arg_value(node, 0, "input")
    weight = get_arg_value(node, 1, "weight")
    return (
        is_node_meta_valid(node)
        and is_node_meta_valid(input)
        and is_node_meta_valid(weight)
        and len(input.meta["example_value"].shape) == 2
        and len(weight.meta["example_value"].shape) == 2
        # the mm -> bmm transform adds an unbind() op,
        # which is not safe for autograd when the output of the mm is mutated.
        # don't pattern match if any users of the mm mutate the input.
        and not any(_is_mutable_node(user.target) for user in node.users)
    )


@register_fusion("batch_linear")
class PreGradBatchLinearFusion(BatchFusion):
    """
    Batch linear fusion in the pre grad pass.
    Fuse linear ops of the same size with torch.baddbmm.
    """
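    # Illustrative sketch (not executed): n linear calls with identical shapes,
    #     y_i = torch.nn.functional.linear(x_i, w_i, b_i)
    # become a single batched op,
    #     y = torch.baddbmm(stack(b).unsqueeze(1), stack(x), stack(w).transpose(1, 2))
    #     y_0, ..., y_{n-1} = torch.unbind(y, dim=0)
    # (torch.bmm is used instead when no node has a bias).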

    def _getitem_args(self, getitem_node: torch.fx.Node):
        if getitem_node.target != operator.__getitem__ or (
            getitem_node.op != "call_function"
        ):
            return None
        return getitem_node.args[0]

    def match(self, node: torch.fx.Node):
        if CallFunctionVarArgs(torch.nn.functional.linear).match(
            node
        ) and is_linear_node_can_be_fused(node):
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 1, "weight")
            bias = get_arg_value(node, 2, "bias")
            if self.graph_search_options.get("fuse_nodes_with_same_users", False):
                users = [user.target for user in node.users.keys()]
            else:
                users = ""  # type: ignore[assignment]
            group_key = (
                "batch_linear",
                self._getitem_args(input),
                str(input.meta["example_value"].shape),
                str(weight.meta["example_value"].shape),
                bias is None,
                str(users),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        batch_nodes = []
        batch_inputs = []
        batch_weights = []
        batch_biases = []
        batch_inputs_metadata = []
        batch_weights_metadata = []
        batch_biases_metadata = []
        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["example_value"])
            weight = get_arg_value(node, 1, "weight")
            batch_weights.append(weight)
            batch_weights_metadata.append(weight.meta["example_value"])
            bias = get_arg_value(node, 2, "bias")
            batch_biases.append(bias)
            if bias is not None and hasattr(bias, "meta"):
                batch_biases_metadata.append(bias.meta["example_value"])

        with graph.inserting_before(subset[0]):
            stack_inputs = graph.call_function(
                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            stack_weights = graph.call_function(
                torch.stack, args=(batch_weights,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_weights, batch_weights_metadata)
            transpose_weight = graph.call_function(
                torch.transpose, args=(stack_weights, 1, 2)
            )
            transpose_weight.meta["example_value"] = torch.transpose(
                stack_weights.meta["example_value"], 1, 2
            )
            if all(bias is None for bias in batch_biases):
                bmm = graph.call_function(
                    torch.bmm,
                    args=(stack_inputs, transpose_weight),
                )
                bmm.meta["example_value"] = torch.bmm(
                    stack_inputs.meta["example_value"],
                    transpose_weight.meta["example_value"],
                )
                bmm_meta = bmm.meta["example_value"]
            else:
                stack_biases = graph.call_function(
                    torch.stack, args=(batch_biases,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_biases, batch_biases_metadata)
                unsqueeze_biases = graph.call_function(
                    torch.unsqueeze, args=(stack_biases, 1)
                )
                unsqueeze_biases.meta["example_value"] = torch.unsqueeze(
                    stack_biases.meta["example_value"], 1
                )
                bmm = graph.call_function(
                    torch.baddbmm,
                    args=(unsqueeze_biases, stack_inputs, transpose_weight),
                )
                try:
                    # Broadcasting raises a runtime error when the metadata
                    # contains dynamic shapes, so skip the metadata update in
                    # that case.
                    bmm.meta["example_value"] = torch.baddbmm(
                        unsqueeze_biases.meta["example_value"],
                        stack_inputs.meta["example_value"],
                        transpose_weight.meta["example_value"],
                    )
                    bmm_meta = bmm.meta["example_value"]
                except Exception as e:
                    log.debug(
                        f"exception when updating bmm meta data with stack: {e}"  # noqa: G004
                    )
                    bmm_meta = None

            bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0})
            if bmm_meta is not None:
                bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0)
            for i, linear in enumerate(batch_nodes):
                with graph.inserting_after(bmm):
                    getitem = graph.call_function(operator.getitem, args=(bmm, i))
                linear.replace_all_uses_with(getitem)
                getitem.meta.update(linear.meta)
                graph.erase_node(linear)
        counters["inductor"]["batch_linear"] += 1


@register_fusion("batch_layernorm")
class BatchLayernormFusion(BatchFusion):
    """
    Batch layer norm fusion in the pre grad pass.
    """
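    # Sketch: inputs are stacked at dim (-1 - len(normalized_shape)), a single
    # unparameterized layer_norm is applied, then the stacked weight and bias
    # are re-applied via broadcast mul/add before unbinding the results.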

    def match(self, node: torch.fx.Node):
        if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node):
            input = get_arg_value(node, 0, "input")
            weight = get_arg_value(node, 2, "weight")
            bias = get_arg_value(node, 3, "bias")
            if self.graph_search_options.get("fuse_nodes_with_same_users", False):
                users = [user.target for user in node.users.keys()]
            else:
                users = ""  # type: ignore[assignment]
            group_key = (
                (
                    "batch_layernorm",
                    str(input.meta["example_value"].shape),
                    str(weight.meta["example_value"].shape)
                    if weight is not None
                    else "",
                    str(bias.meta["example_value"].shape) if bias is not None else "",
                    str(get_arg_value(node, 1, "normalized_shape")),
                    str(get_arg_value(node, 4, "eps")),
                    str(users),
                )
                if "example_value" in input.meta
                and is_node_meta_valid(weight)
                and is_node_meta_valid(bias)
                else None
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        group_inputs = []
        group_shapes = []
        group_weights = []
        group_biases = []
        group_epss = []
        group_nodes = []
        group_inputs_metadata = []
        group_biases_metadata = []
        group_weights_metadata = []
        for node in subset:
            group_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            group_inputs.append(input)
            group_inputs_metadata.append(input.meta["example_value"])
            group_shapes.append(get_arg_value(node, 1, "normalized_shape"))
            weight = get_arg_value(node, 2, "weight")
            group_weights.append(weight)
            if weight is not None and hasattr(weight, "meta"):
                group_weights_metadata.append(weight.meta["example_value"])
            bias = get_arg_value(node, 3, "bias")
            group_biases.append(bias)
            if bias is not None and hasattr(bias, "meta"):
                group_biases_metadata.append(bias.meta["example_value"])
            eps = get_arg_value(node, 4, "eps")
            if eps is None:
                eps = 1e-5
            group_epss.append(eps)
        stack_dim = -1 - len(group_shapes[-1])

        if all(bias is None for bias in group_biases):
            group_biases = None  # type: ignore[assignment]
        if all(weight is None for weight in group_weights):
            group_weights = None  # type: ignore[assignment]
        assert all(
            eps == group_epss[0] for eps in group_epss
        ), "all epsilon values must be equal"

        with graph.inserting_before(subset[0]):
            stack_input = graph.call_function(
                torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim}
            )
            update_stack_example_value(stack_input, group_inputs_metadata, stack_dim)
            if group_weights is not None:
                stack_weight = graph.call_function(
                    torch.stack, args=(group_weights,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_weight, group_weights_metadata)
            else:
                stack_weight = None
            if group_biases is not None:
                stack_bias = graph.call_function(
                    torch.stack, args=(group_biases,), kwargs={"dim": 0}
                )
                update_stack_example_value(stack_bias, group_biases_metadata)
            else:
                stack_bias = None

            batch_layer_norm = graph.call_function(
                torch.nn.functional.layer_norm,
                args=(stack_input, group_shapes[-1]),
                kwargs={"eps": group_epss[-1]},
            )
            batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"]

            if group_weights is not None and group_biases is not None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )
            elif group_weights is not None and group_biases is None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.mul, args=(stack_weight, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_weight.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.mul,
                )
            elif group_weights is None and group_biases is not None:
                previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"]
                batch_layer_norm = graph.call_function(
                    torch.add, args=(stack_bias, batch_layer_norm)
                )
                update_pointwise_example_value(
                    batch_layer_norm,
                    stack_bias.meta["example_value"],
                    previous_batch_layer_norm_meta,
                    torch.add,
                )

            batch_layer_norm_unbind = graph.call_function(
                torch.unbind,
                args=(batch_layer_norm,),
                kwargs={"dim": stack_dim},
            )
            update_stack_example_value(
                batch_layer_norm_unbind,
                batch_layer_norm.meta["example_value"],
                op=torch.unbind,
                dim=stack_dim,
            )

        for i, node in enumerate(group_nodes):
            with graph.inserting_after(batch_layer_norm_unbind):
                new_node = graph.call_function(
                    operator.getitem, args=(batch_layer_norm_unbind, i)
                )
            node.replace_all_uses_with(new_node)
            new_node.meta.update(node.meta)
            graph.erase_node(node)
        counters["inductor"]["batch_layernorm"] += 1


class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in the pre grad pass.
    We fuse wherever the ops are found, and the introduced stack node may be
    merged in split-cat.
    """
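    # Sketch (illustrative): n calls y_i = op(x_i) on same-shaped inputs become
    #     stacked = torch.stack([x_0, ..., x_{n-1}], dim=0)
    #     y_0, ..., y_{n-1} = torch.unbind(op(stacked), dim=0)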

    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        self.op = op

    def match(self, node: torch.fx.Node):
        input = get_arg_value(node, 0, "input")
        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
            if self.graph_search_options.get("fuse_nodes_with_same_parent", False):
                # pyre-fixme[16]
                parent = node.args[0]
                parent = parent.target if parent is not None else ""  # type: ignore[union-attr]
            else:
                parent = ""
            # for the relu op, we also use the inplace flag to construct the key
            group_key = (
                "batch_" + self.op.__name__.lower().split(".")[0],
                str(input.meta["example_value"].shape),
                str(node.kwargs.get("inplace", False)),
                str(parent),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        batch_nodes = []
        batch_inputs = []
        batch_inputs_metadata = []

        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["example_value"])

        with graph.inserting_before(subset[0]):
            stack_inputs = graph.call_function(
                torch.stack, args=(batch_inputs,), kwargs={"dim": 0}
            )
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            if self.op == torch.nn.functional.relu:
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                    kwargs={"inplace": subset[0].kwargs.get("inplace", False)},
                )
                batch_op.meta["example_value"] = self.op(
                    stack_inputs.meta["example_value"],
                    inplace=subset[0].kwargs.get("inplace", False),
                )
            else:
                batch_op = graph.call_function(
                    self.op,
                    args=(stack_inputs,),
                )
                batch_op.meta["example_value"] = self.op(
                    stack_inputs.meta["example_value"]
                )
            unbind_op = graph.call_function(
                torch.unbind, args=(batch_op,), kwargs={"dim": 0}
            )
            unbind_op.meta["example_value"] = torch.unbind(
                batch_op.meta["example_value"], dim=0
            )
            for i, node in enumerate(batch_nodes):
                with graph.inserting_after(unbind_op):
                    getitem = graph.call_function(operator.getitem, args=(unbind_op, i))
                node.replace_all_uses_with(getitem)
                getitem.meta.update(node.meta)
                graph.erase_node(node)
        counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1


class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
    """
    Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in the post grad pass.
    The introduced stack node may be merged in split-cat.
    """
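    # Sketch: same idea as the pre grad variant, but the stack is built with
    # decompose_stack (unsqueeze + cat) and results are read back with
    # aten.select(batch_op, 0, i) instead of torch.unbind.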

    def __init__(self, op, **kwargs) -> None:
        super().__init__(op, **kwargs)
        self.op = op

    def match(self, node: torch.fx.Node):
        input = get_arg_value(node, 0, "input")
        if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node):
            # for the relu op, we also use the inplace flag to construct the key;
            # we batch the ops with the same parent to enable follow-up split-cat
            parent = node.args[0]
            parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else ""  # type: ignore[union-attr]
            group_key = (
                "batch_aten_" + self.op.__name__.lower().split(".")[0],
                str(input.meta["val"].shape),
                str(node.kwargs.get("inplace", False)),
                # pyre-fixme[16]
                str(parent),
            )
        else:
            group_key = None
        return group_key

    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
        batch_nodes = []
        batch_inputs = []
        batch_inputs_metadata = []

        for node in subset:
            batch_nodes.append(node)
            input = get_arg_value(node, 0, "input")
            batch_inputs.append(input)
            batch_inputs_metadata.append(input.meta["val"])

        with graph.inserting_before(subset[0]):
            stack_inputs = decompose_stack(graph, batch_inputs)
            update_stack_example_value(stack_inputs, batch_inputs_metadata)
            batch_op = graph.call_function(
                self.op,
                args=(stack_inputs,),
            )
            for i, node in enumerate(batch_nodes):
                with graph.inserting_after(batch_op):
                    getitem = graph.call_function(aten.select, args=(batch_op, 0, i))
                node.replace_all_uses_with(getitem)
                getitem.meta.update(node.meta)
                graph.erase_node(node)
        counters["inductor"][
            "batch_aten_" + self.op.__name__.lower().split(".")[0]
        ] += 1


@register_fusion("batch_tanh")
class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(torch.tanh, **kwargs)


@register_fusion("batch_sigmoid")
class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(torch.sigmoid, **kwargs)


@register_fusion("batch_relu")
class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(torch.nn.functional.relu, **kwargs)


@register_fusion("batch_aten_tanh", pre_grad=False)
class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.tanh.default, **kwargs)


@register_fusion("batch_aten_sigmoid", pre_grad=False)
class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.sigmoid.default, **kwargs)


@register_fusion("batch_aten_relu", pre_grad=False)
class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.relu.default, **kwargs)


@register_fusion("batch_aten_add", pre_grad=False)
class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.add.Tensor, **kwargs)


@register_fusion("batch_aten_sub", pre_grad=False)
class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.sub.Tensor, **kwargs)


@register_fusion("batch_aten_div", pre_grad=False)
class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.div.Tensor, **kwargs)


@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion):
    def __init__(self, **kwargs) -> None:
        super().__init__(aten.mul.Tensor, **kwargs)


class _OrderedSet:
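    """Minimal insertion-ordered set built on the keys of an OrderedDict."""
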
    def __init__(self, param=None) -> None:
        if param:
            self.rep = OrderedDict(dict.fromkeys(param))
        else:
            self.rep = OrderedDict()

    def __contains__(self, o) -> bool:
        return o in self.rep

    def __len__(self) -> int:
        return self.rep.__len__()

    def append(self, o):
        self.rep[o] = None

    def __iter__(self):
        return self.rep.keys().__iter__()


def find_independent_subset_greedy(
    node_list: Iterable[torch.fx.Node],
    graph_search_options: Dict[str, Any],
) -> Iterator[Iterable[torch.fx.Node]]:
    """
    Yields a list of subsets of `node_list` where no element in the subset
    depends on any other element in the subset. This results in a set of
    independent nodes which can be fused together.

    The order of `node_list` is preserved within each subset so we can benefit
    from split-cat elimination in later passes.

    During iteration it is only safe to mutate the graph by changing the nodes
    that have been returned.

    graph_search_options:
      - min_fuse_set_size: Minimum size of the subset to consider. Subsets below
        this size will be ignored.
      - max_fuse_set_size: Maximum size of the subset to consider. Subsets will
        be broken to be at most this size.
    """
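    # Illustrative example (assuming min_fuse_set_size=1): for node_list
    # [A, B, C] where B depends on A and C is independent of both, the first
    # round yields [A, C] (B is deferred because its dependency A is already
    # in the subset) and the second round yields [B].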

    # Compute all the dependencies (transitive inputs) of `node` which are
    # members of `interesting_nodes`.
    def find_dependent_nodes(node, interesting_nodes):
        visited_node_set: Set[torch.fx.Node] = {node}
        dep_set: Set[torch.fx.Node] = set()

        work = [node]
        while work:
            node = work.pop()
            for input_node in node.all_input_nodes:
                if input_node in interesting_nodes:
                    dep_set.add(input_node)

                if input_node not in visited_node_set:
                    visited_node_set.add(input_node)
                    work.append(input_node)

        return dep_set

    min_fuse_set_size = graph_search_options["min_fuse_set_size"]
    max_fuse_set_size = graph_search_options["max_fuse_set_size"]

    # node_list needs to be a set because we only track the nodes that are left
    # in it (and we want to do the `in` on a set, not a list). But we want to
    # keep the correct order.
    node_list = _OrderedSet(node_list)

    cache: Dict[torch.fx.Node, Set[torch.fx.Node]] = {}
    while node_list:
        subset: List[torch.fx.Node] = []
        subset_deps: Set[torch.fx.Node] = set()

        next_round_node_list = _OrderedSet()
        for node in node_list:
            if len(subset) >= max_fuse_set_size or node in subset_deps:
                next_round_node_list.append(node)
                continue

            dep_set = cache.pop(node, None)
            if dep_set is None:
                dep_set = find_dependent_nodes(node, node_list)

            if not dep_set.intersection(subset):
                subset.append(node)
                subset_deps.update(dep_set)
            else:
                next_round_node_list.append(node)
                cache[node] = dep_set

        if len(subset) >= min_fuse_set_size:
            # Careful here - the caller uses the subsets to fuse nodes together
            # so we need to clear any cache entry that contains one of the
            # returned nodes because the dependency list could be different
            # (larger) after the merge.
            cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)}
            yield subset

        node_list = next_round_node_list


def get_fusion_candidates(
    rule: GroupBatchFusionBase, root_node: torch.fx.Node, fused_set: Set[torch.fx.Node]
) -> DefaultDict[Any, List[torch.fx.Node]]:
    """
    Search fusion candidates for a specific rule using BFS starting from the root node.
    We only search the subgraph within graph_search_options["max_fuse_search_depth"].
    """
    q: Deque[Tuple[int, torch.fx.Node]] = collections.deque()

    candidate_dict: DefaultDict[Any, List[torch.fx.Node]] = collections.defaultdict(
        list
    )

    if root_node.target in SEARCH_EXCLUSIONS:
        return candidate_dict

    visited_set: Set[torch.fx.Node] = set()

    for next_node in root_node.all_input_nodes:
        q.append((1, next_node))
        visited_set.add(next_node)

    while len(q) > 0:
        depth, node = q.popleft()

        if node in fused_set:
            continue

        key = rule.match(node)
        if key is not None:
            candidate_nodes = candidate_dict[key]
            if node not in candidate_nodes:
                candidate_nodes.append(node)
        else:
            if depth < rule.graph_search_options["max_fuse_search_depth"]:
                for next_node in node.all_input_nodes:
                    if next_node not in visited_set:
                        visited_set.add(next_node)
                        q.append((depth + 1, next_node))

    return candidate_dict


def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase):
    stable_topological_sort(graph)  # type: ignore[arg-type]
    fused_set: Set[torch.fx.Node] = set()
    log_to_scuba = False

    for node in reversed(graph.nodes):
        candidates = get_fusion_candidates(rule, node, fused_set)

        for key, candidate_nodes in candidates.items():
            if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]:
                continue

            for subset in find_independent_subset_greedy(
                candidate_nodes, rule.graph_search_options
            ):
                rule.fuse(graph, subset)
                fused_set.update(subset)
                log.debug(
                    f"{rule.__class__.__name__}: key = {key}; subset size = {len(subset)}"  # noqa: G004
                )
                log_to_scuba = True
    if log_to_scuba:
        optimus_scuba_log[rule.__class__.__name__] = upload_graph(graph)


def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
    fusions: List[GroupBatchFusionBase] = []
    for name, options in config_options.items():
        # we skip all patterns from pattern_matcher passes (e.g., split_cat)
        if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS:
            continue
        fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name]
        _options = graph_search_options.copy()
        _options.update(options)
        fusions.append(fusion_cls(graph_search_options=_options))  # type: ignore[operator]
    return fusions
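# Config shape sketch (hypothetical values): each entry maps a registered
# fusion name to per-fusion overrides of graph_search_options, e.g.
#   {"batch_linear": {"min_fuse_set_size": 2}, "group_linear": {"require_fbgemm": True}}
# Names not registered here (e.g. split_cat patterns handled by the
# pattern_matcher passes) are skipped.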


def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
    fusions: List[GroupBatchFusionBase] = []
    # We keep all current pre grad fusions to preserve the current
    # implementation; this will be removed later.
    if pre_grad:
        fusions += generate_fusion_from_config(
            config.pre_grad_fusion_options, pre_grad=True
        )
    else:
        fbgemm_fusion_keys = [
            x
            for x in config.post_grad_fusion_options
            if config.post_grad_fusion_options[x].get("require_fbgemm", False)
        ]
        fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
            for fusion in fbgemm_fusion_keys
        }
        non_fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
            for fusion in config.post_grad_fusion_options.keys()
            if fusion not in fbgemm_fusion_keys
        }
        fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False)
        if has_fbgemm:
            fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False)

    for i, rule in enumerate(fusions):
        with GraphTransformObserver(
            graph.owning_module,
            f"group_batch_fusion_{i}",
            config.trace.log_url_for_graph_xform,
        ):
            apply_group_batch_fusion(graph, rule)  # type: ignore[arg-type]
1318