/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_
#define MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_

#include <xmmintrin.h>      // SSE intrinsics: __m128, _mm_*_ps (the including .c files may already provide this)
#include "nnacl/op_base.h"  // C4NUM (assumed to be the header that defines it)

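// In-place activation for one 4-float block: clamps to [0, +inf) when relu is
// set and to [0, 6] when relu6 is set.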
static inline void ActBlock1(__m128 *v1, size_t relu, size_t relu6) {
  __m128 zero_ma = _mm_setzero_ps();
  if (relu || relu6) {
    *v1 = _mm_max_ps(zero_ma, *v1);
  }
  if (relu6) {
    __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
    *v1 = _mm_min_ps(relu6_ma, *v1);
  }
}

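// Same activation as ActBlock1, applied to two vectors.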
static inline void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6) {
  __m128 zero_ma = _mm_setzero_ps();
  if (relu || relu6) {
    *v1 = _mm_max_ps(zero_ma, *v1);
    *v2 = _mm_max_ps(zero_ma, *v2);
  }
  if (relu6) {
    __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
    *v1 = _mm_min_ps(relu6_ma, *v1);
    *v2 = _mm_min_ps(relu6_ma, *v2);
  }
}

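// Same activation as ActBlock1, applied to four vectors.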
static inline void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6) {
  __m128 zero_ma = _mm_setzero_ps();
  if (relu || relu6) {
    *v1 = _mm_max_ps(zero_ma, *v1);
    *v2 = _mm_max_ps(zero_ma, *v2);
    *v3 = _mm_max_ps(zero_ma, *v3);
    *v4 = _mm_max_ps(zero_ma, *v4);
  }
  if (relu6) {
    __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
    *v1 = _mm_min_ps(relu6_ma, *v1);
    *v2 = _mm_min_ps(relu6_ma, *v2);
    *v3 = _mm_min_ps(relu6_ma, *v3);
    *v4 = _mm_min_ps(relu6_ma, *v4);
  }
}

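// Activation for eight vectors, selected by relu_type: 1 applies ReLU (clamp at
// zero) and 3 applies ReLU6 (clamp to [0, 6]); these presumably match nnacl's
// ActType_Relu / ActType_Relu6 enum values. Any other value leaves the data
// untouched.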
static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7,
                             __m128 *v8, size_t relu_type) {
  __m128 relu6 = _mm_set_ps1(6.0f);
  __m128 zero = _mm_setzero_ps();
  switch (relu_type) {
    case 3:
      *v1 = _mm_min_ps(*v1, relu6);
      *v2 = _mm_min_ps(*v2, relu6);
      *v3 = _mm_min_ps(*v3, relu6);
      *v4 = _mm_min_ps(*v4, relu6);
      *v5 = _mm_min_ps(*v5, relu6);
      *v6 = _mm_min_ps(*v6, relu6);
      *v7 = _mm_min_ps(*v7, relu6);
      *v8 = _mm_min_ps(*v8, relu6);
      // fall through: ReLU6 also clamps at zero
    case 1:
      *v1 = _mm_max_ps(*v1, zero);
      *v2 = _mm_max_ps(*v2, zero);
      *v3 = _mm_max_ps(*v3, zero);
      *v4 = _mm_max_ps(*v4, zero);
      *v5 = _mm_max_ps(*v5, zero);
      *v6 = _mm_max_ps(*v6, zero);
      *v7 = _mm_max_ps(*v7, zero);
      *v8 = _mm_max_ps(*v8, zero);
      break;
    default:
      break;
  }
}

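// The WriteColN helpers below store an N-column tile of results through *dst,
// advancing *dst by `stride` floats per row; r is the number of valid rows
// (up to 4). Rows come from dst1/dst3/dst5/dst7, with dst2/dst4/dst6/dst8
// holding columns 4..7 of the same rows; pointers a given variant does not use
// are presumably kept only for a uniform signature.
// WriteCol1: stores column 0 of each row.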
static inline void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                             __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
  _mm_store_ss(*dst, *dst1);
  if (r > 1) {
    *dst += stride;
    _mm_store_ss(*dst, *dst3);
  }
  if (r > 2) {
    *dst += stride;
    _mm_store_ss(*dst, *dst5);
  }
  if (r > 3) {
    *dst += stride;
    _mm_store_ss(*dst, *dst7);
    *dst += stride;
    *dst += extra_stride;
  }
}

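// WriteCol2: stores columns 0 and 1 of each row; unlike most variants it does
// not advance *dst past the tile after the last row.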
static inline void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                             __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) {
  _mm_store_ss(*dst, *dst1);
  *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1));
  _mm_store_ss(*dst + 1, *dst1);  // second column goes to *dst + 1, as in WriteCol2Opt
  if (r > 1) {
    *dst += stride;
    _mm_store_ss(*dst, *dst3);
    *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst3);
  }
  if (r > 2) {
    *dst += stride;
    _mm_store_ss(*dst, *dst5);
    *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst5);
  }
  if (r > 3) {
    *dst += stride;
    _mm_store_ss(*dst, *dst7);
    *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst7);
  }
}

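// WriteCol2Opt: stores columns 0 and 1 of each row and, after the fourth row,
// advances *dst by stride + 2.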
static inline void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                                __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) {
  _mm_store_ss(*dst, *dst1);
  *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1));
  _mm_store_ss(*dst + 1, *dst1);
  if (r > 1) {
    *dst += stride;
    _mm_store_ss(*dst, *dst3);
    *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst3);
  }
  if (r > 2) {
    *dst += stride;
    _mm_store_ss(*dst, *dst5);
    *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst5);
  }
  if (r > 3) {
    *dst += stride;
    _mm_store_ss(*dst, *dst7);
    *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst7);
    *dst += stride;
    *dst += 2;
  }
}

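// WriteCol3: stores columns 0..2 of each row.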
static inline void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                             __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
  // row 0 is stored unconditionally, matching the other WriteColN variants
  _mm_store_ss(*dst, *dst1);
  *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1));
  _mm_store_ss(*dst + 1, *dst1);
  *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1));
  _mm_store_ss(*dst + 2, *dst1);
  if (r > 1) {
    *dst += stride;
    _mm_store_ss(*dst, *dst3);
    *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst3);
    *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 2, *dst3);
  }
  if (r > 2) {
    *dst += stride;
    _mm_store_ss(*dst, *dst5);
    *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst5);
    *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 2, *dst5);
  }
  if (r > 3) {
    *dst += stride;
    _mm_store_ss(*dst, *dst7);
    *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 1, *dst7);
    *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 2, *dst7);
    *dst += stride;
    *dst += extra_stride;
  }
}

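// WriteCol4: stores columns 0..3 of each row with one unaligned 4-float store.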
static inline void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                             __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
  _mm_storeu_ps(*dst, *dst1);
  if (r > 1) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst3);
  }
  if (r > 2) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst5);
  }
  if (r > 3) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst7);
    *dst += stride;
    *dst += extra_stride;
  }
}

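// WriteCol5: stores columns 0..3 from dstN and column 4 from the paired
// dst(N+1) register of each row.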
static inline void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                             __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
  _mm_storeu_ps(*dst, *dst1);
  _mm_store_ss(*dst + 4, *dst2);
  if (r > 1) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst3);
    _mm_store_ss(*dst + 4, *dst4);
  }
  if (r > 2) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst5);
    _mm_store_ss(*dst + 4, *dst6);
  }
  if (r > 3) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst7);
    _mm_store_ss(*dst + 4, *dst8);
    *dst += stride;
    *dst += extra_stride;
  }
}

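// WriteCol6: stores columns 0..3 plus columns 4 and 5 of each row.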
static inline void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                             __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
  _mm_storeu_ps(*dst, *dst1);
  _mm_store_ss(*dst + 4, *dst2);
  *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
  _mm_store_ss(*dst + 5, *dst2);
  if (r > 1) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst3);
    _mm_store_ss(*dst + 4, *dst4);
    *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 5, *dst4);
  }
  if (r > 2) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst5);
    _mm_store_ss(*dst + 4, *dst6);
    *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 5, *dst6);
  }
  if (r > 3) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst7);
    _mm_store_ss(*dst + 4, *dst8);
    *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 5, *dst8);
    *dst += stride;
    *dst += extra_stride;
  }
}

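// WriteCol7: stores columns 0..3 plus columns 4..6 of each row.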
static inline void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                             __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
  _mm_storeu_ps(*dst, *dst1);
  _mm_store_ss(*dst + 4, *dst2);
  *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
  _mm_store_ss(*dst + 5, *dst2);
  *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
  _mm_store_ss(*dst + 6, *dst2);
  if (r > 1) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst3);
    _mm_store_ss(*dst + 4, *dst4);
    *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 5, *dst4);
    *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 6, *dst4);
  }
  if (r > 2) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst5);
    _mm_store_ss(*dst + 4, *dst6);
    *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 5, *dst6);
    *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 6, *dst6);
  }
  if (r > 3) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst7);
    _mm_store_ss(*dst + 4, *dst8);
    *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 5, *dst8);
    *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
    _mm_store_ss(*dst + 6, *dst8);
    *dst += stride;
    *dst += extra_stride;
  }
}

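// WriteCol8: stores all eight columns of each row with two unaligned 4-float stores.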
static inline void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
                             __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
  _mm_storeu_ps(*dst, *dst1);
  _mm_storeu_ps(*dst + 4, *dst2);
  if (r > 1) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst3);
    _mm_storeu_ps(*dst + 4, *dst4);
  }
  if (r > 2) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst5);
    _mm_storeu_ps(*dst + 4, *dst6);
  }
  if (r > 3) {
    *dst += stride;
    _mm_storeu_ps(*dst, *dst7);
    _mm_storeu_ps(*dst + 4, *dst8);
    *dst += stride;
    *dst += extra_stride;
  }
}

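// Adds an 8-float bias to a 4-row x 8-column block: bias_ptr[0..3] is added to
// columns 0..3 (dst1/dst3/dst5/dst7) and bias_ptr[4..7] (offset C4NUM) to
// columns 4..7 (dst2/dst4/dst6/dst8).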
static inline void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4,
                                __m128 *dst5, __m128 *dst6, __m128 *dst7, __m128 *dst8) {
  __m128 bias1 = _mm_loadu_ps(bias_ptr);
  __m128 bias2 = _mm_loadu_ps(bias_ptr + C4NUM);
  *dst1 = _mm_add_ps(*dst1, bias1);
  *dst2 = _mm_add_ps(*dst2, bias2);
  *dst3 = _mm_add_ps(*dst3, bias1);
  *dst4 = _mm_add_ps(*dst4, bias2);
  *dst5 = _mm_add_ps(*dst5, bias1);
  *dst6 = _mm_add_ps(*dst6, bias2);
  *dst7 = _mm_add_ps(*dst7, bias1);
  *dst8 = _mm_add_ps(*dst8, bias2);
}

#endif  // MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_