# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_param_tensor,
    insert_q_dq_pair,
    is_param_node,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class Conv1dUnsqueezePass(ExportPass):
    """
    Changes conv1d ops into conv2d ops, since TOSA only supports 2d and 3d
    convolution. This is done by modifying the graph as follows:
    1) unsqueeze the convolution's input from 3d to 4d
    2) if the input to unsqueeze is quantized, insert a q/dq-pair after unsqueeze
    3) perform a conv2d (with a modified version of the original conv1d args)
    4) squeeze the output back down to 3d
    5) if all users of squeeze are quantized, insert a q/dq-pair before squeeze
    """

    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
        self.exported_program = exported_program

    def unsqueeze_kernel_weights(self, kernel_node):
        """
        Unsqueezes the weights of a conv1d to make them 4-dimensional.

        Args:
            kernel_node: the weight node of the conv1d to be unsqueezed
        """
        kernel_param_3d = get_param_tensor(self.exported_program, kernel_node)
        if kernel_param_3d is None:
            raise AssertionError("Expected param tensor for the kernel node")

        kernel_param_4d = torch.nn.Parameter(
            data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1),
            requires_grad=False,
        )

        if torch._export.utils.is_param(self.exported_program, kernel_node):
            parameter_name = self.exported_program.graph_signature.inputs_to_parameters[
                kernel_node.name
            ]
            self.exported_program.state_dict[parameter_name] = kernel_param_4d
            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
        elif torch._export.utils.is_buffer(self.exported_program, kernel_node):
            buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
                kernel_node.name
            ]
            self.exported_program.state_dict[buffer_name] = kernel_param_4d
            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
        elif torch._export.utils.is_lifted_tensor_constant(
            self.exported_program, kernel_node
        ):
            buffer_name = (
                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
                    kernel_node.name
                ]
            )
            self.exported_program.constants[buffer_name] = kernel_param_4d
            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
        else:
            setattr(
                kernel_node.graph.owning_module,
                kernel_node.target,
                kernel_param_4d,
            )
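
    # Shape walk-through (illustrative): for an input of shape (N, C, L) and a
    # kernel of shape (C_out, C_in, K), the rewrite performed below is
    #   unsqueeze: (N, C, L)            -> (N, C, L, 1)
    #   conv2d:    (N, C, L, 1)         -> (N, C_out, L_out, 1)  with kernel (C_out, C_in, K, 1)
    #   squeeze:   (N, C_out, L_out, 1) -> (N, C_out, L_out)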

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        node_list = list(graph.nodes)
        for node in node_list:
            if node.op == "call_function":
                if node.target == exir_ops.edge.aten.convolution.default:
                    stride = list(node.args[3])
                    if len(stride) != 1:
                        # Skip the conv if it is not 1d
                        continue

                    kernel_node = node.args[1]
                    if kernel_node.target == dq_op:
                        kernel_node = kernel_node.args[0]

                    if not is_param_node(self.exported_program, kernel_node):
                        raise AssertionError(
                            "Expected the convolution weight node to be a get_attr node or a parameter"
                        )

                    # Modify the graph such that the conv changes from 1d to 2d
                    self.unsqueeze_kernel_weights(kernel_node)

                    # Extend stride, padding, dilation, and output padding with
                    # the extra dim
                    node.args = (
                        node.args[0],
                        node.args[1],
                        node.args[2],
                        node.args[3] + [1],  # stride
                        node.args[4] + [0],  # padding
                        node.args[5] + [1],  # dilation
                        node.args[6],  # transposed
                        node.args[7] + [0],  # output padding
                        node.args[8],  # groups
                    )

                    # Add unsqueeze to the input (3d -> 4d) and squeeze to the
                    # output (4d -> 3d): unsqueeze -> conv2d -> squeeze
                    with graph.inserting_before(node):
                        input_node = node.args[0]
                        unsqueeze_before = create_node(
                            graph, exir_ops.edge.aten.unsqueeze_copy.default
                        )
                        unsqueeze_before.args = (
                            input_node,  # Input is the node's original input
                            -1,  # Last dimension
                        )
                        node.replace_input_with(input_node, unsqueeze_before)

                    # If quantized, we must insert unsqueeze --> q --> dq --> node
                    if input_node.target == dq_op:
                        q_params = input_node.args[1:]
                        insert_q_dq_pair(graph, unsqueeze_before, q_params)

                    with graph.inserting_after(node):
                        squeeze_after = create_node(
                            graph,
                            exir_ops.edge.aten.squeeze_copy.dims,
                        )
                        squeeze_after.args = (
                            node,  # Input is the conv node
                            [-1],  # Last dimension
                        )
                        original_users = [
                            user for user in node.users if user != squeeze_after
                        ]
                        for user in original_users:
                            user.replace_input_with(node, squeeze_after)

                    # If quantized, insert conv2d --> q --> dq --> squeeze
                    if all(
                        original_user.target == q_op
                        for original_user in original_users
                    ):
                        q_params = original_users[0].args[1:]
                        insert_q_dq_pair(graph, node, q_params)

        graph_module.recompile()
        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
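

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, assumed wiring): export a toy conv1d
    # model to the edge dialect and apply the pass directly. In the Arm backend
    # this pass runs as part of the backend's pass pipeline rather than
    # standalone, so the flow below is for demonstration only.
    from executorch.exir import to_edge

    class _Conv1dModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv1d(4, 8, kernel_size=3)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.conv(x)

    _edge_program = to_edge(
        torch.export.export(_Conv1dModel(), (torch.randn(1, 4, 16),))
    ).exported_program()
    _result = Conv1dUnsqueezePass(_edge_program).call(_edge_program.graph_module)
    # After the pass, the graph should contain unsqueeze -> conv2d -> squeeze
    # in place of the original 1d convolution.
    _result.graph_module.graph.print_tabular()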