1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
18 #define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
19 #include <math.h>
20
/*
 * Platform intrinsic headers and per-lane accessors.
 *
 * MS_F32X4_GETI(src, i) / MS_F32X8_GETI(src, i) expand to an lvalue naming
 * float lane `i` of a 4-lane / 8-lane vector, so they can appear on either
 * side of an assignment. Arguments are parenthesized so that arbitrary
 * expressions can be passed safely.
 */
#ifdef ENABLE_ARM
#include <arm_neon.h>
#define MS_F32X4_GETI(src, i) (src)[(i)]
#endif

#if defined(ENABLE_SSE)
#ifdef _MSC_VER
#include <immintrin.h>
/* MSVC exposes the lanes of __m128 through its m128_f32 member. */
#define MS_F32X4_GETI(src, i) (src).m128_f32[(i)]
#else
#include <x86intrin.h>
#define MS_F32X4_GETI(src, i) (src)[(i)]
#endif
#endif

#ifdef ENABLE_AVX
#ifdef _MSC_VER
#include <immintrin.h>
/* MSVC exposes the lanes of __m256 through its m256_f32 member. */
#define MS_F32X8_GETI(src, i) (src).m256_f32[(i)]
#else
/* Included here too so an AVX build does not rely on ENABLE_SSE for the intrinsic declarations. */
#include <x86intrin.h>
#define MS_F32X8_GETI(src, i) (src)[(i)]
#endif
#endif
44
45 #ifdef ENABLE_ARM
/* Vector types: 128-bit NEON registers holding four 32-bit lanes. */
#define MS_FLOAT32X4 float32x4_t
#define MS_INT32X4 int32x4_t
#define MS_UINT32X4 uint32x4_t

/* Load from memory, store to memory, broadcast a scalar to every lane. */
#define MS_LDQ_F32 vld1q_f32
#define MS_LDQ_EPI32 vld1q_s32
#define MS_STQ_F32 vst1q_f32
#define MS_STQ_EPI32 vst1q_s32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_MOVQ_EPI32 vmovq_n_s32

/* Lane-wise add / subtract / multiply. */
#define MS_ADDQ_F32 vaddq_f32
#define MS_ADDQ_EPI32 vaddq_s32
#define MS_SUBQ_F32 vsubq_f32
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2)
/* Multiply-accumulate: src1 + src2 * src3 per lane. */
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)

/* Lane-wise maximum / minimum. */
#define MS_MAXQ_F32 vmaxq_f32
#define MS_MAXQ_EPI32 vmaxq_s32
#define MS_MINQ_F32 vminq_f32
#define MS_MINQ_EPI32 vminq_s32
65 #ifdef ENABLE_ARM64
66 #define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
67 #else
vrecp(float32x4_t v)68 static inline float32x4_t vrecp(float32x4_t v) {
69 float32x4_t r = vrecpeq_f32(v);
70 r = vmulq_f32(vrecpsq_f32(v, r), r);
71 r = vmulq_f32(vrecpsq_f32(v, r), r);
72 return r;
73 }
74 #define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
75 #endif
/* Vector-by-scalar operations: src2 is a scalar applied to every lane of src1. */
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2)
/* NEON provides no vdivq_n_f32 intrinsic: broadcast the scalar and reuse MS_DIVQ_F32. */
#define MS_DIVQ_N_F32(src1, src2) MS_DIVQ_F32(src1, vmovq_n_f32(src2))
/* Shift every lane left by the same amount src2. */
#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
/* Conversions between float lanes and signed 32-bit integer lanes. */
#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
/* Lane-wise "greater than" comparisons producing a per-lane mask. */
#define MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2)
// Note: vbslq_* takes the mask as its FIRST argument, whereas _mm_blendv_ps takes it last.
// The arguments are reordered here so that MS_BLENDQ_*(src1, src2, mask) has the same
// meaning as on x86: take src2 where the mask is set, src1 elsewhere.
#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1)
#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
/* Reinterpret the bits of an int32 vector as a float vector (no value conversion). */
#define MS_CAST_F32_S32(src) vreinterpretq_f32_s32(src)
88 #endif
89
90 #if defined(ENABLE_AVX)
/* Vector types: 256-bit registers holding eight 32-bit lanes. */
#define MS_FLOAT32X8 __m256
#define MS_INT32X8 __m256i
/* Unaligned loads / stores and scalar broadcast. */
#define MS_LD256_F32 _mm256_loadu_ps
#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src))
#define MS_ADD256_F32 _mm256_add_ps
#define MS_ADD256_EPI32 _mm256_add_epi32
#define MS_MOV256_F32 _mm256_set1_ps
#define MS_MOV256_EPI32 _mm256_set1_epi32
/* Multiply-accumulate: src1 + src2 * src3 per lane. */
#define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3))
#define MS_ST256_F32 _mm256_storeu_ps
#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
#define MS_SUB256_F32 _mm256_sub_ps
#define MS_MAX256_F32 _mm256_max_ps
#define MS_MAX256_EPI32 _mm256_max_epi32
#define MS_MIN256_F32 _mm256_min_ps
#define MS_MIN256_EPI32 _mm256_min_epi32
#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2)
/* Lane-wise 32-bit multiply keeping the low 32 bits of each product.
 * _mm256_mul_epi32 is NOT suitable here: it multiplies only the low 32-bit half of each
 * 64-bit element and produces four 64-bit results. */
