1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
18 #define NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
19 #include <math.h>
20 #include "nnacl/intrinsics/ms_simd_cpu_info.h"
21
22 #ifdef ENABLE_AVX512
23 #include "nnacl/intrinsics/ms_simd_avx512_instructions.h"
24 #endif
25
26 #ifdef ENABLE_AVX
27 #include "nnacl/intrinsics/ms_simd_avx_instructions.h"
28 #endif
29
30 #ifdef ENABLE_SSE
31 #include "nnacl/intrinsics/ms_simd_sse_instructions.h"
32 #endif
33
34 #ifdef ENABLE_ARM
35 #include "nnacl/intrinsics/ms_simd_neon_instructions.h"
36 #endif
37
// Token-pasting rules: glue an op name, the register width in bits, and a suffix into one identifier,
// e.g. MS_SIMD_AVX_INSTRUCTION(MS_ADD, _F32) -> MS_ADD256_F32. NEON reuses the 128-bit spelling of SSE.
#define MS_SIMD_AVX512_INSTRUCTION(op, tail) op##512##tail
#define MS_SIMD_AVX_INSTRUCTION(op, tail) op##256##tail
#define MS_SIMD_SSE_INSTRUCTION(op, tail) op##128##tail
#define MS_SIMD_NEON_INSTRUCTION(op, tail) op##128##tail

// Element-type shorthands. They forward to MS_SIMD_INSTRUCTION, which is not defined in this file and
// must be visible (mapped to one of the rules above) wherever these are expanded.
#define MS_SIMD_INSTRUCTION_MASK(op) MS_SIMD_INSTRUCTION(op, _MASK)
#define MS_SIMD_INSTRUCTION_EPI32(op) MS_SIMD_INSTRUCTION(op, _EPI32)
#define MS_SIMD_INSTRUCTION_F32(op) MS_SIMD_INSTRUCTION(op, _F32)
46
// SIMD_* aliases: ISA-neutral names for vector types and operations. Each one pastes an op name and a
// suffix through MS_SIMD_INSTRUCTION, which is not defined in this file (presumably the including code
// maps it to one of MS_SIMD_AVX512/AVX/SSE/NEON_INSTRUCTION — TODO confirm). With the AVX rule, for
// example, SIMD_LD_F32 -> MS_SIMD_INSTRUCTION(MS_LD, _F32) -> MS_LD256_F32.

// define (float/int) data
#define SIMD_F32 MS_SIMD_INSTRUCTION_F32(MS_FLOAT)
#define SIMD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_INT)
#define SIMD_MASK MS_SIMD_INSTRUCTION(MS_MASK, _TYPE)

// read scalar data (pastes to MS<bits>_F32_GETI: the width follows "MS" directly)
#define SIMD_F32_GETI MS_SIMD_INSTRUCTION(MS, _F32_GETI)

// move (float/int) data
#define SIMD_MOV_F32 MS_SIMD_INSTRUCTION_F32(MS_MOV)
#define SIMD_MOV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MOV)
#define SIMD_SET0_F32 MS_SIMD_INSTRUCTION(MS_MOV, _VAL0_F32)

// load (float/int) data
#define SIMD_LD_F32 MS_SIMD_INSTRUCTION_F32(MS_LD)
#define SIMD_LD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_LD)
#define SIMD_LD_HALF_EPI32 MS_SIMD_INSTRUCTION(MS_LD, _HALF_EPI32)

// load 4 (float/int) data (pastes to MS_LOAD<bits>X4_*)
#define SIMD_LDX4_F32 MS_SIMD_INSTRUCTION(MS_LOAD, X4_F32)
#define SIMD_LDX4_EPI32 MS_SIMD_INSTRUCTION(MS_LOAD, X4_EPI32)

// store (float/int) data
#define SIMD_ST_F32 MS_SIMD_INSTRUCTION_F32(MS_ST)
#define SIMD_ST_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ST)
#define SIMD_ST_HALF_EPI32 MS_SIMD_INSTRUCTION(MS_ST, _HALF_EPI32)
73
// sign (note: these paste the SIMD_ prefix, e.g. SIMD_SIGN256_F32, unlike the MS_-prefixed ops below)
#define SIMD_SIGN_F32 MS_SIMD_INSTRUCTION_F32(SIMD_SIGN)
#define SIMD_SIGNABS_F32 MS_SIMD_INSTRUCTION_F32(SIMD_SIGNABS)

// add (float/int) op; the *_N_* forms broadcast the scalar val2 with SIMD_MOV_* first
#define SIMD_ADD_F32 MS_SIMD_INSTRUCTION_F32(MS_ADD)
#define SIMD_ADD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ADD)
#define SIMD_ADD_N_F32(val1, val2) MS_EXPAND(SIMD_ADD_F32(val1, SIMD_MOV_F32(val2)))
#define SIMD_ADD_N_EPI32(val1, val2) MS_EXPAND(SIMD_ADD_EPI32(val1, SIMD_MOV_EPI32(val2)))

