# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
from typing import Any, Dict

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import (
    QCOM_ENCODING,
    QCOM_QUANT_ATTRS,
    QCOM_REQUANTIZE,
    QCOM_SCALES,
    QCOM_ZERO_POINTS,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import dq_ops, get_quant_attrs, q_ops


class AnnotateQuantAttrs(ExportPass):
    """
    Add "quant_attrs" to graph nodes' meta from the QDQ information
    generated after the quantization process.
    """

    def __init__(
        self, edge_program: torch.export.ExportedProgram, skip_advanced_requat: bool
    ):
        # NOTE(review): the parameter name "skip_advanced_requat" looks like a
        # typo of "requant"; kept as-is so existing keyword callers still work.
        super(AnnotateQuantAttrs, self).__init__()
        self.edge_program = edge_program
        self.skip_advanced_requant = skip_advanced_requat

    def _annotate_source_nodes(
        self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
    ):
        # Attach quant_attrs to the node that feeds this q op. When the
        # producer is a getitem (i.e. a multi-output op), annotate both the
        # getitem node and the op it selects from.
        if quant_node.args[0].target == operator.getitem:
            getitem_node = quant_node.args[0]
            getitem_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
            source_n = getitem_node.args[0]
        else:
            source_n = quant_node.args[0]
        source_n.meta[QCOM_QUANT_ATTRS] = quant_attrs

    def _expand(self, tensor, dim, axis) -> torch.Tensor:
        # Reshape a 1-D per-channel tensor (scales / zero-points) so it
        # broadcasts against a rank-`dim` parameter along `axis`: append
        # dim-1 trailing singleton axes, then swap axis 0 with `axis`.
        tensor = tensor[(...,) + (None,) * (dim - 1)]
        order = torch.arange(dim).tolist()
        order[axis], order[0] = order[0], order[axis]
        return tensor.permute(order)

    # Find the last dq node between regular op nodes.
    # Return dq2 in the example below when q1 is given as the node parameter:
    # ... -> n1 -> q1 -> dq1 -> q2 -> dq2 -> n2 -> ...
58 def _find_last_dq_node(self, node: torch.fx.node.Node) -> torch.fx.node.Node: 59 if list(node.users)[0].target in q_ops.union(dq_ops): 60 return self._find_last_dq_node(list(node.users)[0]) 61 return node 62 63 def _annotate_requant(self, n): 64 # Record requant attributes: 65 # node1 -> q_ui8 -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> .... 66 # We store quant info for dq_ui8 and q_int32 in node1.meta 67 if n.target in q_ops and n.args[0].target not in dq_ops: 68 dq_node = self._find_last_dq_node(n) 69 q_attrs = get_quant_attrs(self.edge_program, n) 70 dq_attrs = get_quant_attrs(self.edge_program, dq_node) 71 72 # TODO: Store multiple pairs of requantize attributes when we have an op builder 73 # that has multiple outputs that requires quant attributes. 74 if self.skip_advanced_requant: 75 if q_attrs["dtype"] != dq_attrs["dtype"]: 76 dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] 77 n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs 78 else: 79 # When dtype is the same but other specs such as scale and offset are different, 80 # insert requant to improve accuracy. 81 # Users can turn this feature off if any inference speed drop is observed. 82 if any( 83 q_attrs[attr] != dq_attrs[attr] 84 for attr in [ 85 "scale", 86 "zero_point", 87 "quant_min", 88 "quant_max", 89 "dtype", 90 ] 91 ): 92 dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING] 93 n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs 94 95 # Dequant all the fold_quant parameters back to fp32. 96 # If an operation is not supported by QNN and got fallback, it will expect a fp32 param. 
97 def _dequant_fold_params(self, n, quant_attrs, param): 98 if quant_attrs[QCOM_ENCODING] in [ 99 exir_ops.edge.quantized_decomposed.dequantize_per_channel.default 100 ]: 101 dim, axis = param.dim(), quant_attrs["axis"] 102 scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis) 103 offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis) 104 param = param.sub(offsets).mul(scales).to(torch.float32).contiguous() 105 set_parameter(param, n.args[0], self.edge_program) 106 else: 107 scale = quant_attrs["scale"] 108 offset = quant_attrs["zero_point"] 109 param = param.sub(offset).mul(scale).to(torch.float32).contiguous() 110 set_parameter(param, n.args[0], self.edge_program) 111 112 n.args[0].meta["val"] = param 113 114 def _annotate_quant_attrs( 115 self, graph_module: torch.fx.GraphModule 116 ) -> torch.fx.GraphModule: 117 # Keep track of const params that has been dequant, so it does not get 118 # dequant multiple times if the const param has more than 1 user 119 visited_const_param = set() 120 for n in graph_module.graph.nodes: 121 self._annotate_requant(n) 122 # With fold_quant enabled, check if the input of dq op is quantized param. 123 param = None 124 if n.target in dq_ops: 125 param = get_parameter(n.args[0], self.edge_program) 126 if n.target not in q_ops and param is None: 127 continue 128 quant_attrs = get_quant_attrs(self.edge_program, n) 129 self._annotate_source_nodes(n, quant_attrs) 130 131 if param is not None and n.args[0] not in visited_const_param: 132 visited_const_param.add(n.args[0]) 133 self._dequant_fold_params(n, quant_attrs, param) 134 135 return graph_module 136 137 def call(self, graph_module: torch.fx.GraphModule): 138 self._annotate_quant_attrs(graph_module) 139 graph_module.recompile() 140 return PassResult(graph_module, True) 141