• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
18 #define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
19 #include <math.h>
20 
// Platform intrinsic headers and per-lane accessors.
// MS_F32X4_GETI(src, i) / MS_F32X8_GETI(src, i) expand to an lvalue naming lane i of a
// 128-bit / 256-bit float vector, so they can be used on either side of an assignment.
#ifdef ENABLE_ARM
#include <arm_neon.h>
// Uses vector subscripting on the NEON type (a GCC/Clang extension).
#define MS_F32X4_GETI(src, i) (src)[i]
#endif

#if defined(ENABLE_SSE)
#ifdef _MSC_VER
#include <immintrin.h>
// MSVC exposes the lanes through the m128_f32 member instead of subscripting.
#define MS_F32X4_GETI(src, i) (src).m128_f32[i]
#else
#include <x86intrin.h>
#define MS_F32X4_GETI(src, i) (src)[i]
#endif
#endif

#ifdef ENABLE_AVX
// Included for every compiler so the AVX definitions do not depend on ENABLE_SSE
// having pulled in an intrinsic header first.
#include <immintrin.h>
#ifdef _MSC_VER
#define MS_F32X8_GETI(src, i) (src).m256_f32[i]
#else
#define MS_F32X8_GETI(src, i) (src)[i]
#endif
#endif
44 
#ifdef ENABLE_ARM
// NEON backend for the 128-bit (4 x float / 4 x int32) MS_*Q_* operations.
#define MS_FLOAT32X4 float32x4_t
#define MS_INT32X4 int32x4_t
#define MS_UINT32X4 uint32x4_t
#define MS_LDQ_F32 vld1q_f32
#define MS_LDQ_EPI32 vld1q_s32
#define MS_ADDQ_F32 vaddq_f32
#define MS_ADDQ_EPI32 vaddq_s32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_MOVQ_EPI32 vmovq_n_s32
#define MS_SUBQ_F32 vsubq_f32
// Multiply-accumulate: src1 + src2 * src3.
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
#define MS_STQ_F32 vst1q_f32
#define MS_STQ_EPI32 vst1q_s32
#define MS_MAXQ_F32 vmaxq_f32
#define MS_MAXQ_EPI32 vmaxq_s32
#define MS_MINQ_F32 vminq_f32
#define MS_MINQ_EPI32 vminq_s32
#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2)
#ifdef ENABLE_ARM64
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
#else
// Approximate per-lane reciprocal 1 / v, used where vdivq_f32 is not available.
// Starts from the vrecpeq_f32 estimate and refines it with two vrecpsq_f32
// (Newton-Raphson) steps.
static inline float32x4_t vrecp(float32x4_t v) {
  float32x4_t r = vrecpeq_f32(v);
  r = vmulq_f32(vrecpsq_f32(v, r), r);
  r = vmulq_f32(vrecpsq_f32(v, r), r);
  return r;
}
// Division is emulated as src1 * (1 / src2), so the quotient is an approximation.
#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecp(src2))
#endif
#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2)
// NEON has no divide-by-scalar intrinsic: broadcast the scalar and reuse MS_DIVQ_F32,
// which also picks the right implementation for 32-bit and 64-bit ARM.
#define MS_DIVQ_N_F32(src1, src2) MS_DIVQ_F32(src1, vmovq_n_f32(src2))
#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
#define MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2)
// Note: vbslq_* takes the mask first, so the operands are passed in the reverse order of
// x86 _mm_blendv_ps. The result takes src2 where the mask (src3) is set and src1 elsewhere.
#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1)
#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
#define MS_CAST_F32_S32(src) vreinterpretq_f32_s32(src)
#endif
89 
#if defined(ENABLE_AVX)
// AVX/AVX2 backend for the 256-bit (8 x float / 8 x int32) MS_*256_* operations.
#define MS_FLOAT32X8 __m256
#define MS_INT32X8 __m256i
#define MS_LD256_F32 _mm256_loadu_ps
#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src))
#define MS_ADD256_F32 _mm256_add_ps
#define MS_ADD256_EPI32 _mm256_add_epi32
#define MS_MOV256_F32 _mm256_set1_ps
#define MS_MOV256_EPI32 _mm256_set1_epi32
// Multiply-accumulate: src1 + src2 * src3.
#define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3))
#define MS_ST256_F32 _mm256_storeu_ps
#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
#define MS_SUB256_F32 _mm256_sub_ps
#define MS_MAX256_F32 _mm256_max_ps
#define MS_MAX256_EPI32 _mm256_max_epi32
#define MS_MIN256_F32 _mm256_min_ps
#define MS_MIN256_EPI32 _mm256_min_epi32
#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2)
// Element-wise 32-bit multiply keeping the low 32 bits of each product.
// _mm256_mul_epi32 is NOT equivalent: it multiplies only the even lanes and widens to 64 bits.
#define MS_MUL256_EPI32(src1, src2) _mm256_mullo_epi32(src1, src2)
#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2)
#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2))
#define MS_MUL256_N_EPI32(src1, src2) _mm256_mullo_epi32(src1, _mm256_set1_epi32(src2))
#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2))
#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2)
#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src)  // truncate float to int
#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src)
// src3 is the comparison predicate immediate (one of the _CMP_* constants).
#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3)
#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2)
// Blend: takes src2 where the mask (src3) is set and src1 elsewhere.
#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3)
#define MS_CAST256_F32_S32(src) _mm256_castsi256_ps(src)
#endif
122 
#if defined(ENABLE_SSE)
// SSE backend for the 128-bit (4 x float / 4 x int32) MS_*Q_* operations.
// Several mappings below (_mm_max_epi32, _mm_blendv_ps, _mm_mullo_epi32, ...) require SSE4.1.
#define MS_FLOAT32X4 __m128
#define MS_INT32X4 __m128i
#define MS_LDQ_F32 _mm_loadu_ps
#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
#define MS_ADDQ_F32 _mm_add_ps
#define MS_ADDQ_EPI32 _mm_add_epi32
#define MS_MOVQ_F32 _mm_set1_ps
#define MS_MOVQ_EPI32 _mm_set1_epi32
// Multiply-accumulate: src1 + src2 * src3.
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
#define MS_STQ_F32 _mm_storeu_ps
#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
#define MS_SUBQ_F32 _mm_sub_ps
#define MS_MAXQ_F32 _mm_max_ps
#define MS_MAXQ_EPI32 _mm_max_epi32
#define MS_MINQ_F32 _mm_min_ps
#define MS_MINQ_EPI32 _mm_min_epi32
#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2)
// Element-wise 32-bit multiply keeping the low 32 bits of each product (matches NEON vmulq_s32).
// _mm_mul_epi32 is NOT equivalent: it multiplies only lanes 0 and 2 and widens to 64 bits.
#define MS_MULQ_EPI32(src1, src2) _mm_mullo_epi32(src1, src2)
#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2)
#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2))
#define MS_MULQ_N_EPI32(src1, src2) _mm_mullo_epi32(src1, _mm_set1_epi32(src2))
#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2))
#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2)
#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src)  // truncate float to int
#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src)
#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2)
#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
// Blend: takes src2 where the mask (src3) is set and src1 elsewhere.
#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
#define MS_CAST_F32_S32(src) _mm_castsi128_ps(src)
#endif
155 
156 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
// Element-wise square root of a 4-lane float vector.
// Each lane is passed through scalar sqrtf(); the lanes are spilled to a small
// array, transformed in place and reloaded into a vector.
static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) {
  float lanes[4];
  MS_STQ_F32(lanes, src);
  for (int lane = 0; lane < 4; ++lane) {
    lanes[lane] = sqrtf(lanes[lane]);
  }
  return MS_LDQ_F32(lanes);
}
165 
// Declares eight MS_FLOAT32X4 variables named <src>1 .. <src>8 in the enclosing scope and
// loads them from input_ptr + k * num for k = 0 .. 7 (offsets are in elements of *input_ptr).
// The expansion is a sequence of declarations, so it cannot be wrapped in do { } while (0).
// input_ptr and num are parenthesized so expression arguments (e.g. `stride + 2`) keep
// their intended grouping.
#define LOAD128X8_F32(src, input_ptr, num)                     \
  MS_FLOAT32X4 src##1 = MS_LDQ_F32((input_ptr) + 0 * (num)); \
  MS_FLOAT32X4 src##2 = MS_LDQ_F32((input_ptr) + 1 * (num)); \
  MS_FLOAT32X4 src##3 = MS_LDQ_F32((input_ptr) + 2 * (num)); \
  MS_FLOAT32X4 src##4 = MS_LDQ_F32((input_ptr) + 3 * (num)); \
  MS_FLOAT32X4 src##5 = MS_LDQ_F32((input_ptr) + 4 * (num)); \
  MS_FLOAT32X4 src##6 = MS_LDQ_F32((input_ptr) + 5 * (num)); \
  MS_FLOAT32X4 src##7 = MS_LDQ_F32((input_ptr) + 6 * (num)); \
  MS_FLOAT32X4 src##8 = MS_LDQ_F32((input_ptr) + 7 * (num));