#define MS_MUL256_EPI32(src1, src2) _mm256_mullo_epi32(src1, src2)
#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2)
/* Vector-by-scalar operations: src2 is broadcast to every lane first. */
#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2))
#define MS_MUL256_N_EPI32(src1, src2) _mm256_mullo_epi32(src1, _mm256_set1_epi32(src2))
#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2))
#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2)
#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src)  // truncate float to int
#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src)   // convert int to float
/* src3 is the comparison predicate passed through to _mm256_cmp_ps. */
#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3)
#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2)
/* Blend: take src2 where the mask src3 is set, src1 elsewhere. */
#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3)
/* Reinterpret the bits of an integer vector as a float vector (no value conversion). */
#define MS_CAST256_F32_S32(src) _mm256_castsi256_ps(src)
121 #endif
122
123 #if defined(ENABLE_SSE)
/* Vector types: 128-bit registers holding four 32-bit lanes. */
#define MS_FLOAT32X4 __m128
#define MS_INT32X4 __m128i
/* Unaligned loads / stores and scalar broadcast. */
#define MS_LDQ_F32 _mm_loadu_ps
#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
#define MS_ADDQ_F32 _mm_add_ps
#define MS_ADDQ_EPI32 _mm_add_epi32
#define MS_MOVQ_F32 _mm_set1_ps
#define MS_MOVQ_EPI32 _mm_set1_epi32
/* Multiply-accumulate: src1 + src2 * src3 per lane. */
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
#define MS_STQ_F32 _mm_storeu_ps
#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
#define MS_SUBQ_F32 _mm_sub_ps
#define MS_MAXQ_F32 _mm_max_ps
#define MS_MAXQ_EPI32 _mm_max_epi32
#define MS_MINQ_F32 _mm_min_ps
#define MS_MINQ_EPI32 _mm_min_epi32
#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2)
/* Lane-wise 32-bit multiply keeping the low 32 bits of each product.
 * _mm_mul_epi32 is NOT suitable here: it multiplies only the low 32-bit half of each
 * 64-bit element and produces two 64-bit results. */
