• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_NNACL_OP_BASE_H_
18 #define MINDSPORE_NNACL_OP_BASE_H_
19 
20 #include <stdint.h>
21 #include <stdlib.h>
22 #include <stdbool.h>
23 #include <string.h>
24 #include <limits.h>
25 #if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
26 #include "nnacl/intrinsics/ms_simd_instructions.h"
27 #endif
28 
// Named integer constants: each C<n>NUM expands to the integer n.
#define C1NUM  1
#define C2NUM  2
#define C3NUM  3
#define C4NUM  4
#define C5NUM  5
#define C6NUM  6
#define C8NUM  8
#define C12NUM 12
#define C16NUM 16
#define C20NUM 20
#define C24NUM 24
#define C32NUM 32
#define C40NUM 40
#define C64NUM 64

// Tile size constant.
#define TILE_NUM 8
44 
// MSMIN / MSMAX: smaller / larger of the two arguments. Each argument can be
// evaluated twice, so do not pass expressions with side effects.
#define MSMIN(x, y) ((y) > (x) ? (x) : (y))
#define MSMAX(x, y) ((y) < (x) ? (x) : (y))
// MSCEIL: adds 1 to x when x has a positive fractional part (relative to its
// int truncation), then converts the sum to int. x is evaluated three times.
#define MSCEIL(x) (int)((x) + (0 < ((x) - (int)(x)) ? 1 : 0))
48 
// Integer division / rounding helpers. Arguments are evaluated more than once;
// y must be non-zero.
//   UP_DIV(x, y)       (x + y - 1) / y; the ceiling of x / y for x >= 0, y > 0.
//   UP_ROUND(x, y)     UP_DIV(x, y) * y, i.e. x rounded up to a multiple of y.
//   UP_ROUND_DIV(x, y) x / y, plus 1 when x % y is non-zero.
//   DOWN_DIV(x, y)     x / y.
//   DOWN_ROUND(x, y)   x / y * y, i.e. x rounded down to a multiple of y (x >= 0, y > 0).
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y))
// Every use of x and y is parenthesized so that compound arguments such as
// `a + b` keep their meaning inside the % and / operations.
#define UP_ROUND_DIV(x, y) (((x) % (y)) == 0 ? ((x) / (y)) : ((x) / (y)) + 1)
#define DOWN_DIV(x, y) ((x) / (y))
#define DOWN_ROUND(x, y) ((x) / (y) * (y))
54 
// MSVALID: MSMAX(left, x) capped by MSMIN against right, i.e. x limited to
// [left, right] when left <= right. Relies on MSMIN / MSMAX from this header.
#define MSVALID(left, x, right) (MSMIN(MSMAX(left, x), right))
// SIZE_MUL_OVERFLOW: true when y exceeds SIZE_MAX / x; always false for x == 0.
#define SIZE_MUL_OVERFLOW(x, y) (((x) != 0) ? ((y) > (SIZE_MAX / (x))) : false)
// INT_MUL_OVERFLOW(x, y): true when the magnitude of x * y would exceed INT_MAX,
// decided with divisions only (the product itself is never computed).
// A product equal to INT_MIN is therefore also reported as overflow.
// x == 0 always yields false. Both arguments are evaluated several times.
// All uses of x and y are parenthesized so compound arguments (e.g. `a & 1`)
// are compared as a whole, and the x > 0, y < 0 case negates the quotient
// rather than y so that y == INT_MIN does not itself overflow.
#define INT_MUL_OVERFLOW(x, y)                                                              \
  (((x) == 0) ? false                                                                       \
              : ((x) > 0 ? (((y) >= 0) ? (INT_MAX / (x)) < (y) : (y) < -(INT_MAX / (x)))    \
                         : (((y) >= 0) ? (INT_MAX / (x)) > (-1 * (y)) : (INT_MAX / (x)) > (y))))
61 
// INT_MUL_OVERFLOW_THRESHOLD(x, y, threshold): same sign-by-sign division test
// as INT_MUL_OVERFLOW, but against 'threshold' instead of INT_MAX. For a
// non-negative threshold it is true when the magnitude of x * y would exceed it.
// x == 0 always yields false. Arguments are evaluated several times.
// NOTE: the y < 0 branches compute -1 * y, so y must not be the minimum value
// of its type.
// All uses of x and y are parenthesized so compound arguments are compared as
// a whole.
#define INT_MUL_OVERFLOW_THRESHOLD(x, y, threshold)                                                  \
  (((x) == 0) ? false                                                                                \
              : ((x) > 0 ? (((y) >= 0) ? ((threshold) / (x)) < (y) : ((threshold) / (x)) < (-1 * (y))) \
                         : (((y) >= 0) ? ((threshold) / (x)) > (-1 * (y)) : ((threshold) / (x)) > (y))))