// sub (float/int) op; *_N_* broadcasts the scalar val2
#define SIMD_SUB_F32 MS_SIMD_INSTRUCTION_F32(MS_SUB)
#define SIMD_SUB_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_SUB)
#define SIMD_SUB_N_F32(val1, val2) MS_EXPAND(SIMD_SUB_F32(val1, SIMD_MOV_F32(val2)))
#define SIMD_SUB_N_EPI32(val1, val2) MS_EXPAND(SIMD_SUB_EPI32(val1, SIMD_MOV_EPI32(val2)))

// div (float/int) op; *_N_* broadcasts the scalar val2
#define SIMD_DIV_F32 MS_SIMD_INSTRUCTION_F32(MS_DIV)
#define SIMD_DIV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_DIV)
#define SIMD_DIV_N_F32(val1, val2) MS_EXPAND(SIMD_DIV_F32(val1, SIMD_MOV_F32(val2)))
#define SIMD_DIV_N_EPI32(val1, val2) MS_EXPAND(SIMD_DIV_EPI32(val1, SIMD_MOV_EPI32(val2)))

// sqrt (float) op
#define SIMD_SQRT_F32 MS_SIMD_INSTRUCTION_F32(MS_SQRT)

// rsqrt (float) op
#define SIMD_RSQRT_F32 MS_SIMD_INSTRUCTION_F32(MS_RSQRT)

// log (float) op (pastes to MS<bits>_LOG_F32)
#define SIMD_LOG_F32 MS_SIMD_INSTRUCTION(MS, _LOG_F32)

// cos (float) op
#define SIMD_COS_F32 MS_SIMD_INSTRUCTION_F32(MS_COS)

// sin (float) op
#define SIMD_SIN_F32 MS_SIMD_INSTRUCTION_F32(MS_SIN)

// erf (float) op (pastes to MS<bits>_ERF_F32)
#define SIMD_ERF_F32 MS_SIMD_INSTRUCTION(MS, _ERF_F32)

// abs (float/int) op
#define SIMD_ABS_F32 MS_SIMD_INSTRUCTION_F32(MS_ABS)
#define SIMD_ABS_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ABS)

// round (float) op
#define SIMD_ROUND_F32 MS_SIMD_INSTRUCTION_F32(MS_ROUND)

// ceil (float) op
#define SIMD_CEIL_F32 MS_SIMD_INSTRUCTION_F32(MS_CEIL)

// floor (float) op
#define SIMD_FLOOR_F32 MS_SIMD_INSTRUCTION_F32(MS_FLOOR)

// tanh (float) op
#define SIMD_TANH_F32 MS_SIMD_INSTRUCTION_F32(MS_TANH)
129
// min (float/int) op; *_N_* broadcasts the scalar val2
#define SIMD_MIN_F32 MS_SIMD_INSTRUCTION_F32(MS_MIN)
#define SIMD_MIN_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MIN)
#define SIMD_MIN_N_F32(val1, val2) MS_EXPAND(SIMD_MIN_F32(val1, SIMD_MOV_F32(val2)))
#define SIMD_MIN_N_EPI32(val1, val2) MS_EXPAND(SIMD_MIN_EPI32(val1, SIMD_MOV_EPI32(val2)))

// max (float/int) op; *_N_* broadcasts the scalar val2
#define SIMD_MAX_F32 MS_SIMD_INSTRUCTION_F32(MS_MAX)
#define SIMD_MAX_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MAX)
#define SIMD_MAX_N_F32(val1, val2) MS_EXPAND(SIMD_MAX_F32(val1, SIMD_MOV_F32(val2)))
#define SIMD_MAX_N_EPI32(val1, val2) MS_EXPAND(SIMD_MAX_EPI32(val1, SIMD_MOV_EPI32(val2)))

// get max (float/int) op
#define SIMD_GET_MAX_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_MAX)
#define SIMD_GET_MAX_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_GET_MAX)

// get sum / reduce add (float) op
#define SIMD_GET_SUM_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_SUM)
#define SIMD_REDUCE_ADD_F32 MS_SIMD_INSTRUCTION(MS_REDUCE_ADD, _F32)

// clamp (float/int) op: max against min_val first, then min against max_val, so when
// min_val > max_val the result is max_val. *_N_* broadcasts both scalar bounds.
#define SIMD_CLAMP_F32(val, min_val, max_val) SIMD_MIN_F32(SIMD_MAX_F32(val, min_val), max_val)
#define SIMD_CLAMP_EPI32(val, min_val, max_val) SIMD_MIN_EPI32(SIMD_MAX_EPI32(val, min_val), max_val)
#define SIMD_CLAMP_N_F32(val, min_val, max_val) \
  SIMD_MIN_F32(SIMD_MAX_F32(val, SIMD_MOV_F32(min_val)), SIMD_MOV_F32(max_val))
#define SIMD_CLAMP_N_EPI32(val, min_val, max_val) \
  SIMD_MIN_EPI32(SIMD_MAX_EPI32(val, SIMD_MOV_EPI32(min_val)), SIMD_MOV_EPI32(max_val))

