/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp32/conv_depthwise_fp32.h"
#include "nnacl/common_func.h"
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/activation_fp32.h"

#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
                   int output_channel, int input_step) {
  for (int i = 0; i < num_pixels; i++) {
    for (int c = 0; c < output_channel; c++) {
      *output_ptr++ += weight_ptr[c] * input_ptr[c];
    }
    input_ptr += input_step;
  }
}
#endif

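// Illustration only (not compiled): what ConvDwFp32Row computes. For each of
// num_pixels output positions it accumulates one kernel tap (weight_ptr,
// output_channel floats) times the input at that position, advancing the
// input by input_step floats (stride_w * input_channel). Sizes are made up.
#if 0
#include <stdio.h>
int main(void) {
  enum { kChannels = 2, kPixels = 3, kStep = 2 };  /* stride_w == 1 */
  float out[kPixels * kChannels] = {0};            /* ConvDw pre-fills this with bias */
  const float in[6] = {1, 2, 3, 4, 5, 6};          /* (kPixels - 1) * kStep + kChannels */
  const float w[kChannels] = {0.5f, 0.25f};
  ConvDwFp32Row(out, in, w, kPixels, kChannels, kStep);
  for (int i = 0; i < kPixels * kChannels; i++) {
    printf("%g ", out[i]);  /* 0.5 0.5 1.5 1 2.5 1.5 */
  }
  return 0;
}
#endif
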
int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
           const ConvParameter *conv_param, int task_id) {
  if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) {
    return NNACL_ERR;
  }
  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;

      int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
      int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
      int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));

      for (int ow = 0; ow < conv_param->output_w_; ow++) {
        memcpy(dst_data + ow * conv_param->output_channel_, bias_data,
               conv_param->output_channel_ * (int)(sizeof(float)));
      }
      for (int kh = start_kh; kh < end_kh; kh++) {
        int ih = ih_origin + conv_param->dilation_h_ * kh;

        const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
        const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;

        int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
        for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
          int out_w_start = MSMAX(
            0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
          int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
                                                        conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
                                                         conv_param->stride_w_);

          float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
          int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;

          const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
          int num_pixels = out_w_end - out_w_start;

          ConvDwFp32Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
          weight_kh += conv_param->output_channel_;
        }
      }
      if (relu) {
        Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
      if (relu6) {
        Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
    }
  }
  return NNACL_OK;
}

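// Illustration only (not compiled): a minimal way to drive ConvDw. Output rows
// are split across task_ids in blocks of UP_DIV(output_h_, thread_num_); all
// shape/stride/pad/activation fields used above must be filled in beforehand
// (only fields already referenced in this file are assumed to exist).
#if 0
void ConvDwSingleThread(float *out, const float *in, const float *weight, const float *bias, ConvParameter *p) {
  p->thread_num_ = 1;  /* one worker: task_id 0 covers every output row */
  if (ConvDw(out, in, weight, bias, p, 0) != NNACL_OK) {
    /* rejected: thread_num_, dilation_h_ or stride_w_ was zero */
  }
}
#endif
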
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  if (block == 0) {
    return;
  }
  int left = 0;
  int right = conv_param->output_w_;
  int top = 0;
  int bottom = conv_param->output_h_;

  while (left * conv_param->stride_w_ < conv_param->pad_l_) {
    left++;
  }
  while ((right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
           conv_param->input_w_ &&
         right > left) {
    right--;
  }
  while (top * conv_param->stride_h_ < conv_param->pad_u_) {
    top++;
  }
  while ((bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
           conv_param->input_h_ &&
         bottom > top) {
    bottom--;
  }
  sliding->left_ = left;
  sliding->right_ = right;
  sliding->top_ = top;
  sliding->bottom_ = bottom;
  sliding->c_block_ = UP_DIV(conv_param->output_channel_, block);
  sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block;
  sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_;
  if (conv_param->out_format_ == NNACL_NC4HW4) {
    // write to ncxhwx (nc4hw4 when block == 4, nc8hw8 when block == 8)
    sliding->out_h_step_ = conv_param->output_w_ * block;
    sliding->out_c_step_ = block * conv_param->output_h_ * conv_param->output_w_;
    sliding->out_w_step_ = block;
    sliding->out_block_step_ = sliding->out_c_step_;
  } else {
    // write to nhwc
    sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_;
    sliding->out_c_step_ = block;
    sliding->out_w_step_ = sliding->block_channel_;
    sliding->out_block_step_ = sliding->out_w_step_;
  }
}

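// Worked example (illustration only): for a 3x3 kernel, stride 1, pad 1,
// dilation 1 and an 8x8 input giving an 8x8 output, the loops above yield
// left_ = top_ = 1 and right_ = bottom_ = 7: only outputs in [1, 7) x [1, 7)
// have receptive fields fully inside the input; the rest take the border path.
#if 0
#include <assert.h>
void SlidingParamExample(const ConvParameter *p) {  /* p filled with the shape above */
  SlidingWindowParam sw;
  InitSlidingParam(&sw, p, C4NUM);
  assert(sw.left_ == 1 && sw.right_ == 7 && sw.top_ == 1 && sw.bottom_ == 7);
}
#endif
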
void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block,
                          int weight_block) {
  InitSlidingParam(sliding, conv_param, weight_block);
  AppendSlidingParamConv(sliding, conv_param, input_block, weight_block);
}

void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block,
                            int weight_block) {
  if (input_block == 0) {  // is not aligned
    sliding->ic_align_ = conv_param->input_channel_;
  } else {  // 1x1 input is aligned to input_block
    sliding->ic_align_ = UP_DIV(conv_param->input_channel_, input_block) * input_block;
  }
  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->ic_align_;  // for batch loop
  sliding->in_h_step_ = conv_param->input_w_ * sliding->ic_align_;
  sliding->in_sh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->stride_h_;    // stride H
  sliding->in_sw_step_ = sliding->ic_align_ * conv_param->stride_w_;                           // stride W
  sliding->in_kh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->dilation_h_;  // kernel H
  sliding->in_kw_step_ = sliding->ic_align_ * conv_param->dilation_w_;                         // kernel W
  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * sliding->ic_align_ * weight_block;
}

void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  InitSlidingParam(sliding, conv_param, block);
  AppendSlidingParamConvDw(sliding, conv_param, block);
}

void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_;  // for batch loop
  sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_;
  sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_;    // stride H
  sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_;                           // stride W
  sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_;  // kernel H
  sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_;                         // kernel W
  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block;
}

/*conv depthwise fp32 begin*/
void ConvDwBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                       int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, bool is_relu6) {
  const float *src_kh = src;
  const float *weight_kh = weight;
  for (int c = 0; c < C4NUM; c++) {
    dst[c] = 0;
  }
  for (int kh = 0; kh < height; kh++) {
    const float *src_kw = src_kh;
    const float *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
      for (int c = 0; c < C4NUM; c++) {
        dst[c] += src_kw[c] * weight_kw[c];
      }
      src_kw += in_kw_step;
      weight_kw += C4NUM;
    }  // kernel_w loop
    src_kh += in_kh_step;
    weight_kh += kernel_w_step;
  }  // kernel_h loop
  for (int c = 0; c < C4NUM; c++) {
    dst[c] += bias[c];
    dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]);
    dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]);
  }
}

void ConvDwBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
                  int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
    return;
  }
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  float *dst_h = dst + top * sliding->out_h_step_;
  for (int oh = top; oh < bottom; oh++) {
    int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
    const float *src_h = src + ih * sliding->in_h_step_;

    float *dst_kernel = dst_h + left * sliding->block_channel_;
    for (int ow = left; ow < right; ow++) {
      int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
      const float *src_w = src_h + iw * sliding->block_channel_;

      const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
#ifdef ENABLE_AVX
      ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam));
      if (param == NULL) {
        return;
      }
      param->dst = dst_kernel;
      param->src = src_kernel;
      param->weight = weight_kernel;
      param->bias = bias;
      param->height = end_kh - start_kh;
      param->width = end_kw - start_kw;
      param->in_kh_step = sliding->in_kh_step_ * sizeof(float);
      param->in_kw_step = sliding->in_kw_step_ * sizeof(float);
      param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float);
      param->relu = relu;
      param->relu6 = relu6;
      ConvDwFp32Border(param);
      free(param);
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
      ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                       sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                       conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
#else
      ConvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                        sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6);
#endif
      dst_kernel += sliding->block_channel_;
    }  // width loop
    dst_h += sliding->out_h_step_;
  }  // height loop
}

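// Note on the tap clipping in ConvDwBorder: when a window overhangs the top of
// the image by -ih rows, the first in-bounds tap is kh = ceil(-ih / dilation_h),
// i.e. UP_DIV(-ih, conv_param->dilation_h_). For example (illustration only),
// pad_u_ = 2, stride_h_ = 1, dilation_h_ = 2 and oh = 0 give ih = -2, so
// start_kh = UP_DIV(2, 2) = 1, and that first tap reads input row ih + 1 * 2 = 0.
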
#ifndef ENABLE_ARM64
void ConvDwCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                  int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
                  int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) {
  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      const float *src_kh = src_w;
      const float *weight_kh = weight;
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] = 0;
      }
      for (int kh = 0; kh < kernel_h; kh++) {
        const float *src_kw = src_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            dst_w[c] += src_kw[c] * weight_kw[c];
          }
          src_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      // add bias and relu
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] += bias[c];
        dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]);
        dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]);
      }
      dst_w += block_channel;
      src_w += in_sw_step;
    }  // dst_width loop
    dst_h += out_h_step;
    src_h += in_sh_step;
  }  // dst_height loop
}
#endif

// conv depthwise fp32: sliding window
void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  if (conv_param->thread_num_ == 0) {
    return;
  }
  const float *src = input_data;
  float *dst = output_data;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src + oc * C4NUM;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = weight_data + oc * sliding->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      ConvDwBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, conv_param->output_w_,
                   conv_param, sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
                   sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
                   conv_param->output_w_, conv_param, sliding);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
        float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
        ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                         conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                         sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
                         sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
                         sliding->in_kw_step_ * sizeof(float), relu, relu6);
#else
        ConvDwCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                     conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
                     sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu,
                     relu6);
#endif
      }
    }  // output C4 loop
    src += sliding->in_step_;
    dst += sliding->out_step_;
  }  // batch loop
  // output nhwc4
}
/*conv depthwise fp32 end*/

/*conv depthwise 3x3 fp32 begin*/
bool CheckConvDwUse3X3(const ConvParameter *conv_param) {
  bool use_3x3 =
    conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
    (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
    (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && conv_param->stride_h_ == conv_param->stride_w_ &&
    (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) &&
    conv_param->pad_u_ == conv_param->pad_l_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1;
  if (!use_3x3 || conv_param->input_h_ == 1 || conv_param->input_w_ == 1) {
    return false;
  }
  const int in_h = (conv_param->output_h_ - 1) * conv_param->stride_h_ + conv_param->kernel_h_;
  const int in_w = (conv_param->output_w_ - 1) * conv_param->stride_w_ + conv_param->kernel_w_;
  return in_h == (conv_param->input_h_ + 2 * conv_param->pad_u_) &&
         in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_);
}

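// Example (illustration only): a 3x3 depthwise conv with stride 1, pad 1,
// dilation 1 on a 112x112 input producing 112x112 passes this check, since
// (112 - 1) * 1 + 3 == 112 + 2 * 1. With stride 2 the windows must tile the
// padded input exactly: a 224x224 input fails ((112 - 1) * 2 + 3 == 225 != 226),
// while a 223x223 input with 112x112 output passes. Everything else falls back
// to the generic sliding-window or indirect-buffer paths.
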
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
static void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  v0 = MS_MOVQ_F32(0.0f);
  int ic = 0;
  for (; ic < channel - 3; ic += 4) {
    v1 = MS_LDQ_F32(src + ic);
    v2 = MS_LDQ_F32(src + channel + ic);
    v3 = MS_LDQ_F32(src + 2 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d1 = src[i + ic];
      float d2 = src[i + ic + channel];
      float d3 = src[i + ic + 2 * channel];
      remain_line[i] = 0.0f - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = d3 - d1;
    }
  }
}

static void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  int ic = 0;
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    v2 = MS_LDQ_F32(src + 2 * channel + ic);
    v3 = MS_LDQ_F32(src + 3 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      float d2 = src[i + ic + 2 * channel];
      float d3 = src[i + ic + 3 * channel];
      remain_line[i] = d0 - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = d3 - d1;
    }
  }
}

static void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  int ic = 0;
  v3 = MS_MOVQ_F32(0.0f);
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    v2 = MS_LDQ_F32(src + 2 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      float d2 = src[i + ic + 2 * channel];
      remain_line[i] = d0 - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = 0.0f - d1;
    }
  }
}

static void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2;
  int ic = 0;
  v2 = MS_MOVQ_F32(0.0f);
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_STQ_F32(line + lw * ic, v0);
    MS_STQ_F32(line + lw * ic + 4, v1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    memset(line + lw * ic + 12, 0, 16);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      remain_line[i] = d0;
      remain_line[i + 4] = d1;
      remain_line[i + 8] = 0.0f - d1;
    }
  }
}

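// Reference for the four row kernels above (illustration only): per channel
// they compute t0 = d0 - d2, t1 = d1 + d2, t2 = d2 - d1, t3 = d3 - d1 over
// four consecutive input columns d0..d3 -- the 1-D input transform of a
// Winograd-style F(2,3) tile (this naming is an interpretation; the arithmetic
// is exactly what the intrinsics do). The Left/Right/Single variants zero the
// columns that fall in the padding, and Single leaves the unused t3 slot zero.
#if 0
static void ConvDw3x3RowRef(const float d[4][C4NUM], float t[4][C4NUM]) {
  for (int c = 0; c < C4NUM; c++) {
    t[0][c] = d[0][c] - d[2][c];
    t[1][c] = d[1][c] + d[2][c];
    t[2][c] = d[2][c] - d[1][c];
    t[3][c] = d[3][c] - d[1][c];
  }
}
#endif
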
static void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int c4 = UP_ROUND(channel, C4NUM);
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(line0, 0, c4 * lw * sizeof(float));
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  ConvDw3x3RowLeft(src - width * channel, line0, lw, channel);
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3Row(const float *src, float **lines, int width, int channel) {
  float *tmp = lines[0];
  lines[0] = lines[1];
  lines[1] = lines[2];
  lines[2] = tmp;
  int c4 = UP_ROUND(channel, C4NUM);
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(tmp, 0, c4 * lw * sizeof(float));
  ConvDw3x3RowLeft(src, tmp, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3Bottom(float **lines, int width, int channel) {
  float *tmp = lines[0];
  lines[0] = lines[1];
  lines[1] = lines[2];
  lines[2] = tmp;
  int c4 = UP_ROUND(channel, C4NUM);
  memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float));
}

#ifndef ENABLE_ARM64
void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel,
                   bool relu, bool relu6) {
  int channel = ori_channel;
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  for (; channel > 0; channel -= 4) {
    MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data);
    bias_data += 4;
    MS_FLOAT32X4 g00 = MS_LDQ_F32(weight);
    MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4);
    MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8);
    MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12);
    MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16);
    MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20);
    MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24);
    MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28);
    MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32);
    MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36);
    MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40);
    MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44);
    weight += 48;
    float *cur_dst = dst;
    int ow = 0;
    for (; ow < width - 1; ow += 2) {
      MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
      MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
      MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
      MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03);
      line0 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
      acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13);
      line1 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
      acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23);
      line2 += 16;
      MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
      MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2));
      res0 = MS_ADDQ_F32(res0, bias);
      res1 = MS_ADDQ_F32(res1, bias);
      if (relu || relu6) {
        res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
        res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f));
      }
      if (relu6) {
        res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
        res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f));
      }
      if (channel >= 4) {
        MS_STQ_F32(cur_dst, res0);
        MS_STQ_F32(cur_dst + ori_channel, res1);
      } else {
        for (int i = 0; i < channel; i++) {
          cur_dst[i] = MS_F32X4_GETI(res0, i);
          cur_dst[ori_channel + i] = MS_F32X4_GETI(res1, i);
        }
      }
      cur_dst += 2 * ori_channel;
    }
    if (ow < width) {
      MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
      MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
      MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
      line0 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
      line1 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
      line2 += 16;
      MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
      res0 = MS_ADDQ_F32(res0, bias);
      if (relu || relu6) {
        res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
      }
      if (relu6) {
        res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
      }
      if (channel >= 4) {
        MS_STQ_F32(cur_dst, res0);
      } else {
        for (int i = 0; i < channel; i++) {
          cur_dst[i] = MS_F32X4_GETI(res0, i);
        }
      }
    }
    dst += 4;
  }
}
#endif

void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
               const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) {
  int units = UP_DIV(conv_param->output_w_, C2NUM);
  int c4 = UP_ROUND(conv_param->input_channel_, C4NUM);
  int line = conv_param->input_channel_ * conv_param->input_w_;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    float *line0 = buffer;
    float *line1 = buffer + units * c4 * C4NUM;
    float *line2 = buffer + units * c4 * C8NUM;
    float *lines[3] = {line0, line1, line2};
    int oh = start_oh;
    if (oh == 0) {
      // input trans
      ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_);
    } else {
      // input trans
      ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_);
    }
    // dst calc and trans
    ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                  relu, relu6);
    for (oh = start_oh + 1; oh < end_oh - 1; oh++) {
      // input trans
      ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
      // dst calc and trans
      ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                    relu, relu6);
    }
    if (oh == conv_param->output_h_ - 1) {
      // input trans
      ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_);
    } else {
      // input trans
      ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
    }
    // dst calc and trans
    ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                  relu, relu6);
  }
}
#endif

/*conv depthwise indirect buffer fp32 begin*/
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) {
  bool use_indirect = (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) ||
                      (conv_param->kernel_h_ == 5 && conv_param->kernel_w_ == 5);
  return use_indirect;
}

void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
                           int step_h, int step_w) {
#ifdef ENABLE_AVX
  int div = C8NUM;
#else
  int div = C4NUM;
#endif

  int ic_div = UP_DIV(conv_param->input_channel_, div) * div;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *input = src + b * conv_param->input_h_ * conv_param->input_w_ * ic_div;
    for (int oh = 0; oh < conv_param->output_h_; oh++) {
      for (int kh = 0; kh < conv_param->kernel_h_; kh++) {
        int ih = oh * conv_param->stride_h_ + kh * conv_param->dilation_h_ - conv_param->pad_u_;
        if (ih < conv_param->input_h_ && ih >= 0) {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int iw = ow * conv_param->stride_w_ + kw * conv_param->dilation_w_ - conv_param->pad_l_;
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              if (iw < conv_param->input_w_ && iw >= 0) {
                indirect[index] = input + (ih * conv_param->input_w_ + iw) * ic_div;
              } else {
                indirect[index] = zero_ptr;
              }
            }
          }
        } else {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              indirect[index] = zero_ptr;
            }
          }
        }
      }
    }
  }
}

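// Illustration only: with the layout built above, the kernel_h_ * kernel_w_
// pointers feeding output pixel (oh, ow) are consecutive at
// indirect[oh * step_h + ow * step_w * kernel_h_], ordered kw-major then kh,
// and every out-of-image tap aliases zero_ptr, so the row kernels below can
// multiply-accumulate without any bounds checks.
#if 0
float *ConvDwTapPtr(float **indirect, int step_h, int step_w, const ConvParameter *p, int oh, int ow, int kh, int kw) {
  return indirect[oh * step_h + ow * step_w * p->kernel_h_ + kw * p->kernel_h_ + kh];
}
#endif
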
#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX)
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  do {
    float **in = input;
    size_t c = (size_t)channels;
    const float *w = weights;
    float *out = output;
    memcpy(out, bias, channels * (int)sizeof(float));
    for (; c >= C4NUM; c -= C4NUM) {
      for (int i = 0; i < C4NUM; i++) {
        for (int k = 0; k < kernel; k++) {
          out[i] += in[k][i] * w[i + k * C4NUM];
        }
      }
      w += kernel * C4NUM;
      out += C4NUM;
      for (int k = 0; k < kernel; k++) {
        in[k] += C4NUM;
      }
    }
    for (int i = 0; i < c; i++) {
      for (int k = 0; k < kernel; k++) {
        out[i] += in[k][i] * w[i + k * C4NUM];
      }
    }
    if (relu) {
      Fp32Relu(output, channels, output);
    }
    if (relu6) {
      Fp32Relu6(output, channels, output);
    }
    output += channels;
    input = input + input_stride;
  } while (--output_width != 0);
}
#endif

#ifdef ENABLE_ARM64
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Indirect3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  } else if (kernel == 25) {
    ConvDwFp32Indirect5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  }
}
#endif

#ifdef ENABLE_AVX
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
  } else if (kernel == 25) {
    ConvDwFp32Avx5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
  }
}
#endif

void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
                       float *zero_ptr, const ConvParameter *conv_param, int task_id) {
  if (conv_param->thread_num_ == 0) {
    return;
  }
  int step_w = conv_param->dilation_w_ == 1 ? conv_param->stride_w_ : conv_param->kernel_w_;
  int step_h =
    (conv_param->kernel_h_ * conv_param->kernel_w_) + (conv_param->output_w_ - 1) * step_w * conv_param->kernel_h_;
  int input_stride = conv_param->kernel_h_ * step_w;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);

  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect_b = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *output_b = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float **indirect = indirect_b + oh * step_h;
      float *output_h = output_b + oh * conv_param->output_w_ * conv_param->output_channel_;
      if (conv_param->kernel_w_ == 3) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 9);
      } else if (conv_param->kernel_w_ == 5) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 25);
      }
    }
  }
}
/*conv depthwise indirect buffer fp32 end*/

/*deconv depthwise fp32 begin*/
void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step,
                         int in_kw_step, int kernel_w_step) {
  float *dst_kh = dst;
  const float *weight_kh = weight;
  for (int kh = 0; kh < height; kh++) {
    float *dst_kw = dst_kh;
    const float *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
#ifdef ENABLE_ARM64
      float32x4_t src_4 = vld1q_f32(src);
      float32x4_t weight_4 = vld1q_f32(weight_kw);
      float32x4_t dst_4 = vld1q_f32(dst_kw);
      dst_4 = vfmaq_f32(dst_4, src_4, weight_4);
      vst1q_f32(dst_kw, dst_4);
#else
      for (int c = 0; c < C4NUM; c++) {
        dst_kw[c] += src[c] * weight_kw[c];
      }
#endif
      dst_kw += in_kw_step;
      weight_kw += C4NUM;
    }  // kernel_w loop
    dst_kh += in_kh_step;
    weight_kh += kernel_w_step;
  }  // kernel_h loop
}

void DeconvDwBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right,
                    const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
    return;
  }
  const float *src_h = src + top * sliding->out_h_step_;
  for (int ih = top; ih < bottom; ih++) {
    int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
    float *dst_h = dst + oh * sliding->in_h_step_;

    const float *src_kernel = src_h + left * sliding->block_channel_;
    for (int iw = left; iw < right; iw++) {
      int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
      float *dst_w = dst_h + ow * sliding->block_channel_;

      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
      float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
#ifdef ENABLE_ARM64
      DeconvDwFp32Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw,
                         sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                         conv_param->kernel_w_ * C4NUM * sizeof(float));
#else
      DeconvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw,
                          sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM);
#endif
      src_kernel += sliding->block_channel_;
    }  // width loop
    src_h += sliding->out_h_step_;
  }  // height loop
}

#ifndef ENABLE_ARM64
void DeconvDwCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h,
                    int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step,
                    int in_kw_step) {
  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      float *dst_kh = dst_w;
      const float *weight_kh = weight;
      for (int kh = 0; kh < kernel_h; kh++) {
        float *dst_kw = dst_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            dst_kw[c] += src_w[c] * weight_kw[c];
          }
          dst_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        dst_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      dst_w += in_sw_step;
      src_w += block_channel;
    }  // dst_width loop
    dst_h += in_sh_step;
    src_h += out_h_step;
  }  // dst_height loop
}
#endif

void DeconvDwPost(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) {
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  float *dst_k = dst;
  for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
    for (int c = 0; c < C4NUM; c++) {
      dst_k[c] += bias[c];
      dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
      dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
    }
    dst_k += block_channel;
  }
}

// deconv depthwise fp32: sliding window
void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                    const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  const float *src = input_data;
  float *dst = output_data;
  if (conv_param->thread_num_ == 0) {
    return;
  }
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src + oc * C4NUM;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = weight_data + oc * sliding->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      DeconvDwBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_,
                     conv_param, sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
                     sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_,
                     conv_param->input_w_, conv_param, sliding);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
        const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
        DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                           conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                           sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
                           sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
                           sliding->in_kw_step_ * sizeof(float));
#else
        DeconvDwCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                       conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
                       sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_);
#endif
      }
      DeconvDwPost(dst_data, bias, sliding->block_channel_, conv_param);
    }  // output C4 loop
    src += sliding->out_step_;
    dst += sliding->in_step_;
  }  // batch loop
  // output nhwc4
}
/*deconv depthwise fp32 end*/

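// Note (illustration only): the deconv depthwise path above is the scatter
// dual of ConvDwSWFp32. Each input pixel (ih, iw) accumulates src * weight
// into the output window anchored at (ih * stride_h_ - pad_u_,
// iw * stride_w_ - pad_l_), clipping taps against the output bounds, and
// DeconvDwPost then adds bias and applies the activation over the whole plane.
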
#ifdef ENABLE_AVX
void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left,
                            int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param,
                            const DepthwiseSWKernel kernel, int act_type, int ow_block, int oc_block) {
  // dw border compute
  int ih = top * conv_param->stride_h_ - conv_param->pad_u_;
  int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
  int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
  const float *src_h = src + ih * sw_param->in_h_step_;
  float *dst_kernel = dst + left * sw_param->block_channel_;
  for (int ow = left; ow < right; ow += ow_block) {
    int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
    int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
    int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
    const float *src_w = src_h + iw * sw_param->block_channel_;
    const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_;
    const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM * oc_block;
    kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_block,
           oc_block, sw_param->block_channel_, sw_param->in_kw_step_, sw_param->in_kh_step_, sw_param->in_sw_step_,
           (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block);
    dst_kernel += ow_block * sw_param->block_channel_;
  }  // width loop
}

void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                        const ConvParameter *conv_param, const SlidingWindowParam *sw_param, int task_id) {
  int oh_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int oh_start = oh_step * task_id;
  int oh_end = MSMIN(oh_start + oh_step, conv_param->output_h_);
  if (oh_start >= oh_end) {
    return;
  }
  // depthwise sw in x86 avx instructions
  int oc_tile_ = C8NUM;  // oc is aligned to C8NUM in x86_64 avx
  int act_type = 0;
  if (conv_param->act_type_ == ActType_Relu6) {
    act_type += 1;
  }
  if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) {
    act_type += 2;
  }
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int output_w = conv_param->output_w_;
  int oc_algin = sw_param->block_channel_;
  int oc_num = sw_param->c_block_;
  int in_step = sw_param->in_step_;
  int out_step = sw_param->out_step_;
  int in_sw_step = sw_param->in_sw_step_;
  int in_kw_step = sw_param->in_kw_step_;
  int in_kh_step = sw_param->in_kh_step_;
  int in_sh_step = sw_param->in_sh_step_;
  int out_right = sw_param->right_;
  int out_left = sw_param->left_;
  int out_top = sw_param->top_;
  int out_bottom = sw_param->bottom_;
  int kernel_step = sw_param->kernel_step_;
  int out_h_step = sw_param->out_h_step_;
  int in_h_start = out_top * conv_param->stride_h_ - conv_param->pad_u_;
  int in_w_start = out_left * conv_param->stride_w_ - conv_param->pad_l_;
  int in_start = in_h_start * sw_param->in_h_step_ + in_w_start * oc_algin;
  const int ow_block_num[4] = {8, 4, 4, 3};
  const DepthwiseSWKernel kernel[4][2] = {{DepthwiseSW1x8Kernel, DepthwiseSW8x8Kernel},
                                          {DepthwiseSW1x16Kernel, DepthwiseSW4x16Kernel},
                                          {DepthwiseSW1x24Kernel, DepthwiseSW4x24Kernel},
                                          {DepthwiseSW1x32Kernel, DepthwiseSW3x32Kernel}};
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oh = oh_start; oh < oh_end; ++oh) {
      float *dst_oh = output_data + oh * out_h_step;
      const float *src_h = input_data + in_start + (oh - out_top) * in_sh_step;
      int oc_block = 0;
      const float *bias = bias_data;
      for (int oc = 0; oc < oc_num; oc += oc_block) {
        oc_block = MSMIN(C4NUM, oc_num - oc);  // 4 3 2 1
        int oc_step = oc * oc_tile_;
        const float *weight = weight_data + oc * kernel_step;
        if (bias != NULL) {
          bias = bias_data + oc_step;
        }
        float *dst_w = dst_oh + oc_step;
        const DepthwiseSWKernel kernel_border = kernel[oc_block - 1][0];
        if (oh < out_top || oh >= out_bottom) {  // oh in up or down border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, output_w, conv_param, sw_param,
                                 kernel_border, act_type, 1, oc_block);
        } else {  // oh in center
          // ow in left border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, out_left, conv_param, sw_param,
                                 kernel_border, act_type, 1, oc_block);
          // ow in center
          const float *src_w = src_h + oc_step;
          int ow_block = ow_block_num[oc_block - 1];                 // 8 4 4 3
          for (int ow = out_left; ow < out_right; ow += ow_block) {  // left ~ right
            if (ow_block > out_right - ow) {  // not enough ow left; process one at a time
              ow_block = 1;
            }
            kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]](
              dst_w + ow * oc_algin, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, oc_algin,
              in_kw_step, in_kh_step, in_sw_step, 0);
            src_w += ow_block * in_sw_step;
          }
          // ow in right border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, out_right, output_w, conv_param,
                                 sw_param, kernel_border, act_type, 1, oc_block);
        }
      }
    }  // output h loop
    input_data += in_step;
    output_data += out_step;
  }  // batch loop
}

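// Note (illustration only): the act_type computed in DepthwiseSWAvxFp32 is a
// bit mask for the kernels: bit 0x2 requests the max(x, 0) clamp and bit 0x1
// the additional min(x, 6), so ActType_Relu maps to 2 and ActType_Relu6 to 3.
// A scalar equivalent of the kernels' epilogue:
#if 0
static float ApplyActFlag(float x, size_t act_flag) {
  if (act_flag & 0x1) x = x < 6.0f ? x : 6.0f;  /* relu6 upper clamp */
  if (act_flag & 0x2) x = x > 0.0f ? x : 0.0f;  /* relu */
  return x;
}
#endif
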
#ifdef ENABLE_DEBUG
void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                          size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  __m256 dst_data[12];
  __m256 src_data;
  const float *src_kh[12];
  const float *src_kw[12];
  __m256 weight_data[4];
  for (int i = 0; i < ow_block; ++i) {
    if (bias != NULL) {
      for (int j = 0; j < oc_block; ++j) {
        dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * 8);
      }
    } else {
      for (int j = 0; j < oc_block; ++j) {
        dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f);
      }
    }
    src_kh[i] = src + i * in_sw_step;
    src_kw[i] = NULL;
  }
  const float *weight_kernel = weight;
  for (int kh = 0; kh < kernel_h; kh++) {
    for (int i = 0; i < ow_block; ++i) {
      src_kw[i] = src_kh[i];
    }
    for (int kw = 0; kw < kernel_w; kw++) {
      for (int j = 0; j < oc_block; ++j) {
        weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
      }
      for (int i = 0; i < ow_block; ++i) {  // loop ow
        for (int j = 0; j < oc_block; ++j) {
          src_data = _mm256_loadu_ps(src_kw[i] + j * C8NUM);
          dst_data[i * oc_block + j] += src_data * weight_data[j];
        }
      }
      for (int i = 0; i < ow_block; ++i) {
        src_kw[i] += in_kw_step;  // ic8 * dilation_w
      }
      weight_kernel += oc_block * C8NUM;
    }  // kernel_w loop
    weight_kernel += kw_remainder;
    for (int i = 0; i < ow_block; ++i) {
      src_kh[i] += in_kh_step;
    }
  }  // kernel_h loop
  // apply activation (bias was folded in at init) and store
1184 for (int i = 0; i < ow_block; ++i) {
1185 for (int j = 0; j < oc_block; ++j) {
1186 if (0x1 & act_flag) { // relu6
1187 dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f));
1188 }
1189 if (0x2 & act_flag) { // relu
1190 dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f));
1191 }
1192 _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]);
1193 }
1194 }
1195 }
1196 #endif
1197
DepthwiseSW3x32Kernel(float * dst,const float * src,const float * weight,const float * bias,size_t kernel_h,size_t kernel_w,size_t act_flag,size_t ow_block,size_t oc_block,size_t oc_algin,size_t in_kw_step,size_t in_kh_step,size_t in_sw_step,size_t kw_remainder)1198 void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1199 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1200 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1201 in_kh_step *= sizeof(float);
1202 in_sw_step *= sizeof(float);
1203 in_kw_step *= sizeof(float);
1204 oc_algin *= sizeof(float);
1205 kw_remainder *= sizeof(float);
1206 asm volatile(
1207 "cmpq $0, %2\n"
1208 "je 0f\n"
1209 "vmovups (%2), %%ymm0\n"
1210 "vmovups 0x20(%2), %%ymm1\n"
1211 "vmovups 0x40(%2), %%ymm2\n"
1212 "vmovups 0x60(%2), %%ymm3\n"
1213 "vmovups (%2), %%ymm4\n"
1214 "vmovups 0x20(%2), %%ymm5\n"
1215 "vmovups 0x40(%2), %%ymm6\n"
1216 "vmovups 0x60(%2), %%ymm7\n"
1217 "vmovups (%2), %%ymm8\n"
1218 "vmovups 0x20(%2), %%ymm9\n"
1219 "vmovups 0x40(%2), %%ymm10\n"
1220 "vmovups 0x60(%2), %%ymm11\n"
1221 "jmp 1f\n"
1222 "0:\n"
1223 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1224 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1225 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1226 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1227 "vxorps %%ymm4, %%ymm4, %%ymm4\n"
1228 "vxorps %%ymm5, %%ymm5, %%ymm5\n"
1229 "vxorps %%ymm6, %%ymm6, %%ymm6\n"
1230 "vxorps %%ymm7, %%ymm7, %%ymm7\n"
1231 "vxorps %%ymm8, %%ymm8, %%ymm8\n"
1232 "vxorps %%ymm9, %%ymm9, %%ymm9\n"
1233 "vxorps %%ymm10, %%ymm10, %%ymm10\n"
1234 "vxorps %%ymm11, %%ymm11, %%ymm11\n"
1235 "1:\n" // LoopH
1236 "movq %4, %%rsi\n" // width
1237 "movq %0, %%rcx\n" // src_h
1238 "2:\n" // LoopW
1239
1240 "vmovups (%1), %%ymm12\n"
1241 "vmovups (%%rcx), %%ymm13\n"
1242 "vmovups (%%rcx, %7), %%ymm14\n"
1243 "vmovups (%%rcx, %7, 2), %%ymm15\n"
1244 "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
1245 "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
1246 "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
1247
1248 "vmovups 0x20(%1), %%ymm12\n"
1249 "vmovups 0x20(%%rcx), %%ymm13\n"
1250 "vmovups 0x20(%%rcx, %7), %%ymm14\n"
1251 "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
1252 "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
1253 "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
1254 "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n"
1255
1256 "vmovups 0x40(%1), %%ymm12\n"
1257 "vmovups 0x40(%%rcx), %%ymm13\n"
1258 "vmovups 0x40(%%rcx, %7), %%ymm14\n"
1259 "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
1260 "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
1261 "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n"
1262 "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n"
1263
1264 "vmovups 0x60(%1), %%ymm12\n"
1265 "vmovups 0x60(%%rcx), %%ymm13\n"
1266 "vmovups 0x60(%%rcx, %7), %%ymm14\n"
1267 "vmovups 0x60(%%rcx, %7, 2), %%ymm15\n"
1268 "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
1269 "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
1270 "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n"
1271 "addq $128, %1\n"
1272
1273 "addq %5, %%rcx\n" // in_kw_step
1274 "dec %%rsi\n"
1275 "jg 2b\n"
1276
1277 "addq %6, %0\n" // in_kh_step
1278 "addq %8, %1\n"
1279 "dec %3\n"
1280 "jg 1b\n"
1281 :
1282 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1283 "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 8
1284 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9",
1285 "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
1286
1287 asm volatile(
1288 "and $0x3, %%eax\n"
1289 "je 0f\n"
1290 // Relu
1291 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1292 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1293 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1294 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1295 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1296 "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
1297 "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
1298 "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
1299 "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
1300 "vmaxps %%ymm12, %%ymm8, %%ymm8\n"
1301 "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
1302 "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
1303 "vmaxps %%ymm12, %%ymm11, %%ymm11\n"
1304
1305 "and $0x1, %%eax\n"
1306 "je 0f\n"
1307 // relu6
1308 "mov $0x40C00000, %%ecx\n"
1309 "vmovd %%ecx, %%xmm14\n"
1310 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1311 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1312 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1313 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1314 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1315 "vminps %%ymm14, %%ymm4, %%ymm4\n"
1316 "vminps %%ymm14, %%ymm5, %%ymm5\n"
1317 "vminps %%ymm14, %%ymm6, %%ymm6\n"
1318 "vminps %%ymm14, %%ymm7, %%ymm7\n"
1319 "vminps %%ymm14, %%ymm8, %%ymm8\n"
1320 "vminps %%ymm14, %%ymm9, %%ymm9\n"
1321 "vminps %%ymm14, %%ymm10, %%ymm10\n"
1322 "vminps %%ymm14, %%ymm11, %%ymm11\n"
1323
1324 "0:\n"
1325 "vmovups %%ymm0, (%2)\n" // dst_0
1326 "vmovups %%ymm1, 0x20(%2)\n"
1327 "vmovups %%ymm2, 0x40(%2)\n"
1328 "vmovups %%ymm3, 0x60(%2)\n"
1329 "vmovups %%ymm4, (%2, %1, 1)\n"
1330 "vmovups %%ymm5, 0x20(%2, %1, 1)\n"
1331 "vmovups %%ymm6, 0x40(%2, %1, 1)\n"
1332 "vmovups %%ymm7, 0x60(%2, %1, 1)\n"
1333 "vmovups %%ymm8, (%2, %1, 2)\n"
1334 "vmovups %%ymm9, 0x20(%2, %1, 2)\n"
1335 "vmovups %%ymm10, 0x40(%2, %1, 2)\n"
1336 "vmovups %%ymm11, 0x60(%2, %1, 2)\n"
1337 :
1338 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1339 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
1340 "%ymm11", "%ymm12", "%ymm14");
1341 }
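
// Register budget of the 3x32 kernel above: ymm0-ymm11 are the accumulators
// (3 output pixels x 4 channel groups of 8 floats), ymm12 holds the current
// weight vector, and ymm13-ymm15 hold the inputs of the three pixels at
// (%rcx), (%rcx + in_sw_step) and (%rcx + 2 * in_sw_step), so every weight
// load is reused across all three output pixels.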
1342
1343 void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1344 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1345 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1346 in_kh_step *= sizeof(float);
1347 in_kw_step *= sizeof(float);
1348 oc_algin *= sizeof(float);
1349 kw_remainder *= sizeof(float);
1350 asm volatile(
1351 "cmpq $0, %2\n"
1352 "je 0f\n"
1353 "vmovups (%2), %%ymm0\n"
1354 "vmovups 0x20(%2), %%ymm1\n"
1355 "vmovups 0x40(%2), %%ymm2\n"
1356 "vmovups 0x60(%2), %%ymm3\n"
1357 "jmp 1f\n"
1358 "0:\n"
1359 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1360 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1361 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1362 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1363 "1:\n" // LoopH
1364 "movq %4, %%rsi\n" // width
1365 "movq %0, %%rcx\n" // src_h
1366 "2:\n" // Loopw
1367 "vmovups (%%rcx), %%ymm4\n"
1368 "vmovups 0x20(%%rcx), %%ymm5\n"
1369 "vmovups 0x40(%%rcx), %%ymm6\n"
1370 "vmovups 0x60(%%rcx), %%ymm7\n"
1371 // The weight operand is used directly from memory by vfmadd231ps rather than being loaded into a register first.
1372 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1373 "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
1374 "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n"
1375 "vfmadd231ps 0x60(%1), %%ymm7, %%ymm3\n"
1376 "addq $128, %1\n"
1377
1378 "addq %5, %%rcx\n" // in_kw_step
1379 "dec %%rsi\n"
1380 "jg 2b\n"
1381
1382 "addq %6, %0\n" // in_kh_step
1383 "addq %7, %1\n"
1384 "dec %3\n"
1385 "jg 1b\n"
1386 :
1387 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1388 "r"(in_kh_step), "r"(kw_remainder) // 7
1389 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
1390
1391 asm volatile(
1392 "and $0x3, %%eax\n"
1393 "je 0f\n"
1394 // Relu
1395 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1396 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1397 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1398 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1399 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1400
1401 "and $0x1, %%eax\n"
1402 "je 0f\n"
1403 // relu6
1404 "mov $0x40C00000, %%ecx\n"
1405 "vmovd %%ecx, %%xmm14\n"
1406 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1407 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1408 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1409 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1410 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1411
1412 "0:\n"
1413 "vmovups %%ymm0, (%2)\n" // dst_0
1414 "vmovups %%ymm1, 0x20(%2)\n"
1415 "vmovups %%ymm2, 0x40(%2)\n"
1416 "vmovups %%ymm3, 0x60(%2)\n"
1417 :
1418 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1419 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14");
1420 }
1421
1422 void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1423 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1424 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1425 in_kh_step *= sizeof(float);
1426 in_kw_step *= sizeof(float);
1427 in_sw_step *= sizeof(float);
1428 kw_remainder *= sizeof(float);
1429 size_t src_3_step = 3 * in_sw_step;
1430 float *dst_3 = dst + 3 * oc_algin;
1431 oc_algin *= sizeof(float);
1432 asm volatile(
1433 "cmpq $0, %2\n"
1434 "je 0f\n"
1435 "vmovups (%2), %%ymm0\n"
1436 "vmovups 0x20(%2), %%ymm1\n"
1437 "vmovups 0x40(%2), %%ymm2\n"
1438 // The bias is reloaded from memory for every output row; a register-to-register copy of ymm0-ymm2 (e.g. vmovaps) would avoid the repeated loads.
1439 "vmovups (%2), %%ymm3\n"
1440 "vmovups 0x20(%2), %%ymm4\n"
1441 "vmovups 0x40(%2), %%ymm5\n"
1442 "vmovups (%2), %%ymm6\n"
1443 "vmovups 0x20(%2), %%ymm7\n"
1444 "vmovups 0x40(%2), %%ymm8\n"
1445 "vmovups (%2), %%ymm9\n"
1446 "vmovups 0x20(%2), %%ymm10\n"
1447 "vmovups 0x40(%2), %%ymm11\n"
1448 "jmp 1f\n"
1449 "0:\n"
1450 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1451 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1452 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1453 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1454 "vxorps %%ymm4, %%ymm4, %%ymm4\n"
1455 "vxorps %%ymm5, %%ymm5, %%ymm5\n"
1456 "vxorps %%ymm6, %%ymm6, %%ymm6\n"
1457 "vxorps %%ymm7, %%ymm7, %%ymm7\n"
1458 "vxorps %%ymm8, %%ymm8, %%ymm8\n"
1459 "vxorps %%ymm9, %%ymm9, %%ymm9\n"
1460 "vxorps %%ymm10, %%ymm10, %%ymm10\n"
1461 "vxorps %%ymm11, %%ymm11, %%ymm11\n"
1462 "1:\n" // LoopH
1463 "movq %4, %%rsi\n" // width
1464 "movq %0, %%rcx\n" // src_h
1465 "2:\n" // LoopW
1466 "vmovups (%1), %%ymm12\n"
1467 "vmovups (%%rcx), %%ymm13\n"
1468 "vmovups (%%rcx, %7, 1), %%ymm14\n"
1469 "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
1470 "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
1471 "vmovups (%%rcx, %7, 2), %%ymm15\n"
1472 "vmovups (%%rcx, %9), %%ymm13\n"
1473 "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
1474 "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n"
1475
1476 "vmovups 0x20(%1), %%ymm12\n"
1477 "vmovups 0x20(%%rcx), %%ymm13\n"
1478 "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
1479 "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
1480 "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
1481 "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
1482 "vmovups 0x20(%%rcx, %9), %%ymm13\n"
1483 "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
1484 "vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n"
1485
1486 "vmovups 0x40(%1), %%ymm12\n"
1487 "vmovups 0x40(%%rcx), %%ymm13\n"
1488 "vmovups 0x40(%%rcx, %7, 1), %%ymm14\n"
1489 "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
1490 "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
1491 "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
1492 "vmovups 0x40(%%rcx, %9), %%ymm13\n"
1493 "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
1494 "vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n"
1495
1496 "addq $96, %1\n"
1497 "addq %5, %%rcx\n" // in_kw_step
1498 "dec %%rsi\n"
1499 "jg 2b\n"
1500
1501 "addq %6, %0\n" // in_kh_step
1502 "addq %8, %1\n" // border in sw need to add remainder data
1503 "dec %3\n"
1504 "jg 1b\n"
1505 :
1506 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1507 "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9
1508 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9",
1509 "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
1510
1511 asm volatile(
1512 "and $0x3, %%eax\n"
1513 "je 0f\n"
1514 // Relu
1515 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1516 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1517 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1518 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1519 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1520 "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
1521 "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
1522 "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
1523 "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
1524 "vmaxps %%ymm12, %%ymm8, %%ymm8\n"
1525 "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
1526 "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
1527 "vmaxps %%ymm12, %%ymm11, %%ymm11\n"
1528
1529 "and $0x1, %%eax\n"
1530 "je 0f\n"
1531 // relu6
1532 "mov $0x40C00000, %%ecx\n"
1533 "vmovd %%ecx, %%xmm14\n"
1534 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1535 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1536 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1537 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1538 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1539 "vminps %%ymm14, %%ymm4, %%ymm4\n"
1540 "vminps %%ymm14, %%ymm5, %%ymm5\n"
1541 "vminps %%ymm14, %%ymm6, %%ymm6\n"
1542 "vminps %%ymm14, %%ymm7, %%ymm7\n"
1543 "vminps %%ymm14, %%ymm8, %%ymm8\n"
1544 "vminps %%ymm14, %%ymm9, %%ymm9\n"
1545 "vminps %%ymm14, %%ymm10, %%ymm10\n"
1546 "vminps %%ymm14, %%ymm11, %%ymm11\n"
1547
1548 "0:\n"
1549 "vmovups %%ymm0, (%2)\n" // dst_0
1550 "vmovups %%ymm1, 0x20(%2)\n"
1551 "vmovups %%ymm2, 0x40(%2)\n"
1552 "vmovups %%ymm3, (%2, %1, 1)\n"
1553 "vmovups %%ymm4, 0x20(%2, %1, 1)\n"
1554 "vmovups %%ymm5, 0x40(%2, %1, 1)\n"
1555 "vmovups %%ymm6, (%2, %1, 2)\n"
1556 "vmovups %%ymm7, 0x20(%2, %1, 2)\n"
1557 "vmovups %%ymm8, 0x40(%2, %1, 2)\n"
1558 "vmovups %%ymm9, (%3)\n" // dst+3
1559 "vmovups %%ymm10, 0x20(%3)\n"
1560 "vmovups %%ymm11, 0x40(%3)\n"
1561 :
1562 : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
1563 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
1564 "%ymm11", "%ymm12", "%ymm14");
1565 }
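
// x86 addressing scales an index register only by 1, 2, 4 or 8, so the three
// input rows reachable as (%rcx), (%rcx, in_sw_step, 1) and
// (%rcx, in_sw_step, 2) come for free, while the fourth row's offset
// (3 * in_sw_step) and the fourth destination row (dst + 3 * oc_algin) are
// precomputed as src_3_step and dst_3 above.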
1566
1567 void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1568 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1569 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1570 in_kh_step *= sizeof(float);
1571 in_kw_step *= sizeof(float);
1572 oc_algin *= sizeof(float);
1573 kw_remainder *= sizeof(float);
1574 asm volatile(
1575 "cmpq $0, %2\n"
1576 "je 0f\n"
1577 "vmovups (%2), %%ymm0\n"
1578 "vmovups 0x20(%2), %%ymm1\n"
1579 "vmovups 0x40(%2), %%ymm2\n"
1580 "jmp 1f\n"
1581 "0:\n"
1582 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1583 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1584 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1585 "1:\n" // LoopH
1586 "movq %4, %%rsi\n" // width
1587 "movq %0, %%rcx\n" // src_h
1588 "2:\n" // Loopw
1589 "vmovups (%%rcx), %%ymm4\n"
1590 "vmovups 0x20(%%rcx), %%ymm5\n"
1591 "vmovups 0x40(%%rcx), %%ymm6\n"
1592 // The weight operand is used directly from memory by vfmadd231ps rather than being loaded into a register first.
1593 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1594 "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
1595 "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n"
1596 "addq $96, %1\n"
1597
1598 "addq %5, %%rcx\n" // in_kw_step
1599 "dec %%rsi\n"
1600 "jg 2b\n"
1601
1602 "addq %6, %0\n" // in_kh_step
1603 "addq %7, %1\n" // kw_remainder
1604 "dec %3\n"
1605 "jg 1b\n"
1606 :
1607 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1608 "r"(in_kh_step), "r"(kw_remainder) // 7
1609 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6");
1610
1611 asm volatile(
1612 "and $0x3, %%eax\n"
1613 "je 0f\n"
1614 // Relu
1615 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1616 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1617 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1618 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1619
1620 "and $0x1, %%eax\n"
1621 "je 0f\n"
1622 // relu6
1623 "mov $0x40C00000, %%ecx\n"
1624 "vmovd %%ecx, %%xmm14\n"
1625 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1626 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1627 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1628 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1629
1630 "0:\n"
1631 "vmovups %%ymm0, (%2)\n" // dst_0
1632 "vmovups %%ymm1, 0x20(%2)\n"
1633 "vmovups %%ymm2, 0x40(%2)\n"
1634 :
1635 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1636 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14");
1637 }
1638
1639 void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1640 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1641 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1642 in_kh_step *= sizeof(float);
1643 in_kw_step *= sizeof(float);
1644 in_sw_step *= sizeof(float);
1645 kw_remainder *= sizeof(float);
1646 size_t src_3_step = 3 * in_sw_step;
1647 float *dst_3 = dst + 3 * oc_algin;
1648 oc_algin *= sizeof(float);
1649 asm volatile(
1650 "cmpq $0, %2\n"
1651 "je 0f\n"
1652 "vmovups (%2), %%ymm0\n"
1653 "vmovups 0x20(%2), %%ymm1\n"
1654 // The bias is reloaded from memory for every output row; a register-to-register copy of ymm0 and ymm1 (e.g. vmovaps) would avoid the repeated loads.
1655 "vmovups (%2), %%ymm3\n"
1656 "vmovups 0x20(%2), %%ymm4\n"
1657 "vmovups (%2), %%ymm6\n"
1658 "vmovups 0x20(%2), %%ymm7\n"
1659 "vmovups (%2), %%ymm9\n"
1660 "vmovups 0x20(%2), %%ymm10\n"
1661 "jmp 1f\n"
1662 "0:\n"
1663 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1664 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1665 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1666 "vxorps %%ymm4, %%ymm4, %%ymm4\n"
1667 "vxorps %%ymm6, %%ymm6, %%ymm6\n"
1668 "vxorps %%ymm7, %%ymm7, %%ymm7\n"
1669 "vxorps %%ymm9, %%ymm9, %%ymm9\n"
1670 "vxorps %%ymm10, %%ymm10, %%ymm10\n"
1671 "1:\n" // LoopH
1672 "movq %4, %%rsi\n" // width
1673 "movq %0, %%rcx\n" // src_h
1674 "2:\n" // LoopW
1675 "vmovups (%1), %%ymm12\n"
1676 "vmovups (%%rcx), %%ymm13\n"
1677 "vmovups (%%rcx, %7, 1), %%ymm14\n"
1678 "vmovups (%%rcx, %7, 2), %%ymm15\n"
1679 "vmovups (%%rcx, %9), %%ymm2\n"
1680 "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
1681 "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
1682 "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
1683 "vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n"
1684
1685 "vmovups 0x20(%1), %%ymm12\n"
1686 "vmovups 0x20(%%rcx), %%ymm13\n"
1687 "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
1688 "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
1689 "vmovups 0x20(%%rcx, %9), %%ymm2\n"
1690 "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
1691 "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
1692 "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
1693 "vfmadd231ps %%ymm12, %%ymm2, %%ymm10\n"
1694
1695 "addq $64, %1\n"
1696 "addq %5, %%rcx\n" // in_kw_step
1697 "dec %%rsi\n"
1698 "jg 2b\n"
1699
1700 "addq %6, %0\n" // in_kh_step
1701 "addq %8, %1\n" // border in sw need to add remainder data
1702 "dec %3\n"
1703 "jg 1b\n"
1704 :
1705 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1706 "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9
1707 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm13",
1708 "%ymm14", "%ymm15");
1709
1710 asm volatile(
1711 "and $0x3, %%eax\n"
1712 "je 0f\n"
1713 // Relu
1714 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1715 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1716 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1717 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1718 "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
1719 "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
1720 "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
1721 "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
1722 "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
1723
1724 "and $0x1, %%eax\n"
1725 "je 0f\n"
1726 // relu6
1727 "mov $0x40C00000, %%ecx\n"
1728 "vmovd %%ecx, %%xmm14\n"
1729 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1730 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1731 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1732 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1733 "vminps %%ymm14, %%ymm4, %%ymm4\n"
1734 "vminps %%ymm14, %%ymm6, %%ymm6\n"
1735 "vminps %%ymm14, %%ymm7, %%ymm7\n"
1736 "vminps %%ymm14, %%ymm9, %%ymm9\n"
1737 "vminps %%ymm14, %%ymm10, %%ymm10\n"
1738
1739 "0:\n"
1740 "vmovups %%ymm0, (%2)\n" // dst_0
1741 "vmovups %%ymm1, 0x20(%2)\n"
1742 "vmovups %%ymm3, (%2, %1, 1)\n"
1743 "vmovups %%ymm4, 0x20(%2, %1, 1)\n"
1744 "vmovups %%ymm6, (%2, %1, 2)\n"
1745 "vmovups %%ymm7, 0x20(%2, %1, 2)\n"
1746 "vmovups %%ymm9, (%3)\n" // dst+3
1747 "vmovups %%ymm10, 0x20(%3)\n"
1748 :
1749 : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
1750 : "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14");
1751 }
1752
1753 void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1754 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1755 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1756 in_kh_step *= sizeof(float);
1757 in_kw_step *= sizeof(float);
1758 oc_algin *= sizeof(float);
1759 kw_remainder *= sizeof(float);
1760 asm volatile(
1761 "cmpq $0, %2\n"
1762 "je 0f\n"
1763 "vmovups (%2), %%ymm0\n"
1764 "vmovups 0x20(%2), %%ymm1\n"
1765 "jmp 1f\n"
1766 "0:\n"
1767 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1768 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1769 "1:\n" // LoopH
1770 "movq %4, %%rsi\n" // width
1771 "movq %0, %%rcx\n" // src_h
1772 "2:\n" // Loopw
1773 "vmovups (%%rcx), %%ymm4\n"
1774 "vmovups 0x20(%%rcx), %%ymm5\n"
1775 // The weight operand is used directly from memory by vfmadd231ps rather than being loaded into a register first.
1776 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1777 "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
1778 "addq $64, %1\n"
1779
1780 "addq %5, %%rcx\n" // in_kw_step
1781 "dec %%rsi\n"
1782 "jg 2b\n"
1783
1784 "addq %6, %0\n" // in_kh_step
1785 "addq %7, %1\n" // kw_remainder
1786 "dec %3\n"
1787 "jg 1b\n"
1788 :
1789 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1790 "r"(in_kh_step), "r"(kw_remainder) // 7
1791 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm4", "%ymm5");
1792
1793 asm volatile(
1794 "and $0x3, %%eax\n"
1795 "je 0f\n"
1796 // Relu
1797 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1798 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1799 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1800
1801 "and $0x1, %%eax\n"
1802 "je 0f\n"
1803 // relu6
1804 "mov $0x40C00000, %%ecx\n"
1805 "vmovd %%ecx, %%xmm14\n"
1806 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1807 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1808 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1809
1810 "0:\n"
1811 "vmovups %%ymm0, (%2)\n" // dst_0
1812 "vmovups %%ymm1, 0x20(%2)\n"
1813 :
1814 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1815 : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14");
1816 }
1817
1818 void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1819 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1820 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1821 in_kh_step *= sizeof(float);
1822 in_sw_step *= sizeof(float);
1823 in_kw_step *= sizeof(float);
1824 kw_remainder *= sizeof(float);
1825 size_t src_3_step = 3 * in_sw_step;
1826 float *dst_3 = dst + 3 * oc_algin;
1827 float *dst_5 = dst + 5 * oc_algin;
1828 oc_algin *= sizeof(float);
1829 asm volatile(
1830 "cmpq $0, %0\n"
1831 "je 0f\n"
1832 "vmovups (%0), %%ymm0\n"
1833 "vmovups (%0), %%ymm1\n"
1834 "vmovups (%0), %%ymm2\n"
1835 "vmovups (%0), %%ymm3\n"
1836 "vmovups (%0), %%ymm4\n"
1837 "vmovups (%0), %%ymm5\n"
1838 "vmovups (%0), %%ymm6\n"
1839 "vmovups (%0), %%ymm7\n"
1840 "jmp 1f\n"
1841 "0:\n"
1842 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1843 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1844 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1845 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1846 "vxorps %%ymm4, %%ymm4, %%ymm4\n"
1847 "vxorps %%ymm5, %%ymm5, %%ymm5\n"
1848 "vxorps %%ymm6, %%ymm6, %%ymm6\n"
1849 "vxorps %%ymm7, %%ymm7, %%ymm7\n"
1850 "1:\n"
1851 :
1852 : "r"(bias)
1853 : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
1854
1855 asm volatile(
1856 "LoopH:\n"
1857 "movq %3, %%rsi\n" // width
1858 "movq %0, %%rcx\n" // src_h
1859 "LoopW:\n"
1860 "movq %%rcx, %%rax\n"
1861 "vmovups (%1), %%ymm12\n"
1862 "vmovups (%%rax), %%ymm13\n"
1863 "vmovups (%%rax, %6), %%ymm14\n"
1864 "vmovups (%%rax, %6, 2), %%ymm15\n"
1865 "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
1866 "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n"
1867 "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n"
1868 "addq %7, %%rax\n"
1869 "vmovups (%%rax), %%ymm13\n"
1870 "vmovups (%%rax, %6), %%ymm14\n"
1871 "vmovups (%%rax, %6, 2), %%ymm15\n"
1872 "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
1873 "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
1874 "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n"
1875 "addq %7, %%rax\n"
1876 "vmovups (%%rax), %%ymm13\n"
1877 "vmovups (%%rax, %6), %%ymm14\n"
1878 "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n"
1879 "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
1880
1881 "addq $32, %1\n"
1882 "addq %4, %%rcx\n" // in_kw_step
1883 "dec %%rsi\n"
1884 "jg LoopW\n"
1885
1886 "addq %5, %0\n" // in_kh_step
1887 "addq %8, %1\n" // border in sw need to add remainder data
1888 "dec %2\n"
1889 "jg LoopH\n"
1890 :
1891 : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), "r"(in_kh_step), // 5
1892 "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder) // 8
1893 : "%rcx", "%rsi", "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12",
1894 "%ymm13", "%ymm14", "%ymm15");
1895
1896 asm volatile(
1897 "and $0x3, %%eax\n"
1898 "je Write\n"
1899 // Relu
1900 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1901 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1902 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1903 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1904 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1905 "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
1906 "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
1907 "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
1908 "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
1909
1910 "and $0x1, %%eax\n"
1911 "je Write\n"
1912 // relu6
1913 "mov $0x40C00000, %%ecx\n"
1914 "vmovd %%ecx, %%xmm14\n"
1915 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1916 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1917 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1918 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1919 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1920 "vminps %%ymm14, %%ymm4, %%ymm4\n"
1921 "vminps %%ymm14, %%ymm5, %%ymm5\n"
1922 "vminps %%ymm14, %%ymm6, %%ymm6\n"
1923 "vminps %%ymm14, %%ymm7, %%ymm7\n"
1924
1925 "Write:\n"
1926 "vmovups %%ymm0, (%2)\n" // dst_0
1927 "vmovups %%ymm1, (%2, %1)\n"
1928 "vmovups %%ymm2, (%2, %1, 2)\n"
1929 "vmovups %%ymm3, (%3)\n" // dst_3
1930 "vmovups %%ymm4, (%2, %1, 4)\n"
1931 "vmovups %%ymm5, (%4)\n" // dst_5
1932 "vmovups %%ymm6, (%4, %1, 1)\n"
1933 "vmovups %%ymm7, (%4, %1, 2)\n"
1934 :
1935 : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5)
1936 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14");
1937 }
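
// The 8x8 kernel keeps one accumulator per output pixel (ymm0-ymm7) and walks
// the eight pixels in three hops: %rax covers pixels 0-2 via scaled
// addressing, then advances by src_3_step for pixels 3-5, and again for
// pixels 6-7. The stores use the same trick: output rows 3 and 5 are not
// reachable by scaled addressing, so dst_3 and dst_5 are precomputed and
// rows 6 and 7 are addressed relative to dst_5.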
1938
1939 void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1940 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1941 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1942 in_kh_step *= sizeof(float);
1943 in_kw_step *= sizeof(float);
1944 oc_algin *= sizeof(float);
1945 kw_remainder *= sizeof(float);
1946 asm volatile(
1947 "cmpq $0, %2\n"
1948 "je 0f\n"
1949 "vmovups (%2), %%ymm0\n"
1950 "jmp 1f\n"
1951 "0:\n"
1952 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1953 "1:\n" // LoopH
1954 "movq %4, %%rsi\n" // width
1955 "movq %0, %%rcx\n" // src_h
1956 "2:\n" // Loopw
1957 "vmovups (%%rcx), %%ymm4\n"
1958 // The weight operand is used directly from memory by vfmadd231ps rather than being loaded into a register first.
1959 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1960 "addq $32, %1\n"
1961
1962 "addq %5, %%rcx\n" // in_kw_step
1963 "dec %%rsi\n"
1964 "jg 2b\n"
1965
1966 "addq %6, %0\n" // in_kh_step
1967 "addq %7, %1\n" // kw_remainder
1968 "dec %3\n"
1969 "jg 1b\n"
1970 :
1971 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1972 "r"(in_kh_step), "r"(kw_remainder) // 7
1973 : "%rcx", "%rsi", "%ymm0", "%ymm4");
1974
1975 asm volatile(
1976 "and $0x3, %%eax\n"
1977 "je 0f\n"
1978 // Relu
1979 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1980 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1981
1982 "and $0x1, %%eax\n"
1983 "je 0f\n"
1984 // relu6
1985 "mov $0x40C00000, %%ecx\n"
1986 "vmovd %%ecx, %%xmm14\n"
1987 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1988 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1989
1990 "0:\n"
1991 "vmovups %%ymm0, (%2)\n" // dst_0
1992 :
1993 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1994 : "%ecx", "%ymm0", "%ymm12", "%ymm14");
1995 }
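
// Under ENABLE_DEBUG, DepthwiseSWWxKKernel above can double as a reference for
// spot-checking any of the hand-written kernels. A sketch (hypothetical
// buffers; both kernels take their step arguments in float elements, since
// the asm kernels scale to bytes internally), e.g. for a 1x8 tile:
//   DepthwiseSWWxKKernel(dst_ref, src, weight, bias, kernel_h, kernel_w, act_flag,
//                        1, 1, oc_algin, in_kw_step, in_kh_step, in_sw_step, kw_remainder);
//   DepthwiseSW1x8Kernel(dst_asm, src, weight, bias, kernel_h, kernel_w, act_flag,
//                        1, 1, oc_algin, in_kw_step, in_kh_step, in_sw_step, kw_remainder);
//   // then compare dst_ref[0..7] with dst_asm[0..7] elementwise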
1996 #endif
1997