# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
    insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.library import impl, Library

# Define lib with passthrough operators. The operators have no real meaning in edge IR
# except for argument validation and a passthrough output. The operators will be used
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
# the edge IR graph but will be lowered to a TOSA TRANSPOSE.
lib = Library("passthrough_to_tosa", "DEF")
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
# to switch dim_order before the operation. Changing tosa_dim_order is not sufficient,
# as we also need to transpose the data into the correct data format.
# By utilizing an edge IR passthrough operator we can keep the edge program in
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")


@impl(lib, "_transpose")
def _transpose_impl(*args, **kwargs):
    # Validate length of dim_order array
    dim = args[1]
    assert len(dim) <= 4
    # Pass-through in edge-IR
    return args[0]


register_passable_op(torch.ops.passthrough_to_tosa._transpose)
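
# A minimal usage sketch (for illustration only, not part of the pass): in edge
# IR the operator validates dim_order and forwards its input unchanged.
#
#   x = torch.randn(1, 8, 4, 4)
#   y = torch.ops.passthrough_to_tosa._transpose(x, [0, 2, 3, 1])
#   torch.testing.assert_close(y, x)  # passthrough: the data is untouched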


class AnnotateChannelsLastDimOrder(ExportPass):
    """
    Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
    that in most cases will be (0, 2, 3, 1) for nodes with 4D shapes. The pass also inserts
    passthrough_to_tosa._transpose nodes when a transition between 3D and 4D tensors happens.
    The annotated tosa_dim_order is used to permute the node's shape so that it is TOSA-compliant.
    """

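    # Dim orders over a 4D (N, C, H, W) tensor: NHWC_order permutes NCHW to
    # NHWC, NHWC_inverse_order undoes that permutation, and HWCM_order permutes
    # depthwise-conv2d weights to TOSA's (H, W, C, M) layout.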
    NHWC_order = (0, 2, 3, 1)
    NHWC_inverse_order = (0, 3, 1, 2)
    HWCM_order = (2, 3, 0, 1)

    def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
        """
        Returns True for dq and w in the following sequences:
        w -> depthwise_conv2d -> ...
        w -> dq -> depthwise_conv2d -> ...
        """
        if node.op == "call_function":
            if node.target != dq_op:
                return False
            prev_node = node.args[0]
            if cast(torch.fx.Node, prev_node).op != "placeholder":
                return False
            if is_consumer_node_depthwise_conv2d(node):
                consumer_node = list(node.users)[0]
                return consumer_node.args[1] == node
        elif node.op == "placeholder":
            # node is an input, weight or bias node
            consumer_node = list(node.users)[0]
            if self.is_weight_node_for_depthwise_conv2d(consumer_node):
                return True
            if is_consumer_node_depthwise_conv2d(node):
                # Check that node is the weight argument and not input or bias
                return consumer_node.args[1] == node

        return False

    def insert_input_transpose(self, node, input_node, graph_module):
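        """
        Inserts a passthrough_to_tosa._transpose (using NHWC_inverse_order)
        between input_node and node so the data is permuted back to NCHW. If
        input_node is a dq node, its quantization parameters are reused for
        the new transpose.
        """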
        quantize = input_node.target == dq_op
        q_params = input_node.args[1:] if quantize else None
        with graph_module.graph.inserting_before(node):
            permute_node = create_node(
                graph_module.graph,
                torch.ops.passthrough_to_tosa._transpose,
                args=(input_node, list(self.NHWC_inverse_order)),
                quantize=quantize,
                q_params=q_params,
            )
            node.replace_input_with(input_node, permute_node)

            permute_node.meta["tosa_dim_order"] = tuple(
                range(len(input_node.meta["val"].size()))
            )

    def insert_output_transpose(self, node, graph_module):
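        """
        Inserts a passthrough_to_tosa._transpose (using NHWC_order) after node
        and rewires node's users to consume it. If node's input is a q node,
        a q/dq pair is additionally inserted after node.
        """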
        with graph_module.graph.inserting_after(node):
            permute_node = create_node(
                graph_module.graph,
                torch.ops.passthrough_to_tosa._transpose,
                args=(node, list(self.NHWC_order)),
            )
            permute_node.meta["tosa_dim_order"] = self.NHWC_order
            node.meta["tosa_dim_order"] = (0, 1, 2, 3)
            users = [user for user in node.users if user != permute_node]
            for user in users:
                user.replace_input_with(node, permute_node)

            # Compare the producing node's target, not the node itself, to
            # detect a quantized input.
            quantize = node.args[0].target == q_op
            if quantize:
                q_params = node.args[0].args[1:]
                insert_q_dq_pair(graph_module.graph, node, q_params)

    def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
        """
        Reshape operations are not equivalent in NCHW and NHWC.
        To get around this, transposes need to be added if the previous or new shape
        fulfils the following condition:
            C > 1 and (H or W > 1)

        This is relevant for the following operations:
        squeeze:     4D ->  3D
        unsqueeze:  <4D ->  4D
        view:       <4D ->  4D
        view:        4D -> <4D
        view:        4D ->  4D
        """

        def transpose_condition(shape):
            if len(shape) != 4:
                return False
            C = shape[1]
            H = shape[2]
            W = shape[3]
            return C > 1 and (H > 1 or W > 1)

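        # Illustrative examples (not from the source): shape (1, 3, 2, 2) needs
        # a transpose since C = 3 and H = W = 2, whereas (1, 3, 1, 1) does not;
        # with H = W = 1 (or C = 1), NCHW and NHWC describe the same memory
        # layout, so no data movement is required.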
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                input_node = node.args[0]
                input_shape = input_node.meta["val"].shape
                if transpose_condition(input_shape):
                    self.insert_input_transpose(node, input_node, graph_module)

            elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
                output_shape = node.meta["val"].shape
                if transpose_condition(output_shape):
                    self.insert_output_transpose(node, graph_module)

            elif node.target == exir_ops.edge.aten.view_copy.default:
                input_node = node.args[0]

                old_shape = input_node.meta["val"].shape
                new_shape = node.meta["val"].shape

                if transpose_condition(old_shape):
                    self.insert_input_transpose(node, input_node, graph_module)

                if transpose_condition(new_shape):
                    self.insert_output_transpose(node, graph_module)

    def call(self, graph_module: torch.fx.GraphModule):
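        """
        Annotates every node with a tosa_dim_order, inserts the transposes
        required around rank-changing ops, and recompiles the graph module.
        """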
        for node in graph_module.graph.nodes:
            node_data = get_first_fake_tensor(node).data

            if node_data.dim() == 4:
                dim_order = self.NHWC_order
                if self.is_weight_node_for_depthwise_conv2d(node):
                    # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M), which corresponds to
                    # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                    dim_order = self.HWCM_order
            else:
                dim_order = tuple(range(node_data.dim()))
            node.meta["tosa_dim_order"] = dim_order
        # Take care of cases when:
        # 4D (NHWC) -> <4D (NCH)
        # 3D (NCH)  ->  4D (NHWC)
        self.insert_tosa_transposes(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

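
# A minimal usage sketch (hypothetical; in the real flow the pass runs inside
# the Arm backend's pass pipeline rather than standalone). Given an
# edge-dialect GraphModule `gm`:
#
#   result = AnnotateChannelsLastDimOrder()(gm)
#   annotated = result.graph_module
#   for n in annotated.graph.nodes:
#       print(n.name, n.meta.get("tosa_dim_order"))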