// mul (float/int) op; *_N_* broadcasts the scalar val2
#define SIMD_MUL_F32 MS_SIMD_INSTRUCTION_F32(MS_MUL)
#define SIMD_MUL_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MUL)
#define SIMD_MUL_N_F32(val1, val2) MS_EXPAND(SIMD_MUL_F32(val1, SIMD_MOV_F32(val2)))
#define SIMD_MUL_N_EPI32(val1, val2) MS_EXPAND(SIMD_MUL_EPI32(val1, SIMD_MOV_EPI32(val2)))

// pow (float) op
#define SIMD_POW_F32 MS_SIMD_INSTRUCTION_F32(MS_POW)

// fma (float) op
#define SIMD_FMADD_F32 MS_SIMD_INSTRUCTION_F32(MS_FMADD)

// fms (float) op
#define SIMD_FMSUB_F32 MS_SIMD_INSTRUCTION_F32(MS_FMSUB)

// fsm (float) op
// NOTE(review): named MS_FSMUL_F32 rather than SIMD_FSMUL_F32 like its neighbours; the name is kept for
// existing callers.
#define MS_FSMUL_F32 MS_SIMD_INSTRUCTION_F32(MS_FSMUL)

// square (float/int) op; the argument expression is evaluated twice
#define SIMD_MUL_SQUARE_F32(val1) SIMD_MUL_F32(val1, val1)
#define SIMD_MUL_SQUARE_EPI32(val1) SIMD_MUL_EPI32(val1, val1)
179
// exp (float) op: pastes simd_exp with the register width (and an optional _f32 suffix)
#define SIMD_EXP_ST_F32 MS_SIMD_INSTRUCTION(simd_exp, )
#define SIMD_EXP_F32 MS_SIMD_INSTRUCTION(simd_exp, _f32)
// exp (float) high precision but a little slow op.
#define SIMD_HEXP_F32 MS_SIMD_INSTRUCTION(simd_hexp, _f32)

// compare / blend (float) op
#define SIMD_CMPLT_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPLT)
#define SIMD_CMPLE_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPLE)
#define SIMD_CMPGT_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPGT)
#define SIMD_BLEND_F32 MS_SIMD_INSTRUCTION_F32(MS_BLEND)

// cast data
#define MS_CAST_F32_S32 MS_SIMD_INSTRUCTION(MS_CAST, _F32_S32)

// logical op
#define SIMD_AND_MASK MS_SIMD_INSTRUCTION_MASK(MS_AND)
#define SIMD_OR_F32 MS_SIMD_INSTRUCTION_F32(MS_OR)
#define SIMD_AND_MASK_F32 MS_SIMD_INSTRUCTION(MS_AND, _MASK_F32)
#define SIMD_AND_F32 MS_SIMD_INSTRUCTION_F32(MS_AND)

// Keeps only the sign bit of src (mask 0x80000000) and ORs in the bit pattern 0x3F800000, which is 1.0f in
// IEEE 754 binary32. Assuming MS_CAST_F32_S32 reinterprets bits rather than converting values (TODO confirm),
// each lane becomes +1.0f or -1.0f carrying the sign of src, including for +/-0.
#define SIMD_GETSIGN_F32(src) \
  SIMD_OR_F32(SIMD_AND_F32(src, MS_CAST_F32_S32(SIMD_MOV_EPI32(0x80000000))), \
              MS_CAST_F32_S32(SIMD_MOV_EPI32(0x3F800000)))

// int32/float mutual conversion (pastes to MS<bits>_<FROM>_TO_<TO>)
#define SIMD_EPI32_TO_F32 MS_SIMD_INSTRUCTION(MS, _INT32_TO_FLOAT32)
#define SIMD_F32_TO_EPI32 MS_SIMD_INSTRUCTION(MS, _FLOAT32_TO_INT32)
#define SIMD_F16_TO_F32 MS_SIMD_INSTRUCTION(MS, _FLOAT16_TO_FLOAT32)
#define SIMD_F32_TO_F16 MS_SIMD_INSTRUCTION(MS, _FLOAT32_TO_FLOAT16)
210
// SIMD_RUN_<ISA>(function, index, ...) calls the kernel named function<ISA> as
//   index = function<ISA>(index, ...);
// when that ISA is enabled at build time, and expands to nothing otherwise. The AVX512 call is additionally
// wrapped in AVX512_HARDWARE_SELF_AWARENESS_BEGIN/END, which are defined outside this file (presumably in
// ms_simd_cpu_info.h — TODO confirm).

#ifndef ENABLE_AVX512
#define SIMD_RUN_AVX512(fn, idx, ...)
#else
#define SIMD_RUN_AVX512(fn, idx, ...)   \
  do {                                  \
    AVX512_HARDWARE_SELF_AWARENESS_BEGIN \
    idx = fn##AVX512(idx, __VA_ARGS__); \
    AVX512_HARDWARE_SELF_AWARENESS_END   \
  } while (0)
#endif

#ifndef ENABLE_AVX
#define SIMD_RUN_AVX(fn, idx, ...)
#else
#define SIMD_RUN_AVX(fn, idx, ...) idx = fn##AVX(idx, __VA_ARGS__)
#endif

#ifndef ENABLE_SSE
#define SIMD_RUN_SSE(fn, idx, ...)
#else
#define SIMD_RUN_SSE(fn, idx, ...) idx = fn##SSE(idx, __VA_ARGS__)
#endif

