# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.utils.constants import (
    QCOM_QUANT_ATTRS,
    QCOM_QUANTIZED_IO,
    QCOM_REQUANTIZE,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class InsertRequantize(ExportPass):
    """
    This pass inserts a convert op for operators whose input and activation
    have different quantization specs. The convert op is a QNN-specific op
    that performs the requantization in the QNN backend.
    """

    # Ops that have multiple outputs but should still run the
    # _single_output_annotation logic instead of _multi_output_annotation.
    # Ops might be added to this set because we don't use the 2nd output,
    # the 2nd output is an integer, etc.
    multi_output_op_ignore_set = {
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.topk.default,
    }

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
    ):
        super(InsertRequantize, self).__init__()
        self.edge_program = edge_program

    # TODO: Implement this function when we have an op with
    # multiple outputs that requires quant attributes.
    def _multi_output_annotation(self) -> None:
        raise NotImplementedError("requant is not implemented for multi output yet")

    def _single_output_annotation(
        self, gm: torch.fx.GraphModule, n: torch.fx.Node
    ) -> None:
        with gm.graph.inserting_after(n):
            users = list(n.users.keys())
            # Insert a _to_copy node right after n; it is later mapped to the
            # convert op that performs the requantization in the QNN backend.
            inserted_n = gm.graph.create_node(
                "call_function",
                exir_ops.edge.aten._to_copy.default,
                (n,),
            )

            inserted_n.meta["val"] = n.meta["val"]
            # The requantize spec of n becomes the quant attrs of the new node.
            inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
            if n.meta.get(QCOM_QUANTIZED_IO):
                inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]

            # Reroute all original users of n to the inserted node.
            for user in users:
                user.replace_input_with(n, inserted_n)

    def _insert(self, graph_module: torch.fx.GraphModule) -> None:
        for n in graph_module.graph.nodes:
            if QCOM_REQUANTIZE not in n.meta:
                continue
            # Single-tensor outputs, and ops whose extra outputs we deliberately
            # ignore, take the single-output path; true multi-output
            # requantization is not supported yet.
            if (
                isinstance(n.meta["val"], torch._subclasses.fake_tensor.FakeTensor)
                or n.target in self.multi_output_op_ignore_set
            ):
                self._single_output_annotation(graph_module, n)
            else:
                self._multi_output_annotation()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self._insert(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
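

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original pass): a minimal way to
# run this pass directly on a graph module, assuming `model` and
# `example_inputs` are hypothetical placeholders for an exportable module and
# its sample inputs. Note the pass only rewrites nodes that already carry
# QCOM_REQUANTIZE metadata (attached by earlier quantization/annotation
# steps), so on a freshly exported program it is effectively a no-op. In the
# normal lowering flow the Qualcomm backend applies this pass as part of its
# own pass pipeline rather than by hand.
#
#   exported_program = torch.export.export(model, example_inputs)
#   pass_result = InsertRequantize(exported_program)(
#       exported_program.graph_module
#   )
#   requantized_gm = pass_result.graph_module
# ---------------------------------------------------------------------------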