# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
from typing import Any, Dict

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import (
    QCOM_ENCODING,
    QCOM_QUANT_ATTRS,
    QCOM_REQUANTIZE,
    QCOM_SCALES,
    QCOM_ZERO_POINTS,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import dq_ops, get_quant_attrs, q_ops


class AnnotateQuantAttrs(ExportPass):
    """
    Add "quant_attrs" to graph nodes' meta from the QDQ information
    generated after the quantization process.
    """
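    # A minimal usage sketch (the surrounding wiring is an assumption, not part
    # of this file): ExportPass subclasses are callable on a GraphModule and
    # return a PassResult, so the pass could be driven roughly like
    #   result = AnnotateQuantAttrs(edge_program, skip_advanced_requant=False)(
    #       edge_program.graph_module
    #   )
    #   graph_module = result.graph_module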

    def __init__(
        self, edge_program: torch.export.ExportedProgram, skip_advanced_requant: bool
    ):
        super(AnnotateQuantAttrs, self).__init__()
        self.edge_program = edge_program
        self.skip_advanced_requant = skip_advanced_requant

    def _annotate_source_nodes(
        self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
    ):
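        # Multi-output ops surface through operator.getitem: the quantize node
        # consumes a getitem, so annotate both the getitem and its source op.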
        if quant_node.args[0].target == operator.getitem:
            getitem_node = quant_node.args[0]
            getitem_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
            source_n = getitem_node.args[0]
        else:
            source_n = quant_node.args[0]
        source_n.meta[QCOM_QUANT_ATTRS] = quant_attrs

    def _expand(self, tensor, dim, axis) -> torch.Tensor:
        tensor = tensor[(...,) + (None,) * (dim - 1)]
        order = torch.arange(dim).tolist()
        order[axis], order[0] = order[0], order[axis]
        return tensor.permute(order)
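    # A small worked example of _expand (values here are illustrative only):
    # per-channel scales of shape (C,) for a 4-D weight quantized along axis=1
    # are reshaped so they broadcast over the remaining dims:
    #   scales = torch.tensor([0.1, 0.2])               # shape (2,)
    #   expanded = self._expand(scales, dim=4, axis=1)  # shape (1, 2, 1, 1)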

    # Find the last dq node between regular op nodes.
    # Return dq2 in the example below when q1 is given as the node parameter:
    # ... -> n1 -> q1 -> dq1 -> q2 -> dq2 -> n2 -> ...
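    # (The walk follows the first user at each step, which assumes every q/dq
    # node in such a chain has a single user.)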
    def _find_last_dq_node(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
        if list(node.users)[0].target in q_ops.union(dq_ops):
            return self._find_last_dq_node(list(node.users)[0])
        return node

    def _annotate_requant(self, n):
        # Record requant attributes:
        # node1 -> q_ui8 -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
        # We store the quant info for dq_ui8 and q_int32 in node1.meta
        if n.target in q_ops and n.args[0].target not in dq_ops:
            dq_node = self._find_last_dq_node(n)
            q_attrs = get_quant_attrs(self.edge_program, n)
            dq_attrs = get_quant_attrs(self.edge_program, dq_node)

            # TODO: Store multiple pairs of requantize attributes when we have an op
            # builder that has multiple outputs requiring quant attributes.
            if self.skip_advanced_requant:
                if q_attrs["dtype"] != dq_attrs["dtype"]:
                    dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
                    n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
            else:
                # When the dtype is the same but other specs such as scale and
                # offset differ, insert a requantize op to improve accuracy.
                # Users can turn this feature off if any inference speed drop is observed.
                if any(
                    q_attrs[attr] != dq_attrs[attr]
                    for attr in [
                        "scale",
                        "zero_point",
                        "quant_min",
                        "quant_max",
                        "dtype",
                    ]
                ):
                    dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
                    n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs

    # Dequantize all the fold_quant parameters back to fp32.
    # If an operation is not supported by QNN and falls back, it will expect an fp32 param.
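    # Both branches below apply the affine dequantization f = (q - zero_point) * scale;
    # the per-channel branch first broadcasts scales/zero points along `axis`.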
    def _dequant_fold_params(self, n, quant_attrs, param):
        if quant_attrs[QCOM_ENCODING] in [
            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
        ]:
            dim, axis = param.dim(), quant_attrs["axis"]
            scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
            offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
            param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
        else:
            scale = quant_attrs["scale"]
            offset = quant_attrs["zero_point"]
            param = param.sub(offset).mul(scale).to(torch.float32).contiguous()

        set_parameter(param, n.args[0], self.edge_program)
        n.args[0].meta["val"] = param

    def _annotate_quant_attrs(
        self, graph_module: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        # Keep track of the const params that have been dequantized, so a const
        # param with more than one user does not get dequantized multiple times.
        visited_const_param = set()
        for n in graph_module.graph.nodes:
            self._annotate_requant(n)
            # With fold_quant enabled, check if the input of the dq op is a quantized param.
            param = None
            if n.target in dq_ops:
                param = get_parameter(n.args[0], self.edge_program)
            if n.target not in q_ops and param is None:
                continue
            quant_attrs = get_quant_attrs(self.edge_program, n)
            self._annotate_source_nodes(n, quant_attrs)

            if param is not None and n.args[0] not in visited_const_param:
                visited_const_param.add(n.args[0])
                self._dequant_fold_params(n, quant_attrs, param)

        return graph_module

    def call(self, graph_module: torch.fx.GraphModule):
        self._annotate_quant_attrs(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)