#ifndef ENABLE_NEON
#define SIMD_RUN_NEON(fn, idx, ...)
#else
#define SIMD_RUN_NEON(fn, idx, ...) idx = fn##NEON(idx, __VA_ARGS__)
#endif

// Runs every enabled vector variant from widest to narrowest; each one continues from the index the
// previous one returned. No scalar tail is run here.
#define SIMD_RUN_NO_SCALAR(fn, idx, ...)   \
  do {                                     \
    SIMD_RUN_AVX512(fn, idx, __VA_ARGS__); \
    SIMD_RUN_AVX(fn, idx, __VA_ARGS__);    \
    SIMD_RUN_SSE(fn, idx, __VA_ARGS__);    \
    SIMD_RUN_NEON(fn, idx, __VA_ARGS__);   \
  } while (0)

// Same as SIMD_RUN_NO_SCALAR but limited to the x86 variants.
#define SIMD_RUN_X86_NO_SCALAR(fn, idx, ...) \
  do {                                       \
    SIMD_RUN_AVX512(fn, idx, __VA_ARGS__);   \
    SIMD_RUN_AVX(fn, idx, __VA_ARGS__);      \
    SIMD_RUN_SSE(fn, idx, __VA_ARGS__);      \
  } while (0)
258
// Lanes per register: SIMD<reg_bits>_BLOCK<elem_bits> == reg_bits / elem_bits.
// Kept as plain integer literals so they also work in #if and token pasting.
#define SIMD512_BLOCK16 32  // SIMD : 512 = 16 x 32
#define SIMD256_BLOCK16 16  // SIMD : 256 = 16 x 16
#define SIMD128_BLOCK16 8   // SIMD : 128 = 16 x 8

#define SIMD512_BLOCK32 16  // SIMD : 512 = 32 x 16
#define SIMD256_BLOCK32 8   // SIMD : 256 = 32 x 8
#define SIMD128_BLOCK32 4   // SIMD : 128 = 32 x 4

#define SIMD512_BLOCK64 8  // SIMD : 512 = 64 x 8
#define SIMD256_BLOCK64 4  // SIMD : 256 = 64 x 4
#define SIMD128_BLOCK64 2  // SIMD : 128 = 64 x 2
270
// Re-emits its arguments unchanged. The dispatch macros in this header wrap their expansions in it
// (presumably to force a rescan on preprocessors that forward __VA_ARGS__ as a single token, such as
// MSVC's traditional one — TODO confirm).
#define MS_EXPAND(...) __VA_ARGS__

// Scalar flavour of the op set: register width 32 with a single lane (see MS_SIMD_RUN_SCALAR).
#define MS_FLOAT32X1 float
#define MS_INT32X1 int

// float lane
#define MS_MOV32_F32(x) (x)
#define MS_LD32_F32(ptr) (*(ptr))
#define MS_ST32_F32(ptr, x) (*(ptr) = (x))
#define MS_ADD32_F32(a, b) ((a) + (b))
#define MS_SUB32_F32(a, b) ((a) - (b))
#define MS_MUL32_F32(a, b) ((a) * (b))
#define MS_DIV32_F32(a, b) ((a) / (b))
// fmin/fmax/sqrt are the double-precision <math.h> functions, so these three produce a double.
#define MS_MIN32_F32(a, b) (fmin((a), (b)))
#define MS_MAX32_F32(a, b) (fmax((a), (b)))
#define MS_SQRT32_F32(x) (sqrt(x))

// int lane; MIN/MAX evaluate the selected operand twice, so keep side effects out of the arguments.
#define MS_MOV32_EPI32(x) (x)
#define MS_LD32_EPI32(ptr) (*(ptr))
#define MS_ST32_EPI32(ptr, x) (*(ptr) = (x))
#define MS_ADD32_EPI32(a, b) ((a) + (b))
#define MS_SUB32_EPI32(a, b) ((a) - (b))
#define MS_MUL32_EPI32(a, b) ((a) * (b))
#define MS_DIV32_EPI32(a, b) ((a) / (b))
#define MS_MIN32_EPI32(a, b) ((a) < (b) ? (a) : (b))
#define MS_MAX32_EPI32(a, b) ((a) > (b) ? (a) : (b))
295
// Scalar (single-lane) approximation of expf(data), returned by value.
// Uses exp(x) = 2^n * exp(r) with n = trunc(x / ln(2)) and r = x - n * ln(2); exp(r) comes from a degree-5
// Taylor polynomial and 2^(n - 1) is assembled directly in the float exponent field. Because n is truncated
// rather than rounded, |r| can approach ln(2), so this is less accurate than simd_exp32().
static inline float simd_exp32_f32(float data) {
  typedef union {
    float f;
    int i;
  } fi;
  // Approximate calculation param: {ln(2), 1/5!, 1/4!, 1/3!, 1/2!, 1}. Read-only, hence const.
  static const float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f};
