• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/infer/control/tensorlist_getitem_infer.h"
18 #include "nnacl/infer/infer_register.h"
19 #include "nnacl/tensorlist_c_utils.h"
20 #include "nnacl/tensor_c_utils.h"
21 
TensorListGetItemInferShape(const TensorC * const * inputs,size_t inputs_size,TensorC ** outputs,size_t outputs_size,OpParameter * parameter)22 int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
23                                 size_t outputs_size, OpParameter *parameter) {
24   int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
25   if (check_ret != NNACL_OK) {
26     return check_ret;
27   }
28 
29   if (inputs[0]->data_type_ != kObjectTypeTensorType) {
30     return NNACL_ERR;
31   }
32   TensorListC *input0 = (TensorListC *)(inputs[0]);
33   const TensorC *get_index = inputs[1];
34   if (get_index->data_ == NULL) {
35     return NNACL_INFER_INVALID;
36   }
37   if (GetElementNum(get_index) != 1) {
38     return NNACL_ERR;
39   }
40   TensorC *output = outputs[0];
41   if (!InferFlag(inputs, inputs_size) || input0->element_num_ == 0) {
42     return NNACL_INFER_INVALID;
43   }
44   int index = ((int *)(get_index->data_))[0];
45   if (index < 0 || index > ((int)(input0->element_num_ - 1))) {
46     return NNACL_ERR;
47   }
48   TensorC *tensor_index = input0->tensors_[index];
49   NNACL_CHECK_NULL_RETURN_ERR(tensor_index);
50 
51   if (tensor_index->data_type_ != kTypeUnknown) {
52     output->data_type_ = tensor_index->data_type_;
53   } else {
54     output->data_type_ = input0->tensors_data_type_;
55   }
56   output->format_ = input0->tensors_[index]->format_;
57 
58   if (!InferFlag(inputs, inputs_size)) {
59     return NNACL_INFER_INVALID;
60   }
61 
62   if (tensor_index->data_type_ != kTypeUnknown) {
63     ShapeSet(output->shape_, &(output->shape_size_), tensor_index->shape_, tensor_index->shape_size_);
64   } else {
65     const TensorC *input2 = inputs[2];
66     NNACL_CHECK_NULL_RETURN_ERR(input2);
67     NNACL_CHECK_NULL_RETURN_ERR(input2->data_);
68     int *ele_shape_data = (int *)(input2->data_);
69     NNACL_CHECK_NULL_RETURN_ERR(ele_shape_data);
70     int element_shape[MAX_SHAPE_SIZE] = {0};
71     size_t element_shape_size = 0;
72     for (int i = 0; i < GetElementNum(input2); ++i) {
73       ShapePush(element_shape, &element_shape_size, ele_shape_data[i]);
74     }
75     int status =
76       TensorListMergeShape(element_shape, &element_shape_size, input0->element_shape_, input0->element_shape_size_);
77     if (status != NNACL_OK) {
78       return NNACL_ERR;
79     }
80     if (!TensorListIsFullyDefined(element_shape, element_shape_size)) {
81       for (size_t i = 0; i < input0->element_num_; ++i) {
82         TensorC *input = input0->tensors_[i];
83         NNACL_CHECK_NULL_RETURN_ERR(input);
84         if (input->data_type_ != kTypeUnknown) {
85           status = TensorListMergeShape(element_shape, &element_shape_size, input->shape_, input->shape_size_);
86           if (status != NNACL_OK) {
87             return NNACL_ERR;
88           }
89         }
90       }
91     }
92     if (!TensorListIsFullyDefined(element_shape, element_shape_size)) {  // the pre is the same judge condition
93       return NNACL_ERR;
94     }
95 
96     SetShapeArray(output, element_shape, element_shape_size);
97   }
98 
99   return NNACL_OK;
100 }
101 
102 REG_INFER(TensorListGetItem, PrimType_TensorListGetItem, TensorListGetItemInferShape)
103