• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/infer/control/tensorlist_stack_infer.h"
18 #include "nnacl/infer/infer_register.h"
19 #include "nnacl/tensorlist_c_utils.h"
20 #include "nnacl/tensor_c_utils.h"
21 
/**
 * Infers the output tensor of TensorListStack.
 *
 * inputs[0] must be a tensor list (data_type_ == kObjectTypeTensorType); inputs[1] holds the requested element
 * shape as int values. The output receives the list's tensors_data_type_ and format_ (even when the shape cannot
 * be inferred yet), and its shape is the resolved element shape with element_num_ inserted as the leading dimension.
 *
 * Element-shape resolution:
 *   - if inputs[1] has at least one value and the first one is -1, the list's own element_shape_ is used;
 *   - otherwise the values of inputs[1] are used as-is;
 *   - the result is merged with the list's element_shape_ via TensorListMergeShape and must be fully defined;
 *   - if the list's element_shape_ is not fully defined, every element tensor whose data type is not
 *     kTypeUnknown is merged against the result as well.
 *
 * @return NNACL_OK on success;
 *         the error from CheckAugmentWithMinSize if the argument arrays are invalid;
 *         NNACL_INPUT_TENSOR_ERROR if inputs[0] is not a tensor list;
 *         NNACL_INFER_INVALID if InferFlag() rejects the inputs or the list has no elements;
 *         NNACL_NULL_PTR if the element-shape data, the list storage, or an element tensor is missing;
 *         NNACL_ERR on shape mismatch, a shape that is not fully defined, or rank overflow.
 */
int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                              OpParameter *parameter) {
  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

  TensorC *output = outputs[0];
  if (inputs[0]->data_type_ != kObjectTypeTensorType) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  TensorListC *input0 = (TensorListC *)(inputs[0]);
  // Type and format are propagated before the InferFlag/element checks so they are set even on early return.
  output->data_type_ = input0->tensors_data_type_;
  output->format_ = input0->format_;
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  if (input0->element_num_ == 0) {
    return NNACL_INFER_INVALID;
  }
  const TensorC *ele_shape = inputs[1];  // requested element shape
  if (ele_shape->data_ == NULL) {
    return NNACL_NULL_PTR;
  }
  const int *ele_shape_ptr = (const int *)(ele_shape->data_);
  int ele_shape_num = GetElementNum(ele_shape);
  int output_shape[MAX_SHAPE_SIZE] = {0};
  size_t output_shape_size = 0;
  // Only index ele_shape_ptr[0] when the element-shape tensor actually holds a value; an empty (or negative-count)
  // tensor has nothing to read there.
  if (ele_shape_num > 0 && ele_shape_ptr[0] == -1) {
    // A leading -1 selects the list's own element shape.
    if (input0->element_shape_size_ > MAX_SHAPE_SIZE) {
      return NNACL_ERR;
    }
    for (size_t i = 0; i < input0->element_shape_size_; i++) {
      ShapePush(output_shape, &output_shape_size, input0->element_shape_[i]);
    }
  } else {
    if (ele_shape_num > MAX_SHAPE_SIZE) {
      return NNACL_ERR;
    }
    for (int i = 0; i < ele_shape_num; ++i) {
      ShapePush(output_shape, &output_shape_size, ele_shape_ptr[i]);
    }
  }

  int status =
    TensorListMergeShape(output_shape, &output_shape_size, input0->element_shape_, input0->element_shape_size_);
  if (status != NNACL_OK) {
    return NNACL_ERR;
  }
  if (!TensorListIsFullyDefined(output_shape, output_shape_size)) {
    return NNACL_ERR;
  }
  if (!TensorListIsFullyDefined(input0->element_shape_, input0->element_shape_size_)) {
    // NOTE(review): output_shape is already required to be fully defined above, so this loop appears to act as a
    // consistency check of the element tensors rather than a refinement — confirm against TensorListMergeShape.
    if (input0->tensors_ == NULL) {
      return NNACL_NULL_PTR;
    }
    for (size_t i = 0; i < input0->element_num_; ++i) {
      TensorC *tensor_ele = input0->tensors_[i];
      if (tensor_ele == NULL) {
        return NNACL_NULL_PTR;
      }
      if (tensor_ele->data_type_ == kTypeUnknown) {
        continue;  // elements with unknown type contribute no shape information
      }
      status = TensorListMergeShape(output_shape, &output_shape_size, tensor_ele->shape_, tensor_ele->shape_size_);
      if (status != NNACL_OK) {
        return NNACL_ERR;
      }
    }
  }
  // The list length is inserted as a new leading dimension, so one free slot must remain in output_shape.
  if (output_shape_size >= MAX_SHAPE_SIZE) {
    return NNACL_ERR;
  }
  int ret = ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_);
  if (ret != NNACL_OK) {
    return NNACL_ERR;
  }
  SetShapeArray(output, output_shape, output_shape_size);
  return NNACL_OK;
}
95 
96 REG_INFER(TensorListStack, PrimType_TensorListStack, TensorListStackInferShape)
97