66 
// INT_ADD_OVERFLOW(x, y): true when the int sum x + y would fall outside
// [INT_MIN, INT_MAX]; the sum itself is never computed.
//   x >= 0: only positive overflow is possible, tested as INT_MAX - x < y.
//   x <  0: only negative overflow is possible, tested as y < INT_MIN - x.
// Neither subtraction can overflow in its branch. The expansion is fully
// parenthesized so it can be combined with !, &&, == and similar operators.
// x is evaluated twice.
#define INT_ADD_OVERFLOW(x, y) (((x) < 0) ? ((y) < (INT_MIN - (x))) : ((INT_MAX - (x)) < (y)))
68 
// INT_ADD_OVERFLOW_THRESHOLD(x, y, threshold): true when threshold - x < y,
// i.e. when x + y would exceed 'threshold'. The subtraction threshold - x must
// itself be representable. The expansion is fully parenthesized so it can be
// combined with !, &&, == and similar operators.
#define INT_ADD_OVERFLOW_THRESHOLD(x, y, threshold) (((threshold) - (x)) < (y))
70 
// Shape-size constants.
#define COMM_SHAPE_SIZE 4
#define MAX_SHAPE_SIZE  8

// Zero-based ordinals: FIRST_INPUT is 0 through FIFTH_INPUT at 4.
#define FIRST_INPUT  0
#define SECOND_INPUT 1
#define THIRD_INPUT  2
#define FOURTH_INPUT 3
#define FIFTH_INPUT  4

// DIMENSION_<n>D expands to n. Note that no DIMENSION_9D is defined.
#define DIMENSION_1D  1
#define DIMENSION_2D  2
#define DIMENSION_3D  3
#define DIMENSION_4D  4
#define DIMENSION_5D  5
#define DIMENSION_6D  6
#define DIMENSION_7D  7
#define DIMENSION_8D  8
#define DIMENSION_10D 10
#define DIMENSION_11D 11
// Tensor index constants.
#define kInputIndex  0
#define kWeightIndex 1
#define kBiasIndex   2
#define kOutputIndex 0

// Axis positions for the N, H, W and C components, in that order.
#define kNHWC_N 0
#define kNHWC_H 1
#define kNHWC_W 2
#define kNHWC_C 3

// Input-count constants.
#define kInputSize1   2
#define kInputSize2   3
#define INPUT_MAX_NUM 10

// Assorted limits and identifiers.
#define MAX_AXIS_SIZE 6
#define MAX_LEN       256
#define FLT16_MAX     65504  // largest finite IEEE 754 binary16 value
#define NNACL_NC4HW4  13

// Spin-count constants.
#define kDefaulLiteMaxSpinCount 300000
#define kDefaulLiteMinSpinCount 1
#define kDefaulLiteIosSpinCount 1
108 
109 #if ENABLE_HIGH_PERFORMANCE
// ENABLE_HIGH_PERFORMANCE variant: every check macro expands to nothing.
// The argument expressions are discarded by the preprocessor, so they are
// never evaluated and must not be relied on for side effects.

// Boolean checks.
#define MS_CHECK_TRUE_RET(value, errcode)
#define MS_CHECK_TRUE_RET_VOID(value)
#define MS_CHECK_FALSE(value, errcode)
#define MS_CHECK_TRUE_MSG(value, errcode, msg)
#define MS_CHECK_FALSE_MSG(value, errcode, msg)

// Ordering checks.
#define MS_CHECK_LT(value1, value2, errcode)
#define MS_CHECK_GT(value1, value2, errcode)
#define MS_CHECK_LE(value1, value2, errcode)
#define MS_CHECK_GE(value1, value2, errcode)

// Pointer check.
#define MS_CHECK_PTR_IF_NULL(ptr)

// Integer overflow checks.
#define MS_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode)
#define MS_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, errcode)

// nnacl zero / null checks.
#define NNACL_CHECK_ZERO_RETURN_ERR(val)
#define NNACL_CHECK_ZERO_RETURN(val)
#define NNACL_CHECK_NULL_RETURN_ERR(ptr)
#define NNACL_CHECK_NULL_RETURN_VOID(ptr)
128 #else
// Leaves the enclosing function with 'return errcode;' unless 'value' is true.
// Must be used inside a function whose return type accepts 'errcode'.
#define MS_CHECK_TRUE_RET(value, errcode) \
  do {                                    \
    if ((value)) {                        \
      break;                              \
    }                                     \
    return errcode;                       \
  } while (0)