175 
// Stores the eight MS_FLOAT32X4 variables <dst>1 .. <dst>8 to output_ptr + k * num for
// k = 0 .. 7 (offsets are in elements of *output_ptr).
// The expansion keeps its trailing semicolon so existing call sites written without one
// still compile. output_ptr and num are parenthesized so expression arguments keep their
// intended grouping.
#define STORE128X8_F32(output_ptr, num, dst)      \
  MS_STQ_F32((output_ptr) + 0 * (num), dst##1); \
  MS_STQ_F32((output_ptr) + 1 * (num), dst##2); \
  MS_STQ_F32((output_ptr) + 2 * (num), dst##3); \
  MS_STQ_F32((output_ptr) + 3 * (num), dst##4); \
  MS_STQ_F32((output_ptr) + 4 * (num), dst##5); \
  MS_STQ_F32((output_ptr) + 5 * (num), dst##6); \
  MS_STQ_F32((output_ptr) + 6 * (num), dst##7); \
  MS_STQ_F32((output_ptr) + 7 * (num), dst##8);
185 
// Approximates tanh() on each lane of a 4-lane float vector with the rational function
//   x * (x^6 + 378 x^4 + 17325 x^2 + 135135) / (28 x^6 + 3150 x^4 + 62370 x^2 + 135135)
// and clamps the quotient to [-1, 1].
//
// The input is first clamped to [-5, 5]. At |x| = 5 the quotient already exceeds 1 in
// magnitude, so the output is saturated to +/-1 from there on; clamping the input keeps
// that result while preventing the x^6 / x^7 terms from overflowing to infinity for very
// large |x|. Without it, inf / inf produced NaN, which the final min/max turned into -1
// (x86) or NaN (NEON) regardless of the sign of the input.
static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) {
  static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f};
  static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f};
  static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f};
  static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f};
  static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f};
  static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f};
  static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f};
  static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f};
  static const MS_FLOAT32X4 lower = {-5.0f, -5.0f, -5.0f, -5.0f};
  static const MS_FLOAT32X4 upper = {5.0f, 5.0f, 5.0f, 5.0f};
  // Saturate the argument before evaluating the polynomials (see header comment).
  const MS_FLOAT32X4 x = MS_MINQ_F32(MS_MAXQ_F32(src, lower), upper);
  MS_FLOAT32X4 square = MS_MULQ_F32(x, x);
  // Numerator, evaluated in Horner form: ((x^2 + 378) x^2 + 17325) x^2 + 135135, times x.
  MS_FLOAT32X4 a = MS_MULQ_F32(
    MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(square, data0), square), data1), square), data2), x);
  // Denominator, evaluated in Horner form: ((28 x^2 + 3150) x^2 + 62370) x^2 + 135135.
  MS_FLOAT32X4 b = MS_ADDQ_F32(
    MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(data3, square), data4), square), data5), square),
    data2);
  return MS_MINQ_F32(MS_MAXQ_F32(MS_DIVQ_F32(a, b), neg), pos);
}
203 
// Element-wise error function of a 4-lane float vector.
// Each lane is passed through scalar erff(); the lanes are spilled to a small
// array, transformed in place and reloaded into a vector.
static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
  float lanes[4];
  MS_STQ_F32(lanes, src);
  for (int lane = 0; lane < 4; ++lane) {
    lanes[lane] = erff(lanes[lane]);
  }
  return MS_LDQ_F32(lanes);
}
212 #endif
213 
214 #ifdef ENABLE_AVX
// Element-wise square root of an 8-lane float vector.
// Each lane is passed through scalar sqrtf(); the lanes are spilled to a small
// array, transformed in place and reloaded into a vector.
static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
  float lanes[8];
  MS_ST256_F32(lanes, src);
  for (int lane = 0; lane < 8; ++lane) {
    lanes[lane] = sqrtf(lanes[lane]);
  }
  return MS_LD256_F32(lanes);
}
227 
// Declares eight MS_FLOAT32X8 variables named <src>1 .. <src>8 in the enclosing scope and
// loads them from input_ptr + k * num for k = 0 .. 7 (offsets are in elements of *input_ptr).
// The expansion is a sequence of declarations, so it cannot be wrapped in do { } while (0).
// input_ptr and num are parenthesized so expression arguments keep their intended grouping.
#define LOAD256X8_F32(src, input_ptr, num)                       \
  MS_FLOAT32X8 src##1 = MS_LD256_F32((input_ptr) + 0 * (num)); \
  MS_FLOAT32X8 src##2 = MS_LD256_F32((input_ptr) + 1 * (num)); \
  MS_FLOAT32X8 src##3 = MS_LD256_F32((input_ptr) + 2 * (num)); \
  MS_FLOAT32X8 src##4 = MS_LD256_F32((input_ptr) + 3 * (num)); \
  MS_FLOAT32X8 src##5 = MS_LD256_F32((input_ptr) + 4 * (num)); \
  MS_FLOAT32X8 src##6 = MS_LD256_F32((input_ptr) + 5 * (num)); \
  MS_FLOAT32X8 src##7 = MS_LD256_F32((input_ptr) + 6 * (num)); \
  MS_FLOAT32X8 src##8 = MS_LD256_F32((input_ptr) + 7 * (num));