#ifdef _WIN32
  if (data < -88.0f) {
    return 0.0f;
  } else if (!(data <= 88.0f)) {
    // Negated comparison so NaN is caught here as well; otherwise it would reach the float -> int
    // conversion below, which is undefined behavior for NaN.
    return 1.6516363e+38;  // e^88 = 1.6516363e+38
  }
#else
  // clamp(logf(FLT_MIN), logf(FLT_MAX)). Plain float comparisons give the same values as the fmin()/fmax()
  // based MS_MIN32_F32/MS_MAX32_F32 (including sending NaN to the upper bound) without calling libm.
  const float min_input = -87.3365478515625f;
  const float max_input = 88.72283935546875f;
  if (!(data <= max_input)) {
    data = max_input;
  }
  if (data < min_input) {
    data = min_input;
  }
#endif
  const int integer = (int)(data / param[0]);      // n, truncated toward zero
  const float decimal = data - integer * param[0];  // r
  const int shift = 23;                             // fp32 fraction width
  const int bias = 126;                             // fp32 exponent bias (127) minus 1: 2^(n - 1) is encoded
  const float factor = 2;
  // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r), because n may be 128, and 2^128 has no fp32
  // encoding. After the range handling above integer + bias lies in [0, 254], so the shift is well defined.
  fi int_exp;
  int_exp.i = (integer + bias) << shift;
  // Approximate calculation
  const float decimal_exp =
    1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
  // Scale exp(r) by 2 first (exact), then by 2^(n - 1): forming 2 * 2^(n - 1) first would overflow to +inf
  // for n == 128. Wherever that ordering did not overflow, both orderings give bit-identical results.
  return (factor * decimal_exp) * int_exp.f;
}
326
// exp(x) = exp(n * ln(2) + r) = 2^n * exp(r) = 2 * 2^(n - 1) * exp(r)
// Scalar (single-lane) approximation of expf(src), stored to *dst (which must point to one writable float).
// n is x / ln(2) rounded half-up, so |r| stays within about ln(2) / 2; exp(r) comes from a degree-5 Taylor
// polynomial and 2^(n - 1) is assembled directly in the float exponent field.
static inline void simd_exp32(float src, float *dst) {
  typedef union {
    float f;
    int i;
  } fi;
  // {log(2.0f), 1/5!, 1/4!, 1/3!, 1/2!, 1}. Read-only, hence const.
  static const float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f};
  const float min_input = -87.3365478515625f;  // ~logf(FLT_MIN)
  const float max_input = 88.72283935546875f;  // ~logf(FLT_MAX)
  const float log2e = 1.44269504088896341f;    // 1 / ln(2)
  // clamp(logf(FLT_MIN), logf(FLT_MAX)) with plain float comparisons: same values as the fmin()/fmax()
  // based MS_MIN32_F32/MS_MAX32_F32 (including sending NaN to the upper bound) without calling libm.
  if (!(src <= max_input)) {
    src = max_input;
  }
  if (src < min_input) {
    src = min_input;
  }
  // n = floor(src / ln(2) + 0.5). After the clamp the argument lies in about [-125.5, 128.5], where truncating
  // toward zero and stepping down for negative non-integers is an exact floor.
  const float scaled = src * log2e + 0.5f;
  int integer = (int)scaled;
  if ((float)integer > scaled) {
    --integer;
  }
  const float decimal = src - integer * param[0];  // r
  const int shift = 23;                            // fp32 fraction width
  const int bias = 126;                            // fp32 exponent bias (127) minus 1: 2^(n - 1) is encoded
  const float factor = 2;
  // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r), because n may be 128, and 2^128 has no fp32
  // encoding. integer + bias lies in [0, 254], so the shift is well defined.
  fi int_exp;
  int_exp.i = (integer + bias) << shift;
  const float decimal_exp =
    1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
  // Scale exp(r) by 2 first (exact), then by 2^(n - 1). Forming 2 * 2^(n - 1) first overflows to +inf when
  // n == 128 (e.g. src = 88.5f) even though the true result (~2.72e38) is representable; elsewhere both
  // orderings give bit-identical results.
  *dst = (factor * decimal_exp) * int_exp.f;
}
348
// Bit-width dispatch: MS_<OP>_<TYPE>(bit_num, ...) pastes bit_num into the op name and forwards the remaining
// arguments, e.g. MS_LD_F32(256, p) -> MS_LD256_F32(p). bit_num 32 selects the scalar MS_*32_* forms defined
// earlier in this file. Expansions are wrapped in MS_EXPAND.

// define (float/int) data
// NOTE(review): the argument is pasted after "X" (MS_FLOAT32X1 is the scalar float type), so it behaves like a
// lane count despite the byte_num name — confirm against callers.
#define MS_FLOAT_32xN(byte_num) MS_FLOAT32##X##byte_num
#define MS_INT_32xN(byte_num) MS_INT32##X##byte_num

// move (float/int) data
// NOTE(review): here byte_num is pasted as MS_MOV<byte_num>_F32, so it must be a bit width such as 32.
#define MS_MOVN_F32(byte_num, ...) MS_EXPAND(MS_MOV##byte_num##_F32(__VA_ARGS__))
#define MS_MOVN_EPI32(byte_num, ...) MS_EXPAND(MS_MOV##byte_num##_EPI32(__VA_ARGS__))