#define MS_MULQ_EPI32(src1, src2) _mm_mullo_epi32(src1, src2)
#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2)
/* Vector-by-scalar operations: src2 is broadcast to every lane first. */
#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2))
#define MS_MULQ_N_EPI32(src1, src2) _mm_mullo_epi32(src1, _mm_set1_epi32(src2))
#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2))
#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2)
#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src)  // truncate float to int
#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src)   // convert int to float
/* Lane-wise "greater than" comparisons producing a per-lane mask. */
#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
/* Blend: take src2 where the mask src3 is set, src1 elsewhere. */
#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
/* Reinterpret the bits of an integer vector as a float vector (no value conversion). */
#define MS_CAST_F32_S32(src) _mm_castsi128_ps(src)
154 #endif
155
156 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
/**
 * Square root of each of the four float lanes, computed with scalar sqrtf.
 *
 * @param src  Input vector.
 * @return     Vector whose lane i holds sqrtf(lane i of src).
 */
static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) {
  MS_FLOAT32X4 dst;
  int lane;
  for (lane = 0; lane < 4; ++lane) {
    MS_F32X4_GETI(dst, lane) = sqrtf(MS_F32X4_GETI(src, lane));
  }
  return dst;
}
165
/*
 * LOAD128X8_F32(src, input_ptr, num)
 *
 * Declares eight MS_FLOAT32X4 variables named <src>1 .. <src>8 and loads
 * <src>K from input_ptr + (K - 1) * num.
 *
 * The expansion is a run of declarations, each ending in ';', so it must be
 * used where declarations are allowed (not as the body of an unbraced if/for).
 * input_ptr and num are parenthesized so that expressions such as
 * `stride + 1` are scaled as a whole.
 */
#define LOAD128X8_F32(src, input_ptr, num)                   \
  MS_FLOAT32X4 src##1 = MS_LDQ_F32((input_ptr) + 0 * (num)); \
  MS_FLOAT32X4 src##2 = MS_LDQ_F32((input_ptr) + 1 * (num)); \
  MS_FLOAT32X4 src##3 = MS_LDQ_F32((input_ptr) + 2 * (num)); \
  MS_FLOAT32X4 src##4 = MS_LDQ_F32((input_ptr) + 3 * (num)); \
  MS_FLOAT32X4 src##5 = MS_LDQ_F32((input_ptr) + 4 * (num)); \
  MS_FLOAT32X4 src##6 = MS_LDQ_F32((input_ptr) + 5 * (num)); \
  MS_FLOAT32X4 src##7 = MS_LDQ_F32((input_ptr) + 6 * (num)); \
  MS_FLOAT32X4 src##8 = MS_LDQ_F32((input_ptr) + 7 * (num));
175
/*
 * STORE128X8_F32(output_ptr, num, dst)
 *
 * Stores the eight vectors <dst>1 .. <dst>8 to memory, writing <dst>K to
 * output_ptr + (K - 1) * num.
 *
 * The expansion is a run of statements, each ending in ';', so it must not be
 * used as the body of an unbraced if/for. output_ptr and num are
 * parenthesized so that expressions such as `stride + 1` are scaled as a whole.
 */
