# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpGroupNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
class GroupNormVisitor(NodeVisitor):
    """Lower ``aten.native_group_norm.default`` nodes to a QNN GroupNorm op."""

    target = ["aten.native_group_norm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN GroupNorm op wrapper for one ``native_group_norm`` node.

        Args:
            node: The ``aten.native_group_norm.default`` FX node. Its args
                follow the aten schema
                ``(input, weight, bias, N, C, HxW, group, eps)``.
            nodes_to_wrappers: Mapping from FX nodes to tensor wrappers,
                forwarded unchanged to every ``define_tensor`` call.

        Returns:
            An op wrapper with inputs ``[input, weight, bias]``, a single
            output, a float32 ``epsilon`` parameter and a uint32 ``group``
            parameter.
        """
        tensor_kind = PyQnnWrapper.Qnn_TensorType_t
        data_kind = PyQnnWrapper.Qnn_DataType_t

        # Activation input: a runtime (native) tensor.
        src = node.args[0]
        src_wrapper = self.define_tensor(
            src,
            self.get_tensor(src, node),
            tensor_kind.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Affine parameters weight (args[1]) and bias (args[2]) are fetched
        # from the edge program and baked in as static tensors, in that order.
        # NOTE(review): the aten schema types weight/bias as optional
        # (Tensor?), but get_parameter is called unconditionally here —
        # confirm upstream passes guarantee both are present.
        affine_wrappers = []
        for param_node in (node.args[1], node.args[2]):
            affine_wrappers.append(
                self.define_tensor(
                    param_node,
                    get_parameter(param_node, self.edge_program),
                    tensor_kind.QNN_TENSOR_TYPE_STATIC,
                    nodes_to_wrappers,
                    is_input_tensor=False,
                )
            )

        num_groups, eps = node.args[6], node.args[7]

        # native_group_norm yields (out, mean, rstd); index 0 selects the
        # normalized output, which is the only one lowered here.
        dst_wrapper = self.define_tensor(
            node,
            self.get_tensor(node, node, 0),
            tensor_kind.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        op = PyQnnWrapper.PyQnnOpWrapper(
            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, OpGroupNorm.op_name
        )
        op.AddInputTensors([src_wrapper, *affine_wrappers])
        op.AddOutputTensors([dst_wrapper])

        # Scalar params are attached epsilon first, then group.
        scalar_params = (
            (
                OpGroupNorm.param_epsilon,
                data_kind.QNN_DATATYPE_FLOAT_32,
                np.float32(eps),
            ),
            (
                OpGroupNorm.param_group,
                data_kind.QNN_DATATYPE_UINT_32,
                np.uint32(num_groups),
            ),
        )
        for param_name, param_dtype, param_value in scalar_params:
            op.AddScalarParam(param_name, param_dtype, {"data": param_value})

        return op