// load (float/int) data
#define MS_LD_F32(bit_num, ...) MS_EXPAND(MS_LD##bit_num##_F32(__VA_ARGS__))
#define MS_LD_EPI32(bit_num, ...) MS_EXPAND(MS_LD##bit_num##_EPI32(__VA_ARGS__))

// load 4 (float/int) data
#define MS_LDX4_F32(bit_num, ...) MS_EXPAND(MS_LOAD##bit_num##X4_F32(__VA_ARGS__))
#define MS_LDX4_EPI32(bit_num, ...) MS_EXPAND(MS_LOAD##bit_num##X4_EPI32(__VA_ARGS__))

// store (float/int) data
#define MS_ST_F32(bit_num, ...) MS_EXPAND(MS_ST##bit_num##_F32(__VA_ARGS__))
#define MS_ST_EPI32(bit_num, ...) MS_EXPAND(MS_ST##bit_num##_EPI32(__VA_ARGS__))
368
// add (float/int) op
// The *_N_* forms broadcast the scalar val2 with the move of the SAME element type before the op: an EPI32 op
// must receive an int32 splat (MS_MOV<bits>_EPI32), never a float one.
#define MS_ADD_F32(bit_num, ...) MS_EXPAND(MS_ADD##bit_num##_F32(__VA_ARGS__))
#define MS_ADD_EPI32(bit_num, ...) MS_EXPAND(MS_ADD##bit_num##_EPI32(__VA_ARGS__))
#define MS_ADD_N_F32(bit_num, val1, val2) MS_EXPAND(MS_ADD##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
#define MS_ADD_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_ADD##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2)))

// sub (float/int) op
#define MS_SUB_F32(bit_num, ...) MS_EXPAND(MS_SUB##bit_num##_F32(__VA_ARGS__))
#define MS_SUB_EPI32(bit_num, ...) MS_EXPAND(MS_SUB##bit_num##_EPI32(__VA_ARGS__))
#define MS_SUB_N_F32(bit_num, val1, val2) MS_EXPAND(MS_SUB##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
#define MS_SUB_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_SUB##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2)))

// div (float/int) op
#define MS_DIV_F32(bit_num, ...) MS_EXPAND(MS_DIV##bit_num##_F32(__VA_ARGS__))
#define MS_DIV_EPI32(bit_num, ...) MS_EXPAND(MS_DIV##bit_num##_EPI32(__VA_ARGS__))
#define MS_DIV_N_F32(bit_num, val1, val2) MS_EXPAND(MS_DIV##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
#define MS_DIV_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_DIV##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2)))
386
// sqrt (float) op
#define MS_SQRT_F32(bit_num, ...) MS_EXPAND(MS_SQRT##bit_num##_F32(__VA_ARGS__))

// rsqrt (float) op
#define MS_RSQRT_F32(bit_num, ...) MS_EXPAND(MS_RSQRT##bit_num##_F32(__VA_ARGS__))

// log (float) op
#define MS_LOG_F32(bit_num, ...) MS_EXPAND(MS_LOG##bit_num##_F32(__VA_ARGS__))

// cos (float) op
#define MS_COS_F32(bit_num, ...) MS_EXPAND(MS_COS##bit_num##_F32(__VA_ARGS__))

// sin (float) op
#define MS_SIN_F32(bit_num, ...) MS_EXPAND(MS_SIN##bit_num##_F32(__VA_ARGS__))

// erf (float) op
#define MS_ERF_F32(bit_num, ...) MS_EXPAND(MS_ERF##bit_num##_F32(__VA_ARGS__))

// abs (float) op
#define MS_ABS_F32(bit_num, ...) MS_EXPAND(MS_ABS##bit_num##_F32(__VA_ARGS__))

// round (float) op
#define MS_ROUND_F32(bit_num, ...) MS_EXPAND(MS_ROUND##bit_num##_F32(__VA_ARGS__))

// ceil (float) op
#define MS_CEIL_F32(bit_num, ...) MS_EXPAND(MS_CEIL##bit_num##_F32(__VA_ARGS__))

// floor (float) op
#define MS_FLOOR_F32(bit_num, ...) MS_EXPAND(MS_FLOOR##bit_num##_F32(__VA_ARGS__))
416
// min (float/int) op; *_N_* broadcasts the scalar n with MS_MOVN_*
#define MS_MIN_F32(bit_num, ...) MS_EXPAND(MS_MIN##bit_num##_F32(__VA_ARGS__))
#define MS_MIN_EPI32(bit_num, ...) MS_EXPAND(MS_MIN##bit_num##_EPI32(__VA_ARGS__))
#define MS_MIN_N_F32(bit_num, val, n) MS_MIN_F32(bit_num, val, MS_MOVN_F32(bit_num, n))
#define MS_MIN_N_EPI32(bit_num, val, n) MS_MIN_EPI32(bit_num, val, MS_MOVN_EPI32(bit_num, n))