#define STORE128X8_F32(output_ptr, num, dst)      \
  MS_STQ_F32((output_ptr) + 0 * (num), dst##1);   \
  MS_STQ_F32((output_ptr) + 1 * (num), dst##2);   \
  MS_STQ_F32((output_ptr) + 2 * (num), dst##3);   \
  MS_STQ_F32((output_ptr) + 3 * (num), dst##4);   \
  MS_STQ_F32((output_ptr) + 4 * (num), dst##5);   \
  MS_STQ_F32((output_ptr) + 5 * (num), dst##6);   \
  MS_STQ_F32((output_ptr) + 6 * (num), dst##7);   \
  MS_STQ_F32((output_ptr) + 7 * (num), dst##8);
185
/**
 * Approximates tanh on each of the four float lanes.
 *
 * Evaluates the rational function
 *   x * (x^6 + 378 x^4 + 17325 x^2 + 135135) / (28 x^6 + 3150 x^4 + 62370 x^2 + 135135)
 * and clamps the quotient to [-1, 1].
 *
 * The input is first clamped to [-5, 5]. At |x| = 5 the quotient already
 * exceeds 1 in magnitude, so the result there is exactly +/-1. Without the
 * input clamp, very large or infinite inputs overflow both polynomials and
 * the division yields inf / inf = NaN instead of the saturated value.
 *
 * @param src  Input vector.
 * @return     Per-lane approximation of tanh(src), within [-1, 1].
 */
static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
  static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f};
  static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f};
  static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f};
  static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f};
  static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f};
  static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f};
  static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f};
  static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f};
  static const MS_FLOAT32X4 lower_limit = {-5.0f, -5.0f, -5.0f, -5.0f};
  static const MS_FLOAT32X4 upper_limit = {5.0f, 5.0f, 5.0f, 5.0f};

  // Keep the polynomial arguments in a range where neither can overflow.
  MS_FLOAT32X4 x = MS_MINQ_F32(MS_MAXQ_F32(src, lower_limit), upper_limit);
  MS_FLOAT32X4 square = MS_MULQ_F32(x, x);

  // Numerator: (((x^2 + 378) * x^2 + 17325) * x^2 + 135135) * x
  MS_FLOAT32X4 a = MS_ADDQ_F32(square, data0);
  a = MS_ADDQ_F32(MS_MULQ_F32(a, square), data1);
  a = MS_ADDQ_F32(MS_MULQ_F32(a, square), data2);
  a = MS_MULQ_F32(a, x);

  // Denominator: ((28 * x^2 + 3150) * x^2 + 62370) * x^2 + 135135
  MS_FLOAT32X4 b = MS_ADDQ_F32(MS_MULQ_F32(data3, square), data4);
  b = MS_ADDQ_F32(MS_MULQ_F32(b, square), data5);
  b = MS_ADDQ_F32(MS_MULQ_F32(b, square), data2);

  return MS_MINQ_F32(MS_MAXQ_F32(MS_DIVQ_F32(a, b), neg), pos);
}
203
/**
 * Error function of each of the four float lanes, computed with scalar erff.
 *
 * @param src  Input vector.
 * @return     Vector whose lane i holds erff(lane i of src).
 */
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
  MS_FLOAT32X4 dst;
  int lane;
  for (lane = 0; lane < 4; ++lane) {
    MS_F32X4_GETI(dst, lane) = erff(MS_F32X4_GETI(src, lane));
  }
  return dst;
}
212 #endif
213
214 #ifdef ENABLE_AVX
/**
 * Square root of each of the eight float lanes, computed with scalar sqrtf.
 *
 * @param src  Input vector.
 * @return     Vector whose lane i holds sqrtf(lane i of src).
 */
static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
  MS_FLOAT32X8 dst;
  int lane;
  for (lane = 0; lane < 8; ++lane) {
    MS_F32X8_GETI(dst, lane) = sqrtf(MS_F32X8_GETI(src, lane));
  }
  return dst;
}
227
/*
 * LOAD256X8_F32(src, input_ptr, num)
 *
 * Declares eight MS_FLOAT32X8 variables named <src>1 .. <src>8 and loads
 * <src>K from input_ptr + (K - 1) * num.
 *
 * The expansion is a run of declarations, each ending in ';', so it must be
 * used where declarations are allowed. input_ptr and num are parenthesized so
 * that expressions such as `stride + 1` are scaled as a whole.
 */
#define LOAD256X8_F32(src, input_ptr, num)                     \
  MS_FLOAT32X8 src##1 = MS_LD256_F32((input_ptr) + 0 * (num)); \
  MS_FLOAT32X8 src##2 = MS_LD256_F32((input_ptr) + 1 * (num)); \
  MS_FLOAT32X8 src##3 = MS_LD256_F32((input_ptr) + 2 * (num)); \
  MS_FLOAT32X8 src##4 = MS_LD256_F32((input_ptr) + 3 * (num)); \
  MS_FLOAT32X8 src##5 = MS_LD256_F32((input_ptr) + 4 * (num)); \
  MS_FLOAT32X8 src##6 = MS_LD256_F32((input_ptr) + 5 * (num)); \
  MS_FLOAT32X8 src##7 = MS_LD256_F32((input_ptr) + 6 * (num)); \
  MS_FLOAT32X8 src##8 = MS_LD256_F32((input_ptr) + 7 * (num));
