/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/infer/control/tensorlist_stack_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/tensorlist_c_utils.h"
#include "nnacl/tensor_c_utils.h"

int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                              OpParameter *parameter) {
  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

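  /* The first input must be a tensor list; the stacked output takes over its element data type and format. */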
  TensorC *output = outputs[0];
  if (inputs[0]->data_type_ != kObjectTypeTensorType) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  TensorListC *input0 = (TensorListC *)(inputs[0]);
  output->data_type_ = input0->tensors_data_type_;
  output->format_ = input0->format_;
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
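  /* An empty tensor list gives no element shapes to stack, so inference is not possible yet. */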
  if (input0->element_num_ == 0) {
    return NNACL_INFER_INVALID;
  }
  const TensorC *ele_shape = inputs[1];  // element shape
  if (ele_shape->data_ == NULL) {
    return NNACL_NULL_PTR;
  }
  int *ele_shape_ptr = (int *)(ele_shape->data_);
  int output_shape[MAX_SHAPE_SIZE] = {0};
  size_t output_shape_size = 0;
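  /* A leading -1 means the element shape is unspecified; use the shape recorded in the tensor list instead. */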
  if (ele_shape_ptr[0] == -1) {
    if (input0->element_shape_size_ > MAX_SHAPE_SIZE) {
      return NNACL_ERR;
    }
    for (size_t i = 0; i < input0->element_shape_size_; i++) {
      ShapePush(output_shape, &output_shape_size, input0->element_shape_[i]);
    }
  } else {
    int ele_shape_num = GetElementNum(ele_shape);
    if (ele_shape_num > MAX_SHAPE_SIZE) {
      return NNACL_ERR;
    }
    for (int i = 0; i < ele_shape_num; ++i) {
      ShapePush(output_shape, &output_shape_size, ele_shape_ptr[i]);
    }
  }

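  /* Merge the collected shape with the tensor list's recorded element shape; the result must be fully defined. */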
  int status =
    TensorListMergeShape(output_shape, &output_shape_size, input0->element_shape_, input0->element_shape_size_);
  if (status == NNACL_ERR) {
    return NNACL_ERR;
  }
  if (!TensorListIsFullyDefined(output_shape, output_shape_size)) {
    return NNACL_ERR;
  }
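  /* If the recorded element shape is not fully defined, merge with each stored tensor's shape to check consistency. */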
  if (!TensorListIsFullyDefined(input0->element_shape_, input0->element_shape_size_)) {
    for (size_t i = 0; i < input0->element_num_; ++i) {
      TensorC *tensor_ele = input0->tensors_[i];
      if (tensor_ele->data_type_ != kTypeUnknown) {
        status = TensorListMergeShape(output_shape, &output_shape_size, tensor_ele->shape_, tensor_ele->shape_size_);
        if (status == NNACL_ERR) {
          return NNACL_ERR;
        }
      }
    }
  }
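  /* Prepend the number of list elements as the new leading (stack) dimension. */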
  if (output_shape_size >= MAX_SHAPE_SIZE) {
    return NNACL_ERR;
  }
  int ret = ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_);
  if (ret != NNACL_OK) {
    return NNACL_ERR;
  }
  SetShapeArray(output, output_shape, output_shape_size);
  return NNACL_OK;
}

REG_INFER(TensorListStack, PrimType_TensorListStack, TensorListStackInferShape)