/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/infer/common_infer.h"
#include <stdlib.h>
#include <string.h>
#include "nnacl/infer/infer_register.h"
#include "nnacl/op_base.h"
#include "nnacl/tensor_c_utils.h"
#include "nnacl/tensorlist_c_utils.h"

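// Returns true only when every dimension of every tensor in `tensors` is fully known,
// i.e. no -1 placeholders remain. (The name is presumably a long-standing typo for
// "CheckShapeValid"; it is kept as-is to match the declaration in common_infer.h.)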
bool CheckShaleValid(TensorC **tensors, int tensors_size) {
  for (int i = 0; i < tensors_size; i++) {
    TensorC *t = tensors[i];
    for (size_t j = 0; j < t->shape_size_; j++) {
      if (t->shape_[j] == -1) {
        return false;
      }
    }
  }
  return true;
}

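// Returns true once shape inference has completed on both sides: no input or output
// tensor may still contain a -1 (unknown) dimension.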
bool CheckInferShapeDone(TensorC **in, int in_size, TensorC **out, int out_size) {
  for (int i = 0; i < in_size; i++) {
    TensorC *t = in[i];
    for (size_t j = 0; j < t->shape_size_; j++) {
      if (t->shape_[j] == -1) {
        return false;
      }
    }
  }
  for (int i = 0; i < out_size; i++) {
    TensorC *t = out[i];
    for (size_t j = 0; j < t->shape_size_; j++) {
      if (t->shape_[j] == -1) {
        return false;
      }
    }
  }
  return true;
}

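// Copies src_shape into dst_shape, silently truncating at MAX_SHAPE_SIZE dimensions,
// and records the number of dimensions actually copied in *dst_shape_size.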
void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) {
  size_t i = 0;
  for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) {
    dst_shape[i] = src_shape[i];
  }
  *dst_shape_size = i;
}

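// Same as ShapeSet but narrows an int64_t shape to int32_t. Returns false on NULL
// arguments or when any dimension falls outside the int32_t range.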
bool Int64ShapeSet(int *dst_shape, size_t *dst_shape_size, const int64_t *src_shape, size_t src_shape_size) {
  if (dst_shape_size == NULL || dst_shape == NULL || src_shape == NULL) {
    return false;
  }
  size_t i = 0;
  for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) {
    if (MS_UNLIKELY(src_shape[i] > (int64_t)INT32_MAX || src_shape[i] < (int64_t)INT32_MIN)) {
      return false;
    }
    dst_shape[i] = (int32_t)(src_shape[i]);
  }
  *dst_shape_size = i;
  return true;
}

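// Appends one dimension to the end of `shape`; a silent no-op when the shape is
// already at MAX_SHAPE_SIZE, so callers that must detect overflow should check first.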
void ShapePush(int *shape, size_t *shape_size, int value) {
  if (*shape_size >= MAX_SHAPE_SIZE) {
    return;
  }
  shape[*shape_size] = value;
  *shape_size = *shape_size + 1;
}

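// Reads the elements of an int32 or int64 tensor into `result` as int32 values.
// The caller must supply a buffer with room for GetElementNum(tensor) elements;
// int64 values that cannot be represented as int32 are rejected with NNACL_ERR.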
int GetInt32DataFromTensor(const TensorC *tensor, int *result, size_t *result_size) {
  if (tensor->data_ == NULL || result == NULL || result_size == NULL) {
    return NNACL_ERR;
  }
  if (tensor->shape_size_ > MAX_SHAPE_SIZE) {
    return NNACL_ERR;
  }
  int ele_num = GetElementNum(tensor);
  if (ele_num <= 0) {
    return NNACL_ERR;
  }
  *result_size = (size_t)ele_num;
  if (tensor->data_type_ == kNumberTypeInt || tensor->data_type_ == kNumberTypeInt32) {
    int *data = (int *)(tensor->data_);
    for (int i = 0; i < ele_num; i++) {
      result[i] = data[i];
    }
  } else if (tensor->data_type_ == kNumberTypeInt64) {
    int64_t *data = (int64_t *)(tensor->data_);
    for (int i = 0; i < ele_num; i++) {
      if (data[i] >= INT32_MAX || data[i] <= INT32_MIN) {
        return NNACL_ERR;
      }
      result[i] = (int32_t)data[i];
    }
  } else {
    return NNACL_UNSUPPORTED_DATA_TYPE;
  }
  return NNACL_OK;
}

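// Inserts `value` at position `index`, shifting later dimensions one slot to the
// right. Fails when the index is out of range or the shape is already full.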
int ShapeInsert(int *shape, size_t *shape_size, int index, int value) {
  if (index < 0 || index > (int)(*shape_size)) {
    return NNACL_ERR;
  }
  if (*shape_size >= MAX_SHAPE_SIZE) {
    return NNACL_ERR;
  }
  for (int i = (int)(*shape_size); i > index; i--) {
    shape[i] = shape[i - 1];
  }
  shape[index] = value;
  *shape_size = *shape_size + 1;
  return NNACL_OK;
}

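// Removes the dimension at `index`, shifting later dimensions one slot to the left.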
int ShapeErase(int *shape, size_t *shape_size, int index) {
  if (index < 0 || index >= (int)(*shape_size)) {
    return NNACL_ERR;
  }

  for (int i = index; i < (int)(*shape_size) - 1; i++) {
    shape[i] = shape[i + 1];
  }
  *shape_size = *shape_size - 1;
  return NNACL_OK;
}

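// Element-wise shape comparison; shapes of different rank are never equal.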
bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size) {
  if (shape0_size != shape1_size) {
    return false;
  }
  for (size_t i = 0; i < shape0_size; i++) {
    if (shape0[i] != shape1[i]) {
      return false;
    }
  }
  return true;
}

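// Small scalar helpers shared by the infer routines below.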
void iswap(int *a, int *b) {
  int tmp = *a;
  *a = *b;
  *b = tmp;
}

int imin(int a, int b) { return a > b ? b : a; }

int imax(int a, int b) { return a < b ? b : a; }

// Common infer-shape: the output tensor completely mirrors the input
// (same shape, data type, and format). Used by element-wise ops such as:
// 1. zeros_like
int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
  if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {
    return NNACL_NULL_PTR;
  }
  SetDataTypeFormat(outputs[0], inputs[0]);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(outputs[0], inputs[0]);
  return NNACL_OK;
}

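// Gradient variant: expects exactly two inputs whose shapes must match (the forward
// value and the incoming gradient); the output copies input 0.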
int CommonGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                         OpParameter *parameter) {
  int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2);
  if (ret != NNACL_OK) {
    return ret;
  }
  SetDataTypeFormat(outputs[0], inputs[0]);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  NNACL_CHECK_TRUE_RET(inputs[0]->shape_size_ == inputs[1]->shape_size_, NNACL_ERR);
  for (size_t i = 0; i < inputs[0]->shape_size_; i++) {
    if (inputs[0]->shape_[i] != inputs[1]->shape_[i]) {
      return NNACL_ERR;
    }
  }
  SetShapeTensor(outputs[0], inputs[0]);
  return NNACL_OK;
}

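// Same as CommonInferShape, but additionally validates that exactly one input was provided.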
int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
                                 size_t outputs_size, OpParameter *parameter) {
  int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1);
  if (ret != NNACL_OK) {
    return ret;
  }
  SetDataTypeFormat(outputs[0], inputs[0]);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(outputs[0], inputs[0]);
  return NNACL_OK;
}

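// Same as CommonInferShape, but additionally validates that exactly two inputs were
// provided; only the first input determines the output shape.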
int CommonInferShapeWithTwoInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
                                 size_t outputs_size, OpParameter *parameter) {
  int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2);
  if (ret != NNACL_OK) {
    return ret;
  }
  SetDataTypeFormat(outputs[0], inputs[0]);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(outputs[0], inputs[0]);
  return NNACL_OK;
}

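// Same as CommonInferShape, but rejects inputs that are not in NHWC layout.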
int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                             OpParameter *parameter) {
  if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {
    return NNACL_NULL_PTR;
  }
  if (inputs[0]->format_ != Format_NHWC) {
    return NNACL_FORMAT_ERROR;
  }
  SetDataTypeFormat(outputs[0], inputs[0]);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(outputs[0], inputs[0]);
  return NNACL_OK;
}

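// FFT infer-shape: the output is float32 and drops the innermost dimension of the
// input (presumably the real/imaginary pair), e.g. [N, T, 2] -> [N, T].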
int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                  const OpParameter *parameter) {
  int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
  if (ret != NNACL_OK) {
    return ret;
  }
  const TensorC *input = inputs[0];
  TensorC *output = outputs[0];
  output->data_type_ = kNumberTypeFloat32;
  output->format_ = input->format_;
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  if (input->shape_size_ > MAX_SHAPE_SIZE) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  int input_shape[MAX_SHAPE_SIZE] = {0};
  size_t input_shape_size = 0;
  ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_);
  if (input_shape_size == 0) {
    return NNACL_ERR;
  }
  input_shape_size--;
  SetShapeArray(output, input_shape, input_shape_size);
  return NNACL_OK;
}

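// Returns true when shape inference can proceed: every input must be non-NULL and,
// for plain tensors, free of negative (unknown) dimensions; tensor-list inputs are
// delegated to InferFlagTensorList.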
bool InferFlag(const TensorC *const *inputs, size_t inputs_size) {
  if (inputs == NULL) {
    return false;
  }
  for (size_t i = 0; i < inputs_size; i++) {
    if (inputs[i] == NULL) {
      return false;
    }
    if (inputs[i]->data_type_ == kObjectTypeTensorType) {
      if (InferFlagTensorList((TensorC *)inputs[i]) == false) {
        return false;
      }
    } else {
      for (size_t j = 0; j < inputs[i]->shape_size_; ++j) {
        if (inputs[i]->shape_[j] < 0) {
          return false;
        }
      }
    }
  }
  return true;
}

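// Register the shared infer functions for every primitive whose output mirrors its
// input. REG_INFER(name, primitive type, infer function) presumably expands to a
// static registration that maps the primitive type to its infer function at load time.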
REG_INFER(Abs, PrimType_Abs, CommonInferShape)
REG_INFER(AbsGrad, PrimType_AbsGrad, CommonGradInferShape)
REG_INFER(Activation, PrimType_Activation, CommonInferShape)
REG_INFER(BatchNorm, PrimType_BatchNorm, CommonInferShape)
REG_INFER(BinaryCrossEntropyGrad, PrimType_BinaryCrossEntropyGrad, CommonInferShape)
REG_INFER(Ceil, PrimType_Ceil, CommonInferShape)
REG_INFER(Clip, PrimType_Clip, CommonInferShape)
REG_INFER(Cos, PrimType_Cos, CommonInferShape)
REG_INFER(Depend, PrimType_Depend, CommonInferShape)
REG_INFER(Elu, PrimType_Elu, CommonInferShape)
REG_INFER(Erf, PrimType_Erf, CommonInferShape)
REG_INFER(Exp, PrimType_ExpFusion, CommonInferShape)
REG_INFER(FakeQuantWithMinMaxVars, PrimType_FakeQuantWithMinMaxVars, CommonInferShape)
REG_INFER(Floor, PrimType_Floor, CommonInferShapeWithOneInput)
REG_INFER(LeakyRelu, PrimType_LeakyRelu, CommonInferShape)
REG_INFER(Log, PrimType_Log, CommonInferShape)
REG_INFER(Log1p, PrimType_Log1p, CommonInferShape)
REG_INFER(LogGrad, PrimType_LogGrad, CommonGradInferShape)
REG_INFER(LogicalNot, PrimType_LogicalNot, CommonInferShape)
REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC)
REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape)
REG_INFER(Neg, PrimType_Neg, CommonInferShape)
REG_INFER(NegGrad, PrimType_NegGrad, CommonGradInferShape)
REG_INFER(OnesLike, PrimType_OnesLike, CommonInferShape)
REG_INFER(PowerGrad, PrimType_PowerGrad, CommonGradInferShape)
REG_INFER(PReLU, PrimType_PReLUFusion, CommonInferShape)
REG_INFER(Reciprocal, PrimType_Reciprocal, CommonInferShape)
REG_INFER(ReverseSequence, PrimType_ReverseSequence, CommonInferShape)
REG_INFER(Reverse, PrimType_ReverseV2, CommonInferShape)
REG_INFER(Round, PrimType_Round, CommonInferShape)
REG_INFER(Rsqrt, PrimType_Rsqrt, CommonInferShape)
REG_INFER(Scale, PrimType_ScaleFusion, CommonInferShape)
REG_INFER(SigmoidCrossEntropyWithLogits, PrimType_SigmoidCrossEntropyWithLogits, CommonInferShape)
REG_INFER(SigmoidCrossEntropyWithLogitsGrad, PrimType_SigmoidCrossEntropyWithLogitsGrad, CommonInferShape)
REG_INFER(Sin, PrimType_Sin, CommonInferShape)
REG_INFER(SmoothL1Loss, PrimType_SmoothL1Loss, CommonInferShape)
REG_INFER(SmoothL1LossGrad, PrimType_SmoothL1LossGrad, CommonInferShape)
REG_INFER(Sqrt, PrimType_Sqrt, CommonInferShape)
REG_INFER(SqrtGrad, PrimType_SqrtGrad, CommonInferShape)
REG_INFER(Square, PrimType_Square, CommonInferShape)
REG_INFER(ZerosLike, PrimType_ZerosLike, CommonInferShape)
REG_INFER(ScatterElements, PrimType_ScatterElements, CommonInferShape)