// max (float/int) op
#define MS_MAX_F32(bit_num, ...) MS_EXPAND(MS_MAX##bit_num##_F32(__VA_ARGS__))
#define MS_MAX_EPI32(bit_num, ...) MS_EXPAND(MS_MAX##bit_num##_EPI32(__VA_ARGS__))

// get max (float/int) op
#define MS_GET_MAX_F32(bit_num, ...) MS_EXPAND(MS_GET_MAX##bit_num##_F32(__VA_ARGS__))
#define MS_GET_MAX_EPI32(bit_num, ...) MS_EXPAND(MS_GET_MAX##bit_num##_EPI32(__VA_ARGS__))

// get sum (float) op
#define MS_GET_SUM_F32(bit_num, ...) MS_EXPAND(MS_GET_SUM##bit_num##_F32(__VA_ARGS__))

// max n (float/int) op; broadcasts the scalar n with MS_MOVN_*
#define MS_MAX_N_F32(bit_num, val, n) MS_MAX_F32(bit_num, val, MS_MOVN_F32(bit_num, n))
#define MS_MAX_N_EPI32(bit_num, val, n) MS_MAX_EPI32(bit_num, val, MS_MOVN_EPI32(bit_num, n))
// clamp (float/int) op: max against min_val first, then min against max_val
#define MS_CLAMP_F32(bit_num, val, min_val, max_val) MS_MIN_F32(bit_num, MS_MAX_F32(bit_num, val, min_val), max_val)
#define MS_CLAMP_EPI32(bit_num, val, min_val, max_val) \
  MS_MIN_EPI32(bit_num, MS_MAX_EPI32(bit_num, val, min_val), max_val)

