# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import (
    QCOM_QUANT_ATTRS,
    QCOM_QUANT_MAX,
    QCOM_SCALE,
)

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpBatchnorm, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
class BatchNorm(NodeVisitor):
    """Lowers inference-mode batch norm to a QNN Batchnorm op.

    The running mean/variance are folded into the static filter and bias
    tensors before handing them to QNN:

        y = x * (w / sqrt(var + eps)) + (b - w * mean / sqrt(var + eps))

    so the emitted op only receives an input, a filter, and a bias.
    """

    target = ["aten._native_batch_norm_legit_no_training.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def update_encoding(
        self, node: torch.fx.Node, tensor: torch.Tensor, eps: float
    ) -> None:
        """Refresh the quantization scale of ``node`` to cover ``tensor``.

        Folding mean/var changes the value range of the static tensors, so
        any scale captured during quantization is stale.  ``eps`` is added
        to keep the scale strictly positive.

        Args:
            node: fx node whose quant attributes (if any) are updated in place.
            tensor: the folded tensor whose min/max define the new range.
            eps: small positive constant added to the range.
        """
        if isinstance(tensor, torch._subclasses.FakeTensor):
            # Fake tensors carry no concrete values to derive a range from.
            return

        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
            # scale value equals to zero will cause failure in HTP
            diff = max(abs(tensor.max()), abs(tensor.min())) + eps
            quant_attrs[QCOM_SCALE] = diff / quant_attrs[QCOM_QUANT_MAX]

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN Batchnorm op wrapper for ``node``.

        Folds running statistics into the filter/bias tensors, refreshes
        their quantization encodings, and wires up input/output tensor
        wrappers.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)

        mean_node, var_node = node.args[3], node.args[4]
        # _native_batch_norm_legit_no_training schema:
        # (input, weight, bias, running_mean, running_var, momentum, eps),
        # so the model's eps is args[6].  A hard-coded 1e-9 (the previous
        # behavior, kept as fallback) diverges from eager-mode results,
        # where eps defaults to 1e-5.
        eps = node.args[6] if len(node.args) > 6 else 1e-9
        mean_tensor = get_parameter(mean_node, self.edge_program)
        var_tensor = get_parameter(var_node, self.edge_program)

        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # NOTE(review): weight/bias are Optional in the aten schema
        # (affine=False); get_parameter on a None node would fail here —
        # confirm the partitioner never routes non-affine BN this way.
        bias_node = node.args[2]
        bias_tensor = get_parameter(bias_node, self.edge_program)
        filter_node = node.args[1]
        filter_tensor = get_parameter(filter_node, self.edge_program)

        # Fold: bias' = bias - w * mean / sqrt(var + eps).  This must use
        # the UNSCALED filter, hence it happens before the filter fold.
        amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
        bias_tensor = bias_tensor - amount
        self.update_encoding(bias_node, bias_tensor, eps)
        bias_tensor_wrapper = self.define_tensor(
            bias_node,
            bias_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # Fold: w' = w / sqrt(var + eps).
        filter_tensor = filter_tensor / torch.sqrt(var_tensor + eps)
        self.update_encoding(filter_node, filter_tensor, eps)
        filter_tensor_wrapper = self.define_tensor(
            filter_node,
            filter_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        batch_norm_input_tensors = [
            input_tensor_wrapper,
            filter_tensor_wrapper,
            bias_tensor_wrapper,
        ]

        # The aten op returns a tuple; index 0 is the normalized output.
        output_tensor = self.get_tensor(node, node, 0)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        batch_norm_output_tensors = [output_tensor_wrapper]

        batch_norm_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpBatchnorm.op_name,
        )
        batch_norm_op.AddInputTensors(batch_norm_input_tensors)
        batch_norm_op.AddOutputTensors(batch_norm_output_tensors)

        return batch_norm_op