# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
    insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.library import impl, Library

# Define a lib with passthrough operators. The operators have no real meaning in edge IR
# except for argument validation and a passthrough output. The operators will be used
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
# the edge IR graph but will be lowered to a TOSA TRANSPOSE.
lib = Library("passthrough_to_tosa", "DEF")
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
# to switch dim_order before the operation. Changing tosa_dim_order is not sufficient,
# as we also need to transpose the data into the correct data format.
# By utilizing an edge IR passthrough operator we can keep the edge program in
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")


@impl(lib, "_transpose")
def _transpose_impl(*args, **kwargs):
    # Validate the length of the dim_order argument.
    dim = args[1]
    assert len(dim) <= 4
    # Passthrough in edge IR: the input data is returned unchanged.
    return args[0]


register_passable_op(torch.ops.passthrough_to_tosa._transpose)
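
# Illustration (a minimal sketch, not executed as part of this pass): because
# _transpose_impl only validates its arguments and returns its input, invoking the
# operator directly is an identity in edge IR. Only the TOSA lowering turns it into
# a real data transpose.
#
#   x = torch.randn(1, 3, 8, 8)
#   y = torch.ops.passthrough_to_tosa._transpose(x, [0, 2, 3, 1])
#   assert torch.equal(y, x)  # passthrough: no data movement in edge IR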


class AnnotateChannelsLastDimOrder(ExportPass):
    """
    Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last
    dim order, which in most cases will be (0, 2, 3, 1) for nodes with 4D shapes. The pass also
    inserts passthrough_to_tosa._transpose when a transition between 3D and 4D tensors happens.
    The annotated tosa_dim_order is used to permute the node's shape such that it gives a
    TOSA-compliant shape.
    """

    NHWC_order = (0, 2, 3, 1)
    NHWC_inverse_order = (0, 3, 1, 2)
    HWCM_order = (2, 3, 0, 1)

    def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
        """
        Returns True for dq and w in the following sequences:
        w -> depthwise_conv2d -> ...
        w -> dq -> depthwise_conv2d -> ...
        """
        if node.op == "call_function":
            if node.target != dq_op:
                return False
            prev_node = node.args[0]
            if cast(torch.fx.Node, prev_node).op != "placeholder":
                return False
            if is_consumer_node_depthwise_conv2d(node):
                consumer_node = list(node.users)[0]
                return consumer_node.args[1] == node
        elif node.op == "placeholder":
            # node is an input, weight or bias node
            consumer_node = list(node.users)[0]
            if self.is_weight_node_for_depthwise_conv2d(consumer_node):
                return True
            if is_consumer_node_depthwise_conv2d(node):
                # Check that node is the weight argument, not the input or bias.
                return consumer_node.args[1] == node

        return False

    def insert_input_transpose(self, node, input_node, graph_module):
        quantize = input_node.target == dq_op
        q_params = input_node.args[1:] if quantize else None
        with graph_module.graph.inserting_before(node):
            permute_node = create_node(
                graph_module.graph,
                torch.ops.passthrough_to_tosa._transpose,
                args=(input_node, list(self.NHWC_inverse_order)),
                quantize=quantize,
                q_params=q_params,
            )
            node.replace_input_with(input_node, permute_node)

            permute_node.meta["tosa_dim_order"] = tuple(
                range(len(input_node.meta["val"].size()))
            )

    def insert_output_transpose(self, node, graph_module):
        with graph_module.graph.inserting_after(node):
            permute_node = create_node(
                graph_module.graph,
                torch.ops.passthrough_to_tosa._transpose,
                args=(node, list(self.NHWC_order)),
            )
            permute_node.meta["tosa_dim_order"] = self.NHWC_order
            node.meta["tosa_dim_order"] = (0, 1, 2, 3)
            users = [user for user in node.users if user != permute_node]
            for user in users:
                user.replace_input_with(node, permute_node)

            # If the input is produced by a quantize op, re-insert a q/dq pair
            # around the transpose so the quantization chain stays intact.
            quantize = node.args[0].target == q_op
            if quantize:
                q_params = node.args[0].args[1:]
                insert_q_dq_pair(graph_module.graph, node, q_params)

    def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
        """
        Reshape operations are not equivalent in NCHW and NHWC.
        To get around this, transposes need to be added if the previous or new shape
        fulfils the following condition:
            C > 1 and (H > 1 or W > 1)

        This is relevant for the following operations:
            squeeze:   4D  -> 3D
            unsqueeze: <4D -> 4D
            view:      <4D -> 4D
            view:      4D  -> <4D
            view:      4D  -> 4D
        """

        def transpose_condition(shape):
            if len(shape) != 4:
                return False
            C = shape[1]
            H = shape[2]
            W = shape[3]
            return C > 1 and (H > 1 or W > 1)
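
        # Examples (for illustration): transpose_condition((1, 3, 2, 2)) is True,
        # since C=3 > 1 and H=2 > 1, so the NCHW and NHWC memory layouts differ and
        # data must actually be moved. transpose_condition((1, 1, 64, 64)) and
        # transpose_condition((1, 64, 1, 1)) are False: with C == 1, or with
        # H == W == 1, only size-1 dims are permuted, so the layouts coincide and
        # no transpose is needed.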
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                input_node = node.args[0]
                input_shape = input_node.meta["val"].shape
                if transpose_condition(input_shape):
                    self.insert_input_transpose(node, input_node, graph_module)

            elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
                output_shape = node.meta["val"].shape
                if transpose_condition(output_shape):
                    self.insert_output_transpose(node, graph_module)

            elif node.target == exir_ops.edge.aten.view_copy.default:
                input_node = node.args[0]

                old_shape = input_node.meta["val"].shape
                new_shape = node.meta["val"].shape

                if transpose_condition(old_shape):
                    self.insert_input_transpose(node, input_node, graph_module)

                if transpose_condition(new_shape):
                    self.insert_output_transpose(node, graph_module)

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            node_data = get_first_fake_tensor(node).data

            if node_data.dim() == 4:
                dim_order = self.NHWC_order
                if self.is_weight_node_for_depthwise_conv2d(node):
                    # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M), which corresponds to
                    # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                    dim_order = self.HWCM_order
            else:
                dim_order = tuple(range(node_data.dim()))
            node.meta["tosa_dim_order"] = dim_order
        # Take care of the rank transitions:
        # 4D (NHWC)  -> <4D (NCH)
        # <4D (NCH)  -> 4D (NHWC)
        self.insert_tosa_transposes(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
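
# Minimal usage sketch (illustrative only; `edge_graph_module` is assumed to be an
# already-lowered edge-dialect torch.fx.GraphModule, not something defined here):
#
#   result = AnnotateChannelsLastDimOrder()(edge_graph_module)
#   annotated = result.graph_module
#   for n in annotated.graph.nodes:
#       print(n.name, n.meta.get("tosa_dim_order"))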