// clamp n (float/int) op; both scalar bounds are broadcast with MS_MOV<bits>_*
#define MS_CLAMP_N_F32(bit_num, val, min_val, max_val) \
  MS_MIN_F32(bit_num, MS_MAX_F32(bit_num, val, MS_MOV##bit_num##_F32(min_val)), MS_MOV##bit_num##_F32(max_val))
#define MS_CLAMP_N_EPI32(bit_num, val, min_val, max_val) \
  MS_MIN_EPI32(bit_num, MS_MAX_EPI32(bit_num, val, MS_MOV##bit_num##_EPI32(min_val)), MS_MOV##bit_num##_EPI32(max_val))

// mul (float/int) op; *_N_* broadcasts the scalar val2
#define MS_MUL_F32(bit_num, ...) MS_EXPAND(MS_MUL##bit_num##_F32(__VA_ARGS__))
#define MS_MUL_EPI32(bit_num, ...) MS_EXPAND(MS_MUL##bit_num##_EPI32(__VA_ARGS__))
#define MS_MUL_N_F32(bit_num, val1, val2) MS_EXPAND(MS_MUL##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))
#define MS_MUL_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_MUL##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2)))
452
// fma (float) op
// NOTE(review): MS_FMADD_N_F32 / MS_FMSUB_N_F32 forward only two operands (val1 and the broadcast val2) to
// MS_FMADD<bits>_F32 / MS_FMSUB<bits>_F32, whereas the plain forms pass __VA_ARGS__ through unchanged.
// Confirm the arity the ISA-specific definitions expect before relying on the _N forms.
#define MS_FMADD_F32(bit_num, ...) MS_EXPAND(MS_FMADD##bit_num##_F32(__VA_ARGS__))
#define MS_FMADD_N_F32(bit_num, val1, val2) MS_EXPAND(MS_FMADD##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))

// fms (float) op
#define MS_FMSUB_F32(bit_num, ...) MS_EXPAND(MS_FMSUB##bit_num##_F32(__VA_ARGS__))
#define MS_FMSUB_N_F32(bit_num, val1, val2) MS_EXPAND(MS_FMSUB##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2)))

// square (float/int) op; val is evaluated twice
#define MS_MUL_SQUARE_F32(bit_num, val) MS_EXPAND((MS_MUL##bit_num##_F32(val, val)))
#define MS_MUL_SQUARE_EPI32(bit_num, val) MS_EXPAND((MS_MUL##bit_num##_EPI32(val, val)))

// exp (float) op: for bit_num 32 these name simd_exp32 (stores through its dst argument) and
// simd_exp32_f32 (returns the value), both defined earlier in this file.
#define MS_EXP_ST_F32(bit_num, ...) MS_EXPAND((simd_exp##bit_num(__VA_ARGS__)))
#define MS_EXP_F32(bit_num, ...) MS_EXPAND((simd_exp##bit_num##_f32(__VA_ARGS__)))

// compare / blend (float) op
#define MS_CMPLT_F32(bit_num, ...) MS_EXPAND((MS_CMPLT##bit_num##_F32(__VA_ARGS__)))
#define MS_CMPLE_F32(bit_num, ...) MS_EXPAND((MS_CMPLE##bit_num##_F32(__VA_ARGS__)))
#define MS_CMPGT_F32(bit_num, ...) MS_EXPAND((MS_CMPGT##bit_num##_F32(__VA_ARGS__)))
#define MS_BLEND_F32(bit_num, ...) MS_EXPAND((MS_BLEND##bit_num##_F32(__VA_ARGS__)))
473
// Element-type conversions. These paste to MS<bit_num>_<FROM>_TO_<TO> (the width directly follows "MS"),
// unlike the MS_<OP><bit_num>_<TYPE> spelling used by the arithmetic dispatch above.
#define MS_INT16_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT16_TO_FLOAT16(__VA_ARGS__)))
#define MS_FLOAT16_TO_INT16(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT16(__VA_ARGS__)))

#define MS_INT32_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT16(__VA_ARGS__)))
#define MS_FLOAT16_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT32(__VA_ARGS__)))

#define MS_INT32_TO_FLOAT32(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT32(__VA_ARGS__)))
#define MS_FLOAT32_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT32_TO_INT32(__VA_ARGS__)))

#define MS_INT64_TO_FLOAT32(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT32(__VA_ARGS__)))
#define MS_FLOAT32_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT32_TO_INT64(__VA_ARGS__)))

#define MS_INT64_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT16(__VA_ARGS__)))
#define MS_FLOAT16_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT64(__VA_ARGS__)))

#define MS_INT32_TO_FLOAT64(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT64(__VA_ARGS__)))
#define MS_FLOAT64_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT64_TO_INT32(__VA_ARGS__)))

#define MS_INT64_TO_FLOAT64(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT64(__VA_ARGS__)))
#define MS_FLOAT64_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT64_TO_INT64(__VA_ARGS__)))
494
// MS_SIMD_RUN_<ISA>(function, ...) invokes function(bit_num, lanes, ...) with that ISA's register width in bits
// and its 32-bit lane count (512/16, 256/8, 128/4) when the ISA is enabled at build time, and expands to nothing
// otherwise. The scalar variant is always available as width 32 with one lane.

#ifndef ENABLE_AVX512
#define MS_SIMD_RUN_AVX512(fn, ...)
#else
#define MS_SIMD_RUN_AVX512(fn, ...) MS_EXPAND(fn(512, 16, __VA_ARGS__))
#endif

#ifndef ENABLE_AVX
#define MS_SIMD_RUN_AVX(fn, ...)
#else
#define MS_SIMD_RUN_AVX(fn, ...) MS_EXPAND(fn(256, 8, __VA_ARGS__))
#endif

#ifndef ENABLE_SSE
#define MS_SIMD_RUN_SSE(fn, ...)
#else
#define MS_SIMD_RUN_SSE(fn, ...) MS_EXPAND(fn(128, 4, __VA_ARGS__))
#endif

#ifndef ENABLE_NEON
#define MS_SIMD_RUN_NEON(fn, ...)
#else
#define MS_SIMD_RUN_NEON(fn, ...) MS_EXPAND(fn(128, 4, __VA_ARGS__))
#endif

// A single 128-bit variant shared by SSE and NEON builds.
#if !defined(ENABLE_NEON) && !defined(ENABLE_SSE)
#define MS_SIMD_RUN_SSEORNEON128(fn, ...)
#else
#define MS_SIMD_RUN_SSEORNEON128(fn, ...) MS_EXPAND(fn(128, 4, __VA_ARGS__))
#endif

// scalar (c style data)
#define MS_SIMD_RUN_SCALAR(fn, ...) MS_EXPAND(fn(32, 1, __VA_ARGS__))

// Widest-to-narrowest: every enabled vector variant runs in turn, followed by the scalar one.
#define MS_SIMD_RUN(fn, ...)                   \
  do {                                         \
    MS_SIMD_RUN_AVX512(fn, __VA_ARGS__);       \
    MS_SIMD_RUN_AVX(fn, __VA_ARGS__);          \
    MS_SIMD_RUN_SSEORNEON128(fn, __VA_ARGS__); \
    MS_SIMD_RUN_SCALAR(fn, __VA_ARGS__);       \
  } while (0)

// As MS_SIMD_RUN, without the scalar variant.
#define MS_SIMD_RUN_NO_SCALAR(fn, ...)         \
  do {                                         \
    MS_SIMD_RUN_AVX512(fn, __VA_ARGS__);       \
    MS_SIMD_RUN_AVX(fn, __VA_ARGS__);          \
    MS_SIMD_RUN_SSEORNEON128(fn, __VA_ARGS__); \
  } while (0)

// x86-only variants (SSE rather than SSE-or-NEON), followed by the scalar one.
#define MS_SIMD_RUN_X86(fn, ...)         \
  do {                                   \
    MS_SIMD_RUN_AVX512(fn, __VA_ARGS__); \
    MS_SIMD_RUN_AVX(fn, __VA_ARGS__);    \
    MS_SIMD_RUN_SSE(fn, __VA_ARGS__);    \
    MS_SIMD_RUN_SCALAR(fn, __VA_ARGS__); \
  } while (0)

// x86-only variants, without the scalar one.
#define MS_SIMD_RUN_X86_NO_SCALAR(fn, ...) \
  do {                                     \
    MS_SIMD_RUN_AVX512(fn, __VA_ARGS__);   \
    MS_SIMD_RUN_AVX(fn, __VA_ARGS__);      \
    MS_SIMD_RUN_SSE(fn, __VA_ARGS__);      \
  } while (0)
562
563 #endif // NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
564