1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/infer/control/tensorlist_setitem_infer.h"
18 #include "nnacl/infer/infer_register.h"
19 #include "nnacl/tensorlist_c_utils.h"
20 #include "nnacl/tensor_c_utils.h"
21
/*
 * Validates the index tensor of a TensorListSetItem node before shape inference.
 *
 * get_index:    tensor holding the position to write; it must have non-NULL
 *               data, be of type int/int32, and contain exactly one element.
 * input0:       input tensor list; not inspected here.
 * value_tensor: tensor being stored; not inspected here.
 *
 * Returns:
 *   NNACL_INFER_INVALID  when the index data is not available (NULL),
 *   NNACL_ERR            when the index has the wrong dtype or element count,
 *   NNACL_OK             otherwise.
 */
int PreJudge(const TensorC *get_index, TensorListC *input0, const TensorC *value_tensor) {
  /* Kept in the signature for callers; not needed for these checks. */
  (void)input0;
  (void)value_tensor;

  /* The index value is required to build the output shape, so without data
   * the shape cannot be inferred yet. This single NULL check covers the whole
   * function; a later duplicate check was unreachable and has been dropped. */
  if (get_index->data_ == NULL) {
    return NNACL_INFER_INVALID;
  }
  if (get_index->data_type_ != kNumberTypeInt && get_index->data_type_ != kNumberTypeInt32) {
    return NNACL_ERR;
  }
  /* The index must be a single scalar value. */
  if (GetElementNum(get_index) != 1) {
    return NNACL_ERR;
  }
  return NNACL_OK;
}
38
TensorListSetItemInferShape(const TensorC * const * inputs,size_t inputs_size,TensorC ** outputs,size_t outputs_size,OpParameter * parameter)39 int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
40 size_t outputs_size, OpParameter *parameter) {
41 int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1);
42 if (check_ret != NNACL_OK) {
43 return check_ret;
44 }
45
46 TensorListC *input0 = (TensorListC *)(inputs[0]);
47 const TensorC *get_index = inputs[1];
48 const TensorC *value_tensor = inputs[2];
49 TensorListC *output0 = (TensorListC *)(outputs[0]);
50 output0->data_type_ = input0->data_type_;
51 output0->format_ = input0->format_;
52 output0->tensors_data_type_ = value_tensor->data_type_;
53
54 if (!InferFlag(inputs, inputs_size)) {
55 return NNACL_INFER_INVALID;
56 }
57
58 int judge_ret = PreJudge(get_index, input0, value_tensor);
59 if (judge_ret != NNACL_OK) {
60 return judge_ret;
61 }
62
63 int index = ((int *)(get_index->data_))[0];
64 output0->max_elements_num_ = input0->max_elements_num_;
65
66 if (input0->element_num_ == 0 && input0->element_shape_size_ == 0 && index == 0) {
67 ShapeSet(input0->element_shape_, &(input0->element_shape_size_), value_tensor->shape_, value_tensor->shape_size_);
68 ShapeSet(output0->element_shape_, &(output0->element_shape_size_), value_tensor->shape_, value_tensor->shape_size_);
69 } else {
70 ShapeSet(output0->element_shape_, &(output0->element_shape_size_), input0->element_shape_,
71 input0->element_shape_size_);
72 }
73
74 vvector out_shape;
75 out_shape.size_ = 0;
76 out_shape.shape_ = (int **)malloc((input0->element_num_ + 1) * sizeof(int *));
77 if (out_shape.shape_ == NULL) {
78 return NNACL_NULL_PTR;
79 }
80 out_shape.shape_size_ = (int *)malloc((input0->element_num_ + 1) * sizeof(int));
81 if (out_shape.shape_size_ == NULL) {
82 free(out_shape.shape_);
83 return NNACL_NULL_PTR;
84 }
85
86 if (index == 0 && input0->element_num_ == 0) { // uninitialized tensorlist
87 out_shape.shape_[out_shape.size_] = (int *)(value_tensor->shape_);
88 out_shape.shape_size_[out_shape.size_] = value_tensor->shape_size_;
89 out_shape.size_++;
90 output0->element_num_ = 1;
91 } else {
92 output0->element_num_ = input0->element_num_;
93 for (size_t i = 0; i < input0->element_num_; ++i) {
94 TensorC *src_ptr = input0->tensors_[i];
95 if (src_ptr == NULL) {
96 free(out_shape.shape_);
97 free(out_shape.shape_size_);
98 return NNACL_NULL_PTR;
99 }
100 if (src_ptr->data_type_ != kTypeUnknown) {
101 out_shape.shape_[out_shape.size_] = src_ptr->shape_;
102 out_shape.shape_size_[out_shape.size_] = (int)(src_ptr->shape_size_);
103 out_shape.size_++;
104 } else {
105 out_shape.shape_[out_shape.size_] = NULL;
106 out_shape.shape_size_[out_shape.size_] = 0;
107 out_shape.size_++;
108 }
109 }
110 }
111
112 if (input0->tensors_data_type_ == kTypeUnknown) {
113 input0->tensors_data_type_ = value_tensor->data_type_;
114 }
115
116 out_shape.shape_[index] = (int *)(value_tensor->shape_);
117 out_shape.shape_size_[index] = (int)value_tensor->shape_size_;
118 int ret = MallocTensorListData(output0, input0->tensors_data_type_, &out_shape);
119 if (ret != NNACL_OK) {
120 free(out_shape.shape_);
121 free(out_shape.shape_size_);
122 return NNACL_ERR;
123 }
124 free(out_shape.shape_);
125 free(out_shape.shape_size_);
126 return NNACL_OK;
127 }
128
129 REG_INFER(TensorListSetItem, PrimType_TensorListSetItem, TensorListSetItemInferShape)
130