237
/*
 * LOAD256X16_F32(src, input_ptr, num)
 *
 * Declares sixteen MS_FLOAT32X8 variables named <src>1 .. <src>16 and loads
 * <src>K from input_ptr + (K - 1) * num.
 *
 * The expansion is a run of declarations, each ending in ';', so it must be
 * used where declarations are allowed. input_ptr and num are parenthesized so
 * that expressions such as `stride + 1` are scaled as a whole.
 */
#define LOAD256X16_F32(src, input_ptr, num)                      \
  MS_FLOAT32X8 src##1 = MS_LD256_F32((input_ptr) + 0 * (num));   \
  MS_FLOAT32X8 src##2 = MS_LD256_F32((input_ptr) + 1 * (num));   \
  MS_FLOAT32X8 src##3 = MS_LD256_F32((input_ptr) + 2 * (num));   \
  MS_FLOAT32X8 src##4 = MS_LD256_F32((input_ptr) + 3 * (num));   \
  MS_FLOAT32X8 src##5 = MS_LD256_F32((input_ptr) + 4 * (num));   \
  MS_FLOAT32X8 src##6 = MS_LD256_F32((input_ptr) + 5 * (num));   \
  MS_FLOAT32X8 src##7 = MS_LD256_F32((input_ptr) + 6 * (num));   \
  MS_FLOAT32X8 src##8 = MS_LD256_F32((input_ptr) + 7 * (num));   \
  MS_FLOAT32X8 src##9 = MS_LD256_F32((input_ptr) + 8 * (num));   \
  MS_FLOAT32X8 src##10 = MS_LD256_F32((input_ptr) + 9 * (num));  \
  MS_FLOAT32X8 src##11 = MS_LD256_F32((input_ptr) + 10 * (num)); \
  MS_FLOAT32X8 src##12 = MS_LD256_F32((input_ptr) + 11 * (num)); \
  MS_FLOAT32X8 src##13 = MS_LD256_F32((input_ptr) + 12 * (num)); \
  MS_FLOAT32X8 src##14 = MS_LD256_F32((input_ptr) + 13 * (num)); \
  MS_FLOAT32X8 src##15 = MS_LD256_F32((input_ptr) + 14 * (num)); \
  MS_FLOAT32X8 src##16 = MS_LD256_F32((input_ptr) + 15 * (num));
255
/*
 * STORE256X8_F32(output_ptr, num, dst)
 *
 * Stores the eight vectors <dst>1 .. <dst>8 to memory, writing <dst>K to
 * output_ptr + (K - 1) * num.
 *
 * The expansion is a run of statements, each ending in ';', so it must not be
 * used as the body of an unbraced if/for. output_ptr and num are
 * parenthesized so that expressions such as `stride + 1` are scaled as a whole.
 */
