/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "nnacl/infer/tile_infer.h"
18 #include <limits.h>
19 #include "nnacl/infer/infer_register.h"
20 #include "nnacl/tile_parameter.h"
21 #include "nnacl/tensor_c_utils.h"
22
/**
 * Rewrites param->multiples_ from a sparse (dims_[i], multiples_[i]) pairing
 * into a dense per-axis array ("caffe" layout to "tflite" layout).
 *
 * When param->dims_size_ is zero the function does nothing. Otherwise a
 * 5-entry scratch array is built: entries for the first out_shape_size axes
 * start at 1, the remaining entries stay 0, and each multiples_[i] is stored
 * at index dims_[i]. The scratch array is then copied back into
 * param->multiples_[0..4].
 *
 * The function returns without modifying param when:
 *   - out_shape_size is larger than 5,
 *   - dims_size_ is larger than MAX_SHAPE_SIZE,
 *   - any dims_[i] is negative or not smaller than 5.
 *
 * @param param           Tile parameter; multiples_ is updated in place.
 * @param out_shape_size  Number of axes of the output shape.
 */
void TileParamCaffe2Tflite(TileParameter *param, size_t out_shape_size) {
  // Capacity of the dense per-axis scratch array.
  enum { kTileDenseAxes = 5 };

  if (param->dims_size_ == 0) {
    return;
  }
  NNACL_CHECK_TRUE_RET_VOID(out_shape_size <= (size_t)kTileDenseAxes);

  int dense_multiples[kTileDenseAxes] = {0};
  for (size_t i = 0; i < out_shape_size; i++) {
    dense_multiples[i] = 1;
  }

  for (size_t i = 0; i < param->dims_size_; i++) {
    if (i >= MAX_SHAPE_SIZE) {
      return;
    }
    const int axis = param->dims_[i];
    // The axis is used as an index into the scratch array: reject anything
    // outside [0, kTileDenseAxes) instead of writing out of bounds.
    if (axis < 0 || axis >= kTileDenseAxes) {
      return;
    }
    dense_multiples[axis] = param->multiples_[i];
  }

  // Commit only after every (axis, multiple) pair has been validated.
  for (size_t i = 0; i < (size_t)kTileDenseAxes; i++) {
    param->multiples_[i] = dense_multiples[i];
  }
}
41
/**
 * Shape inference for the Tile operator.
 *
 * inputs[0] is the tensor to tile and inputs[1] holds the integer multiples.
 * The output copies the data type/format of inputs[0]; its shape is the input
 * shape with every axis listed in param->dims_ multiplied by the matching
 * entry of the multiples. When param->dims_size_ is zero, dims_ is filled
 * with 0..N-1 (N = element count of inputs[1]) and dims_size_ is stored back
 * into the parameter. param->multiples_ is overwritten with the contents of
 * inputs[1] and then passed through TileParamCaffe2Tflite.
 *
 * @param inputs        Input tensors; exactly two are expected.
 * @param inputs_size   Number of entries in inputs.
 * @param outputs       Output tensors; exactly one is expected.
 * @param outputs_size  Number of entries in outputs.
 * @param parameter     Operator parameter, used as a TileParameter.
 * @return NNACL_OK on success;
 *         the CheckAugmentNullSize error code if the argument check fails;
 *         NNACL_INFER_INVALID if InferFlag() is false or inputs[1] has no data;
 *         NNACL_INPUT_TENSOR_ERROR if the ranks are inconsistent or inputs[1]
 *         is not an int/int32 tensor;
 *         NNACL_ERR for an invalid multiples count, an axis outside the input
 *         rank, a zero-sized tiled axis, or an overflowing output dimension.
 */
int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                   OpParameter *parameter) {
  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

  const TensorC *input = inputs[0];
  const TensorC *multiples_tensor = inputs[1];
  TensorC *output = outputs[0];

  SetDataTypeFormat(output, input);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }

  TileParameter *param = (TileParameter *)parameter;

  // The multiples tensor may not have a higher rank than the input, and the
  // input rank must fit the fixed-size shape arrays.
  int input1_shape_size = (int)multiples_tensor->shape_size_;
  if (input1_shape_size > (int)(input->shape_size_) || input->shape_size_ > MAX_SHAPE_SIZE) {
    return NNACL_INPUT_TENSOR_ERROR;
  }

  int data_num = GetElementNum(multiples_tensor);
  if (multiples_tensor->data_type_ != kNumberTypeInt && multiples_tensor->data_type_ != kNumberTypeInt32) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  const int *input1_data = (const int *)multiples_tensor->data_;
  if (input1_data == NULL) {
    return NNACL_INFER_INVALID;
  }
  // data_num bounds every write into multiples_ and dims_ below.
  NNACL_CHECK_TRUE_RET(data_num >= 0 && data_num <= MAX_SHAPE_SIZE, NNACL_ERR);
  size_t multiples_size = (size_t)data_num;
  for (int i = 0; i < data_num; i++) {
    param->multiples_[i] = input1_data[i];
  }

  int *dims = param->dims_;
  size_t dims_size = param->dims_size_;
  if (dims_size == 0) {
    // No explicit axes: the i-th multiple applies to axis i.
    for (int dim = 0; dim < data_num; ++dim) {
      ShapePush(dims, &dims_size, dim);
    }
    param->dims_size_ = dims_size;
  }
  NNACL_CHECK_TRUE_RET(multiples_size == dims_size, NNACL_ERR);

  int out_shape[MAX_SHAPE_SIZE] = {0};
  size_t out_shape_size = 0;
  for (size_t i = 0; i < input->shape_size_; ++i) {
    ShapePush(out_shape, &out_shape_size, input->shape_[i]);
  }

  for (size_t i = 0; i < dims_size; ++i) {
    const int axis = dims[i];
    // The axis indexes both input->shape_ and out_shape, so it must be a
    // valid axis of the input. input->shape_size_ <= MAX_SHAPE_SIZE was
    // checked above, which also keeps the index inside the arrays.
    if (axis < 0 || (size_t)axis >= input->shape_size_) {
      return NNACL_ERR;
    }
    const int in_dim = input->shape_[axis];
    if (in_dim == 0) {
      return NNACL_ERR;
    }
    const int multiple = param->multiples_[i];
    // Refuse products that do not fit in an int before multiplying.
    if (multiple > INT_MAX / in_dim) {
      return NNACL_ERR;
    }
    NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(in_dim, multiple), NNACL_ERR);
    out_shape[axis] = in_dim * multiple;
  }

  // change caffe param format to tflite
  TileParamCaffe2Tflite(param, out_shape_size);
  SetShapeArray(output, out_shape, out_shape_size);
  return NNACL_OK;
}
110
111 REG_INFER(Tile, PrimType_TileFusion, TileInferShape)
112