• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_SSE
18 #include "nnacl/intrinsics/ms_simd_instructions.h"
19 #include "nnacl/fp32/conv_depthwise_fp32.h"
20 #include "nnacl/intrinsics/sse/sse_common.h"
21 
22 #ifndef ENABLE_AVX
ConvDwFp32Border(float * dst,const float * src,const float * weight,const float * bias,size_t height,size_t width,size_t in_kh_step,size_t in_kw_step,size_t kernel_w_step,size_t relu,size_t relu6)23 void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
24                       size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) {
25   in_kh_step /= sizeof(float);
26   in_kw_step /= sizeof(float);
27   kernel_w_step /= sizeof(float);
28 
29   const float *src_kh = src;
30   const float *weight_kh = weight;
31   __m128 dst_ma = _mm_setzero_ps();
32 
33   for (int kh = 0; kh < height; kh++) {
34     const float *src_kw = src_kh;
35     const float *weight_kw = weight_kh;
36 
37     int c1 = 0;
38     int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
39     int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
40     // c4 loop
41     for (; c1 < c4; c1 += C4NUM) {
42       __m128 src_ma1 = _mm_loadu_ps(src_kw);
43       __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step);
44       __m128 src_ma3 = _mm_loadu_ps(src_kw + 2 * in_kw_step);
45       __m128 src_ma4 = _mm_loadu_ps(src_kw + 3 * in_kw_step);
46 
47       __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
48       __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
49       __m128 weight_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
50       __m128 weight_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
51 
52       __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
53       __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2);
54       __m128 mul_ma3 = _mm_mul_ps(src_ma3, weight_ma3);
55       __m128 mul_ma4 = _mm_mul_ps(src_ma4, weight_ma4);
56 
57       dst_ma = _mm_add_ps(dst_ma, mul_ma1);
58       dst_ma = _mm_add_ps(dst_ma, mul_ma2);
59       dst_ma = _mm_add_ps(dst_ma, mul_ma3);
60       dst_ma = _mm_add_ps(dst_ma, mul_ma4);
61 
62       src_kw += in_kw_step * 4;
63       weight_kw += C4NUM * 4;
64     }
65 
66     // c2 loop
67     for (; c1 < c2; c1 += C2NUM) {
68       __m128 src_ma1 = _mm_loadu_ps(src_kw);
69       __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step);
70       __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
71       __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
72       __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
73       __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2);
74       dst_ma = _mm_add_ps(dst_ma, mul_ma1);
75       dst_ma = _mm_add_ps(dst_ma, mul_ma2);
76 
77       src_kw += in_kw_step * 2;
78       weight_kw += C4NUM * 2;
79     }
80 
81     // remaining
82     for (; c1 < width; ++c1) {
83       __m128 src_ma1 = _mm_loadu_ps(src_kw);
84       __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
85       __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
86       dst_ma = _mm_add_ps(dst_ma, mul_ma1);
87 
88       src_kw += in_kw_step;
89       weight_kw += C4NUM;
90     }
91 
92     src_kh += in_kh_step;
93     weight_kh += kernel_w_step;
94   }
95 
96   __m128 bias_ma = _mm_loadu_ps(bias);
97   dst_ma = _mm_add_ps(dst_ma, bias_ma);
98   __m128 zero_ma = _mm_setzero_ps();
99   if (relu || relu6) {
100     dst_ma = _mm_max_ps(zero_ma, dst_ma);
101     if (relu6) {
102       __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
103       dst_ma = _mm_min_ps(const_ma, dst_ma);
104     }
105   }
106   _mm_storeu_ps(dst, dst_ma);
107 }
108 #endif
109 
ConvDwFp32Center(float * dst,const float * src,const float * weight,const float * bias,size_t height,size_t width,size_t kernel_h,size_t kernel_w,size_t out_h_step,size_t block_channel,size_t in_sh_step,size_t in_sw_step,size_t in_kh_step,size_t in_kw_step,size_t relu,size_t relu6)110 void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
111                       size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
112                       size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6) {
113   out_h_step /= sizeof(float);
114   block_channel /= sizeof(float);
115   in_sh_step /= sizeof(float);
116   in_sw_step /= sizeof(float);
117   in_kh_step /= sizeof(float);
118   in_kw_step /= sizeof(float);
119 
120   float *dst_h = dst;
121   const float *src_h = src;
122   for (int oh = 0; oh < height; oh++) {
123     float *dst_w = dst_h;
124     const float *src_w = src_h;
125     int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
126     int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
127     int c1 = 0;
128     // c4 loop
129     for (; c1 < c4; c1 += C4NUM, dst_w += C4NUM * block_channel, src_w += C4NUM * in_sw_step) {
130       const float *src_kh = src_w, *weight_kh = weight;
131       __m128 dst_w_ma1 = _mm_setzero_ps();
132       __m128 dst_w_ma2 = _mm_setzero_ps();
133       __m128 dst_w_ma3 = _mm_setzero_ps();
134       __m128 dst_w_ma4 = _mm_setzero_ps();
135 
136       for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
137         const float *src_kw = src_kh, *weight_kw = weight_kh;
138         for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
139           __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
140           __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
141           __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
142           dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
143 
144           __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
145           __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
146           __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
147           dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
148 
149           __m128 src_kw_ma3 = _mm_loadu_ps(src_kw + 2 * in_sw_step);
150           __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw);
151           __m128 tmp_ma3 = _mm_mul_ps(src_kw_ma3, weight_kw_ma3);
152           dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);
153 
154           __m128 src_kw_ma4 = _mm_loadu_ps(src_kw + 3 * in_sw_step);
155           __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw);
156           __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4);
157           dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
158         }  // kernel_w loop
159       }    // kernel_h loop
160 
161       // add bias relu
162       __m128 bias_ma = _mm_loadu_ps(bias);
163       dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
164       dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
165       dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma);
166       dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma);
167 
168       ActBlock4(&dst_w_ma1, &dst_w_ma2, &dst_w_ma3, &dst_w_ma4, relu, relu6);
169 
170       _mm_storeu_ps(dst_w, dst_w_ma1);
171       _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
172       _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3);
173       _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4);
174     }  // dst_width loop
175 
176     // c2 loop
177     for (; c1 < c2; c1 += C2NUM, dst_w += C2NUM * block_channel, src_w += C2NUM * in_sw_step) {
178       const float *src_kh = src_w, *weight_kh = weight;
179       __m128 dst_w_ma1 = _mm_setzero_ps();
180       __m128 dst_w_ma2 = _mm_setzero_ps();
181 
182       for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
183         const float *src_kw = src_kh, *weight_kw = weight_kh;
184         for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
185           __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
186           __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
187           __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
188           dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
189 
190           __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
191           __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
192           __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
193           dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
194         }  // kernel_w loop
195       }    // kernel_h loop
196       // add bias relu
197       __m128 bias_ma = _mm_loadu_ps(bias);
198       dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
199       dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
200 
201       ActBlock2(&dst_w_ma1, &dst_w_ma2, relu, relu6);
202 
203       _mm_storeu_ps(dst_w, dst_w_ma1);
204       _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
205     }
206 
207     // remaining
208     for (; c1 < width; c1++, dst_w += block_channel, src_w += in_sw_step) {
209       const float *src_kh = src_w, *weight_kh = weight;
210       __m128 dst_w_ma1 = _mm_setzero_ps();
211       for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
212         const float *src_kw = src_kh, *weight_kw = weight_kh;
213         for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
214           __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
215           __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
216           __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
217           dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
218         }  // kernel_w loop
219       }    // kernel_h loop
220 
221       // add bias relu
222       __m128 bias_ma = _mm_loadu_ps(bias);
223       dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
224       ActBlock1(&dst_w_ma1, relu, relu6);
225       _mm_storeu_ps(dst_w, dst_w_ma1);
226     }
227     dst_h += out_h_step;
228     src_h += in_sh_step;
229   }  // dst_height loop
230 }
231 
// Depthwise deconvolution (transposed conv) over the center region: each
// input pixel's C4NUM-channel block is multiplied by every kernel tap and
// scatter-accumulated into the output. dst is read-modified-written in
// place, so it must already hold valid partial sums (typically zeros).
// All *_step / block_channel arguments arrive as byte offsets and are
// converted to float-element offsets on entry. The kernel-width loop is
// unrolled by four, then two, with a scalar tail.
void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
                        size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
                        size_t in_kh_step, size_t in_kw_step) {
  out_h_step /= sizeof(float);  // bytes -> float elements
  block_channel /= sizeof(float);
  in_sh_step /= sizeof(float);
  in_sw_step /= sizeof(float);
  in_kh_step /= sizeof(float);
  in_kw_step /= sizeof(float);

  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      float *dst_kh = dst_w;
      const float *weight_kh = weight;
      // One C4 block of the current input pixel, reused for every kernel tap.
      __m128 src_w_ma = _mm_loadu_ps(src_w);
      for (int kh = 0; kh < kernel_h; kh++) {
        float *dst_kw = dst_kh;
        const float *weight_kw = weight_kh;

        int c4 = DOWN_DIV(kernel_w, C4NUM) * C4NUM;
        int c2 = DOWN_DIV(kernel_w, C2NUM) * C2NUM;
        int c1 = 0;
        // c4 loop: four kernel-width taps per iteration, each a
        // load-accumulate-store on a distinct output position.
        for (; c1 < c4; c1 += C4NUM) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);

          __m128 dst_w_ma3 = _mm_loadu_ps(dst_kw + 2 * in_kw_step);
          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
          __m128 tmp_ma3 = _mm_mul_ps(src_w_ma, weight_kw_ma3);
          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);
          _mm_storeu_ps(dst_kw + 2 * in_kw_step, dst_w_ma3);

          __m128 dst_w_ma4 = _mm_loadu_ps(dst_kw + 3 * in_kw_step);
          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
          __m128 tmp_ma4 = _mm_mul_ps(src_w_ma, weight_kw_ma4);
          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
          _mm_storeu_ps(dst_kw + 3 * in_kw_step, dst_w_ma4);

          dst_kw += 4 * in_kw_step;
          weight_kw += 4 * C4NUM;
        }
        // c2 loop: two taps per iteration.
        for (; c1 < c2; c1 += C2NUM) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);

          dst_kw += 2 * in_kw_step;
          weight_kw += 2 * C4NUM;
        }
        // remaining scalar tail: one tap per iteration.
        for (; c1 < kernel_w; ++c1) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          dst_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop

        dst_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      dst_w += in_sw_step;
      src_w += block_channel;
    }  // dst_width loop
    dst_h += in_sh_step;
    src_h += out_h_step;
  }  // dst_height loop
}
326 
327 #endif
328