#define STORE256X8_F32(output_ptr, num, dst)        \
  MS_ST256_F32((output_ptr) + 0 * (num), dst##1);   \
  MS_ST256_F32((output_ptr) + 1 * (num), dst##2);   \
  MS_ST256_F32((output_ptr) + 2 * (num), dst##3);   \
  MS_ST256_F32((output_ptr) + 3 * (num), dst##4);   \
  MS_ST256_F32((output_ptr) + 4 * (num), dst##5);   \
  MS_ST256_F32((output_ptr) + 5 * (num), dst##6);   \
  MS_ST256_F32((output_ptr) + 6 * (num), dst##7);   \
  MS_ST256_F32((output_ptr) + 7 * (num), dst##8);
265
/*
 * STORE256X16_F32(output_ptr, num, dst)
 *
 * Stores the sixteen vectors <dst>1 .. <dst>16 to memory, writing <dst>K to
 * output_ptr + (K - 1) * num.
 *
 * The expansion is a run of statements, each ending in ';', so it must not be
 * used as the body of an unbraced if/for. output_ptr and num are
 * parenthesized so that expressions such as `stride + 1` are scaled as a whole.
 */
#define STORE256X16_F32(output_ptr, num, dst)         \
  MS_ST256_F32((output_ptr) + 0 * (num), dst##1);     \
  MS_ST256_F32((output_ptr) + 1 * (num), dst##2);     \
  MS_ST256_F32((output_ptr) + 2 * (num), dst##3);     \
  MS_ST256_F32((output_ptr) + 3 * (num), dst##4);     \
  MS_ST256_F32((output_ptr) + 4 * (num), dst##5);     \
  MS_ST256_F32((output_ptr) + 5 * (num), dst##6);     \
  MS_ST256_F32((output_ptr) + 6 * (num), dst##7);     \
  MS_ST256_F32((output_ptr) + 7 * (num), dst##8);     \
  MS_ST256_F32((output_ptr) + 8 * (num), dst##9);     \
  MS_ST256_F32((output_ptr) + 9 * (num), dst##10);    \
  MS_ST256_F32((output_ptr) + 10 * (num), dst##11);   \
  MS_ST256_F32((output_ptr) + 11 * (num), dst##12);   \
  MS_ST256_F32((output_ptr) + 12 * (num), dst##13);   \
  MS_ST256_F32((output_ptr) + 13 * (num), dst##14);   \
  MS_ST256_F32((output_ptr) + 14 * (num), dst##15);   \
  MS_ST256_F32((output_ptr) + 15 * (num), dst##16);
283
/**
 * Approximates tanh on each of the eight float lanes.
 *
 * Evaluates the rational function
 *   x * (x^6 + 378 x^4 + 17325 x^2 + 135135) / (28 x^6 + 3150 x^4 + 62370 x^2 + 135135)
 * and clamps the quotient to [-1, 1].
 *
 * The input is first clamped to [-5, 5]. At |x| = 5 the quotient already
 * exceeds 1 in magnitude, so the result there is exactly +/-1. Without the
 * input clamp, very large or infinite inputs overflow both polynomials and
 * the division yields inf / inf = NaN instead of the saturated value.
 *
 * @param src  Input vector.
 * @return     Per-lane approximation of tanh(src), within [-1, 1].
 */
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
  static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f};
  static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f};
  static const MS_FLOAT32X8 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f,
                                     135135.0f, 135135.0f, 135135.0f, 135135.0f};
  static const MS_FLOAT32X8 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f};
  static const MS_FLOAT32X8 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f};
  static const MS_FLOAT32X8 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f};
  static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
  static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
  static const MS_FLOAT32X8 lower_limit = {-5.0f, -5.0f, -5.0f, -5.0f, -5.0f, -5.0f, -5.0f, -5.0f};
  static const MS_FLOAT32X8 upper_limit = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f};

  // Keep the polynomial arguments in a range where neither can overflow.
  MS_FLOAT32X8 x = MS_MIN256_F32(MS_MAX256_F32(src, lower_limit), upper_limit);
  MS_FLOAT32X8 square = MS_MUL256_F32(x, x);

  // Numerator: (((x^2 + 378) * x^2 + 17325) * x^2 + 135135) * x
  MS_FLOAT32X8 a = MS_ADD256_F32(square, data0);
  a = MS_ADD256_F32(MS_MUL256_F32(a, square), data1);
  a = MS_ADD256_F32(MS_MUL256_F32(a, square), data2);
  a = MS_MUL256_F32(a, x);

  // Denominator: ((28 * x^2 + 3150) * x^2 + 62370) * x^2 + 135135
  MS_FLOAT32X8 b = MS_ADD256_F32(MS_MUL256_F32(data3, square), data4);
  b = MS_ADD256_F32(MS_MUL256_F32(b, square), data5);
  b = MS_ADD256_F32(MS_MUL256_F32(b, square), data2);

  return MS_MIN256_F32(MS_MAX256_F32(MS_DIV256_F32(a, b), neg), pos);
}
305 #endif
306
307 #endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
308