# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional

import torch

from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.utils.constants import (
    QCOM_ENCODING,
    QCOM_QUANT_ATTRS,
    QCOM_QUANTIZED_IO,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import q_ops


class InsertIOQDQ(ExportPass):
    """
    For delegated QNN subgraph, no more QDQ operators will appear after
    'fold_qdq pass'.
    This pass will insert quantize nodes right after inputs, dequantize nodes
    right before outputs according to stored quantization encodings.
    """

    # Maps a quantize op (key) to the dequantize op (value) used when
    # materializing output dequantization. Both per-tensor quantize overloads
    # map to the `.tensor` dequantize overload, whose scale / zero_point
    # arguments are Tensors — hence the scalar-to-tensor conversion in
    # _create_args below.
    q_dq_map = {
        # per tensor
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        # per channel
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    }

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def _create_args(self, target: torch.fx.node.Target, quant_attrs: Dict):
        """Build the trailing argument list for a q/dq op from stored
        quantization attributes, following the op schema's argument order.

        The first schema argument (the input tensor) is skipped; the caller
        supplies it separately when creating the node.

        Raises:
            KeyError: if a schema argument other than "out_dtype" is missing
                from quant_attrs.
        """
        ret = []

        arg_schemas = list(target._schema.arguments)[1:]
        for arg_schema in arg_schemas:
            name = arg_schema.name
            # TODO: Due to the new parameter "out_dtype" in the dequantize node,
            # it could not be found in the quant_attrs of other nodes,
            # and it will cause a key error. For now, the output type
            # of our dequantize node is only float. (by default in pytorch)
            if name == "out_dtype":
                continue
            value = quant_attrs[name]
            # The `.tensor` overloads declare scale/zero_point as Tensor-typed
            # schema arguments, so wrap stored Python scalars in tensors.
            # BUGFIX: the original tested `isinstance(arg_schema.type,
            # torch.Tensor)`, comparing a schema *type descriptor* against the
            # Tensor class — always False, which left this conversion dead.
            # `torch.TensorType` is the JIT schema type class to match here.
            if isinstance(arg_schema.type, torch.TensorType) and isinstance(
                value, (int, float)
            ):
                value = torch.tensor(value)
            ret.append(value)
        return ret

    def _create_node(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        target: torch.fx.node.Target,
        quant_attrs: Optional[Dict] = None,
    ) -> torch.fx.Node:
        """Create a q/dq call_function node that consumes `node`.

        Returns the newly created node with its "val" metadata populated.
        """
        # check if there has a specified quant_attrs
        # if not, use the existent info. from current node
        if quant_attrs is None:
            quant_attrs = node.meta.get(QCOM_QUANT_ATTRS)

        inserted_node = graph_module.graph.create_node(
            "call_function",
            target,
            (node, *self._create_args(target, quant_attrs)),
        )
        meta_val = node.meta["val"]
        # When inserting a quantize node (targets are q_dq_map keys), move the
        # quant attrs from the source node onto the new node and reflect the
        # quantized dtype in the fake-tensor metadata. Dequantize nodes keep
        # the source node's (float) metadata unchanged.
        if target in self.q_dq_map:
            inserted_node.meta[QCOM_QUANT_ATTRS] = node.meta.pop(QCOM_QUANT_ATTRS)
            meta_val = meta_val.to(quant_attrs["dtype"])

        inserted_node.meta["val"] = meta_val
        return inserted_node

    def _insert_quant_node(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        target: torch.fx.node.Target,
        quant_attrs: Optional[Dict] = None,
    ) -> torch.fx.Node:
        """Insert a quantize node after `node` and rewire its users to it."""
        with graph_module.graph.inserting_after(node):
            users = list(node.users.keys())
            inserted_node = self._create_node(graph_module, node, target, quant_attrs)
            for user in users:
                # If we found mix quantization pattern and reuse the existing
                # q_node, we skip adding a new q node.
                if user.target not in q_ops:
                    user.replace_input_with(node, inserted_node)

        return inserted_node

    def _insert_dequant_node(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        target: torch.fx.node.Target,
    ) -> None:
        """Insert a dequantize node after `node`, feeding only output users."""
        with graph_module.graph.inserting_after(node):
            users = list(node.users.keys())
            inserted_node = self._create_node(graph_module, node, target)
            for user in users:
                if user.op == "output":
                    user.replace_input_with(node, inserted_node)

    def _insert(self, graph_module: torch.fx.GraphModule) -> None:
        """Walk the graph, inserting q nodes after non-parameter quantized
        placeholders and dq nodes before quantized graph outputs.

        NOTE: the original annotated this as returning GraphModule, but no
        value is returned; the graph is mutated in place.
        """
        for n in graph_module.graph.nodes:
            # do nothing when a node is expected to output a quant tensor
            if n.meta.get(QCOM_QUANTIZED_IO):
                continue

            # insert q after input or fold mix_quantization dq if applicable
            if (
                n.op == "placeholder"
                and n.meta.get(QCOM_QUANT_ATTRS)
                and not is_parameter(n, self.edge_program)
            ):
                self._insert_quant_node(
                    graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
                )

            # insert dq before output or fold mix_quantization q if applicable
            users = list(n.users.keys())
            if n.meta.get(QCOM_QUANT_ATTRS) and any(
                user.op == "output" for user in users
            ):
                self._insert_dequant_node(
                    graph_module,
                    n,
                    self.q_dq_map[n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]],
                )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Apply the pass: insert IO q/dq nodes, then clean and recompile."""
        self._insert(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)