# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import (
    QCOM_QUANT_ATTRS,
    QCOM_QUANT_MAX,
    QCOM_SCALE,
)

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpBatchnorm, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
class BatchNorm(NodeVisitor):
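    """
    Lowers aten._native_batch_norm_legit_no_training to the QNN Batchnorm op.

    Only the input, filter, and bias are handed to QNN, so the running mean
    and variance are folded into the filter and bias tensors beforehand.
    """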
    target = ["aten._native_batch_norm_legit_no_training.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def update_encoding(self, node: torch.fx.Node, tensor: torch.Tensor, eps):
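        """
        Recompute the quantization scale from the folded tensor's value range,
        since folding the running statistics changes the static data.
        """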
        if isinstance(tensor, torch._subclasses.FakeTensor):
            return

        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
            # a scale of zero causes a failure on HTP; eps keeps it non-zero
            diff = max(abs(tensor.max()), abs(tensor.min())) + eps
            quant_attrs[QCOM_SCALE] = diff / quant_attrs[QCOM_QUANT_MAX]

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)

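        # Args of aten._native_batch_norm_legit_no_training:
        #   (input, weight, bias, running_mean, running_var, momentum, eps).
        # A fixed eps of 1e-9 is used for the folding below.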
        mean_node, var_node, eps = node.args[3], node.args[4], 1e-9
        mean_tensor = get_parameter(mean_node, self.edge_program)
        var_tensor = get_parameter(var_node, self.edge_program)

        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        bias_node = node.args[2]
        bias_tensor = get_parameter(bias_node, self.edge_program)
        filter_node = node.args[1]
        filter_tensor = get_parameter(filter_node, self.edge_program)

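        # Fold the running statistics into the affine parameters:
        #   y = (x - mean) / sqrt(var + eps) * weight + bias
        #     = x * (weight / sqrt(var + eps))
        #         + (bias - weight * mean / sqrt(var + eps))
        # so the QNN op only needs the adjusted filter and bias.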
        amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
        bias_tensor = bias_tensor - amount
        self.update_encoding(bias_node, bias_tensor, eps)
        bias_tensor_wrapper = self.define_tensor(
            bias_node,
            bias_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        filter_tensor = filter_tensor / torch.sqrt(var_tensor + eps)
        self.update_encoding(filter_node, filter_tensor, eps)
        filter_tensor_wrapper = self.define_tensor(
            filter_node,
            filter_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        batch_norm_input_tensors = [
            input_tensor_wrapper,
            filter_tensor_wrapper,
            bias_tensor_wrapper,
        ]

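        # The ATen op returns a tuple; index 0 selects the normalized output.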
        output_tensor = self.get_tensor(node, node, 0)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        batch_norm_output_tensors = [output_tensor_wrapper]

        batch_norm_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpBatchnorm.op_name,
        )
        batch_norm_op.AddInputTensors(batch_norm_input_tensors)
        batch_norm_op.AddOutputTensors(batch_norm_output_tensors)

        return batch_norm_op