// Same check for functions returning void: plain 'return;' unless 'value' is true.
#define MS_CHECK_TRUE_RET_VOID(value) \
  do {                                \
    if ((value)) {                    \
      break;                          \
    }                                 \
    return;                           \
  } while (0)
143 
// Leaves the enclosing function with 'return errcode;' when 'value' is true;
// execution continues past the macro only when 'value' is false.
#define MS_CHECK_FALSE(value, errcode) \
  do {                                 \
    if (!(value)) {                    \
      break;                           \
    }                                  \
    return errcode;                    \
  } while (0)
151 
// Logging variants of the boolean checks. On failure they stream the
// stringified 'msg' argument (#msg, i.e. its source spelling, so a string
// literal is logged including its quotes) to MS_LOG(ERROR) and then
// 'return errcode;'. MS_LOG is not defined in the visible part of this header
// and must be available where the macro is expanded.

// Fails (logs and returns 'errcode') unless 'value' is true.
#define MS_CHECK_TRUE_MSG(value, errcode, msg) \
  do {                                         \
    if ((value)) {                             \
      break;                                   \
    }                                          \
    MS_LOG(ERROR) << #msg;                     \
    return errcode;                            \
  } while (0)

// Fails (logs and returns 'errcode') when 'value' is true.
#define MS_CHECK_FALSE_MSG(value, errcode, msg) \
  do {                                          \
    if (!(value)) {                             \
      break;                                    \
    }                                           \
    MS_LOG(ERROR) << #msg;                      \
    return errcode;                             \
  } while (0)
169 
// Requires value1 < value2. Otherwise logs both values through MS_LOG(ERROR)
// and leaves the enclosing function with 'return errcode;'.
// The log text names the check that failed ("lt"), and the streamed operands
// are parenthesized so arguments containing operators that bind looser than
// '<<' (e.g. '&', '<', '?:') are logged as a whole.
// value1 and value2 are evaluated again when the failure is logged.
#define MS_CHECK_LT(value1, value2, errcode)                                             \
  do {                                                                                   \
    if ((value1) >= (value2)) {                                                          \
      MS_LOG(ERROR) << "check lt fail, value1: " << (value1) << " value2: " << (value2); \
      return errcode;                                                                    \
    }                                                                                    \
  } while (0)
177 
// Requires value1 > value2. Otherwise logs both values through MS_LOG(ERROR)
// and leaves the enclosing function with 'return errcode;'.
// The streamed operands are parenthesized so arguments containing operators
// that bind looser than '<<' (e.g. '&', '<', '?:') are logged as a whole.
// value1 and value2 are evaluated again when the failure is logged.
#define MS_CHECK_GT(value1, value2, errcode)                                             \
  do {                                                                                   \
    if ((value1) <= (value2)) {                                                          \
      MS_LOG(ERROR) << "check gt fail, value1: " << (value1) << " value2: " << (value2); \
      return errcode;                                                                    \
    }                                                                                    \
  } while (0)
185 
// Requires value1 <= value2. Otherwise logs both values through MS_LOG(ERROR)
// and leaves the enclosing function with 'return errcode;'.
// The streamed operands are parenthesized so arguments containing operators
// that bind looser than '<<' (e.g. '&', '<', '?:') are logged as a whole.
// value1 and value2 are evaluated again when the failure is logged.
#define MS_CHECK_LE(value1, value2, errcode)                                             \
  do {                                                                                   \
    if ((value1) > (value2)) {                                                           \
      MS_LOG(ERROR) << "check le fail, value1: " << (value1) << " value2: " << (value2); \
      return errcode;                                                                    \
    }                                                                                    \
  } while (0)
193 
// Requires value1 >= value2. Otherwise logs both values through MS_LOG(ERROR)
// and leaves the enclosing function with 'return errcode;'.
// The streamed operands are parenthesized so arguments containing operators
// that bind looser than '<<' (e.g. '&', '<', '?:') are logged as a whole.
// value1 and value2 are evaluated again when the failure is logged.
#define MS_CHECK_GE(value1, value2, errcode)                                             \
  do {                                                                                   \
    if ((value1) < (value2)) {                                                           \
      MS_LOG(ERROR) << "check ge fail, value1: " << (value1) << " value2: " << (value2); \
      return errcode;                                                                    \
    }                                                                                    \
  } while (0)
201 
// For functions returning void: when 'ptr' compares equal to nullptr, logs the
// source spelling of the argument (#ptr) through MS_LOG(ERROR) and executes
// 'return;'. Uses 'nullptr' and the stream-style MS_LOG, both of which must be
// available where the macro is expanded.
#define MS_CHECK_PTR_IF_NULL(ptr)                                \
  do {                                                           \
    if ((ptr) != nullptr) {                                      \
      break;                                                     \
    }                                                            \
    MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
    return;                                                      \
  } while (0)
