• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8# pyre-unsafe
9
10from typing import cast
11
12from executorch.backends.arm._passes.arm_pass_utils import (
13    create_node,
14    get_first_fake_tensor,
15)
16
17from executorch.exir.dialects._ops import ops as exir_ops
18
19from executorch.exir.pass_base import ExportPass, PassResult
20from torch.fx import GraphModule, Node
21
22
class MatchArgRanksPass(ExportPass):
    """
    For ops in 'targeted_ops', make sure that the inputs share the same rank.

    Ranks are matched by left-padding the lower-rank input's shape with
    singleton (size-1) dimensions, consistent with broadcasting semantics:
    either by inserting a view_copy node (for activations) or by reshaping
    the backing tensor in the state dict (for buffers/parameters).
    """

    def __init__(self, exported_program):
        super().__init__()
        # Needed to look up and mutate buffers/parameters referenced by
        # placeholder nodes.
        self.exported_program = exported_program

    # Elementwise binary ops whose tensor inputs must end up with equal rank.
    targeted_ops = [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    ]

    def _match_op_rank(self, graph_module, node, arg, max_rank):
        """
        In graph_module, insert a view between arg and node to make the
        rank of arg match the other args to node.
        """
        shape = get_first_fake_tensor(arg).shape
        rank = len(shape)
        # Left-pad with singleton dims until the rank reaches max_rank.
        new_shape = [1] * (max_rank - rank) + list(shape)
        with graph_module.graph.inserting_before(node):
            view = create_node(
                graph_module.graph,
                exir_ops.edge.aten.view_copy.default,
                args=(arg, new_shape),
                kwargs={},
            )
            node.replace_input_with(arg, view)

    def _match_buffer_rank(self, arg, max_rank):
        """
        Change arg's fake tensor meta to match max_rank if:
            - arg is found in inputs_to_buffers or inputs_to_parameters.

        The backing tensor in the exported program's state dict is reshaped
        in place (data is unchanged; only singleton dims are prepended), and
        the node's fake-tensor metadata is refreshed to match.
        """
        fake_tensor = get_first_fake_tensor(arg)
        shape = fake_tensor.shape
        rank = len(shape)
        new_shape = [1] * (max_rank - rank) + list(shape)

        graph_signature = self.exported_program.graph_signature
        buffer_name = None
        if arg.name in graph_signature.inputs_to_buffers:
            buffer_name = graph_signature.inputs_to_buffers[arg.name]
        elif arg.name in graph_signature.inputs_to_parameters:
            buffer_name = graph_signature.inputs_to_parameters[arg.name]

        # Explicit None-check: the placeholder may be a plain user input,
        # in which case there is no backing tensor to adjust here.
        if buffer_name is not None:
            new_tensor = self.exported_program.state_dict[buffer_name].reshape(
                new_shape
            )
            self.exported_program.state_dict[buffer_name] = new_tensor
            arg.meta["val"] = fake_tensor.fake_mode.from_tensor(
                new_tensor, static_shapes=True
            )

    def call(self, graph_module: GraphModule) -> PassResult:
        """Equalize the ranks of all tensor inputs to each targeted op."""
        for node in graph_module.graph.nodes:
            node = cast(Node, node)

            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue

            # Calculate max rank of all inputs to node.
            max_rank = 1
            for arg in node.args:
                if isinstance(arg, Node):
                    shape = get_first_fake_tensor(arg).shape
                    max_rank = max(max_rank, len(shape))

            # Adjust output shape of args if needed.
            for arg in node.args:
                if not isinstance(arg, Node):
                    continue
                shape = get_first_fake_tensor(arg).shape
                rank = len(shape)
                if rank == max_rank:
                    continue

                # If the argument is call_function, match shape by inserting
                # a view node; otherwise (buffer/parameter placeholder),
                # adjust the fake tensor meta and state-dict tensor directly.
                if arg.op == "call_function":
                    self._match_op_rank(graph_module, node, arg, max_rank)
                else:
                    self._match_buffer_rank(arg, max_rank)

        graph_module.recompile()
        # Re-trace so downstream passes see updated metadata.
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
118