• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_
18 #define MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_
19 
ActBlock1(__m128 * v1,size_t relu,size_t relu6)20 static inline void ActBlock1(__m128 *v1, size_t relu, size_t relu6) {
21   __m128 zero_ma = _mm_setzero_ps();
22   if (relu || relu6) {
23     *v1 = _mm_max_ps(zero_ma, *v1);
24   }
25   if (relu6) {
26     __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
27     *v1 = _mm_min_ps(relu6_ma, *v1);
28   }
29 }
30 
ActBlock2(__m128 * v1,__m128 * v2,size_t relu,size_t relu6)31 static inline void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6) {
32   __m128 zero_ma = _mm_setzero_ps();
33   if (relu || relu6) {
34     *v1 = _mm_max_ps(zero_ma, *v1);
35     *v2 = _mm_max_ps(zero_ma, *v2);
36   }
37   if (relu6) {
38     __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
39     *v1 = _mm_min_ps(relu6_ma, *v1);
40     *v2 = _mm_min_ps(relu6_ma, *v2);
41   }
42 }
43 
ActBlock4(__m128 * v1,__m128 * v2,__m128 * v3,__m128 * v4,size_t relu,size_t relu6)44 static inline void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6) {
45   __m128 zero_ma = _mm_setzero_ps();
46   if (relu || relu6) {
47     *v1 = _mm_max_ps(zero_ma, *v1);
48     *v2 = _mm_max_ps(zero_ma, *v2);
49     *v3 = _mm_max_ps(zero_ma, *v3);
50     *v4 = _mm_max_ps(zero_ma, *v4);
51   }
52   if (relu6) {
53     __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
54     *v1 = _mm_min_ps(relu6_ma, *v1);
55     *v2 = _mm_min_ps(relu6_ma, *v2);
56     *v3 = _mm_min_ps(relu6_ma, *v3);
57     *v4 = _mm_min_ps(relu6_ma, *v4);
58   }
59 }
60 
ActBlock8(__m128 * v1,__m128 * v2,__m128 * v3,__m128 * v4,__m128 * v5,__m128 * v6,__m128 * v7,__m128 * v8,size_t relu_type)61 static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7,
62                              __m128 *v8, size_t relu_type) {
63   __m128 relu6 = _mm_set_ps1(6.0);
64   __m128 zero = _mm_setzero_ps();
65   switch (relu_type) {
66     case 3:
67       *v1 = _mm_min_ps(*v1, relu6);
68       *v2 = _mm_min_ps(*v2, relu6);
69       *v3 = _mm_min_ps(*v3, relu6);
70       *v4 = _mm_min_ps(*v4, relu6);
71       *v5 = _mm_min_ps(*v5, relu6);
72       *v6 = _mm_min_ps(*v6, relu6);
73       *v7 = _mm_min_ps(*v7, relu6);
74       *v8 = _mm_min_ps(*v8, relu6);
75     case 1:
76       *v1 = _mm_max_ps(*v1, zero);
77       *v2 = _mm_max_ps(*v2, zero);
78       *v3 = _mm_max_ps(*v3, zero);
79       *v4 = _mm_max_ps(*v4, zero);
80       *v5 = _mm_max_ps(*v5, zero);
81       *v6 = _mm_max_ps(*v6, zero);
82       *v7 = _mm_max_ps(*v7, zero);
83       *v8 = _mm_max_ps(*v8, zero);
84       break;
85   }
86 }
87 
WriteCol1(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int extra_stride,int r)88 static inline void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
89                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
90   _mm_store_ss(*dst, *dst1);
91   if (r > 1) {
92     *dst += stride;
93     _mm_store_ss(*dst, *dst3);
94   }
95   if (r > 2) {
96     *dst += stride;
97     _mm_store_ss(*dst, *dst5);
98   }
99   if (r > 3) {
100     *dst += stride;
101     _mm_store_ss(*dst, *dst7);
102     *dst += stride;
103     *dst += extra_stride;
104   }
105 }
106 
WriteCol2(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int r)107 static inline void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
108                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) {
109   _mm_store_ss(*dst, *dst1);
110   *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1));
111   _mm_store_ss(*dst, *dst1);
112   if (r > 1) {
113     *dst += stride;
114     _mm_store_ss(*dst, *dst3);
115     *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
116     _mm_store_ss(*dst, *dst3);
117   }
118   if (r > 2) {
119     *dst += stride;
120     _mm_store_ss(*dst, *dst5);
121     *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
122     _mm_store_ss(*dst, *dst5);
123   }
124   if (r > 3) {
125     *dst += stride;
126     _mm_store_ss(*dst, *dst7);
127     *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
128     _mm_store_ss(*dst, *dst7);
129   }
130 }
131 
WriteCol2Opt(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int r)132 static inline void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
133                                 __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) {
134   _mm_store_ss(*dst, *dst1);
135   *dst1 = _mm_shuffle_ps(*dst1, *dst1, _MM_SHUFFLE(0, 3, 2, 1));
136   _mm_store_ss(*dst + 1, *dst1);
137   if (r > 1) {
138     *dst += stride;
139     _mm_store_ss(*dst, *dst3);
140     *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
141     _mm_store_ss(*dst + 1, *dst3);
142   }
143   if (r > 2) {
144     *dst += stride;
145     _mm_store_ss(*dst, *dst5);
146     *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
147     _mm_store_ss(*dst + 1, *dst5);
148   }
149   if (r > 3) {
150     *dst += stride;
151     _mm_store_ss(*dst, *dst7);
152     *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
153     _mm_store_ss(*dst + 1, *dst7);
154     *dst += stride;
155     *dst += 2;
156   }
157 }
158 
WriteCol3(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int extra_stride,int r)159 static inline void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
160                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
161   if (r > 1) {
162     *dst += stride;
163     _mm_store_ss(*dst, *dst3);
164     *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
165     _mm_store_ss(*dst + 1, *dst3);
166     *dst3 = _mm_shuffle_ps(*dst3, *dst3, _MM_SHUFFLE(0, 3, 2, 1));
167     _mm_store_ss(*dst + 2, *dst3);
168   }
169   if (r > 2) {
170     *dst += stride;
171     _mm_store_ss(*dst, *dst5);
172     *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
173     _mm_store_ss(*dst + 1, *dst5);
174     *dst5 = _mm_shuffle_ps(*dst5, *dst5, _MM_SHUFFLE(0, 3, 2, 1));
175     _mm_store_ss(*dst + 2, *dst5);
176   }
177   if (r > 3) {
178     *dst += stride;
179     _mm_store_ss(*dst, *dst7);
180     *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
181     _mm_store_ss(*dst + 1, *dst7);
182     *dst7 = _mm_shuffle_ps(*dst7, *dst7, _MM_SHUFFLE(0, 3, 2, 1));
183     _mm_store_ss(*dst + 2, *dst7);
184     *dst += stride;
185     *dst += extra_stride;
186   }
187 }
188 
WriteCol4(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int extra_stride,int r)189 static inline void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
190                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
191   _mm_storeu_ps(*dst, *dst1);
192   if (r > 1) {
193     *dst += stride;
194     _mm_storeu_ps(*dst, *dst3);
195   }
196   if (r > 2) {
197     *dst += stride;
198     _mm_storeu_ps(*dst, *dst5);
199   }
200   if (r > 3) {
201     *dst += stride;
202     _mm_storeu_ps(*dst, *dst7);
203     *dst += stride;
204     *dst += extra_stride;
205   }
206 }
207 
WriteCol5(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int extra_stride,int r)208 static inline void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
209                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
210   _mm_storeu_ps(*dst, *dst1);
211   _mm_store_ss(*dst + 4, *dst2);
212   if (r > 1) {
213     *dst += stride;
214     _mm_storeu_ps(*dst, *dst3);
215     _mm_store_ss(*dst + 4, *dst4);
216   }
217   if (r > 2) {
218     *dst += stride;
219     _mm_storeu_ps(*dst, *dst5);
220     _mm_store_ss(*dst + 4, *dst6);
221   }
222   if (r > 3) {
223     *dst += stride;
224     _mm_storeu_ps(*dst, *dst7);
225     _mm_store_ss(*dst + 4, *dst8);
226     *dst += stride;
227     *dst += extra_stride;
228   }
229 }
230 
WriteCol6(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int extra_stride,int r)231 static inline void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
232                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
233   _mm_storeu_ps(*dst, *dst1);
234   _mm_store_ss(*dst + 4, *dst2);
235   *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
236   _mm_store_ss(*dst + 5, *dst2);
237   if (r > 1) {
238     *dst += stride;
239     _mm_storeu_ps(*dst, *dst3);
240     _mm_store_ss(*dst + 4, *dst4);
241     *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
242     _mm_store_ss(*dst + 5, *dst4);
243   }
244   if (r > 2) {
245     *dst += stride;
246     _mm_storeu_ps(*dst, *dst5);
247     _mm_store_ss(*dst + 4, *dst6);
248     *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
249     _mm_store_ss(*dst + 5, *dst6);
250   }
251   if (r > 3) {
252     *dst += stride;
253     _mm_storeu_ps(*dst, *dst7);
254     _mm_store_ss(*dst + 4, *dst8);
255     *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
256     _mm_store_ss(*dst + 5, *dst8);
257     *dst += stride;
258     *dst += extra_stride;
259   }
260 }
261 
WriteCol7(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int extra_stride,int r)262 static inline void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
263                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
264   _mm_storeu_ps(*dst, *dst1);
265   _mm_store_ss(*dst + 4, *dst2);
266   *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
267   _mm_store_ss(*dst + 5, *dst2);
268   *dst2 = _mm_shuffle_ps(*dst2, *dst2, _MM_SHUFFLE(0, 3, 2, 1));
269   _mm_store_ss(*dst + 6, *dst2);
270   if (r > 1) {
271     *dst += stride;
272     _mm_storeu_ps(*dst, *dst3);
273     _mm_store_ss(*dst + 4, *dst4);
274     *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
275     _mm_store_ss(*dst + 5, *dst4);
276     *dst4 = _mm_shuffle_ps(*dst4, *dst4, _MM_SHUFFLE(0, 3, 2, 1));
277     _mm_store_ss(*dst + 6, *dst4);
278   }
279   if (r > 2) {
280     *dst += stride;
281     _mm_storeu_ps(*dst, *dst5);
282     _mm_store_ss(*dst + 4, *dst6);
283     *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
284     _mm_store_ss(*dst + 5, *dst6);
285     *dst6 = _mm_shuffle_ps(*dst6, *dst6, _MM_SHUFFLE(0, 3, 2, 1));
286     _mm_store_ss(*dst + 6, *dst6);
287   }
288   if (r > 3) {
289     *dst += stride;
290     _mm_storeu_ps(*dst, *dst7);
291     _mm_store_ss(*dst + 4, *dst8);
292     *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
293     _mm_store_ss(*dst + 5, *dst8);
294     *dst8 = _mm_shuffle_ps(*dst8, *dst8, _MM_SHUFFLE(0, 3, 2, 1));
295     _mm_store_ss(*dst + 6, *dst8);
296     *dst += stride;
297     *dst += extra_stride;
298   }
299 }
300 
WriteCol8(float ** dst,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8,int stride,int extra_stride,int r)301 static inline void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5,
302                              __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) {
303   _mm_storeu_ps(*dst, *dst1);
304   _mm_storeu_ps(*dst + 4, *dst2);
305   if (r > 1) {
306     *dst += stride;
307     _mm_storeu_ps(*dst, *dst3);
308     _mm_storeu_ps(*dst + 4, *dst4);
309   }
310   if (r > 2) {
311     *dst += stride;
312     _mm_storeu_ps(*dst, *dst5);
313     _mm_storeu_ps(*dst + 4, *dst6);
314   }
315   if (r > 3) {
316     *dst += stride;
317     _mm_storeu_ps(*dst, *dst7);
318     _mm_storeu_ps(*dst + 4, *dst8);
319     *dst += stride;
320     *dst += extra_stride;
321   }
322 }
323 
DoBiasBlock8(const float * bias_ptr,__m128 * dst1,__m128 * dst2,__m128 * dst3,__m128 * dst4,__m128 * dst5,__m128 * dst6,__m128 * dst7,__m128 * dst8)324 static inline void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4,
325                                 __m128 *dst5, __m128 *dst6, __m128 *dst7, __m128 *dst8) {
326   __m128 bias1 = _mm_loadu_ps(bias_ptr);
327   __m128 bias2 = _mm_loadu_ps(bias_ptr + C4NUM);
328   *dst1 = _mm_add_ps(*dst1, bias1);
329   *dst2 = _mm_add_ps(*dst2, bias2);
330   *dst3 = _mm_add_ps(*dst3, bias1);
331   *dst4 = _mm_add_ps(*dst4, bias2);
332   *dst5 = _mm_add_ps(*dst5, bias1);
333   *dst6 = _mm_add_ps(*dst6, bias2);
334   *dst7 = _mm_add_ps(*dst7, bias1);
335   *dst8 = _mm_add_ps(*dst8, bias2);
336 }
337 
338 #endif  // MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_
339