• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
9
10
11import torch
12from executorch.backends.arm._passes.arm_pass_utils import (
13    create_node,
14    get_param_tensor,
15    insert_q_dq_pair,
16    is_param_node,
17)
18from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
19from executorch.exir import ExportedProgram
20from executorch.exir.dialects._ops import ops as exir_ops
21from executorch.exir.pass_base import ExportPass, PassResult
22
23
class Conv1dUnsqueezePass(ExportPass):
    """
    This pass is used to change conv1d ops into conv2d since TOSA only
    supports 2d and 3d convolution. This is done by modifying the graph to do the
    following:
    1) unsqueeze the convolution's input from 3d to 4d
    2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
    3) perform a conv2d (with a modified version of the original conv1d args)
    4) squeeze the output back down to 3d.
    5) if all users of squeeze are quantized, insert q/dq-pair before squeeze

    A weight tensor that is shared by several conv1d nodes is unsqueezed only
    once, so every such conv ends up with the same 4d kernel.
    """

    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
        self.exported_program = exported_program

    def unsqueeze_kernel_weights(self, kernel_node: torch.fx.Node) -> None:
        """
        Unsqueezes the weights of a conv1d to make it 4 dimensional.

        A trailing size-1 dimension is appended to the stored weight tensor
        (``unsqueeze(dim=-1)``). The 4d tensor is written back to wherever the
        weight lives: the exported program's ``state_dict`` (parameters and
        buffers), its ``constants`` (lifted tensor constants), or the owning
        module's attribute (get_attr nodes). For the placeholder-backed cases
        the node's ``meta["val"]`` is unsqueezed too.

        Args:
            kernel_node: the weights of conv1d node to be unsqueezed.

        Raises:
            AssertionError: if no parameter tensor can be found for kernel_node.
        """
        program = self.exported_program
        kernel_param_3d = get_param_tensor(program, kernel_node)
        if kernel_param_3d is None:
            raise AssertionError("Expected param tensor for the kernel node")

        kernel_param_4d = torch.nn.Parameter(
            data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1),
            requires_grad=False,
        )

        signature = program.graph_signature
        if torch._export.utils.is_param(program, kernel_node):
            parameter_name = signature.inputs_to_parameters[kernel_node.name]
            program.state_dict[parameter_name] = kernel_param_4d
        elif torch._export.utils.is_buffer(program, kernel_node):
            buffer_name = signature.inputs_to_buffers[kernel_node.name]
            program.state_dict[buffer_name] = kernel_param_4d
        elif torch._export.utils.is_lifted_tensor_constant(program, kernel_node):
            constant_name = signature.inputs_to_lifted_tensor_constants[
                kernel_node.name
            ]
            program.constants[constant_name] = kernel_param_4d
        else:
            # get_attr node: the tensor is stored directly on the owning module.
            # meta["val"] is not touched here; presumably the retrace at the end
            # of call() regenerates it from the new attribute — confirm.
            setattr(
                kernel_node.graph.owning_module,
                kernel_node.target,
                kernel_param_4d,
            )
            return

        # Placeholder-backed weight: keep the node's tensor metadata in sync with
        # the new 4d tensor.
        kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite every 1d edge ``aten.convolution`` as unsqueeze -> conv2d -> squeeze.

        Convolutions whose stride list does not have exactly one element are
        left untouched. After the rewrite the graph is recompiled and retraced
        via ``ExportPass.call`` so node metadata is regenerated.

        Args:
            graph_module: the graph module to transform in place.

        Returns:
            A PassResult wrapping the retraced graph module.

        Raises:
            AssertionError: if a conv1d weight is not a parameter/get_attr node.
        """
        graph = graph_module.graph
        # Kernel nodes whose backing tensor has already been made 4d. A weight
        # shared by several conv1d nodes must only be unsqueezed once; doing it
        # again would turn the 4d kernel into a 5d one.
        unsqueezed_kernels = set()

        for node in list(graph.nodes):
            if (
                node.op != "call_function"
                or node.target != exir_ops.edge.aten.convolution.default
            ):
                continue

            stride = list(node.args[3])
            if len(stride) != 1:
                # skip conv if it is not 1d
                continue

            kernel_node = node.args[1]
            if kernel_node.target == dq_op:
                kernel_node = kernel_node.args[0]

            if not is_param_node(self.exported_program, kernel_node):
                raise AssertionError(
                    "Expected op for convolution weight node to be a get_attr node or a parameter"
                )

            # (a) Modify the kernel so the conv changes from 1d to 2d
            # (once per distinct weight tensor).
            if kernel_node not in unsqueezed_kernels:
                self.unsqueeze_kernel_weights(kernel_node)
                unsqueezed_kernels.add(kernel_node)

            # (b) Extend stride, padding, dilation and output_padding for the
            # extra dim. list(...) makes the concatenation valid whether the
            # args are stored as lists or tuples.
            node.args = (
                node.args[0],
                node.args[1],
                node.args[2],
                stride + [1],  # stride
                list(node.args[4]) + [0],  # padding
                list(node.args[5]) + [1],  # dilation
                node.args[6],  # transposed
                list(node.args[7]) + [0],  # output_padding
                node.args[8],  # groups
            )

            # (c) Add unsqueeze to input (3d -> 4d) and squeeze to output (4d -> 3d)
            # unsqueeze -> conv2d -> squeeze
            with graph.inserting_before(node):
                input_node = node.args[0]
                unsqueeze_before = create_node(
                    graph, exir_ops.edge.aten.unsqueeze_copy.default
                )
                unsqueeze_before.args = (
                    input_node,  # Input is node's original input
                    -1,  # Last Dimension
                )
                node.replace_input_with(input_node, unsqueeze_before)

            # If Quantized we must insert unsqueeze --> q --> dq --> node
            if input_node.target == dq_op:
                q_params = input_node.args[1:]
                insert_q_dq_pair(graph, unsqueeze_before, q_params)

            with graph.inserting_after(node):
                squeeze_after = create_node(
                    graph,
                    exir_ops.edge.aten.squeeze_copy.dims,
                )
                squeeze_after.args = (
                    node,  # Input is the conv node
                    [-1],  # Last dimension
                )
                original_users = [
                    user for user in node.users if user != squeeze_after
                ]
                for user in original_users:
                    user.replace_input_with(node, squeeze_after)

            # If quantized, insert conv2d --> q --> dq --> squeeze.
            # The emptiness check matters: all() over an empty list is True,
            # and original_users[0] would then raise IndexError for a conv
            # node that had no users.
            if original_users and all(
                user.target == q_op for user in original_users
            ):
                q_params = original_users[0].args[1:]
                insert_q_dq_pair(graph, node, q_params)

        graph_module.recompile()
        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
167