# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Index(NodeVisitor):
    # schema = aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
    target = ["aten.index.Tensor"]
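    # Illustrative sketch (not part of the original builder) of the pattern
    # this visitor lowers:
    #   x = torch.randn(4, 8)
    #   idx = torch.tensor([1, 3])
    #   y = x[idx]  # dispatches to aten.index.Tensor with indices=(idx,)
    # which is lowered to a single QNN Gather along axis 0.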

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
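        # Register the tensor being indexed (node.args[0]) as a native QNN
        # input tensor.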
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

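        # aten.index.Tensor carries a tuple of optional index tensors in
        # node.args[1]; QNN Gather consumes a single indices tensor, so only
        # the single-tensor case is handled here.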
        if len(node.args[1]) > 1:
            # TODO: consider implementing this recursively.
            raise NotImplementedError("Tuples of index tensors are not supported.")

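        # Cast the index tensor to int32 before wrapping it for QNN.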
        indices_node = node.args[1][0]
        indices_tensor = self.get_tensor(indices_node, node).to(torch.int32)
        assert indices_tensor.size(0) != 0, "Empty indices list is not supported"

        indices_tensor_wrapper = self.define_tensor(
            indices_node,
            indices_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

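        # Gather takes the data tensor first, followed by the index tensor.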
        gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper]

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        gather_output_tensors = [output_tensor_wrapper]

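        # Assemble the QNN Gather op and attach its input and output tensors.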
        gather_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpGather.op_name,
        )
        gather_op.AddInputTensors(gather_input_tensors)
        gather_op.AddOutputTensors(gather_output_tensors)

        # The single index tensor applies to dim 0, so the gather axis is fixed
        # at 0. If tuples of index tensors become supported, the axis will need
        # to be derived from their length.
        gather_op.AddScalarParam(
            OpGather.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
            {"data": np.int32(0)},
        )

        return gather_op