237 
// Declares sixteen MS_FLOAT32X8 variables named <src>1 .. <src>16 in the enclosing scope and
// loads them from input_ptr + k * num for k = 0 .. 15 (offsets are in elements of *input_ptr).
// The expansion is a sequence of declarations, so it cannot be wrapped in do { } while (0).
// input_ptr and num are parenthesized so expression arguments keep their intended grouping.
#define LOAD256X16_F32(src, input_ptr, num)                        \
  MS_FLOAT32X8 src##1 = MS_LD256_F32((input_ptr) + 0 * (num));   \
  MS_FLOAT32X8 src##2 = MS_LD256_F32((input_ptr) + 1 * (num));   \
  MS_FLOAT32X8 src##3 = MS_LD256_F32((input_ptr) + 2 * (num));   \
  MS_FLOAT32X8 src##4 = MS_LD256_F32((input_ptr) + 3 * (num));   \
  MS_FLOAT32X8 src##5 = MS_LD256_F32((input_ptr) + 4 * (num));   \
  MS_FLOAT32X8 src##6 = MS_LD256_F32((input_ptr) + 5 * (num));   \
  MS_FLOAT32X8 src##7 = MS_LD256_F32((input_ptr) + 6 * (num));   \
  MS_FLOAT32X8 src##8 = MS_LD256_F32((input_ptr) + 7 * (num));   \
  MS_FLOAT32X8 src##9 = MS_LD256_F32((input_ptr) + 8 * (num));   \
  MS_FLOAT32X8 src##10 = MS_LD256_F32((input_ptr) + 9 * (num));  \
  MS_FLOAT32X8 src##11 = MS_LD256_F32((input_ptr) + 10 * (num)); \
  MS_FLOAT32X8 src##12 = MS_LD256_F32((input_ptr) + 11 * (num)); \
  MS_FLOAT32X8 src##13 = MS_LD256_F32((input_ptr) + 12 * (num)); \
  MS_FLOAT32X8 src##14 = MS_LD256_F32((input_ptr) + 13 * (num)); \
  MS_FLOAT32X8 src##15 = MS_LD256_F32((input_ptr) + 14 * (num)); \
  MS_FLOAT32X8 src##16 = MS_LD256_F32((input_ptr) + 15 * (num));
