1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/infer/control/tensorlist_getitem_infer.h"
18 #include "nnacl/infer/infer_register.h"
19 #include "nnacl/tensorlist_c_utils.h"
20 #include "nnacl/tensor_c_utils.h"
21
/**
 * Infers the data type, format and shape of the TensorListGetItem output.
 *
 * Inputs:
 *   inputs[0] - the tensor list; its data_type_ must be kObjectTypeTensorType.
 *   inputs[1] - a single-element tensor whose data_ holds the element index (read as int).
 *   inputs[2] - element-shape tensor (data_ read as int). It is only consulted when the
 *               selected element has data_type_ == kTypeUnknown, and is then required.
 * Output:
 *   outputs[0] - receives the data type, format and shape of the selected element.
 *
 * @return NNACL_OK on success;
 *         NNACL_INFER_INVALID when the index tensor has no data, InferFlag() reports the
 *         inputs are not ready, or the list is empty;
 *         NNACL_ERR for invalid inputs (wrong list type, non-scalar or out-of-range index,
 *         missing/oversized element-shape input, shapes that cannot be merged or stay
 *         not fully defined);
 *         otherwise the error code returned by CheckAugmentWithMinSize.
 */
int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
                                size_t outputs_size, OpParameter *parameter) {
  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

  if (inputs[0]->data_type_ != kObjectTypeTensorType) {
    return NNACL_ERR;
  }
  TensorListC *input0 = (TensorListC *)(inputs[0]);
  const TensorC *get_index = inputs[1];
  if (get_index->data_ == NULL) {
    return NNACL_INFER_INVALID;
  }
  if (GetElementNum(get_index) != 1) {
    return NNACL_ERR;
  }
  TensorC *output = outputs[0];
  if (!InferFlag(inputs, inputs_size) || input0->element_num_ == 0) {
    return NNACL_INFER_INVALID;
  }
  NNACL_CHECK_NULL_RETURN_ERR(input0->tensors_);

  // NOTE(review): the index is read as int regardless of get_index->data_type_ — confirm
  // that callers always provide an int32 index tensor.
  int index = ((int *)(get_index->data_))[0];
  // Compare in size_t so that a large element_num_ is never truncated by an int cast.
  if (index < 0 || (size_t)index >= input0->element_num_) {
    return NNACL_ERR;
  }
  TensorC *tensor_index = input0->tensors_[index];
  NNACL_CHECK_NULL_RETURN_ERR(tensor_index);

  // A typed element supplies its own data type; otherwise fall back to the list-wide type.
  if (tensor_index->data_type_ != kTypeUnknown) {
    output->data_type_ = tensor_index->data_type_;
  } else {
    output->data_type_ = input0->tensors_data_type_;
  }
  output->format_ = tensor_index->format_;

  // InferFlag() was already checked above, so shape inference can proceed directly.
  if (tensor_index->data_type_ != kTypeUnknown) {
    ShapeSet(output->shape_, &(output->shape_size_), tensor_index->shape_, tensor_index->shape_size_);
    return NNACL_OK;
  }

  // The selected element has no type yet: derive its shape from the element-shape input,
  // the list's element_shape_, and (if still not fully defined) the typed list elements.
  // inputs[2] is optional for CheckAugmentWithMinSize (min size 2), so guard the access.
  if (inputs_size < 3) {
    return NNACL_ERR;
  }
  const TensorC *input2 = inputs[2];
  NNACL_CHECK_NULL_RETURN_ERR(input2);
  NNACL_CHECK_NULL_RETURN_ERR(input2->data_);
  const int *ele_shape_data = (const int *)(input2->data_);
  int ele_shape_num = GetElementNum(input2);
  // element_shape below holds at most MAX_SHAPE_SIZE dims; reject larger inputs instead of
  // overflowing or silently truncating them.
  if (ele_shape_num > MAX_SHAPE_SIZE) {
    return NNACL_ERR;
  }
  int element_shape[MAX_SHAPE_SIZE] = {0};
  size_t element_shape_size = 0;
  for (int i = 0; i < ele_shape_num; ++i) {
    ShapePush(element_shape, &element_shape_size, ele_shape_data[i]);
  }
  int status =
    TensorListMergeShape(element_shape, &element_shape_size, input0->element_shape_, input0->element_shape_size_);
  if (status != NNACL_OK) {
    return NNACL_ERR;
  }
  if (!TensorListIsFullyDefined(element_shape, element_shape_size)) {
    // Refine the still-partial shape with every element that already has a concrete type.
    for (size_t i = 0; i < input0->element_num_; ++i) {
      TensorC *input = input0->tensors_[i];
      NNACL_CHECK_NULL_RETURN_ERR(input);
      if (input->data_type_ != kTypeUnknown) {
        status = TensorListMergeShape(element_shape, &element_shape_size, input->shape_, input->shape_size_);
        if (status != NNACL_OK) {
          return NNACL_ERR;
        }
      }
    }
  }
  // Re-check after refinement: if the shape is still not fully defined, inference fails.
  if (!TensorListIsFullyDefined(element_shape, element_shape_size)) {
    return NNACL_ERR;
  }

  SetShapeArray(output, element_shape, element_shape_size);
  return NNACL_OK;
}
101
// Hook TensorListGetItemInferShape up to PrimType_TensorListGetItem via the REG_INFER macro
// from nnacl/infer/infer_register.h.
REG_INFER(TensorListGetItem, PrimType_TensorListGetItem, TensorListGetItemInferShape)
103