# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional

import torch

from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.utils.constants import (
    QCOM_ENCODING,
    QCOM_QUANT_ATTRS,
    QCOM_QUANTIZED_IO,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import q_ops


class InsertIOQDQ(ExportPass):
    """
    For a delegated QNN subgraph, no QDQ operators remain after the
    'fold_qdq' pass.
    This pass inserts quantize nodes right after graph inputs and dequantize
    nodes right before graph outputs, according to the stored quantization
    encodings.
    """

    q_dq_map = {
        # per tensor
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        # per channel
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default: exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    }

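    # Note on q_dq_map above: both per-tensor quantize overloads map to the
    # .tensor dequantize overload, presumably because _create_args below
    # materializes scalar scale / zero_point values as tensors for
    # tensor-typed schema arguments.
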
    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        self.edge_program = edge_program

    def _create_args(self, target: torch.fx.node.Target, quant_attrs: Dict):
        ret = []

        arg_schemas = list(target._schema.arguments)[1:]
        for arg_schema in arg_schemas:
            name = arg_schema.name
            # TODO: The dequantize node takes a new parameter, "out_dtype",
            # which is not present in the quant_attrs of other nodes and
            # would raise a KeyError. For now, the output type of our
            # dequantize node is only float (the PyTorch default).
            if name == "out_dtype":
                continue
            value = quant_attrs[name]
            # The .tensor overloads expect scale/zero_point as tensors, so
            # wrap scalar values when the schema declares a Tensor argument.
            if isinstance(arg_schema.type, torch.TensorType) and isinstance(
                value, (int, float)
            ):
                value = torch.tensor(value)
            ret.append(value)
        return ret

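    # Example for _create_args above (an illustrative sketch, not real data):
    # with target = quantize_per_tensor.default and
    #   quant_attrs = {"scale": 0.05, "zero_point": 0, "quant_min": -128,
    #                  "quant_max": 127, "dtype": torch.int8}
    # the returned list is [0.05, 0, -128, 127, torch.int8] -- the operator's
    # schema arguments after the input tensor.
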
    def _create_node(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        target: torch.fx.node.Target,
        quant_attrs: Optional[Dict] = None,
    ) -> torch.fx.Node:
        # if no quant_attrs are specified, fall back to the
        # quantization info stored on the current node
        if quant_attrs is None:
            quant_attrs = node.meta.get(QCOM_QUANT_ATTRS)

        inserted_node = graph_module.graph.create_node(
            "call_function",
            target,
            (node, *self._create_args(target, quant_attrs)),
        )
        meta_val = node.meta["val"]
        if target in self.q_dq_map:
            inserted_node.meta[QCOM_QUANT_ATTRS] = node.meta.pop(QCOM_QUANT_ATTRS)
            meta_val = meta_val.to(quant_attrs["dtype"])

        inserted_node.meta["val"] = meta_val
        return inserted_node

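    # Note on _create_node above: when the inserted op is a quantize op (a key
    # of q_dq_map), it takes over the source node's QCOM_QUANT_ATTRS and its
    # "val" meta is cast to the quantized dtype, so downstream passes see the
    # quantized tensor type.
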
    def _insert_quant_node(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        target: torch.fx.node.Target,
        quant_attrs: Optional[Dict] = None,
    ) -> torch.fx.Node:
        with graph_module.graph.inserting_after(node):
            users = list(node.users.keys())
            inserted_node = self._create_node(graph_module, node, target, quant_attrs)
            for user in users:
                # In a mixed-quantization pattern the user may already be a
                # quantize node; reuse it instead of rerouting it through the
                # newly inserted one.
                if user.target not in q_ops:
                    user.replace_input_with(node, inserted_node)

        return inserted_node

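    # Rewrite performed by _insert_quant_node above (names illustrative):
    #   before: x (placeholder) -> conv
    #   after:  x (placeholder) -> quantize_per_tensor -> conv
    # A user that is already a quantize op keeps reading x directly
    # (the mixed-quantization case).
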
    def _insert_dequant_node(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        target: torch.fx.node.Target,
    ) -> None:
        with graph_module.graph.inserting_after(node):
            users = list(node.users.keys())
            inserted_node = self._create_node(graph_module, node, target)
            for user in users:
                # only reroute the graph output; other users keep consuming
                # the quantized tensor directly
                if user.op == "output":
                    user.replace_input_with(node, inserted_node)

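    # Rewrite performed by _insert_dequant_node above (names illustrative):
    #   before: y -> output
    #   after:  y -> dequantize_per_tensor -> output
    # Callers thus receive float tensors while the delegated region stays
    # quantized.
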
    def _insert(self, graph_module: torch.fx.GraphModule) -> None:
        for n in graph_module.graph.nodes:
            # do nothing when a node is expected to output a quantized tensor
            if n.meta.get(QCOM_QUANTIZED_IO):
                continue

            # insert q after input, or fold the mixed-quantization dq if applicable
            if (
                n.op == "placeholder"
                and n.meta.get(QCOM_QUANT_ATTRS)
                and not is_parameter(n, self.edge_program)
            ):
                self._insert_quant_node(
                    graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
                )

            # insert dq before output, or fold the mixed-quantization q if applicable
            users = list(n.users.keys())
            if n.meta.get(QCOM_QUANT_ATTRS) and any(
                user.op == "output" for user in users
            ):
                self._insert_dequant_node(
                    graph_module,
                    n,
                    self.q_dq_map[n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]],
                )

    def call(self, graph_module: torch.fx.GraphModule):
        self._insert(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
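
# Example usage (a minimal sketch; assumes `edge_program` is an already
# quantized torch.export.ExportedProgram whose QDQ pairs were folded by a
# prior pass):
#
#   insert_io_qdq = InsertIOQDQ(edge_program)
#   pass_result = insert_io_qdq.call(edge_program.graph_module)
#   graph_module = pass_result.graph_module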