1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifdef ENABLE_SSE
18 #include "nnacl/intrinsics/ms_simd_instructions.h"
19 #include "nnacl/fp32/conv_depthwise_fp32.h"
20 #include "nnacl/intrinsics/sse/sse_common.h"
21
22 #ifndef ENABLE_AVX
ConvDwFp32Border(float * dst,const float * src,const float * weight,const float * bias,size_t height,size_t width,size_t in_kh_step,size_t in_kw_step,size_t kernel_w_step,size_t relu,size_t relu6)23 void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
24 size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) {
25 in_kh_step /= sizeof(float);
26 in_kw_step /= sizeof(float);
27 kernel_w_step /= sizeof(float);
28
29 const float *src_kh = src;
30 const float *weight_kh = weight;
31 __m128 dst_ma = _mm_setzero_ps();
32
33 for (int kh = 0; kh < height; kh++) {
34 const float *src_kw = src_kh;
35 const float *weight_kw = weight_kh;
36
37 int c1 = 0;
38 int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
39 int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
40 // c4 loop
41 for (; c1 < c4; c1 += C4NUM) {
42 __m128 src_ma1 = _mm_loadu_ps(src_kw);
43 __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step);
44 __m128 src_ma3 = _mm_loadu_ps(src_kw + 2 * in_kw_step);
45 __m128 src_ma4 = _mm_loadu_ps(src_kw + 3 * in_kw_step);
46
47 __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
48 __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
49 __m128 weight_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
50 __m128 weight_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
51
52 __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
53 __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2);
54 __m128 mul_ma3 = _mm_mul_ps(src_ma3, weight_ma3);
55 __m128 mul_ma4 = _mm_mul_ps(src_ma4, weight_ma4);
56
57 dst_ma = _mm_add_ps(dst_ma, mul_ma1);
58 dst_ma = _mm_add_ps(dst_ma, mul_ma2);
59 dst_ma = _mm_add_ps(dst_ma, mul_ma3);
60 dst_ma = _mm_add_ps(dst_ma, mul_ma4);
61
62 src_kw += in_kw_step * 4;
63 weight_kw += C4NUM * 4;
64 }
65
66 // c2 loop
67 for (; c1 < c2; c1 += C2NUM) {
68 __m128 src_ma1 = _mm_loadu_ps(src_kw);
69 __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step);
70 __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
71 __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
72 __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
73 __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2);
74 dst_ma = _mm_add_ps(dst_ma, mul_ma1);
75 dst_ma = _mm_add_ps(dst_ma, mul_ma2);
76
77 src_kw += in_kw_step * 2;
78 weight_kw += C4NUM * 2;
79 }
80
81 // remaining
82 for (; c1 < width; ++c1) {
83 __m128 src_ma1 = _mm_loadu_ps(src_kw);
84 __m128 weight_ma1 = _mm_loadu_ps(weight_kw);
85 __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1);
86 dst_ma = _mm_add_ps(dst_ma, mul_ma1);
87
88 src_kw += in_kw_step;
89 weight_kw += C4NUM;
90 }
91
92 src_kh += in_kh_step;
93 weight_kh += kernel_w_step;
94 }
95
96 __m128 bias_ma = _mm_loadu_ps(bias);
97 dst_ma = _mm_add_ps(dst_ma, bias_ma);
98 __m128 zero_ma = _mm_setzero_ps();
99 if (relu || relu6) {
100 dst_ma = _mm_max_ps(zero_ma, dst_ma);
101 if (relu6) {
102 __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f);
103 dst_ma = _mm_min_ps(const_ma, dst_ma);
104 }
105 }
106 _mm_storeu_ps(dst, dst_ma);
107 }
108 #endif
109
// Depthwise-conv kernel for the center region of the output (no boundary
// handling is performed here; presumably the caller routes edge pixels to the
// border kernel — verify against call sites). One __m128 (4 floats) of
// channels is produced per output position; the output-width loop is unrolled
// by 4, then 2, then 1. Note that the unrolled positions share the SAME
// weight vector per tap (depthwise: weights depend on the kernel tap, not on
// the output position). All *_step arguments are byte strides.
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
                      size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                      size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6) {
  // Convert byte strides to float-element strides.
  out_h_step /= sizeof(float);
  block_channel /= sizeof(float);
  in_sh_step /= sizeof(float);
  in_sw_step /= sizeof(float);
  in_kh_step /= sizeof(float);
  in_kw_step /= sizeof(float);

  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    int c4 = DOWN_DIV(width, C4NUM) * C4NUM;
    int c2 = DOWN_DIV(width, C2NUM) * C2NUM;
    int c1 = 0;
    // c4 loop: four output positions per iteration, four independent accumulators.
    for (; c1 < c4; c1 += C4NUM, dst_w += C4NUM * block_channel, src_w += C4NUM * in_sw_step) {
      const float *src_kh = src_w, *weight_kh = weight;
      __m128 dst_w_ma1 = _mm_setzero_ps();
      __m128 dst_w_ma2 = _mm_setzero_ps();
      __m128 dst_w_ma3 = _mm_setzero_ps();
      __m128 dst_w_ma4 = _mm_setzero_ps();

      for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
        const float *src_kw = src_kh, *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
          // The four positions read src offset by multiples of in_sw_step but
          // multiply by the same per-tap weight vector.
          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);

          __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);

          __m128 src_kw_ma3 = _mm_loadu_ps(src_kw + 2 * in_sw_step);
          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma3 = _mm_mul_ps(src_kw_ma3, weight_kw_ma3);
          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);

          __m128 src_kw_ma4 = _mm_loadu_ps(src_kw + 3 * in_sw_step);
          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4);
          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
        }  // kernel_w loop
      }    // kernel_h loop

      // add bias relu
      __m128 bias_ma = _mm_loadu_ps(bias);
      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
      dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);
      dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma);
      dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma);

      ActBlock4(&dst_w_ma1, &dst_w_ma2, &dst_w_ma3, &dst_w_ma4, relu, relu6);

      _mm_storeu_ps(dst_w, dst_w_ma1);
      _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
      _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3);
      _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4);
    }  // dst_width loop

    // c2 loop: two output positions per iteration.
    for (; c1 < c2; c1 += C2NUM, dst_w += C2NUM * block_channel, src_w += C2NUM * in_sw_step) {
      const float *src_kh = src_w, *weight_kh = weight;
      __m128 dst_w_ma1 = _mm_setzero_ps();
      __m128 dst_w_ma2 = _mm_setzero_ps();

      for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
        const float *src_kw = src_kh, *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);

          __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
        }  // kernel_w loop
      }    // kernel_h loop
      // add bias relu
      __m128 bias_ma = _mm_loadu_ps(bias);
      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
      dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma);

      ActBlock2(&dst_w_ma1, &dst_w_ma2, relu, relu6);

      _mm_storeu_ps(dst_w, dst_w_ma1);
      _mm_storeu_ps(dst_w + block_channel, dst_w_ma2);
    }

    // remaining: one output position at a time.
    for (; c1 < width; c1++, dst_w += block_channel, src_w += in_sw_step) {
      const float *src_kh = src_w, *weight_kh = weight;
      __m128 dst_w_ma1 = _mm_setzero_ps();
      for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) {
        const float *src_kw = src_kh, *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) {
          __m128 src_kw_ma1 = _mm_loadu_ps(src_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
        }  // kernel_w loop
      }    // kernel_h loop

      // add bias relu
      __m128 bias_ma = _mm_loadu_ps(bias);
      dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma);
      ActBlock1(&dst_w_ma1, relu, relu6);
      _mm_storeu_ps(dst_w, dst_w_ma1);
    }
    dst_h += out_h_step;
    src_h += in_sh_step;
  }  // dst_height loop
}
231
// Depthwise-deconvolution (transposed conv) center kernel: for each input
// position, scatters src * weight into the kernel_h x kernel_w window of the
// output, accumulating into dst in place (read-modify-write, so dst must be
// pre-initialized by the caller — TODO confirm it is zeroed/biased upstream).
// One __m128 (4 floats) of channels per position; the kernel_w loop is
// unrolled by 4, then 2, then 1. All *_step arguments are byte strides.
// No bias/activation is fused here, unlike ConvDwFp32Center.
void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
                        size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
                        size_t in_kh_step, size_t in_kw_step) {
  // Convert byte strides to float-element strides.
  out_h_step /= sizeof(float);
  block_channel /= sizeof(float);
  in_sh_step /= sizeof(float);
  in_sw_step /= sizeof(float);
  in_kh_step /= sizeof(float);
  in_kw_step /= sizeof(float);

  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      float *dst_kh = dst_w;
      const float *weight_kh = weight;
      // One source vector is broadcast across the whole kernel window.
      __m128 src_w_ma = _mm_loadu_ps(src_w);
      for (int kh = 0; kh < kernel_h; kh++) {
        float *dst_kw = dst_kh;
        const float *weight_kw = weight_kh;

        int c4 = DOWN_DIV(kernel_w, C4NUM) * C4NUM;
        int c2 = DOWN_DIV(kernel_w, C2NUM) * C2NUM;
        int c1 = 0;
        // c4 loop: four kernel taps per iteration, each a load-accumulate-store on dst.
        for (; c1 < c4; c1 += C4NUM) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);

          __m128 dst_w_ma3 = _mm_loadu_ps(dst_kw + 2 * in_kw_step);
          __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM);
          __m128 tmp_ma3 = _mm_mul_ps(src_w_ma, weight_kw_ma3);
          dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3);
          _mm_storeu_ps(dst_kw + 2 * in_kw_step, dst_w_ma3);

          __m128 dst_w_ma4 = _mm_loadu_ps(dst_kw + 3 * in_kw_step);
          __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM);
          __m128 tmp_ma4 = _mm_mul_ps(src_w_ma, weight_kw_ma4);
          dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4);
          _mm_storeu_ps(dst_kw + 3 * in_kw_step, dst_w_ma4);

          dst_kw += 4 * in_kw_step;
          weight_kw += 4 * C4NUM;
        }
        // c2 loop: two kernel taps per iteration.
        for (; c1 < c2; c1 += C2NUM) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step);
          __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM);
          __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2);
          dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2);
          _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2);

          dst_kw += 2 * in_kw_step;
          weight_kw += 2 * C4NUM;
        }
        // remaining: one kernel tap at a time.
        for (; c1 < kernel_w; ++c1) {
          __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw);
          __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw);
          __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1);
          dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1);
          _mm_storeu_ps(dst_kw, dst_w_ma1);

          dst_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop

        dst_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      dst_w += in_sw_step;
      src_w += block_channel;
    }  // dst_width loop
    dst_h += in_sh_step;
    src_h += out_h_step;
  }  // dst_height loop
}
326
327 #endif
328