/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp32/conv_depthwise_fp32.h"
#include "nnacl/common_func.h"
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/activation_fp32.h"

#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
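// Portable fallback: for one kernel tap, multiply-accumulate `output_channel` weights into
// `num_pixels` consecutive output pixels, advancing the input by `input_step` floats per pixel.
// ARM and SSE builds supply intrinsic implementations of this routine instead.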
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
                   int output_channel, int input_step) {
  for (int i = 0; i < num_pixels; i++) {
    for (int c = 0; c < output_channel; c++) {
      *output_ptr++ += weight_ptr[c] * input_ptr[c];
    }
    input_ptr += input_step;
  }
}
#endif

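// Depthwise convolution over NHWC data: output rows [h_start, h_end) are split evenly across
// threads via task_id; each row is initialized with the bias and then accumulated one kernel tap
// (kh, kw) at a time through ConvDwFp32Row, with the tap and pixel ranges clamped so all reads
// stay inside the input.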
int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
           const ConvParameter *conv_param, int task_id) {
  if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) {
    return NNACL_ERR;
  }
  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;

      int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
      int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
      int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));

      for (int ow = 0; ow < conv_param->output_w_; ow++) {
        memcpy(dst_data + ow * conv_param->output_channel_, bias_data,
               conv_param->output_channel_ * (int)(sizeof(float)));
      }
      for (int kh = start_kh; kh < end_kh; kh++) {
        int ih = ih_origin + conv_param->dilation_h_ * kh;

        const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
        const float *dw_weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;

        int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
        for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
          int out_w_start = MSMAX(
            0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
          int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
                                                        conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
                                                         conv_param->stride_w_);

          float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
          int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;

          const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
          int num_pixels = out_w_end - out_w_start;

          ConvDwFp32Row(dst_w, src_kw, dw_weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
          dw_weight_kh += conv_param->output_channel_;
        }
      }
      if (relu) {
        Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
      if (relu6) {
        Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
    }
  }
  return NNACL_OK;
}
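/* A minimal usage sketch (ours, not part of this file): callers hand each worker its own
 * task_id, and the h_step split above makes the per-thread row ranges disjoint. Serially:
 *
 *   for (int t = 0; t < conv_param->thread_num_; ++t) {
 *     ConvDw(output, input, weight, bias, conv_param, t);  // thread t computes its row slice
 *   }
 *
 * In practice each iteration runs on its own thread; no synchronization is needed because the
 * output row ranges never overlap.
 */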

#ifdef ENABLE_AVX512
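// AVX512 variant: per-kw pixel ranges (num_pixels_, out_w_start_) are precomputed in
// ConvDwCalcParam, and the tap first_calc_kw_ (when not -1) is fused with bias initialization
// (ConvDwAVX512Fp32Row with first_calc == true), avoiding a separate bias-copy pass over the row.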
int ConvDwAVX512(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                 const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param) {
  if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) {
    return NNACL_ERR;
  }
  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  int32_t *num_pixels = conv_dw_calc_param->num_pixels_;
  int32_t *out_w_start = conv_dw_calc_param->out_w_start_;
  int first_calc_kw = conv_dw_calc_param->first_calc_kw_;

  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;

      int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
      int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
      int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));

      bool first_calc_flag = true;
      if (first_calc_kw == -1) {
        first_calc_flag = false;
        for (int ow = 0; ow < conv_param->output_w_; ow++) {
          memcpy(dst_data + ow * conv_param->output_channel_, bias_data,
                 conv_param->output_channel_ * (int)(sizeof(float)));
        }
      }
      for (int kh = start_kh; kh < end_kh; kh++) {
        int ih = ih_origin + conv_param->dilation_h_ * kh;

        const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
        const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;

        int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;

        if (first_calc_flag) {
          int iw_origin = -conv_param->pad_l_ + conv_param->dilation_w_ * first_calc_kw;

          const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;

          ConvDwAVX512Fp32Row(dst_data, src_kw, weight_kh + first_calc_kw * conv_param->output_channel_,
                              conv_param->output_w_, conv_param->output_channel_, in_sw_step, true, bias_data);
        }
        for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
          if (first_calc_flag && (kw == first_calc_kw)) {
            first_calc_flag = false;
            weight_kh += conv_param->output_channel_;
            continue;
          }

          float *dst_w = dst_data + out_w_start[kw] * conv_param->output_channel_;
          int iw_origin = (out_w_start[kw] * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;

          const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;

          ConvDwAVX512Fp32Row(dst_w, src_kw, weight_kh, num_pixels[kw], conv_param->output_channel_, in_sw_step, false,
                              bias_data);
          weight_kh += conv_param->output_channel_;
        }
      }
      if (relu) {
        Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      } else if (relu6) {
        Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
    }
  }
  return NNACL_OK;
}
#endif

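// Computes the sliding-window geometry: left_/right_/top_/bottom_ bound the interior output
// region whose receptive field lies fully inside the input, so the center loops can run without
// bounds checks; the remaining border rows/columns are handled separately. The output strides
// depend on whether results are written in block-channel-tiled or plain NHWC layout.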
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  if (block == 0) {
    return;
  }
  int left = 0;
  int right = conv_param->output_w_;
  int top = 0;
  int bottom = conv_param->output_h_;

  while (left * conv_param->stride_w_ < conv_param->pad_l_) {
    left++;
  }
  while ((right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
           conv_param->input_w_ &&
         right > left) {
    right--;
  }
  while (top * conv_param->stride_h_ < conv_param->pad_u_) {
    top++;
  }
  while ((bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
           conv_param->input_h_ &&
         bottom > top) {
    bottom--;
  }
  sliding->left_ = left;
  sliding->right_ = right;
  sliding->top_ = top;
  sliding->bottom_ = bottom;
  sliding->c_block_ = UP_DIV(conv_param->output_channel_, block);
  sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block;
  sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_;
  if (conv_param->out_format_ == Format_NC4HW4) {
    // write to nc8hw8
    sliding->out_h_step_ = conv_param->output_w_ * block;
    sliding->out_c_step_ = block * conv_param->output_h_ * conv_param->output_w_;
    sliding->out_w_step_ = block;
    sliding->out_block_step_ = sliding->out_c_step_;
  } else {
    // write to nhwc
    sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_;
    sliding->out_c_step_ = block;
    sliding->out_w_step_ = sliding->block_channel_;
    sliding->out_block_step_ = sliding->out_w_step_;
  }
}

void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block,
                          int weight_block) {
  InitSlidingParam(sliding, conv_param, weight_block);
  AppendSlidingParamConv(sliding, conv_param, input_block, weight_block);
}

void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block,
                            int weight_block) {
  if (input_block == 0) {  // is not aligned
    sliding->ic_align_ = conv_param->input_channel_;
  } else {  // 1x1 input is aligned to input_block
    sliding->ic_align_ = UP_DIV(conv_param->input_channel_, input_block) * input_block;
  }
  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->ic_align_;  // for batch loop
  sliding->in_h_step_ = conv_param->input_w_ * sliding->ic_align_;
  sliding->in_sh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->stride_h_;    // stride H
  sliding->in_sw_step_ = sliding->ic_align_ * conv_param->stride_w_;                           // stride W
  sliding->in_kh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->dilation_h_;  // kernel H
  sliding->in_kw_step_ = sliding->ic_align_ * conv_param->dilation_w_;                         // kernel W
  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * sliding->ic_align_ * weight_block;
}

void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  InitSlidingParam(sliding, conv_param, block);
  AppendSlidingParamConvDw(sliding, conv_param, block);
}

void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_;  // for batch loop
  sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_;
  sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_;    // stride H
  sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_;                           // stride W
  sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_;  // kernel H
  sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_;                         // kernel W
  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block;
}

/*conv depthwise fp32 begin*/
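// Strategy: ConvDwBorder handles output pixels whose receptive field overhangs the input
// (clamping the tap range per pixel), while the interior block is computed by ConvDwCenter or a
// platform-specific ConvDwFp32Center kernel with no bounds checks at all.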
void ConvDwBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                       int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, bool is_relu6) {
  const float *src_kh = src;
  const float *weight_kh = weight;
  for (int c = 0; c < C4NUM; c++) {
    dst[c] = 0;
  }
  for (int kh = 0; kh < height; kh++) {
    const float *src_kw = src_kh;
    const float *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
      for (int c = 0; c < C4NUM; c++) {
        dst[c] += src_kw[c] * weight_kw[c];
      }
      src_kw += in_kw_step;
      weight_kw += C4NUM;
    }  // kernel_w loop
    src_kh += in_kh_step;
    weight_kh += kernel_w_step;
  }  // kernel_h loop
  for (int c = 0; c < C4NUM; c++) {
    dst[c] += bias[c];
    dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]);
    dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]);
  }
}

void ConvDwBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
                  int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
    return;
  }
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  float *dst_h = dst + top * sliding->out_h_step_;
  for (int oh = top; oh < bottom; oh++) {
    int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
    const float *src_h = src + ih * sliding->in_h_step_;

    float *dst_kernel = dst_h + left * sliding->block_channel_;
    for (int ow = left; ow < right; ow++) {
      int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
      const float *src_w = src_h + iw * sliding->block_channel_;

      const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
#ifdef ENABLE_AVX
      ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam));
      if (param == NULL) {
        return;
      }
      param->dst = dst_kernel;
      param->src = src_kernel;
      param->weight = weight_kernel;
      param->bias = bias;
      param->height = end_kh - start_kh;
      param->width = end_kw - start_kw;
      param->in_kh_step = sliding->in_kh_step_ * sizeof(float);
      param->in_kw_step = sliding->in_kw_step_ * sizeof(float);
      param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float);
      param->relu = relu;
      param->relu6 = relu6;
      ConvDwFp32Border(param);
      free(param);
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
      ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                       sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                       conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
#else
      ConvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                        sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6);
#endif
      dst_kernel += sliding->block_channel_;
    }  // width loop
    dst_h += sliding->out_h_step_;
  }  // height loop
}

#ifndef ENABLE_ARM64
void ConvDwCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                  int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
                  int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) {
  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      const float *src_kh = src_w;
      const float *weight_kh = weight;
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] = 0;
      }
      for (int kh = 0; kh < kernel_h; kh++) {
        const float *src_kw = src_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            dst_w[c] += src_kw[c] * weight_kw[c];
          }
          src_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      // add bias, apply relu
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] += bias[c];
        dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]);
        dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]);
      }
      dst_w += block_channel;
      src_w += in_sw_step;
    }  // dst_width loop
    dst_h += out_h_step;
    src_h += in_sh_step;
  }  // dst_height loop
}
#endif

// conv depthwise fp32: sliding window
void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  if (conv_param->thread_num_ == 0) {
    return;
  }
  const float *src = input_data;
  float *dst = output_data;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src + oc * C4NUM;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = weight_data + oc * sliding->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      ConvDwBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, conv_param->output_w_,
                   conv_param, sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
                   sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
                   conv_param->output_w_, conv_param, sliding);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
        float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
        ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                         conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                         sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
                         sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
                         sliding->in_kw_step_ * sizeof(float), relu, relu6);
#else
        ConvDwCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                     conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
                     sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu,
                     relu6);
#endif
      }
    }  // output C4 loop
    src += sliding->in_step_;
    dst += sliding->out_step_;
  }  // batch loop
  // output nhwc4
}
/*conv depthwise fp32 end*/

/*conv depthwise 3x3 fp32 begin*/
bool CheckConvDwUse3X3(const ConvParameter *conv_param) {
  bool use_3x3 =
    conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
    (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
    (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && conv_param->stride_h_ == conv_param->stride_w_ &&
    (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) &&
    conv_param->pad_u_ == conv_param->pad_l_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1;
  if (!use_3x3 || conv_param->input_h_ == 1 || conv_param->input_w_ == 1) {
    return false;
  }
  const int in_h = (conv_param->output_h_ - 1) * conv_param->stride_h_ + conv_param->kernel_h_;
  const int in_w = (conv_param->output_w_ - 1) * conv_param->stride_w_ + conv_param->kernel_w_;
  return in_h == (conv_param->input_h_ + 2 * conv_param->pad_u_) &&
         in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_);
}

#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
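// The Row helpers below feed the two-output-per-tile width tiling used by ConvDw3x3Line: for each
// group of four channels they store the combinations (d0 - d2, d1 + d2, d2 - d1, d3 - d1) of up to
// four neighboring input columns, effectively a Winograd-style F(2,3) input transform (the
// Left/Right/Single variants substitute zeros where columns fall outside the row). The matching
// weights are loaded four vectors per kernel row, apparently transformed at pack time outside
// this file.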
static void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  v0 = MS_MOVQ_F32(0.0f);
  int ic = 0;
  for (; ic < channel - 3; ic += 4) {
    v1 = MS_LDQ_F32(src + ic);
    v2 = MS_LDQ_F32(src + channel + ic);
    v3 = MS_LDQ_F32(src + 2 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d1 = src[i + ic];
      float d2 = src[i + ic + channel];
      float d3 = src[i + ic + 2 * channel];
      remain_line[i] = 0.0f - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = d3 - d1;
    }
  }
}

static void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  int ic = 0;
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    v2 = MS_LDQ_F32(src + 2 * channel + ic);
    v3 = MS_LDQ_F32(src + 3 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      float d2 = src[i + ic + 2 * channel];
      float d3 = src[i + ic + 3 * channel];
      remain_line[i] = d0 - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = d3 - d1;
    }
  }
}

static void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  int ic = 0;
  v3 = MS_MOVQ_F32(0.0f);
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    v2 = MS_LDQ_F32(src + 2 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      float d2 = src[i + ic + 2 * channel];
      remain_line[i] = d0 - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = 0.0f - d1;
    }
  }
}

static void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2;
  int ic = 0;
  v2 = MS_MOVQ_F32(0.0f);
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_STQ_F32(line + lw * ic, v0);
    MS_STQ_F32(line + lw * ic + 4, v1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    memset(line + lw * ic + 12, 0, 16);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      remain_line[i] = d0;
      remain_line[i + 4] = d1;
      remain_line[i + 8] = 0.0f - d1;
    }
  }
}

static void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int c4 = UP_ROUND(channel, C4NUM);
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(line0, 0, c4 * lw * sizeof(float));
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  ConvDw3x3RowLeft(src - width * channel, line0, lw, channel);
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3Row(const float *src, float **lines, int width, int channel) {
  float *tmp = lines[0];
  lines[0] = lines[1];
  lines[1] = lines[2];
  lines[2] = tmp;
  int c4 = UP_ROUND(channel, C4NUM);
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(tmp, 0, c4 * lw * sizeof(float));
  ConvDw3x3RowLeft(src, tmp, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3Bottom(float **lines, int width, int channel) {
  float *tmp = lines[0];
  lines[0] = lines[1];
  lines[1] = lines[2];
  lines[2] = tmp;
  int c4 = UP_ROUND(channel, C4NUM);
  memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float));
}

#ifndef ENABLE_ARM64
void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel,
                   bool relu, bool relu6) {
  int channel = ori_channel;
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  for (; channel > 0; channel -= 4) {
    MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data);
    bias_data += 4;
    MS_FLOAT32X4 g00 = MS_LDQ_F32(weight);
    MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4);
    MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8);
    MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12);
    MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16);
    MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20);
    MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24);
    MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28);
    MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32);
    MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36);
    MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40);
    MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44);
    weight += 48;
    float *cur_dst = dst;
    int ow = 0;
    for (; ow < width - 1; ow += 2) {
      MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
      MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
      MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
      MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03);
      line0 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
      acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13);
      line1 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
      acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23);
      line2 += 16;
      MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
      MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2));
      res0 = MS_ADDQ_F32(res0, bias);
      res1 = MS_ADDQ_F32(res1, bias);
      if (relu || relu6) {
        res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
        res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f));
      }
      if (relu6) {
        res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
        res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f));
      }
      if (channel >= 4) {
        MS_STQ_F32(cur_dst, res0);
        MS_STQ_F32(cur_dst + ori_channel, res1);
      } else {
        for (int i = 0; i < channel; i++) {
          cur_dst[i] = MS_F32X4_GETI(res0, i);
          cur_dst[ori_channel + i] = MS_F32X4_GETI(res1, i);
        }
      }
      cur_dst += 2 * ori_channel;
    }
    if (ow < width) {
      MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
      MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
      MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
      line0 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
      line1 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
      line2 += 16;
      MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
      res0 = MS_ADDQ_F32(res0, bias);
      if (relu || relu6) {
        res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
      }
      if (relu6) {
        res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
      }
      if (channel >= 4) {
        MS_STQ_F32(cur_dst, res0);
      } else {
        for (int i = 0; i < channel; i++) {
          cur_dst[i] = MS_F32X4_GETI(res0, i);
        }
      }
    }
    dst += 4;
  }
}
#endif

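// Driver for the 3x3 path: keeps three transformed input rows in `buffer` (line0..line2, rotated
// in place by ConvDw3x3Row), seeds them with ConvDw3x3InitTop (zeroed top padding) or
// ConvDw3x3InitRow, and emits one output row per ConvDw3x3Line call; ConvDw3x3Bottom zero-fills
// the row below the input when the last output row is reached.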
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
               const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) {
  int units = UP_DIV(conv_param->output_w_, C2NUM);
  int c4 = UP_ROUND(conv_param->input_channel_, C4NUM);
  int line = conv_param->input_channel_ * conv_param->input_w_;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    float *line0 = buffer;
    float *line1 = buffer + units * c4 * C4NUM;
    float *line2 = buffer + units * c4 * C8NUM;
    float *lines[3] = {line0, line1, line2};
    int oh = start_oh;
    if (oh == 0) {
      // input trans
      ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_);
    } else {
      // input trans
      ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_);
    }
    // dst calc and trans
    ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                  relu, relu6);
    for (oh = start_oh + 1; oh < end_oh - 1; oh++) {
      // input trans
      ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
      // dst calc and trans
      ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                    relu, relu6);
    }
    if (oh == conv_param->output_h_ - 1) {
      // input trans
      ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_);
    } else {
      // input trans
      ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
    }
    // dst calc and trans
    ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                  relu, relu6);
  }
}
#endif

/*conv depthwise indirect buffer fp32 begin*/
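// Indirect-buffer scheme: ConvDwInitIndirection precomputes, for every (oh, ow, kh, kw), a
// pointer to the input pixel that tap reads (or to zero_ptr for padded positions), so the row
// kernels walk a flat pointer table instead of recomputing addresses and bounds per pixel.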
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) {
  bool use_indirect = (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) ||
                      (conv_param->kernel_h_ == 5 && conv_param->kernel_w_ == 5);
  return use_indirect;
}

void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
                           int step_h, int step_w) {
#ifdef ENABLE_AVX
  int div = C8NUM;
#else
  int div = C4NUM;
#endif

  int ic_div = UP_DIV(conv_param->input_channel_, div) * div;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *input = src + b * conv_param->input_h_ * conv_param->input_w_ * ic_div;
    for (int oh = 0; oh < conv_param->output_h_; oh++) {
      for (int kh = 0; kh < conv_param->kernel_h_; kh++) {
        int ih = oh * conv_param->stride_h_ + kh * conv_param->dilation_h_ - conv_param->pad_u_;
        if (ih < conv_param->input_h_ && ih >= 0) {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int iw = ow * conv_param->stride_w_ + kw * conv_param->dilation_w_ - conv_param->pad_l_;
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              if (iw < conv_param->input_w_ && iw >= 0) {
                indirect[index] = input + (ih * conv_param->input_w_ + iw) * ic_div;
              } else {
                indirect[index] = zero_ptr;
              }
            }
          }
        } else {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              indirect[index] = zero_ptr;
            }
          }
        }
      }
    }
  }
}

#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX)
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  do {
    float **in = input;
    size_t c = (size_t)channels;
    const float *w = weights;
    float *out = output;
    memcpy(out, bias, channels * (int)sizeof(float));
    for (; c >= C4NUM; c -= C4NUM) {
      for (int i = 0; i < C4NUM; i++) {
        for (int k = 0; k < kernel; k++) {
          out[i] += in[k][i] * w[i + k * C4NUM];
        }
      }
      w += kernel * C4NUM;
      out += C4NUM;
      for (int k = 0; k < kernel; k++) {
        in[k] += C4NUM;
      }
    }
    for (int i = 0; i < c; i++) {
      for (int k = 0; k < kernel; k++) {
        out[i] += in[k][i] * w[i + k * C4NUM];
      }
    }
    if (relu) {
      Fp32Relu(output, channels, output);
    }
    if (relu6) {
      Fp32Relu6(output, channels, output);
    }
    output += channels;
    input = input + input_stride;
  } while (--output_width != 0);
}
#endif

#ifdef ENABLE_ARM64
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Indirect3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  } else if (kernel == 25) {
    ConvDwFp32Indirect5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  }
}
#endif

#ifdef ENABLE_AVX
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
  } else if (kernel == 25) {
    ConvDwFp32Avx5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
  }
}
#endif

void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
                       float *zero_ptr, const ConvParameter *conv_param, int task_id) {
  if (conv_param->thread_num_ == 0) {
    return;
  }
  int step_w = conv_param->dilation_w_ == 1 ? conv_param->stride_w_ : conv_param->kernel_w_;
  int step_h =
    (conv_param->kernel_h_ * conv_param->kernel_w_) + (conv_param->output_w_ - 1) * step_w * conv_param->kernel_h_;
  int input_stride = conv_param->kernel_h_ * step_w;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);

  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect_b = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *output_b = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float **indirect = indirect_b + oh * step_h;
      float *output_h = output_b + oh * conv_param->output_w_ * conv_param->output_channel_;
      if (conv_param->kernel_w_ == 3) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 9);
      } else if (conv_param->kernel_w_ == 5) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 25);
      }
    }
  }
}
/*conv depthwise indirect buffer fp32 end*/

/*deconv depthwise fp32 begin*/
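// Transposed (deconv) depthwise: the loops mirror the conv sliding window but scatter instead of
// gather. Here the small input is iterated and each pixel is accumulated into a kernel-sized
// window of the output, so the in_/out_ step roles are swapped relative to ConvDwSWFp32, and
// bias plus activation are applied afterwards in DeconvDwPost.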
DeconvDwBorderPixel(float * dst,const float * src,const float * weight,int height,int width,int in_kh_step,int in_kw_step,int kernel_w_step)944 void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step,
945                          int in_kw_step, int kernel_w_step) {
946   float *dst_kh = dst;
947   const float *weight_kh = weight;
948   for (int kh = 0; kh < height; kh++) {
949     float *dst_kw = dst_kh;
950     const float *weight_kw = weight_kh;
951     for (int kw = 0; kw < width; kw++) {
952 #ifdef ENABLE_ARM64
953       float32x4_t src_4 = vld1q_f32(src);
954       float32x4_t weight_4 = vld1q_f32(weight_kw);
955       float32x4_t dst_4 = vld1q_f32(dst_kw);
956       dst_4 = vfmaq_f32(dst_4, src_4, weight_4);
957       vst1q_f32(dst_kw, dst_4);
958 #else
959       for (int c = 0; c < C4NUM; c++) {
960         dst_kw[c] += src[c] * weight_kw[c];
961       }
962 #endif
963       dst_kw += in_kw_step;
964       weight_kw += C4NUM;
965     }  // kernel_w loop
966     dst_kh += in_kh_step;
967     weight_kh += kernel_w_step;
968   }  // kernel_h loop
969 }
970 
DeconvDwBorder(float * dst,const float * src,const float * weight,int top,int bottom,int left,int right,const ConvParameter * conv_param,const SlidingWindowParam * sliding)971 void DeconvDwBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right,
972                     const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
973   if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
974     return;
975   }
976   const float *src_h = src + top * sliding->out_h_step_;
977   for (int ih = top; ih < bottom; ih++) {
978     int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
979     int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
980     int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
981     float *dst_h = dst + oh * sliding->in_h_step_;
982 
983     const float *src_kernel = src_h + left * sliding->block_channel_;
984     for (int iw = left; iw < right; iw++) {
985       int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
986       int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
987       int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
988       float *dst_w = dst_h + ow * sliding->block_channel_;
989 
990       const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
991       float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
992 #ifdef ENABLE_ARM64
993       DeconvDwFp32Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw,
994                          sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
995                          conv_param->kernel_w_ * C4NUM * sizeof(float));
996 #else
997       DeconvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw,
998                           sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM);
999 #endif
1000       src_kernel += sliding->block_channel_;
1001     }  // width loop
1002     src_h += sliding->out_h_step_;
1003   }  // height loop
1004 }
1005 
1006 #ifndef ENABLE_ARM64
DeconvDwCenter(float * dst,const float * src,const float * weight,int height,int width,int kernel_h,int kernel_w,int out_h_step,int block_channel,int in_sh_step,int in_sw_step,int in_kh_step,int in_kw_step)1007 void DeconvDwCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h,
1008                     int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step,
1009                     int in_kw_step) {
1010   float *dst_h = dst;
1011   const float *src_h = src;
1012   for (int oh = 0; oh < height; oh++) {
1013     float *dst_w = dst_h;
1014     const float *src_w = src_h;
1015     for (int ow = 0; ow < width; ow++) {
1016       float *dst_kh = dst_w;
1017       const float *weight_kh = weight;
1018       for (int kh = 0; kh < kernel_h; kh++) {
1019         float *dst_kw = dst_kh;
1020         const float *weight_kw = weight_kh;
1021         for (int kw = 0; kw < kernel_w; kw++) {
1022           for (int c = 0; c < C4NUM; c++) {
1023             dst_kw[c] += src_w[c] * weight_kw[c];
1024           }
1025           dst_kw += in_kw_step;
1026           weight_kw += C4NUM;
1027         }  // kernel_w loop
1028         dst_kh += in_kh_step;
1029         weight_kh += kernel_w * C4NUM;
1030       }  // kernel_h loop
1031       dst_w += in_sw_step;
1032       src_w += block_channel;
1033     }  // dst_width loop
1034     dst_h += in_sh_step;
1035     src_h += out_h_step;
1036   }  // dst_height loop
1037 }
1038 #endif
1039 
DeconvDwPost(float * dst,const float * bias,int block_channel,const ConvParameter * conv_param)1040 void DeconvDwPost(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) {
1041   bool relu = conv_param->act_type_ == ActType_Relu;
1042   bool relu6 = conv_param->act_type_ == ActType_Relu6;
1043   float *dst_k = dst;
1044   for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
1045     for (int c = 0; c < C4NUM; c++) {
1046       dst_k[c] += bias[c];
1047       dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
1048       dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
1049     }
1050     dst_k += block_channel;
1051   }
1052 }
1053 
1054 // deconv depthwise fp32: sliding window
void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                    const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  const float *src = input_data;
  float *dst = output_data;
  if (conv_param->thread_num_ == 0) {
    return;
  }
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src + oc * C4NUM;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = weight_data + oc * sliding->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      DeconvDwBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_,
                     conv_param, sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
                     sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, conv_param->input_w_,
                     conv_param, sliding);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
        const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
        DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                           conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                           sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
                           sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
                           sliding->in_kw_step_ * sizeof(float));
#else
        DeconvDwCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                       conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
                       sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_);
#endif
      }
      DeconvDwPost(dst_data, bias, sliding->block_channel_, conv_param);
    }  // output C4 loop
    src += sliding->out_step_;
    dst += sliding->in_step_;
  }  // batch loop
  // output nhwc4
}
/* deconv depthwise fp32 end */
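/* Note: the SlidingWindowParam fields are presumably named from the forward
 * convolution's point of view, which is why the deconvolution advances its
 * input by out_step_ and its output by in_step_ in the batch loop above. */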

#ifdef ENABLE_AVX
void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left,
                            int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param,
                            const DepthwiseSWKernel kernel, int act_type, int ow_block, int oc_block) {
  // depthwise border compute: clip the kernel window to the valid input region
  int ih = top * conv_param->stride_h_ - conv_param->pad_u_;
  int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
  int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
  const float *src_h = src + ih * sw_param->in_h_step_;
  float *dst_kernel = dst + left * sw_param->block_channel_;
  for (int ow = left; ow < right; ow += ow_block) {
    int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
    int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
    int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
    const float *src_w = src_h + iw * sw_param->block_channel_;
    const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_;
    const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM * oc_block;
    kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_block,
           oc_block, sw_param->block_channel_, sw_param->in_kw_step_, sw_param->in_kh_step_, sw_param->in_sw_step_,
           (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block);
    dst_kernel += ow_block * sw_param->block_channel_;
  }  // width loop
}
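
/* Example of the clipping above: with pad_u_ == 1, stride_h_ == 1 and
 * dilation_h_ == 1, the border row top == 0 gives ih == -1, so start_kh ==
 * UP_DIV(1, 1) == 1: kernel row 0 falls outside the input and is skipped via
 * the weight_kernel offset, while the last kernel argument re-aligns the
 * weight pointer for columns clipped on the left/right. */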
void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                        const ConvParameter *conv_param, const SlidingWindowParam *sw_param, int task_id) {
  int oh_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int oh_start = oh_step * task_id;
  int oh_end = MSMIN(oh_start + oh_step, conv_param->output_h_);
  if (oh_start >= oh_end) {
    return;
  }
  // depthwise sliding window using x86 AVX instructions
  int oc_tile_ = C8NUM;  // oc is aligned up to C8NUM on x86_64 AVX
  int act_type = 0;
  if (conv_param->act_type_ == ActType_Relu6) {
    act_type += 1;
  }
  if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) {
    act_type += 2;
  }
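  /* act_type bit layout, consumed as act_flag by the kernels below:
   *   bit 0 (0x1): clamp to the relu6 upper bound of 6.0f
   *   bit 1 (0x2): clamp to the relu lower bound of 0.0f
   * so ActType_Relu maps to 2 and ActType_Relu6 to 3. */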
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int output_w = conv_param->output_w_;
  int oc_algin = sw_param->block_channel_;
  int oc_num = sw_param->c_block_;
  int in_step = sw_param->in_step_;
  int out_step = sw_param->out_step_;
  int in_sw_step = sw_param->in_sw_step_;
  int in_kw_step = sw_param->in_kw_step_;
  int in_kh_step = sw_param->in_kh_step_;
  int in_sh_step = sw_param->in_sh_step_;
  int out_right = sw_param->right_;
  int out_left = sw_param->left_;
  int out_top = sw_param->top_;
  int out_bottom = sw_param->bottom_;
  int kernel_step = sw_param->kernel_step_;
  int out_h_step = sw_param->out_h_step_;
  int in_h_start = out_top * conv_param->stride_h_ - conv_param->pad_u_;
  int in_w_start = out_left * conv_param->stride_w_ - conv_param->pad_l_;
  int in_start = in_h_start * sw_param->in_h_step_ + in_w_start * oc_algin;
  const int ow_block_num[4] = {8, 4, 4, 3};
  const DepthwiseSWKernel kernel[4][2] = {{DepthwiseSW1x8Kernel, DepthwiseSW8x8Kernel},
                                          {DepthwiseSW1x16Kernel, DepthwiseSW4x16Kernel},
                                          {DepthwiseSW1x24Kernel, DepthwiseSW4x24Kernel},
                                          {DepthwiseSW1x32Kernel, DepthwiseSW3x32Kernel}};
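  /* kernel[oc_block - 1][x] dispatches on the channel tile (oc_block * C8NUM =
   * 8/16/24/32 floats): x == 1 selects the multi-ow center kernel and x == 0
   * the single-ow border kernel. The center ow tile comes from ow_block_num
   * (8/4/4/3), chosen so that the ow_block * oc_block accumulators plus the
   * weight and input registers fit in the 16 ymm registers; in the center loop
   * below, ow_block / ow_block_num[oc_block - 1] is 1 for a full tile and 0
   * (the single-ow kernel) for the tail. */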
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oh = oh_start; oh < oh_end; ++oh) {
      float *dst_oh = output_data + oh * out_h_step;
      const float *src_h = input_data + in_start + (oh - out_top) * in_sh_step;
      int oc_block = 0;
      const float *bias = bias_data;
      for (int oc = 0; oc < oc_num; oc += oc_block) {
        oc_block = MSMIN(C4NUM, oc_num - oc);  // 4 3 2 1
        int oc_step = oc * oc_tile_;
        const float *weight = weight_data + oc * kernel_step;
        if (bias != NULL) {
          bias = bias_data + oc_step;
        }
        float *dst_w = dst_oh + oc_step;
        const DepthwiseSWKernel kernel_border = kernel[oc_block - 1][0];
        if (oh < out_top || oh >= out_bottom) {  // oh in top or bottom border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, output_w, conv_param, sw_param,
                                 kernel_border, act_type, 1, oc_block);
        } else {  // oh in center
          // ow in left border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, out_left, conv_param, sw_param,
                                 kernel_border, act_type, 1, oc_block);
          // ow in center
          const float *src_w = src_h + oc_step;
          int ow_block = ow_block_num[oc_block - 1];                 // 8 4 4 3
          for (int ow = out_left; ow < out_right; ow += ow_block) {  // left ~ right
            if (ow_block > out_right - ow) {  // remaining ow is less than a full block; process one ow at a time
              ow_block = 1;
            }
            kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]](
              dst_w + ow * oc_algin, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, oc_algin,
              in_kw_step, in_kh_step, in_sw_step, 0);
            src_w += ow_block * in_sw_step;
          }
          // ow in right border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, out_right, output_w, conv_param,
                                 sw_param, kernel_border, act_type, 1, oc_block);
        }
      }
    }  // output h loop
    input_data += in_step;
    output_data += out_step;
  }  // batch loop
}

#ifdef ENABLE_DEBUG
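/* DepthwiseSWWxKKernel appears to be the generic C (intrinsics) reference for
 * the specialized asm kernels below, compiled only under ENABLE_DEBUG; its
 * dst_data[12] accumulator array caps ow_block * oc_block at 12, matching the
 * largest asm tiles (4x24 and 3x32). */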
void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                          size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  __m256 dst_data[12];
  __m256 src_data;
  const float *src_kh[12];
  const float *src_kw[12];
  __m256 weight_data[4];
  for (int i = 0; i < ow_block; ++i) {
    if (bias != NULL) {
      for (int j = 0; j < oc_block; ++j) {
        dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * 8);
      }
    } else {
      for (int j = 0; j < oc_block; ++j) {
        dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f);
      }
    }
    src_kh[i] = src + i * in_sw_step;
    src_kw[i] = NULL;
  }
  const float *weight_kernel = weight;
  for (int kh = 0; kh < kernel_h; kh++) {
    for (int i = 0; i < ow_block; ++i) {
      src_kw[i] = src_kh[i];
    }
    for (int kw = 0; kw < kernel_w; kw++) {
      for (int j = 0; j < oc_block; ++j) {
        weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
      }
      for (int i = 0; i < ow_block; ++i) {  // loop ow
        for (int j = 0; j < oc_block; ++j) {
          src_data = _mm256_loadu_ps(src_kw[i] + j * C8NUM);
          dst_data[i * oc_block + j] += src_data * weight_data[j];
        }
      }
      for (int i = 0; i < ow_block; ++i) {
        src_kw[i] += in_kw_step;  // ic8 * dilation_w
      }
      weight_kernel += oc_block * C8NUM;
    }  // kernel_w loop
    weight_kernel += kw_remainder;
    for (int i = 0; i < ow_block; ++i) {
      src_kh[i] += in_kh_step;  // next kernel row: in_h_step * dilation_h
    }
  }  // kernel_h loop
  // apply activation and store
  for (int i = 0; i < ow_block; ++i) {
    for (int j = 0; j < oc_block; ++j) {
      if (0x1 & act_flag) {  // relu6
        dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f));
      }
      if (0x2 & act_flag) {  // relu
        dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f));
      }
      _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]);
    }
  }
}
#endif

void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_sw_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
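  // The strides above are converted from float counts to byte offsets, since
  // the inline asm below uses them directly in addressing.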
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "vmovups 0x40(%2), %%ymm2\n"
    "vmovups 0x60(%2), %%ymm3\n"
    "vmovups (%2), %%ymm4\n"
    "vmovups 0x20(%2), %%ymm5\n"
    "vmovups 0x40(%2), %%ymm6\n"
    "vmovups 0x60(%2), %%ymm7\n"
    "vmovups (%2), %%ymm8\n"
    "vmovups 0x20(%2), %%ymm9\n"
    "vmovups 0x40(%2), %%ymm10\n"
    "vmovups 0x60(%2), %%ymm11\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "vxorps %%ymm8, %%ymm8, %%ymm8\n"
    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
    "vxorps %%ymm11, %%ymm11, %%ymm11\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // width
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW

    "vmovups (%1), %%ymm12\n"
    "vmovups (%%rcx), %%ymm13\n"
    "vmovups (%%rcx, %7), %%ymm14\n"
    "vmovups (%%rcx, %7, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"

    "vmovups 0x20(%1), %%ymm12\n"
    "vmovups 0x20(%%rcx), %%ymm13\n"
    "vmovups 0x20(%%rcx, %7), %%ymm14\n"
    "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n"

    "vmovups 0x40(%1), %%ymm12\n"
    "vmovups 0x40(%%rcx), %%ymm13\n"
    "vmovups 0x40(%%rcx, %7), %%ymm14\n"
    "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n"

    "vmovups 0x60(%1), %%ymm12\n"
    "vmovups 0x60(%%rcx), %%ymm13\n"
    "vmovups 0x60(%%rcx, %7), %%ymm14\n"
    "vmovups 0x60(%%rcx, %7, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n"
    "addq $128, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %8, %1\n"
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder)                               // 8
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9",
      "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
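  // Register tiling in the block above: ymm0-ymm11 hold 3 output pixels x 4
  // C8NUM channel groups of accumulators, ymm12 holds the shared weights, and
  // ymm13-ymm15 hold the three input pixels spaced in_sw_step bytes apart.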

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // Relu
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
    "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
    "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
    "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
    "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
    "vmaxps %%ymm12, %%ymm8, %%ymm8\n"
    "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
    "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
    "vmaxps %%ymm12, %%ymm11, %%ymm11\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6
    "mov $0x40C00000, %%ecx\n"
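    // 0x40C00000 is the IEEE-754 bit pattern of 6.0f; vpermps with the
    // all-zero indices in ymm12 broadcasts it to every lane of ymm14.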
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm2, %%ymm2\n"
    "vminps %%ymm14, %%ymm3, %%ymm3\n"
    "vminps %%ymm14, %%ymm4, %%ymm4\n"
    "vminps %%ymm14, %%ymm5, %%ymm5\n"
    "vminps %%ymm14, %%ymm6, %%ymm6\n"
    "vminps %%ymm14, %%ymm7, %%ymm7\n"
    "vminps %%ymm14, %%ymm8, %%ymm8\n"
    "vminps %%ymm14, %%ymm9, %%ymm9\n"
    "vminps %%ymm14, %%ymm10, %%ymm10\n"
    "vminps %%ymm14, %%ymm11, %%ymm11\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    "vmovups %%ymm2, 0x40(%2)\n"
    "vmovups %%ymm3, 0x60(%2)\n"
    "vmovups %%ymm4, (%2, %1, 1)\n"
    "vmovups %%ymm5, 0x20(%2, %1, 1)\n"
    "vmovups %%ymm6, 0x40(%2, %1, 1)\n"
    "vmovups %%ymm7, 0x60(%2, %1, 1)\n"
    "vmovups %%ymm8, (%2, %1, 2)\n"
    "vmovups %%ymm9, 0x20(%2, %1, 2)\n"
    "vmovups %%ymm10, 0x40(%2, %1, 2)\n"
    "vmovups %%ymm11, 0x60(%2, %1, 2)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst)
    : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
      "%ymm11", "%ymm12", "%ymm14");
}
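
/* The accumulators stay live in ymm0-ymm11 between the two asm statements in
 * these kernels; carrying values across separate asm blocks in fixed registers
 * works only as long as the compiler emits nothing in between that touches
 * them, so treat this as a deliberate but fragile pattern. */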

void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "vmovups 0x40(%2), %%ymm2\n"
    "vmovups 0x60(%2), %%ymm3\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // width
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%%rcx), %%ymm4\n"
    "vmovups 0x20(%%rcx), %%ymm5\n"
    "vmovups 0x40(%%rcx), %%ymm6\n"
    "vmovups 0x60(%%rcx), %%ymm7\n"
    // The weight data is used directly from memory rather than being loaded into registers first.
    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
    "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
    "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n"
    "vfmadd231ps 0x60(%1), %%ymm7, %%ymm3\n"
    "addq $128, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %7, %1\n"
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(kw_remainder)                                                // 7
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // Relu
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm2, %%ymm2\n"
    "vminps %%ymm14, %%ymm3, %%ymm3\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    "vmovups %%ymm2, 0x40(%2)\n"
    "vmovups %%ymm3, 0x60(%2)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst)
    : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14");
}

void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  in_sw_step *= sizeof(float);
  kw_remainder *= sizeof(float);
  size_t src_3_step = 3 * in_sw_step;
  float *dst_3 = dst + 3 * oc_algin;
  oc_algin *= sizeof(float);
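  // x86 addressing scales an index register only by 1, 2, 4 or 8, so the
  // offset of the fourth input pixel (3 * in_sw_step) and its matching
  // destination pointer are precomputed here.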
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "vmovups 0x40(%2), %%ymm2\n"
    // The bias is reloaded from memory for each output pixel; a register-to-register
    // copy of ymm0-ymm2 (e.g. vmovaps) should avoid the repeated loads.
    "vmovups (%2), %%ymm3\n"
    "vmovups 0x20(%2), %%ymm4\n"
    "vmovups 0x40(%2), %%ymm5\n"
    "vmovups (%2), %%ymm6\n"
    "vmovups 0x20(%2), %%ymm7\n"
    "vmovups 0x40(%2), %%ymm8\n"
    "vmovups (%2), %%ymm9\n"
    "vmovups 0x20(%2), %%ymm10\n"
    "vmovups 0x40(%2), %%ymm11\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "vxorps %%ymm8, %%ymm8, %%ymm8\n"
    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
    "vxorps %%ymm11, %%ymm11, %%ymm11\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // width
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%1), %%ymm12\n"
    "vmovups (%%rcx), %%ymm13\n"
    "vmovups (%%rcx, %7, 1), %%ymm14\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
    "vmovups (%%rcx, %7, 2), %%ymm15\n"
    "vmovups (%%rcx, %9), %%ymm13\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n"

    "vmovups 0x20(%1), %%ymm12\n"
    "vmovups 0x20(%%rcx), %%ymm13\n"
    "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
    "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
    "vmovups 0x20(%%rcx, %9), %%ymm13\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n"

    "vmovups 0x40(%1), %%ymm12\n"
    "vmovups 0x40(%%rcx), %%ymm13\n"
    "vmovups 0x40(%%rcx, %7, 1), %%ymm14\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
    "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
    "vmovups 0x40(%%rcx, %9), %%ymm13\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n"

    "addq $96, %1\n"
    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %8, %1\n"  // in the border case, skip the clipped remainder of the weight row
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step)              // 9
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9",
      "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // Relu
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
    "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
    "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
    "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
    "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
    "vmaxps %%ymm12, %%ymm8, %%ymm8\n"
    "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
    "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
    "vmaxps %%ymm12, %%ymm11, %%ymm11\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm2, %%ymm2\n"
    "vminps %%ymm14, %%ymm3, %%ymm3\n"
    "vminps %%ymm14, %%ymm4, %%ymm4\n"
    "vminps %%ymm14, %%ymm5, %%ymm5\n"
    "vminps %%ymm14, %%ymm6, %%ymm6\n"
    "vminps %%ymm14, %%ymm7, %%ymm7\n"
    "vminps %%ymm14, %%ymm8, %%ymm8\n"
    "vminps %%ymm14, %%ymm9, %%ymm9\n"
    "vminps %%ymm14, %%ymm10, %%ymm10\n"
    "vminps %%ymm14, %%ymm11, %%ymm11\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    "vmovups %%ymm2, 0x40(%2)\n"
    "vmovups %%ymm3, (%2, %1, 1)\n"
    "vmovups %%ymm4, 0x20(%2, %1, 1)\n"
    "vmovups %%ymm5, 0x40(%2, %1, 1)\n"
    "vmovups %%ymm6, (%2, %1, 2)\n"
    "vmovups %%ymm7, 0x20(%2, %1, 2)\n"
    "vmovups %%ymm8, 0x40(%2, %1, 2)\n"
    "vmovups %%ymm9, (%3)\n"  // dst_3
    "vmovups %%ymm10, 0x20(%3)\n"
    "vmovups %%ymm11, 0x40(%3)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
    : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
      "%ymm11", "%ymm12", "%ymm14");
}

void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "vmovups 0x40(%2), %%ymm2\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // width
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%%rcx), %%ymm4\n"
    "vmovups 0x20(%%rcx), %%ymm5\n"
    "vmovups 0x40(%%rcx), %%ymm6\n"
    // The weight data is used directly from memory rather than being loaded into registers first.
    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
    "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
    "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n"
    "addq $96, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %7, %1\n"  // kw_remainder
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(kw_remainder)                                                // 7
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // Relu
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm2, %%ymm2\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    "vmovups %%ymm2, 0x40(%2)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst)
    : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14");
}

void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  in_sw_step *= sizeof(float);
  kw_remainder *= sizeof(float);
  size_t src_3_step = 3 * in_sw_step;
  float *dst_3 = dst + 3 * oc_algin;
  oc_algin *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    // The bias is reloaded from memory for each output pixel; a register-to-register
    // copy of ymm0-ymm1 (e.g. vmovaps) should avoid the repeated loads.
    "vmovups (%2), %%ymm3\n"
    "vmovups 0x20(%2), %%ymm4\n"
    "vmovups (%2), %%ymm6\n"
    "vmovups 0x20(%2), %%ymm7\n"
    "vmovups (%2), %%ymm9\n"
    "vmovups 0x20(%2), %%ymm10\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // width
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%1), %%ymm12\n"
    "vmovups (%%rcx), %%ymm13\n"
    "vmovups (%%rcx, %7, 1), %%ymm14\n"
    "vmovups (%%rcx, %7, 2), %%ymm15\n"
    "vmovups (%%rcx, %9), %%ymm2\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
    "vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n"

    "vmovups 0x20(%1), %%ymm12\n"
    "vmovups 0x20(%%rcx), %%ymm13\n"
    "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
    "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
    "vmovups 0x20(%%rcx, %9), %%ymm2\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
    "vfmadd231ps %%ymm12, %%ymm2, %%ymm10\n"

    "addq $64, %1\n"
    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %8, %1\n"  // in the border case, skip the clipped remainder of the weight row
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step)              // 9
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12",
      "%ymm13", "%ymm14", "%ymm15");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // Relu
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
    "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
    "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
    "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
    "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
    "vmaxps %%ymm12, %%ymm10, %%ymm10\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm3, %%ymm3\n"
    "vminps %%ymm14, %%ymm4, %%ymm4\n"
    "vminps %%ymm14, %%ymm6, %%ymm6\n"
    "vminps %%ymm14, %%ymm7, %%ymm7\n"
    "vminps %%ymm14, %%ymm9, %%ymm9\n"
    "vminps %%ymm14, %%ymm10, %%ymm10\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    "vmovups %%ymm3, (%2, %1, 1)\n"
    "vmovups %%ymm4, 0x20(%2, %1, 1)\n"
    "vmovups %%ymm6, (%2, %1, 2)\n"
    "vmovups %%ymm7, 0x20(%2, %1, 2)\n"
    "vmovups %%ymm9, (%3)\n"  // dst_3
    "vmovups %%ymm10, 0x20(%3)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
    : "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14");
}

void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // width
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%%rcx), %%ymm4\n"
    "vmovups 0x20(%%rcx), %%ymm5\n"
    // The weight data is used directly from memory rather than being loaded into registers first.
    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
    "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
    "addq $64, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %7, %1\n"  // kw_remainder
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(kw_remainder)                                                // 7
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm4", "%ymm5");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // Relu
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst)
    : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14");
}

void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                          size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_sw_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  kw_remainder *= sizeof(float);
  size_t src_3_step = 3 * in_sw_step;
  float *dst_3 = dst + 3 * oc_algin;
  float *dst_5 = dst + 5 * oc_algin;
  oc_algin *= sizeof(float);
  asm volatile(
    "cmpq $0, %0\n"
    "je 0f\n"
    "vmovups (%0), %%ymm0\n"
    "vmovups (%0), %%ymm1\n"
    "vmovups (%0), %%ymm2\n"
    "vmovups (%0), %%ymm3\n"
    "vmovups (%0), %%ymm4\n"
    "vmovups (%0), %%ymm5\n"
    "vmovups (%0), %%ymm6\n"
    "vmovups (%0), %%ymm7\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "1:\n"
    :
    : "r"(bias)
    : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");

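  // Note: this kernel uses named asm labels (LoopH/LoopW/Write) instead of the
  // numeric local labels used elsewhere; if the compiler ever duplicates this
  // function body (e.g. by inlining it twice), named labels may collide.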
  asm volatile(
    "LoopH:\n"
    "movq %3, %%rsi\n"  // width
    "movq %0, %%rcx\n"  // src_h
    "LoopW:\n"
    "movq %%rcx, %%rax\n"
    "vmovups (%1), %%ymm12\n"
    "vmovups (%%rax), %%ymm13\n"
    "vmovups (%%rax, %6), %%ymm14\n"
    "vmovups (%%rax, %6, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n"
    "addq %7, %%rax\n"
    "vmovups (%%rax), %%ymm13\n"
    "vmovups (%%rax, %6), %%ymm14\n"
    "vmovups (%%rax, %6, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n"
    "addq %7, %%rax\n"
    "vmovups (%%rax), %%ymm13\n"
    "vmovups (%%rax, %6), %%ymm14\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"

    "addq $32, %1\n"
    "addq %4, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg LoopW\n"

    "addq %5, %0\n"  // in_kh_step
    "addq %8, %1\n"  // in the border case, skip the clipped remainder of the weight row
    "dec %2\n"
    "jg LoopH\n"
    :
    : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), "r"(in_kh_step),  // 5
      "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder)                                     // 8
    : "%rcx", "%rsi", "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12",
      "%ymm13", "%ymm14", "%ymm15");

  asm volatile(
    "and $0x3, %%eax\n"
    "je Write\n"
    // Relu
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
    "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
    "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
    "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
    "vmaxps %%ymm12, %%ymm7, %%ymm7\n"

    "and $0x1, %%eax\n"
    "je Write\n"
    // relu6
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm2, %%ymm2\n"
    "vminps %%ymm14, %%ymm3, %%ymm3\n"
    "vminps %%ymm14, %%ymm4, %%ymm4\n"
    "vminps %%ymm14, %%ymm5, %%ymm5\n"
    "vminps %%ymm14, %%ymm6, %%ymm6\n"
    "vminps %%ymm14, %%ymm7, %%ymm7\n"

    "Write:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, (%2, %1)\n"
    "vmovups %%ymm2, (%2, %1, 2)\n"
    "vmovups %%ymm3, (%3)\n"  // dst_3
    "vmovups %%ymm4, (%2, %1, 4)\n"
    "vmovups %%ymm5, (%4)\n"  // dst_5
    "vmovups %%ymm6, (%4, %1, 1)\n"
    "vmovups %%ymm7, (%4, %1, 2)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5)
    : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14");
}

void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                          size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // width
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%%rcx), %%ymm4\n"
    // The weight data is used directly from memory rather than being loaded into registers first.
    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
    "addq $32, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %7, %1\n"  // kw_remainder
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(kw_remainder)                                                // 7
    : "%rcx", "%rsi", "%ymm0", "%ymm4");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // Relu
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst)
    : "%ecx", "%ymm0", "%ymm12", "%ymm14");
}
#endif