1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import warnings
8from typing import Dict
9
10import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
11import numpy as np
12
13import torch
14from executorch.backends.qualcomm.builders.utils import get_parameter
15from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
16from executorch.exir.dialects._ops import ops as exir_ops
17
18from .node_visitor import NodeVisitor, register_node_visitor
19from .qnn_constants import OpRmsNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
20
21
@register_node_visitor
class RmsNormVisitor(NodeVisitor):
    """Builds the QNN ``RmsNorm`` op for ``aten.rms_norm.default`` nodes."""

    target = ["aten.rms_norm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Translate an ``aten.rms_norm.default`` node into a QNN op wrapper.

        args of node: ['input', 'normalized_shape', 'weight', 'eps']

        Returns the populated ``PyQnnOpWrapper``, or ``None`` when the
        normalization is not over exactly the last input dimension (the only
        layout the QNN RmsNorm op supports here).
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # normalized_shape should be an immutable list. Only normalization
        # over exactly the last input dimension is supported, so reject a
        # multi-dim shape OR a mismatched size (bug fix: the original used
        # `and`, which let unsupported shapes through).
        normalized_shapes = node.args[1]
        if (
            len(normalized_shapes) != 1
            or normalized_shapes[0] != input_tensor.shape[-1]
        ):
            warnings.warn(
                "[QNN Delegate Op Builder]: Only supports normalization with last input dimension.",
                stacklevel=1,
            )
            return
        # Normalize over the last axis of the (meta) input rank.
        axes = [node.args[0].meta["val"].dim() - 1]
        axes_shape = [len(axes)]

        weight_node = node.args[2]
        weight_tensor = get_parameter(weight_node, self.edge_program)
        weight_tensor_wrapper = self.define_tensor(
            weight_node,
            weight_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # Fake node: the nn module seems to be inconsistent with the QNN
        # document, which expects a bias input — supply an all-zero bias.
        bias_tensor = torch.zeros(weight_tensor.shape)
        bias_node = torch.fx.Node(
            node.graph,
            node.name + "_runtime_bias",
            "call_function",
            exir_ops.edge.aten.tensor.default,
            (),  # args
            {},  # kwargs
        )
        # Propagate quantization attributes so the zero bias is quantized
        # consistently with the node's output.
        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
            bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
        bias_tensor_wrapper = self.define_tensor(
            bias_node,
            bias_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # eps may arrive as a graph node holding a parameter; resolve it to a
        # float, falling back to the dtype's machine epsilon for tensor params.
        epsilon = node.args[3]
        if isinstance(epsilon, torch.fx.Node):
            epsilon = get_parameter(epsilon, self.edge_program)
            epsilon = (
                epsilon
                if isinstance(epsilon, float)
                else torch.finfo(epsilon.dtype).eps
            )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        rms_norm_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpRmsNorm.op_name,
        )

        rms_norm_op.AddInputTensors(
            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
        )
        rms_norm_op.AddOutputTensors([output_tensor_wrapper])
        rms_norm_op.AddScalarParam(
            OpRmsNorm.param_epsilon,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
            {QCOM_DATA: np.float32(epsilon)},
        )
        rms_norm_op.AddTensorParam(
            OpRmsNorm.param_axes,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(axes_shape),
            axes_shape,
            np.array(axes, dtype=np.uint32),
            True,
        )

        return rms_norm_op
132