# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Index(NodeVisitor):
    """Lower ``aten.index.Tensor`` to a single QNN ``Gather`` op.

    Only one index tensor is supported. ``None`` placeholders in the
    ``indices`` list (as produced by slicing, e.g. ``x[:, idx]`` ->
    ``[None, idx]``) are accepted. The position of the single index tensor
    in that list is used as the Gather axis.
    """

    # schema = aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
    target = ["aten.index.Tensor"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _get_gather_axis(indices) -> int:
        """Return the position of the single non-``None`` entry in *indices*.

        With exactly one index tensor at position ``k``, ``aten.index``
        yields ``self.shape[:k] + idx.shape + self.shape[k + 1:]``. That is
        the output shape of Gather along axis ``k``.

        Raises:
            NotImplementedError: if *indices* contains no index tensor, or
                more than one.
        """
        positions = [i for i, index in enumerate(indices) if index is not None]
        if not positions:
            raise NotImplementedError("Not support indices without any tensor.")
        if len(positions) > 1:
            # TODO consider to implement it in a recursive way.
            raise NotImplementedError("Not support tuple of tensor.")
        return positions[0]

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN Gather op wrapper for an ``aten.index.Tensor`` node.

        Args:
            node: The FX node; ``node.args[0]`` is the source tensor and
                ``node.args[1]`` is the (possibly ``None``-padded) list of
                index tensors.
            nodes_to_wrappers: Cache of already-defined tensor wrappers,
                passed through to ``define_tensor``.

        Returns:
            The Gather op wrapper with its input/output tensors and axis set.

        Raises:
            NotImplementedError: if the indices list does not contain exactly
                one index tensor.
        """
        # Validate the indices before defining any tensor wrapper, so an
        # unsupported node leaves ``nodes_to_wrappers`` untouched.
        axis = self._get_gather_axis(node.args[1])

        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        indices_node = node.args[1][axis]
        indices_tensor = self.get_tensor(indices_node, node).to(torch.int32)
        assert indices_tensor.size(0) != 0, "Not support empty indices list"

        indices_tensor_wrapper = self.define_tensor(
            indices_node,
            indices_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper]

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        gather_output_tensors = [output_tensor_wrapper]

        gather_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpGather.op_name,
        )
        gather_op.AddInputTensors(gather_input_tensors)
        gather_op.AddOutputTensors(gather_output_tensors)

        # Gather along the dimension the index tensor occupies in the
        # indices list (0 when it is the only entry).
        gather_op.AddScalarParam(
            OpGather.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
            {"data": np.int32(axis)},
        )

        return gather_op