• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8from typing import List
9
10import torch
11
12from executorch.backends.transforms import get_shape
13from executorch.backends.transforms.addmm_mm_to_linear import (
14    apply_addmm_mm_to_linear_transform,
15)
16from executorch.exir.dialects._ops import ops as exir_ops
17from executorch.exir.pass_base import ExportPass
18
19from torch.fx.passes.infra.pass_base import PassResult
20from torch.fx.passes.utils.source_matcher_utils import (
21    get_source_partitions,
22    SourcePartition,
23)
24
# Module-level logger. Its level is pinned to WARNING, so the debug traces
# emitted in this module are dropped unless this logger's level is lowered.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.WARNING)
27
28
class ConvertToLinearPass(ExportPass):
    """Replace linear source partitions with a single ``aten.linear`` node.

    Partitions that ``get_source_partitions`` attributes to ``torch.nn.Linear``
    or ``torch.nn.functional.linear`` and that contain an ``mm`` / ``addmm`` /
    ``bmm`` node are rewritten: the partition's input, weight and optional bias
    are recovered and the partition's output is replaced with an
    ``exir_ops.edge.aten.linear.default`` node. Afterwards the graph is passed
    to ``apply_addmm_mm_to_linear_transform`` as a fallback, and the module is
    retraced through ``ExportPass.call``.
    """

    # Source modules/functions whose partitions are candidates for conversion.
    linear_modules = [
        torch.nn.Linear,
        torch.nn.functional.linear,
    ]

    # Ops that mark a node as the matmul core of a linear partition.
    targets = [
        exir_ops.edge.aten.mm.default,
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.bmm.default,
    ]

    @staticmethod
    def find(
        node: torch.fx.Node,
        args: List[torch.fx.Node],
        kind: str = "args",
        index: int = 0,
    ):
        """Walk from ``node`` until reaching a node in ``args`` or a stop node.

        Starting at ``node``, repeatedly step to ``node.args[index]``
        (``kind="args"``) or to the ``index``-th user (``kind="users"``). The
        walk stops and returns the current node when it is falsy (e.g.
        ``None``), is contained in ``args``, or is a placeholder.

        Note: ``index`` applies only to the first hop; the recursive steps use
        the default index 0.

        Raises:
            AssertionError: if ``kind`` is neither ``"args"`` nor ``"users"``.
        """
        # This is a hack to support lifted graphs.
        # TODO(T171263351) - fix source partitioning for lifted graphs
        if not node or node in args or node.op == "placeholder":
            return node
        if kind == "args":
            other = node.args[index]
        elif kind == "users":
            other = list(node.users.keys())[index]
        else:
            raise AssertionError(f"Unexpected kind: {kind}")
        return ConvertToLinearPass.find(other, args, kind)  # pyre-ignore[6]

    @staticmethod
    def get_arg(node: torch.fx.Node, arg: str):
        """Return the ``"input"``, ``"weight"`` or ``"bias"`` operand of ``node``.

        ``addmm`` operands are laid out as ``(bias, input, weight)``. Every
        other target is treated as ``(input, weight)`` with no bias, so
        ``None`` is returned for ``"bias"``.

        Raises:
            KeyError: if ``arg`` is not one of the names above.
        """
        if node.target == exir_ops.edge.aten.addmm.default:
            map_ = {
                "bias": 0,
                "input": 1,
                "weight": 2,
            }
            return node.args[map_[arg]]
        else:
            map_ = {"input": 0, "weight": 1}
            return None if arg == "bias" else node.args[map_[arg]]

    def find_bias_for_mm(self, src_partition: SourcePartition, mm_node: torch.fx.Node):
        """
        For linear decomposed with mm + add, find bias in src partition.

        Returns the operand of the single ``add.Tensor`` user of ``mm_node``
        that is not ``mm_node`` itself and is a partition input or a partition
        param. Returns ``None`` if ``mm_node`` does not have exactly one user,
        if that user is not ``add.Tensor``, or if no such operand exists.
        """
        mm_users = list(mm_node.users)
        if len(mm_users) != 1:
            return None

        add_node = mm_users[0]
        if add_node.target != exir_ops.edge.aten.add.Tensor:
            return None

        # The bias may be a param rather than a partition input. This mirrors
        # create_linear, which also searches params for the bias. Missing it
        # here would replace the add with a bias-less linear.
        candidates = src_partition.input_nodes + src_partition.params
        for arg in add_node.all_input_nodes:
            if arg != mm_node and arg in candidates:
                return arg

        return None

    def create_linear(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        src_partition: SourcePartition,
    ):
        """Replace the linear partition containing ``node`` with ``aten.linear``.

        Args:
            graph_module: module whose graph is rewritten in place.
            node: the ``mm`` / ``addmm`` / ``bmm`` target node of the partition.
            src_partition: the source partition that ``node`` belongs to.

        Raises:
            AssertionError: if the partition does not have exactly one output,
                ignoring ``sym_size.int`` and placeholder output nodes.
        """
        logger.debug("Source Partition: %s", src_partition)
        linear_input = self.find(
            self.get_arg(node, "input"),
            src_partition.input_nodes,
        )
        logger.debug("Found input: %s from node %s", linear_input, node)

        linear_weight = self.find(
            self.get_arg(node, "weight"),
            src_partition.input_nodes
            + src_partition.params,  # non quant weight can be in params
        )
        logger.debug("Found weight: %s from node %s", linear_weight, node)

        linear_bias = self.find(
            self.get_arg(node, "bias"),
            src_partition.input_nodes + src_partition.params,  # bias can be in params
        )
        # mm carries no bias operand; it may instead be added by a following add.
        if linear_bias is None and node.target == exir_ops.edge.aten.mm.default:
            linear_bias = self.find_bias_for_mm(src_partition, node)

        logger.debug("Found bias(?): %s from node %s", linear_bias, node)

        # Ignore dynamic shape nodes
        outputs = [
            out
            for out in src_partition.output_nodes
            if out.target != torch.ops.aten.sym_size.int and out.op != "placeholder"
        ]
        assert (
            len(outputs) == 1
        ), f"Unexpected number of outputs for a torch.nn.Linear module, expecting 1 but got {outputs}"
        output = outputs[0]

        with graph_module.graph.inserting_before(output):
            args = (linear_input, linear_weight)
            if linear_bias is not None:
                args += (linear_bias,)
            linear_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten.linear.default,  # HACK not edge_op/CATen
                args,
            )
        # TODO - calculate output even when dynamic_shape=True
        # Placeholder value with the output's shape. Keep the replaced output's
        # dtype when its meta value exposes one, instead of the default dtype.
        output_val = output.meta.get("val")
        linear_node.meta["val"] = torch.zeros(
            get_shape(output), dtype=getattr(output_val, "dtype", None)
        )
        # Guarded because the get_shape calls are evaluated eagerly as arguments.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Replacing %s%s node with %s%s",
                output,
                get_shape(output),
                linear_node,
                get_shape(linear_node),
            )
        output.replace_all_uses_with(linear_node)
        graph_module.graph.eliminate_dead_code()

    # override
    def call(self, graph_module: torch.fx.GraphModule):
        """Convert every linear source partition, then retrace the module.

        Source partitions are recomputed after each conversion because
        ``create_linear`` mutates the graph and removes dead nodes, which
        invalidates previously computed partitions.

        Returns:
            ``PassResult`` wrapping the retraced module; ``modified`` is
            always ``True``.
        """
        # Rendering the whole graph is expensive; only do it when it is logged.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)
        logger.debug("ConvertToLinear Begin: ")
        if debug_enabled:
            logger.debug(graph_module.print_readable(print_output=False))

        processed_partitions = 0
        # Target nodes already handed to create_linear. If one survives dead-code
        # elimination, it would otherwise be rediscovered on every iteration and
        # the loop would never finish.
        converted = set()
        while True:
            src_partition_dict = get_source_partitions(
                graph_module.graph, self.linear_modules
            )

            src_node_dict = {
                node: src_partition
                for src_partitions in src_partition_dict.values()
                for src_partition in src_partitions
                for node in src_partition.nodes
                if node.target in self.targets and node not in converted
            }

            # No more [add]mm target in source partitions
            if not src_node_dict:
                if processed_partitions == 0:
                    logger.debug(
                        "Did not find any [add]mm target in source partitions, skipping the pass."
                    )
                else:
                    logger.debug(
                        "Converted %d [add]mm target(s) into Linear.",
                        processed_partitions,
                    )
                break

            logger.debug("Converting [add]mm into Linear")
            # Only convert the first [add]mm target; partitions are stale afterwards.
            node, src_partition = next(iter(src_node_dict.items()))
            self.create_linear(graph_module, node, src_partition)
            converted.add(node)
            processed_partitions += 1

        # fall back to linear transform
        graph_module.graph = apply_addmm_mm_to_linear_transform(graph_module.graph)

        graph_module.recompile()

        # Propagate metadata and retrace module
        graph_module = super().call(graph_module).graph_module

        logger.debug("ConvertToLinear End: ")
        if debug_enabled:
            logger.debug(graph_module.print_readable(print_output=False))

        return PassResult(graph_module, True)
199