/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/infer/common_infer.h"
#include <stdlib.h>
#include <string.h>
#include "nnacl/infer/infer_register.h"
#include "nnacl/op_base.h"

#ifndef CONTROLFLOW_TENSORLIST_CLIP
int MallocTensorListData(TensorListC *tensor_list, TypeIdC dtype, const vvector *tensor_shape) {
  // Creates the tensors_ array of this tensor list.
  // The caller must set the shape (param 2: tensor_shape) and data type (tensors_data_type_ = param 1: dtype) of each
  // tensor in tensors_. After that, call MallocData to allocate the data buffer of each tensor in tensors_.

  if (tensor_list->element_num_ == 0) {
    return NNACL_OK;
  }
  if (((size_t)(tensor_list->element_num_)) != tensor_shape->size_) {
    return NNACL_ERR;
  }
  tensor_list->tensors_data_type_ = dtype;
  tensor_list->tensors_ = (TensorC *)malloc(tensor_list->element_num_ * sizeof(TensorC));  // freed in infer_manager
  if (tensor_list->tensors_ == NULL) {
    return NNACL_NULL_PTR;
  }
  memset(tensor_list->tensors_, 0, tensor_list->element_num_ * sizeof(TensorC));
  for (size_t i = 0; i < tensor_list->element_num_; ++i) {
    tensor_list->tensors_[i].format_ = Format_NHWC;
    tensor_list->tensors_[i].data_type_ = dtype;
    ShapeSet(tensor_list->tensors_[i].shape_, &(tensor_list->tensors_[i].shape_size_), tensor_shape->shape_[i],
             tensor_shape->shape_size_[i]);
  }
  return NNACL_OK;
}

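/* Usage sketch (illustrative, not part of the original file): how a caller might
 * drive MallocTensorListData. The vvector field names below follow their use in
 * the function above (shape_ as per-element shape pointers, shape_size_ as
 * per-element ranks, size_ as the element count); the exact member types are an
 * assumption.
 *
 *   int elem_shape[2] = {4, 8};
 *   int *shapes[1] = {elem_shape};
 *   size_t ranks[1] = {2};
 *   vvector shape_vec;
 *   shape_vec.shape_ = shapes;
 *   shape_vec.shape_size_ = ranks;
 *   shape_vec.size_ = 1;
 *   tensor_list->element_num_ = 1;
 *   int ret = MallocTensorListData(tensor_list, kNumberTypeFloat32, &shape_vec);
 *   // On NNACL_OK, every tensors_[i] has its shape and dtype set; the data
 *   // buffers remain unallocated until MallocData is called.
 */
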
int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const int *tmp, size_t tmp_size) {
  if (*element_shape_size >= 255 || element_shape[0] == -1) {
    ShapeSet(element_shape, element_shape_size, tmp, tmp_size);
    return NNACL_OK;
  }
  if (*element_shape_size != tmp_size) {
    return NNACL_ERR;
  }
  for (size_t j = 0; j < tmp_size; ++j) {
    if (element_shape[j] >= 0 && tmp[j] >= 0 && element_shape[j] != tmp[j]) {
      return NNACL_ERR;
    }
    element_shape[j] = element_shape[j] >= 0 ? element_shape[j] : tmp[j];
  }
  return NNACL_OK;
}

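/* Merge semantics sketch (illustrative): unknown dimensions (-1) are filled in
 * from the incoming shape, while known dimensions must agree.
 *
 *   int shape[2] = {2, -1};
 *   size_t size = 2;
 *   int incoming[2] = {2, 5};
 *   TensorListMergeShape(shape, &size, incoming, 2);  // shape == {2, 5}, NNACL_OK
 *   // Merging {2, 5} with {3, 5} instead returns NNACL_ERR (2 != 3).
 *   // A list whose element shape starts with -1 simply adopts the incoming shape.
 */
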
bool TensorListIsFullyDefined(const int *shape, size_t shape_size) {
  for (size_t i = 0; i < shape_size; ++i) {
    if (shape[i] < 0) {
      return false;
    }
  }
  return true;
}
#endif

int CheckAugmentNull(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     const OpParameter *parameter) {
  NNACL_CHECK_NULL_RETURN_ERR(inputs);
  NNACL_CHECK_NULL_RETURN_ERR(outputs);
  for (size_t i = 0; i < inputs_size; i++) {
    if (inputs[i] == NULL) {
      return NNACL_NULL_PTR;
    }
  }
  for (size_t i = 0; i < outputs_size; i++) {
    if (outputs[i] == NULL) {
      return NNACL_NULL_PTR;
    }
  }
  if (parameter == NULL) {
    return NNACL_NULL_PTR;
  }
  return NNACL_OK;
}

int CheckAugmentNullSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                         const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj) {
  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
  if (check_ret == NNACL_NULL_PTR) {
    return NNACL_NULL_PTR;
  }
  if (inputs_size != inputs_size_obj || outputs_size != outputs_size_obj) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  return NNACL_OK;
}

int CheckAugmentNullSizeInputTwo(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
                                 size_t outputs_size, const OpParameter *parameter, size_t inputs_size_obj_0,
                                 size_t inputs_size_obj_1, size_t outputs_size_obj) {
  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
  if (check_ret == NNACL_NULL_PTR) {
    return NNACL_NULL_PTR;
  }
  if ((inputs_size != inputs_size_obj_0 && inputs_size != inputs_size_obj_1) || outputs_size != outputs_size_obj) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  return NNACL_OK;
}

int CheckAugmentNullInputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                              const OpParameter *parameter, size_t inputs_size_obj) {
  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
  if (check_ret == NNACL_NULL_PTR) {
    return NNACL_NULL_PTR;
  }
  if (inputs_size != inputs_size_obj) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  return NNACL_OK;
}

int CheckAugmentNullOutputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                               const OpParameter *parameter, size_t outputs_size_obj) {
  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
  if (check_ret == NNACL_NULL_PTR) {
    return NNACL_NULL_PTR;
  }
  if (outputs_size != outputs_size_obj) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  return NNACL_OK;
}

int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                            const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj) {
  int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter);
  if (check_ret == NNACL_NULL_PTR) {
    return NNACL_NULL_PTR;
  }
  if (inputs_size < inputs_size_obj || outputs_size < outputs_size_obj) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  return NNACL_OK;
}

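/* Validation sketch (illustrative): a typical infer function guards its argument
 * arrays with one of the helpers above before touching them, e.g. for an op
 * expecting exactly 2 inputs and 1 output:
 *
 *   int ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1);
 *   if (ret != NNACL_OK) {
 *     return ret;  // NNACL_NULL_PTR or NNACL_INPUT_TENSOR_ERROR
 *   }
 */
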
void SetShapeTensor(TensorC *dst, const TensorC *src) {
  for (size_t i = 0; i < src->shape_size_; i++) {
    dst->shape_[i] = src->shape_[i];
  }
  dst->shape_size_ = src->shape_size_;
}

void SetShapeArray(TensorC *dst, const int *src, size_t src_size) {
  size_t i = 0;
  for (; i < src_size && i < MAX_SHAPE_SIZE; i++) {
    dst->shape_[i] = src[i];
  }
  dst->shape_size_ = i;  // clamp to MAX_SHAPE_SIZE so shape_size_ matches what was actually copied
}

void SetDataTypeFormat(TensorC *dst, const TensorC *src) {
  dst->format_ = src->format_;
  dst->data_type_ = src->data_type_;
}

int GetBatch(const TensorC *tensor) {
  if (tensor->shape_size_ != 4 && tensor->shape_size_ != 2) {
    return -1;
  }
  switch (tensor->format_) {
    case Format_NHWC:
    case Format_NHWC4:
    case Format_NCHW:
    case Format_NC4HW4:
    case Format_KCHW:
    case Format_KHWC:
    case Format_NC:
    case Format_NC4:
      return tensor->shape_[0];
    case Format_HWCK:
    case Format_CHWK:
      return tensor->shape_[3];
    case Format_HWKC:
      return tensor->shape_[2];
    case Format_CKHW:
      return tensor->shape_[1];
    default:
      return -1;
  }
}
int GetHeight(const TensorC *tensor) {
  if (tensor->shape_size_ != 4 && tensor->shape_size_ != 2) {
    return -1;
  }
  switch (tensor->format_) {
    case Format_NCHW:
    case Format_KCHW:
    case Format_CKHW:
      return tensor->shape_[2];
    case Format_NHWC:
    case Format_NHWC4:
    case Format_NC4HW4:
    case Format_KHWC:
    case Format_CHWK:
      return tensor->shape_[1];
    case Format_HWCK:
    case Format_HWKC:
    case Format_HW:
    case Format_HW4:
      return tensor->shape_[0];
    default:
      return -1;
  }
}
int GetWidth(const TensorC *tensor) {
  if (tensor->shape_size_ != 4 && tensor->shape_size_ != 2) {
    return -1;
  }
  switch (tensor->format_) {
    case Format_NCHW:
    case Format_KCHW:
    case Format_CKHW:
      return tensor->shape_[3];
    case Format_KHWC:
    case Format_NHWC:
    case Format_NHWC4:
    case Format_NC4HW4:
    case Format_CHWK:
      return tensor->shape_[2];
    case Format_HWCK:
    case Format_HWKC:
    case Format_HW:
    case Format_HW4:
      return tensor->shape_[1];
    default:
      return -1;
  }
}
int GetChannel(const TensorC *tensor) {
  if (tensor->shape_size_ != 4 && tensor->shape_size_ != 2) {
    return -1;
  }
  switch (tensor->format_) {
    case Format_NCHW:
    case Format_KCHW:
    case Format_NC:
    case Format_NC4:
      return tensor->shape_[1];
    case Format_HWCK:
      return tensor->shape_[2];
    case Format_HWKC:
    case Format_NHWC:
    case Format_NHWC4:
    case Format_NC4HW4:
    case Format_KHWC:
      return tensor->shape_[3];
    case Format_CKHW:
    case Format_CHWK:
      return tensor->shape_[0];
    default:
      return -1;
  }
}

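/* Axis-lookup sketch (illustrative): the four getters above hide the
 * format-dependent axis positions. For a 4-D tensor with shape_ = {1, 16, 16, 3}:
 *
 *   tensor->format_ = Format_NHWC;  // GetChannel(tensor) == 3  (shape_[3])
 *   tensor->format_ = Format_NCHW;  // GetChannel(tensor) == 16 (shape_[1])
 *
 * Tensors that are neither 4-D nor 2-D yield -1 from all four getters.
 */
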
int GetElementNum(const TensorC *tensor) {
  if (tensor->shape_size_ == 0) {
    return 1;  // scalar mode
  }
  int res = 1;
  for (size_t i = 0; i < tensor->shape_size_; i++) {
    MS_CHECK_INT_MUL_NOT_OVERFLOW(res, tensor->shape_[i], NNACL_ERRCODE_MUL_OVERFLOW);
    res = res * tensor->shape_[i];
  }
  return res;
}
int GetDimensionSize(const TensorC *tensor, const size_t index) {
  int dim_size = -1;
  if (index < tensor->shape_size_) {
    dim_size = tensor->shape_[index];
  }
  return dim_size;
}

void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) {
  size_t i = 0;
  for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) {
    dst_shape[i] = src_shape[i];
  }
  *dst_shape_size = i;
}

void ShapePush(int *shape, size_t *shape_size, int value) {
  if (*shape_size >= MAX_SHAPE_SIZE) {
    return;
  }
  shape[*shape_size] = value;
  *shape_size = *shape_size + 1;
}

int ShapeInsert(int *shape, size_t *shape_size, int index, int value) {
  if (index < 0 || (size_t)index > *shape_size) {
    return NNACL_ERR;
  }
  if (*shape_size >= MAX_SHAPE_SIZE) {
    return NNACL_ERR;
  }
  for (int i = (int)(*shape_size); i > index; i--) {
    shape[i] = shape[i - 1];
  }
  shape[index] = value;
  *shape_size = *shape_size + 1;
  return NNACL_OK;
}

int ShapeErase(int *shape, size_t *shape_size, int index) {
  if (index < 0 || (size_t)index >= *shape_size) {
    return NNACL_ERR;
  }

  for (int i = index; i < (int)(*shape_size) - 1; i++) {
    shape[i] = shape[i + 1];
  }
  *shape_size = *shape_size - 1;
  return NNACL_OK;
}

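/* Shape-editing sketch (illustrative): the helpers treat a shape array as a small
 * bounded vector. Starting from int shape[MAX_SHAPE_SIZE] = {2, 3}; size_t size = 2:
 *
 *   ShapePush(shape, &size, 5);       // shape == {2, 3, 5},    size == 3
 *   ShapeInsert(shape, &size, 1, 7);  // shape == {2, 7, 3, 5}, size == 4
 *   ShapeErase(shape, &size, 0);      // shape == {7, 3, 5},    size == 3
 *
 * ShapePush silently ignores pushes past MAX_SHAPE_SIZE; ShapeInsert returns
 * NNACL_ERR in that case, and both ShapeInsert and ShapeErase return NNACL_ERR
 * for an out-of-range index.
 */
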
bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size) {
  if (shape0_size != shape1_size) {
    return false;
  }
  for (size_t i = 0; i < shape0_size; i++) {
    if (shape0[i] != shape1[i]) {
      return false;
    }
  }
  return true;
}

void iswap(int *a, int *b) {
  int tmp = *a;
  *a = *b;
  *b = tmp;
}

int imin(int a, int b) { return a > b ? b : a; }

int imax(int a, int b) { return a < b ? b : a; }

// Ops whose output completely mirrors the input (same shape, data type, and format), e.g.:
// 1. zeros_like
int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
  if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {
    return NNACL_NULL_PTR;
  }
  SetDataTypeFormat(outputs[0], inputs[0]);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(outputs[0], inputs[0]);
  return NNACL_OK;
}

int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                             OpParameter *parameter) {
  if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {
    return NNACL_NULL_PTR;
  }
  if (inputs[0]->format_ != Format_NHWC) {
    return NNACL_FORMAT_ERROR;
  }
  SetDataTypeFormat(outputs[0], inputs[0]);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(outputs[0], inputs[0]);
  return NNACL_OK;
}

int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                  const OpParameter *parameter) {
  int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
  if (ret != NNACL_OK) {
    return ret;
  }
  const TensorC *input = inputs[0];
  TensorC *output = outputs[0];
  output->data_type_ = kNumberTypeFloat32;
  output->format_ = input->format_;
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  if (input->shape_size_ > MAX_SHAPE_SIZE) {
    return NNACL_INPUT_TENSOR_ERROR;
  }
  int input_shape[MAX_SHAPE_SIZE] = {0};
  size_t input_shape_size = 0;
  ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_);
  if (input_shape_size == 0) {
    return NNACL_ERR;
  }
  input_shape_size--;  // the output drops the innermost input dimension
  SetShapeArray(output, input_shape, input_shape_size);
  return NNACL_OK;
}

bool InferFlag(const TensorC *const *inputs, size_t inputs_size) {
  if (inputs == NULL) {
    return false;
  }
  for (size_t i = 0; i < inputs_size; i++) {
    if (inputs[i] == NULL) {
      return false;
    }
#ifndef CONTROLFLOW_TENSORLIST_CLIP
    if (inputs[i]->data_type_ == kObjectTypeTensorType) {
      TensorListC *input_tensor_list = (TensorListC *)inputs[i];
      if (input_tensor_list->shape_value_ == -1) {
        return false;
      }
    } else {
#endif
      for (size_t j = 0; j < inputs[i]->shape_size_; ++j) {
        if (inputs[i]->shape_[j] == -1) {
          return false;
        }
      }
#ifndef CONTROLFLOW_TENSORLIST_CLIP
    }
#endif
  }
  return true;
}

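/* InferFlag is true only when every input shape is fully known (no -1 dimensions,
 * and no tensor list with an undefined shape_value_); infer functions use it to
 * defer shape inference to runtime by returning NNACL_INFER_INVALID, as
 * CommonInferShape and FftInferShape do above.
 */
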
REG_INFER(Abs, PrimType_Abs, CommonInferShape)
REG_INFER(AbsGrad, PrimType_AbsGrad, CommonInferShape)
REG_INFER(Activation, PrimType_Activation, CommonInferShape)
REG_INFER(ActivationGrad, PrimType_ActivationGrad, CommonInferShape)
REG_INFER(BatchNorm, PrimType_BatchNorm, CommonInferShape)
REG_INFER(BinaryCrossEntropyGrad, PrimType_BinaryCrossEntropyGrad, CommonInferShape)
REG_INFER(BiasAdd, PrimType_BiasAdd, CommonInferShape)
REG_INFER(Ceil, PrimType_Ceil, CommonInferShape)
REG_INFER(Clip, PrimType_Clip, CommonInferShape)
REG_INFER(Cos, PrimType_Cos, CommonInferShape)
REG_INFER(Depend, PrimType_Depend, CommonInferShape)
REG_INFER(Elu, PrimType_Elu, CommonInferShape)
REG_INFER(Erf, PrimType_Erf, CommonInferShape)
REG_INFER(Exp, PrimType_ExpFusion, CommonInferShape)
REG_INFER(FakeQuantWithMinMaxVars, PrimType_FakeQuantWithMinMaxVars, CommonInferShape)
REG_INFER(Floor, PrimType_Floor, CommonInferShape)
REG_INFER(InstanceNorm, PrimType_InstanceNorm, CommonInferShape)
REG_INFER(IsFinite, PrimType_IsFinite, CommonInferShape)
REG_INFER(LeakyRelu, PrimType_LeakyRelu, CommonInferShape)
REG_INFER(Log, PrimType_Log, CommonInferShape)
REG_INFER(LogGrad, PrimType_LogGrad, CommonInferShape)
REG_INFER(LogicalNot, PrimType_LogicalNot, CommonInferShape)
REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC)
REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape)
REG_INFER(Neg, PrimType_Neg, CommonInferShape)
REG_INFER(NegGrad, PrimType_NegGrad, CommonInferShape)
REG_INFER(PowerGrad, PrimType_PowerGrad, CommonInferShape)
REG_INFER(PReLU, PrimType_PReLUFusion, CommonInferShape)
REG_INFER(Reciprocal, PrimType_Reciprocal, CommonInferShape)
REG_INFER(ReverseSequence, PrimType_ReverseSequence, CommonInferShape)
REG_INFER(Reverse, PrimType_ReverseV2, CommonInferShape)
REG_INFER(Round, PrimType_Round, CommonInferShape)
REG_INFER(Rsqrt, PrimType_Rsqrt, CommonInferShape)
REG_INFER(Scale, PrimType_ScaleFusion, CommonInferShape)
REG_INFER(SigmoidCrossEntropyWithLogits, PrimType_SigmoidCrossEntropyWithLogits, CommonInferShape)
REG_INFER(SigmoidCrossEntropyWithLogitsGrad, PrimType_SigmoidCrossEntropyWithLogitsGrad, CommonInferShape)
REG_INFER(Sin, PrimType_Sin, CommonInferShape)
REG_INFER(SmoothL1Loss, PrimType_SmoothL1Loss, CommonInferShape)
REG_INFER(SmoothL1LossGrad, PrimType_SmoothL1LossGrad, CommonInferShape)
REG_INFER(Sqrt, PrimType_Sqrt, CommonInferShape)
REG_INFER(Square, PrimType_Square, CommonInferShape)
REG_INFER(ZerosLike, PrimType_ZerosLike, CommonInferShape)
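
/* Registration sketch (illustrative): REG_INFER, declared in
 * nnacl/infer/infer_register.h, binds a primitive type to its shape-inference
 * routine. A hypothetical new elementwise op whose output mirrors its input
 * would reuse CommonInferShape the same way:
 *
 *   REG_INFER(MyUnaryOp, PrimType_MyUnaryOp, CommonInferShape)  // hypothetical op name
 */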