255 
// Stores the eight MS_FLOAT32X8 variables <dst>1 .. <dst>8 to output_ptr + k * num for
// k = 0 .. 7 (offsets are in elements of *output_ptr).
// The expansion keeps its trailing semicolon so existing call sites written without one
// still compile. output_ptr and num are parenthesized so expression arguments keep their
// intended grouping.
#define STORE256X8_F32(output_ptr, num, dst)        \
  MS_ST256_F32((output_ptr) + 0 * (num), dst##1); \
  MS_ST256_F32((output_ptr) + 1 * (num), dst##2); \
  MS_ST256_F32((output_ptr) + 2 * (num), dst##3); \
  MS_ST256_F32((output_ptr) + 3 * (num), dst##4); \
  MS_ST256_F32((output_ptr) + 4 * (num), dst##5); \
  MS_ST256_F32((output_ptr) + 5 * (num), dst##6); \
  MS_ST256_F32((output_ptr) + 6 * (num), dst##7); \
  MS_ST256_F32((output_ptr) + 7 * (num), dst##8);
265 
// Stores the sixteen MS_FLOAT32X8 variables <dst>1 .. <dst>16 to output_ptr + k * num for
// k = 0 .. 15 (offsets are in elements of *output_ptr).
// The expansion keeps its trailing semicolon so existing call sites written without one
// still compile. output_ptr and num are parenthesized so expression arguments keep their
// intended grouping.
#define STORE256X16_F32(output_ptr, num, dst)         \
  MS_ST256_F32((output_ptr) + 0 * (num), dst##1);   \
  MS_ST256_F32((output_ptr) + 1 * (num), dst##2);   \
  MS_ST256_F32((output_ptr) + 2 * (num), dst##3);   \
  MS_ST256_F32((output_ptr) + 3 * (num), dst##4);   \
  MS_ST256_F32((output_ptr) + 4 * (num), dst##5);   \
  MS_ST256_F32((output_ptr) + 5 * (num), dst##6);   \
  MS_ST256_F32((output_ptr) + 6 * (num), dst##7);   \
  MS_ST256_F32((output_ptr) + 7 * (num), dst##8);   \
  MS_ST256_F32((output_ptr) + 8 * (num), dst##9);   \
  MS_ST256_F32((output_ptr) + 9 * (num), dst##10);  \
  MS_ST256_F32((output_ptr) + 10 * (num), dst##11); \
  MS_ST256_F32((output_ptr) + 11 * (num), dst##12); \
  MS_ST256_F32((output_ptr) + 12 * (num), dst##13); \
  MS_ST256_F32((output_ptr) + 13 * (num), dst##14); \
  MS_ST256_F32((output_ptr) + 14 * (num), dst##15); \
  MS_ST256_F32((output_ptr) + 15 * (num), dst##16);
283 
// Approximates tanh() on each lane of an 8-lane float vector with the rational function
//   x * (x^6 + 378 x^4 + 17325 x^2 + 135135) / (28 x^6 + 3150 x^4 + 62370 x^2 + 135135)
// and clamps the quotient to [-1, 1].
//
// The input is first clamped to [-5, 5]. At |x| = 5 the quotient already exceeds 1 in
// magnitude, so the output is saturated to +/-1 from there on; clamping the input keeps
// that result while preventing the x^6 / x^7 terms from overflowing to infinity for very
// large |x|. Without it, inf / inf produced NaN, which the final min/max turned into -1
// regardless of the sign of the input.
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
  static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f};
  static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f};
  static const MS_FLOAT32X8 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f,
                                     135135.0f, 135135.0f, 135135.0f, 135135.0f};
  static const MS_FLOAT32X8 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f};
  static const MS_FLOAT32X8 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f};
  static const MS_FLOAT32X8 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f};
  static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
  static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
  static const MS_FLOAT32X8 lower = {-5.0f, -5.0f, -5.0f, -5.0f, -5.0f, -5.0f, -5.0f, -5.0f};
  static const MS_FLOAT32X8 upper = {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f};
  // Saturate the argument before evaluating the polynomials (see header comment).
  const MS_FLOAT32X8 x = MS_MIN256_F32(MS_MAX256_F32(src, lower), upper);
  MS_FLOAT32X8 square = MS_MUL256_F32(x, x);
  // Numerator, evaluated in Horner form: ((x^2 + 378) x^2 + 17325) x^2 + 135135, times x.
  MS_FLOAT32X8 a = MS_MUL256_F32(
    MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(square, data0), square), data1), square),
                  data2),
    x);
  // Denominator, evaluated in Horner form: ((28 x^2 + 3150) x^2 + 62370) x^2 + 135135.
  MS_FLOAT32X8 b = MS_ADD256_F32(
    MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(data3, square), data4), square), data5),
                  square),
    data2);
  return MS_MIN256_F32(MS_MAX256_F32(MS_DIV256_F32(a, b), neg), pos);
}
305 #endif
306 
307 #endif  // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
308