209 
// Leave the enclosing function with 'return errcode;' when INT_MUL_OVERFLOW
// (resp. INT_ADD_OVERFLOW) reports that value1 * value2 (resp. value1 + value2)
// would overflow. Built on MS_CHECK_FALSE from this header.
#define MS_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode) \
  MS_CHECK_FALSE(INT_MUL_OVERFLOW(value1, value2), errcode)
#define MS_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, errcode) \
  MS_CHECK_FALSE(INT_ADD_OVERFLOW(value1, value2), errcode)
214 
// Leaves the enclosing function with 'return NNACL_ERR;' when 'val' compares
// equal to 0. NNACL_ERR is not defined in the visible part of this header and
// must be available where the macro is expanded.
#define NNACL_CHECK_ZERO_RETURN_ERR(val) \
  do {                                   \
    if ((val) != 0) {                    \
      break;                             \
    }                                    \
    return NNACL_ERR;                    \
  } while (0)

// Same check for functions returning void: plain 'return;' when 'val' is 0.
#define NNACL_CHECK_ZERO_RETURN(val) \
  do {                               \
    if ((val) != 0) {                \
      break;                         \
    }                                \
    return;                          \
  } while (0)
228 
// Leaves the enclosing function with 'return NNACL_NULL_PTR;' when 'ptr'
// compares equal to NULL. NNACL_NULL_PTR is not defined in the visible part of
// this header and must be available where the macro is expanded.
#define NNACL_CHECK_NULL_RETURN_ERR(ptr) \
  do {                                   \
    if ((ptr) != NULL) {                 \
      break;                             \
    }                                    \
    return NNACL_NULL_PTR;               \
  } while (0)

// Same check for functions returning void: plain 'return;' when 'ptr' is NULL.
#define NNACL_CHECK_NULL_RETURN_VOID(ptr) \
  do {                                    \
    if ((ptr) != NULL) {                  \
      break;                              \
    }                                     \
    return;                               \
  } while (0)
242 
243 #endif
244 
// Data type tags. Enumerators take consecutive values starting at 0.
typedef enum LiteDataType {
  kDataTypeFloat = 0,
  kDataTypeFloat16 = 1,
  kDataTypeInt = 2,
  kDataTypeInt8 = 3,
  kDataTypeBool = 4,
  kDataTypeFloat64 = 5
} LiteDataType;
253 
// Data ordering tag: RowMajor is 0, ColMajor is 1.
typedef enum DataOrder {
  RowMajor = 0,
  ColMajor = 1,
} DataOrder;
258 
// Parameter record for an operator. Field order is part of the layout and
// must not be changed.
struct OpParameter {
  char name_[100];         // fixed 100-byte name buffer
  int type_;               // integer type tag
  int thread_num_;         // thread count
  int quant_type_;         // integer quantization-type tag
  bool is_train_session_;  // flag
  bool is_zero_shape_;     // flag
  // Callback receiving this structure; presumably a cleanup hook (confirm with callers).
  void (*destroy_func_)(struct OpParameter *param);
};
typedef struct OpParameter OpParameter;
268 
// Quantization argument pair: a float scale and a 32-bit integer zp_
// (presumably the zero point; confirm with callers).
struct QuantArg {
  float scale_;
  int32_t zp_;
};
typedef struct QuantArg QuantArg;
273 
// Integer multiplier together with a left and a right shift amount.
struct QuantMulArg {
  int32_t multiplier_;
  int left_shift_;
  int right_shift_;
};
typedef struct QuantMulArg QuantMulArg;
279 
// Activation type tags; enumerators take consecutive values starting at 0.
typedef enum ActType {
  ActType_No = 0,
  ActType_Relu = 1,
  ActType_Sigmod = 2,
  ActType_Relu6 = 3,
  ActType_Prelu = 4
} ActType;
// Padding mode tags; enumerators take consecutive values starting at 0.
typedef enum PadMode {
  Pad_pad = 0,
  Pad_same = 1,
  Pad_valid = 2
} PadMode;
// Rounding mode tags; enumerators take consecutive values starting at 0.
typedef enum RoundingMode {
  Rounding_No = 0,
  Rounding_Away_from_zero = 1,
  Rounding_Up = 2
} RoundingMode;
// Fixed-multiplier calculation method tags; enumerators take consecutive
// values starting at 0.
typedef enum CalFixedMultiplierMode {
  Method_No = 0,
  Method_SinglePrecision = 1,
  Method_DoublePrecision = 2
} CalFixedMultiplierMode;
288 
289 #endif  // MINDSPORE_NNACL_OP_BASE_H_
290