# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpRmsNorm, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class RmsNormVisitor(NodeVisitor):
    """Lowers ``aten.rms_norm.default`` to the QNN ``RmsNorm`` op.

    Only normalization over exactly the last input dimension is supported;
    other configurations are skipped with a warning.
    """

    target = ["aten.rms_norm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build and return the QNN op wrapper for an rms_norm node.

        args of node: [input, normalized_shape, weight, eps]

        Returns None (node is skipped) when the normalized shape is not
        exactly the last input dimension.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Should be an immutable list/tuple of the dims being normalized.
        # BUGFIX: the guard previously used `and`, which let a multi-dim
        # normalized_shape pass whenever its first entry matched the last
        # input dim, even though `axes` below only covers the last dim.
        # Either condition alone makes the node unsupported, hence `or`.
        normalized_shapes = node.args[1]
        if (
            len(normalized_shapes) != 1
            or normalized_shapes[0] != input_tensor.shape[-1]
        ):
            warnings.warn(
                "[QNN Delegate Op Builder]: Only supports normalization with last input dimension.",
                stacklevel=1,
            )
            return
        # QNN expects explicit axes; normalize over the last dimension only.
        axes = [node.args[0].meta["val"].dim() - 1]
        axes_shape = [len(axes)]

        weight_node = node.args[2]
        weight_tensor = get_parameter(weight_node, self.edge_program)
        weight_tensor_wrapper = self.define_tensor(
            weight_node,
            weight_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # Fake node: the nn module seems to be inconsistent with the QNN
        # document, which requires a bias input that aten.rms_norm lacks.
        # Synthesize a zero bias tensor backed by a graph-less fx node.
        bias_tensor = torch.zeros(weight_tensor.shape)
        bias_node = torch.fx.Node(
            node.graph,
            node.name + "_runtime_bias",
            "call_function",
            exir_ops.edge.aten.tensor.default,
            (),  # args
            {},  # kwargs
        )
        # Propagate quantization attributes so the zero bias is encoded
        # consistently with the op's quantized output.
        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
            bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
        bias_tensor_wrapper = self.define_tensor(
            bias_node,
            bias_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # eps may arrive as a literal float or as a parameter node; when the
        # resolved parameter is not a plain float, fall back to the dtype's
        # machine epsilon (matching aten.rms_norm's eps=None default).
        epsilon = node.args[3]
        if isinstance(epsilon, torch.fx.Node):
            epsilon = get_parameter(epsilon, self.edge_program)
            epsilon = (
                epsilon
                if isinstance(epsilon, float)
                else torch.finfo(epsilon.dtype).eps
            )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        rms_norm_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpRmsNorm.op_name,
        )

        rms_norm_op.AddInputTensors(
            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
        )
        rms_norm_op.AddOutputTensors([output_tensor_wrapper])
        rms_norm_op.AddScalarParam(
            OpRmsNorm.param_epsilon,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
            {QCOM_DATA: np.float32(epsilon)},
        )
        rms_norm_op.AddTensorParam(
            OpRmsNorm.param_axes,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(axes_shape),
            axes_shape,
            np.array(axes, dtype=np.uint32),
            True,
        )

        return rms_norm_op