/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp32/conv_depthwise_fp32.h"
#include "nnacl/common_func.h"
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/activation_fp32.h"

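/* Scalar fallback for ConvDwFp32Row: accumulates one kernel tap across a row of output
 * pixels (weight_ptr holds one channel-wide tap). SIMD implementations are supplied
 * elsewhere for ARM and SSE builds, so this is only compiled when neither is enabled. */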
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
                   int output_channel, int input_step) {
  for (int i = 0; i < num_pixels; i++) {
    for (int c = 0; c < output_channel; c++) {
      *output_ptr++ += weight_ptr[c] * input_ptr[c];
    }
    input_ptr += input_step;
  }
}
#endif

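/* Depthwise convolution on NHWC data. Output rows [h_start, h_end) are assigned to this
 * task so threads share the output height evenly. Per output row, the kernel-height
 * bounds are clipped against the input, the bias is broadcast into every output pixel,
 * and ConvDwFp32Row then accumulates one kernel tap across all valid output columns. */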
int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
           const ConvParameter *conv_param, int task_id) {
  if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) {
    return NNACL_ERR;
  }
  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;

      int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
      int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
      int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));

      for (int ow = 0; ow < conv_param->output_w_; ow++) {
        memcpy(dst_data + ow * conv_param->output_channel_, bias_data,
               conv_param->output_channel_ * (int)(sizeof(float)));
      }
      for (int kh = start_kh; kh < end_kh; kh++) {
        int ih = ih_origin + conv_param->dilation_h_ * kh;

        const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
        const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;

        int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
        for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
          int out_w_start = MSMAX(
            0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
          int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
                                                        conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
                                                         conv_param->stride_w_);

          float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
          int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;

          const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
          int num_pixels = out_w_end - out_w_start;

          ConvDwFp32Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
          weight_kh += conv_param->output_channel_;
        }
      }
      if (relu) {
        Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
      if (relu6) {
        Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
    }
  }
  return NNACL_OK;
}

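/* Computes the sliding-window bounds: left_/right_/top_/bottom_ delimit the largest
 * output region whose receptive field lies entirely inside the input, so the center
 * loop can run without per-pixel padding checks; the border strips are handled
 * separately. Also derives the channel-blocked output strides for the chosen layout. */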
void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  if (block == 0) {
    return;
  }
  int left = 0;
  int right = conv_param->output_w_;
  int top = 0;
  int bottom = conv_param->output_h_;

  while (left * conv_param->stride_w_ < conv_param->pad_l_) {
    left++;
  }
  while ((right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
           conv_param->input_w_ &&
         right > left) {
    right--;
  }
  while (top * conv_param->stride_h_ < conv_param->pad_u_) {
    top++;
  }
  while ((bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
           conv_param->input_h_ &&
         bottom > top) {
    bottom--;
  }
  sliding->left_ = left;
  sliding->right_ = right;
  sliding->top_ = top;
  sliding->bottom_ = bottom;
  sliding->c_block_ = UP_DIV(conv_param->output_channel_, block);
  sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block;
  sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_;
  if (conv_param->out_format_ == NNACL_NC4HW4) {
    // write to nc4hw4 / nc8hw8, depending on block
    sliding->out_h_step_ = conv_param->output_w_ * block;
    sliding->out_c_step_ = block * conv_param->output_h_ * conv_param->output_w_;
    sliding->out_w_step_ = block;
    sliding->out_block_step_ = sliding->out_c_step_;
  } else {
    // write to nhwc
    sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_;
    sliding->out_c_step_ = block;
    sliding->out_w_step_ = sliding->block_channel_;
    sliding->out_block_step_ = sliding->out_w_step_;
  }
}

void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block,
                          int weight_block) {
  InitSlidingParam(sliding, conv_param, weight_block);
  AppendSlidingParamConv(sliding, conv_param, input_block, weight_block);
}

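/* Input-side strides for the generic sliding-window convolution. ic_align_ is the input
 * channel count, optionally rounded up to input_block (the aligned case is used by 1x1
 * kernels); all in_* steps are element counts derived from stride and dilation. */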
void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block,
                            int weight_block) {
  if (input_block == 0) {  // not aligned
    sliding->ic_align_ = conv_param->input_channel_;
  } else {  // 1x1 input is aligned to input_block
    sliding->ic_align_ = UP_DIV(conv_param->input_channel_, input_block) * input_block;
  }
  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->ic_align_;  // for batch loop
  sliding->in_h_step_ = conv_param->input_w_ * sliding->ic_align_;
  sliding->in_sh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->stride_h_;    // stride H
  sliding->in_sw_step_ = sliding->ic_align_ * conv_param->stride_w_;                           // stride W
  sliding->in_kh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->dilation_h_;  // kernel H
  sliding->in_kw_step_ = sliding->ic_align_ * conv_param->dilation_w_;                         // kernel W
  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * sliding->ic_align_ * weight_block;
}

void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  InitSlidingParam(sliding, conv_param, block);
  AppendSlidingParamConvDw(sliding, conv_param, block);
}

void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_;  // for batch loop
  sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_;
  sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_;    // stride H
  sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_;                           // stride W
  sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_;  // kernel H
  sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_;                         // kernel W
  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block;
}

/*conv depthwise fp32 begin*/
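/* Scalar border kernel: computes one C4 output pixel over the clipped kernel window,
 * then applies bias and the optional relu/relu6 activation. */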
void ConvDwBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                       int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, bool is_relu6) {
  const float *src_kh = src;
  const float *weight_kh = weight;
  for (int c = 0; c < C4NUM; c++) {
    dst[c] = 0;
  }
  for (int kh = 0; kh < height; kh++) {
    const float *src_kw = src_kh;
    const float *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
      for (int c = 0; c < C4NUM; c++) {
        dst[c] += src_kw[c] * weight_kw[c];
      }
      src_kw += in_kw_step;
      weight_kw += C4NUM;
    }  // kernel_w loop
    src_kh += in_kh_step;
    weight_kh += kernel_w_step;
  }  // kernel_h loop
  for (int c = 0; c < C4NUM; c++) {
    dst[c] += bias[c];
    dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]);
    dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]);
  }
}

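/* Walks one border region of the output, clipping the kernel window per pixel and
 * dispatching to the platform kernel: AVX packs its arguments into a
 * ConvDwFp32BorderParam struct, ARM/SSE take them directly, and other builds fall back
 * to the scalar ConvDwBorderPixel above. */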
void ConvDwBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
                  int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
    return;
  }
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  float *dst_h = dst + top * sliding->out_h_step_;
  for (int oh = top; oh < bottom; oh++) {
    int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
    const float *src_h = src + ih * sliding->in_h_step_;

    float *dst_kernel = dst_h + left * sliding->block_channel_;
    for (int ow = left; ow < right; ow++) {
      int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
      const float *src_w = src_h + iw * sliding->block_channel_;

      const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
#ifdef ENABLE_AVX
      ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam));
      if (param == NULL) {
        return;
      }
      param->dst = dst_kernel;
      param->src = src_kernel;
      param->weight = weight_kernel;
      param->bias = bias;
      param->height = end_kh - start_kh;
      param->width = end_kw - start_kw;
      param->in_kh_step = sliding->in_kh_step_ * sizeof(float);
      param->in_kw_step = sliding->in_kw_step_ * sizeof(float);
      param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float);
      param->relu = relu;
      param->relu6 = relu6;
      ConvDwFp32Border(param);
      free(param);
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
      ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                       sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                       conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
#else
      ConvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                        sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6);
#endif
      dst_kernel += sliding->block_channel_;
    }  // width loop
    dst_h += sliding->out_h_step_;
  }  // height loop
}

#ifndef ENABLE_ARM64
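/* C center kernel, used when no platform-specific ConvDwFp32Center is linked in: every
 * pixel of the interior region sees the full kernel window, so no clipping is needed. */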
void ConvDwCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                  int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
                  int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) {
  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      const float *src_kh = src_w;
      const float *weight_kh = weight;
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] = 0;
      }
      for (int kh = 0; kh < kernel_h; kh++) {
        const float *src_kw = src_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            dst_w[c] += src_kw[c] * weight_kw[c];
          }
          src_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      // add bias and apply activation
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] += bias[c];
        dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]);
        dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]);
      }
      dst_w += block_channel;
      src_w += in_sw_step;
    }  // dst_width loop
    dst_h += out_h_step;
    src_h += in_sh_step;
  }  // dst_height loop
}
#endif

// conv depthwise fp32: sliding window
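// The output is decomposed into four border strips plus an interior rectangle
// (top_/bottom_/left_/right_ from InitSlidingParam). Borders take the clipping path,
// the interior takes the fast unclipped path. C4NUM channel blocks are distributed
// round-robin across threads via task_id.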
void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  if (conv_param->thread_num_ == 0) {
    return;
  }
  const float *src = input_data;
  float *dst = output_data;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src + oc * C4NUM;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = weight_data + oc * sliding->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      ConvDwBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, conv_param->output_w_,
                   conv_param, sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
                   sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
                   conv_param->output_w_, conv_param, sliding);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
        float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
        ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                         conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                         sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
                         sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
                         sliding->in_kw_step_ * sizeof(float), relu, relu6);
#else
        ConvDwCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                     conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
                     sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu,
                     relu6);
#endif
      }
    }  // output C4 loop
    src += sliding->in_step_;
    dst += sliding->out_step_;
  }  // batch loop
  // output nhwc4
}
/*conv depthwise fp32 end*/

/*conv depthwise 3x3 fp32 begin*/
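/* The dedicated 3x3 path requires a 3x3 kernel, equal strides of 1 or 2, equal pads of
 * 0 or 1, no dilation, and output dimensions consistent with that padding. */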
bool CheckConvDwUse3X3(const ConvParameter *conv_param) {
  bool use_3x3 =
    conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
    (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
    (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && conv_param->stride_h_ == conv_param->stride_w_ &&
    (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) &&
    conv_param->pad_u_ == conv_param->pad_l_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1;
  if (!use_3x3 || conv_param->input_h_ == 1 || conv_param->input_w_ == 1) {
    return false;
  }
  const int in_h = (conv_param->output_h_ - 1) * conv_param->stride_h_ + conv_param->kernel_h_;
  const int in_w = (conv_param->output_w_ - 1) * conv_param->stride_w_ + conv_param->kernel_w_;
  return in_h == (conv_param->input_h_ + 2 * conv_param->pad_u_) &&
         in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_);
}

#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
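/* The ConvDw3x3Row* helpers apply what amounts to the 1-D Winograd F(2,3) input
 * transform [d0 - d2, d1 + d2, d2 - d1, d3 - d1] to 4-channel groups of an input row,
 * writing four transformed vectors per pair of output columns into a line buffer. The
 * Left/Right/Single variants substitute zeros for columns that fall in the padding. */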
static void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  v0 = MS_MOVQ_F32(0.0f);
  int ic = 0;
  for (; ic < channel - 3; ic += 4) {
    v1 = MS_LDQ_F32(src + ic);
    v2 = MS_LDQ_F32(src + channel + ic);
    v3 = MS_LDQ_F32(src + 2 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d1 = src[i + ic];
      float d2 = src[i + ic + channel];
      float d3 = src[i + ic + 2 * channel];
      remain_line[i] = 0.0f - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = d3 - d1;
    }
  }
}

static void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  int ic = 0;
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    v2 = MS_LDQ_F32(src + 2 * channel + ic);
    v3 = MS_LDQ_F32(src + 3 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      float d2 = src[i + ic + 2 * channel];
      float d3 = src[i + ic + 3 * channel];
      remain_line[i] = d0 - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = d3 - d1;
    }
  }
}

static void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  int ic = 0;
  v3 = MS_MOVQ_F32(0.0f);
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    v2 = MS_LDQ_F32(src + 2 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      float d2 = src[i + ic + 2 * channel];
      remain_line[i] = d0 - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = 0.0f - d1;
    }
  }
}

static void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2;
  int ic = 0;
  v2 = MS_MOVQ_F32(0.0f);
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_STQ_F32(line + lw * ic, v0);
    MS_STQ_F32(line + lw * ic + 4, v1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    memset(line + lw * ic + 12, 0, 16);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      remain_line[i] = d0;
      remain_line[i + 4] = d1;
      remain_line[i + 8] = 0.0f - d1;
    }
  }
}

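/* Prepares the three transformed line buffers for output row 0: line0 stays zero
 * (implicit top padding) while line1/line2 hold the transforms of input rows 0 and 1. */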
static void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int c4 = UP_ROUND(channel, C4NUM);
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(line0, 0, c4 * lw * sizeof(float));
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  ConvDw3x3RowLeft(src - width * channel, line0, lw, channel);
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}

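/* Advances the three-line pipeline by one output row: rotates the buffers so the oldest
 * becomes scratch, then fills it with the transform of the next input row. Bottom does
 * the same rotation but zero-fills the scratch line (implicit bottom padding). */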
static void ConvDw3x3Row(const float *src, float **lines, int width, int channel) {
  float *tmp = lines[0];
  lines[0] = lines[1];
  lines[1] = lines[2];
  lines[2] = tmp;
  int c4 = UP_ROUND(channel, C4NUM);
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(tmp, 0, c4 * lw * sizeof(float));
  ConvDw3x3RowLeft(src, tmp, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3Bottom(float **lines, int width, int channel) {
  float *tmp = lines[0];
  lines[0] = lines[1];
  lines[1] = lines[2];
  lines[2] = tmp;
  int c4 = UP_ROUND(channel, C4NUM);
  memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float));
}

#ifndef ENABLE_ARM64
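/* Output stage (intrinsics version; compiled out on ARM64, which supplies its own
 * implementation): multiplies the three transformed lines by the transformed 3x3
 * weights and applies the F(2,3) output transform, producing two output columns per
 * iteration as res0 = acc0 + acc1 + acc2 and res1 = acc1 - acc2 + acc3. Bias and
 * relu/relu6 are fused; a channel tail smaller than 4 is stored lane by lane. */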
void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel,
                   bool relu, bool relu6) {
  int channel = ori_channel;
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  for (; channel > 0; channel -= 4) {
    MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data);
    bias_data += 4;
    MS_FLOAT32X4 g00 = MS_LDQ_F32(weight);
    MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4);
    MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8);
    MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12);
    MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16);
    MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20);
    MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24);
    MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28);
    MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32);
    MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36);
    MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40);
    MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44);
    weight += 48;
    float *cur_dst = dst;
    int ow = 0;
    for (; ow < width - 1; ow += 2) {
      MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
      MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
      MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
      MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03);
      line0 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
      acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13);
      line1 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
      acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23);
      line2 += 16;
      MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
      MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2));
      res0 = MS_ADDQ_F32(res0, bias);
      res1 = MS_ADDQ_F32(res1, bias);
      if (relu || relu6) {
        res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
        res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f));
      }
      if (relu6) {
        res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
        res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f));
      }
      if (channel >= 4) {
        MS_STQ_F32(cur_dst, res0);
        MS_STQ_F32(cur_dst + ori_channel, res1);
      } else {
        for (int i = 0; i < channel; i++) {
          cur_dst[i] = MS_F32X4_GETI(res0, i);
          cur_dst[ori_channel + i] = MS_F32X4_GETI(res1, i);
        }
      }
      cur_dst += 2 * ori_channel;
    }
    if (ow < width) {
      MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
      MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
      MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
      line0 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
      line1 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
      line2 += 16;
      MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
      res0 = MS_ADDQ_F32(res0, bias);
      if (relu || relu6) {
        res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
      }
      if (relu6) {
        res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
      }
      if (channel >= 4) {
        MS_STQ_F32(cur_dst, res0);
      } else {
        for (int i = 0; i < channel; i++) {
          cur_dst[i] = MS_F32X4_GETI(res0, i);
        }
      }
    }
    dst += 4;
  }
}
#endif

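/* Row-pipelined 3x3 depthwise convolution: transform the first rows (zero top padding
 * when start_oh == 0), then per output row emit one output line and transform the next
 * input row, finishing with zero bottom padding on the last row. */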
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
               const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) {
  int units = UP_DIV(conv_param->output_w_, C2NUM);
  int c4 = UP_ROUND(conv_param->input_channel_, C4NUM);
  int line = conv_param->input_channel_ * conv_param->input_w_;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    float *line0 = buffer;
    float *line1 = buffer + units * c4 * C4NUM;
    float *line2 = buffer + units * c4 * C8NUM;
    float *lines[3] = {line0, line1, line2};
    int oh = start_oh;
    if (oh == 0) {
      // input trans
      ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_);
    } else {
      // input trans
      ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_);
    }
    // dst calc and trans
    ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                  relu, relu6);
    for (oh = start_oh + 1; oh < end_oh - 1; oh++) {
      // input trans
      ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
      // dst calc and trans
      ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                    relu, relu6);
    }
    if (oh == conv_param->output_h_ - 1) {
      // input trans
      ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_);
    } else {
      // input trans
      ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
    }
    // dst calc and trans
    ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                  relu, relu6);
  }
}
#endif

/*conv depthwise indirect buffer fp32 begin*/
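/* The indirect-buffer variant precomputes, per output pixel, an array of pointers to
 * the input positions it reads (or to a shared zero buffer for padded positions), so
 * the inner kernels can gather without bounds checks. Only 3x3 and 5x5 kernels use it. */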
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) {
  bool use_indirect = (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) ||
                      (conv_param->kernel_h_ == 5 && conv_param->kernel_w_ == 5);
  return use_indirect;
}

void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
                           int step_h, int step_w) {
#ifdef ENABLE_AVX
  int div = C8NUM;
#else
  int div = C4NUM;
#endif

  int ic_div = UP_DIV(conv_param->input_channel_, div) * div;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *input = src + b * conv_param->input_h_ * conv_param->input_w_ * ic_div;
    for (int oh = 0; oh < conv_param->output_h_; oh++) {
      for (int kh = 0; kh < conv_param->kernel_h_; kh++) {
        int ih = oh * conv_param->stride_h_ + kh * conv_param->dilation_h_ - conv_param->pad_u_;
        if (ih < conv_param->input_h_ && ih >= 0) {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int iw = ow * conv_param->stride_w_ + kw * conv_param->dilation_w_ - conv_param->pad_l_;
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              if (iw < conv_param->input_w_ && iw >= 0) {
                indirect[index] = input + (ih * conv_param->input_w_ + iw) * ic_div;
              } else {
                indirect[index] = zero_ptr;
              }
            }
          }
        } else {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              indirect[index] = zero_ptr;
            }
          }
        }
      }
    }
  }
}

#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX)
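/* Portable C fallback for one output row: each output pixel reads `kernel` pointers
 * from the indirect buffer; weights are packed as C4 channel blocks per kernel tap.
 * ARM64 and AVX builds replace this with the specialized 3x3/5x5 kernels below. */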
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  do {
    float **in = input;
    size_t c = (size_t)channels;
    const float *w = weights;
    float *out = output;
    memcpy(out, bias, channels * (int)sizeof(float));
    for (; c >= C4NUM; c -= C4NUM) {
      for (int i = 0; i < C4NUM; i++) {
        for (int k = 0; k < kernel; k++) {
          out[i] += in[k][i] * w[i + k * C4NUM];
        }
      }
      w += kernel * C4NUM;
      out += C4NUM;
      for (int k = 0; k < kernel; k++) {
        in[k] += C4NUM;
      }
    }
    for (int i = 0; i < c; i++) {
      for (int k = 0; k < kernel; k++) {
        out[i] += in[k][i] * w[i + k * C4NUM];
      }
    }
    if (relu) {
      Fp32Relu(output, channels, output);
    }
    if (relu6) {
      Fp32Relu6(output, channels, output);
    }
    output += channels;
    input = input + input_stride;
  } while (--output_width != 0);
}
#endif

#ifdef ENABLE_ARM64
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Indirect3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  } else if (kernel == 25) {
    ConvDwFp32Indirect5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  }
}
#endif

#ifdef ENABLE_AVX
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
  } else if (kernel == 25) {
    ConvDwFp32Avx5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
  }
}
#endif

void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
                       float *zero_ptr, const ConvParameter *conv_param, int task_id) {
  if (conv_param->thread_num_ == 0) {
    return;
  }
  int step_w = conv_param->dilation_w_ == 1 ? conv_param->stride_w_ : conv_param->kernel_w_;
  int step_h =
    (conv_param->kernel_h_ * conv_param->kernel_w_) + (conv_param->output_w_ - 1) * step_w * conv_param->kernel_h_;
  int input_stride = conv_param->kernel_h_ * step_w;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);

  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect_b = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *output_b = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float **indirect = indirect_b + oh * step_h;
      float *output_h = output_b + oh * conv_param->output_w_ * conv_param->output_channel_;
      if (conv_param->kernel_w_ == 3) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 9);
      } else if (conv_param->kernel_w_ == 5) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 25);
      }
    }
  }
}
/*conv depthwise indirect buffer fp32 end*/

/*deconv depthwise fp32 begin*/
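/* Depthwise deconvolution (transposed convolution): the loop structure mirrors the
 * forward sliding window, but each input pixel is scatter-accumulated into its
 * kernel-sized output window, so the src/dst stride roles are swapped. Bias and
 * activation are applied afterwards in DeconvDwPost. */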
void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step,
                         int in_kw_step, int kernel_w_step) {
  float *dst_kh = dst;
  const float *weight_kh = weight;
  for (int kh = 0; kh < height; kh++) {
    float *dst_kw = dst_kh;
    const float *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
#ifdef ENABLE_ARM64
      float32x4_t src_4 = vld1q_f32(src);
      float32x4_t weight_4 = vld1q_f32(weight_kw);
      float32x4_t dst_4 = vld1q_f32(dst_kw);
      dst_4 = vfmaq_f32(dst_4, src_4, weight_4);
      vst1q_f32(dst_kw, dst_4);
#else
      for (int c = 0; c < C4NUM; c++) {
        dst_kw[c] += src[c] * weight_kw[c];
      }
#endif
      dst_kw += in_kw_step;
      weight_kw += C4NUM;
    }  // kernel_w loop
    dst_kh += in_kh_step;
    weight_kh += kernel_w_step;
  }  // kernel_h loop
}

void DeconvDwBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right,
                    const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
    return;
  }
  const float *src_h = src + top * sliding->out_h_step_;
  for (int ih = top; ih < bottom; ih++) {
    int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
    float *dst_h = dst + oh * sliding->in_h_step_;

    const float *src_kernel = src_h + left * sliding->block_channel_;
    for (int iw = left; iw < right; iw++) {
      int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
      float *dst_w = dst_h + ow * sliding->block_channel_;

      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
      float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
#ifdef ENABLE_ARM64
      DeconvDwFp32Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw,
                         sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                         conv_param->kernel_w_ * C4NUM * sizeof(float));
#else
      DeconvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw,
                          sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM);
#endif
      src_kernel += sliding->block_channel_;
    }  // width loop
    src_h += sliding->out_h_step_;
  }  // height loop
}

#ifndef ENABLE_ARM64
void DeconvDwCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h,
                    int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step,
                    int in_kw_step) {
  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      float *dst_kh = dst_w;
      const float *weight_kh = weight;
      for (int kh = 0; kh < kernel_h; kh++) {
        float *dst_kw = dst_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            dst_kw[c] += src_w[c] * weight_kw[c];
          }
          dst_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        dst_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      dst_w += in_sw_step;
      src_w += block_channel;
    }  // dst_width loop
    dst_h += in_sh_step;
    src_h += out_h_step;
  }  // dst_height loop
}
#endif

void DeconvDwPost(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) {
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  float *dst_k = dst;
  for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
    for (int c = 0; c < C4NUM; c++) {
      dst_k[c] += bias[c];
      dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
      dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
    }
    dst_k += block_channel;
  }
}

// deconv depthwise fp32: sliding window
void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                    const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  const float *src = input_data;
  float *dst = output_data;
  if (conv_param->thread_num_ == 0) {
    return;
  }
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src + oc * C4NUM;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = weight_data + oc * sliding->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      DeconvDwBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_,
                     conv_param, sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
                     sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, conv_param->input_w_,
                     conv_param, sliding);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
        const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
        DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                           conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                           sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
                           sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
                           sliding->in_kw_step_ * sizeof(float));
#else
        DeconvDwCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                       conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
                       sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_);
#endif
      }
      DeconvDwPost(dst_data, bias, sliding->block_channel_, conv_param);
    }  // output C4 loop
    src += sliding->out_step_;
    dst += sliding->in_step_;
  }  // batch loop
  // output nhwc4
}
/*deconv depthwise fp32 end*/

#ifdef ENABLE_AVX
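/* AVX sliding-window depthwise convolution. Work is tiled by oc_block (1-4 blocks of
 * C8NUM channels) and ow_block; the kernel table below selects a register-tiled
 * micro-kernel for each tile shape, with the 1-wide variants used on borders. */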
void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left,
                            int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param,
                            const DepthwiseSWKernel kernel, int act_type, int ow_block, int oc_block) {
  // dw border compute
  int ih = top * conv_param->stride_h_ - conv_param->pad_u_;
  int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
  int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
  const float *src_h = src + ih * sw_param->in_h_step_;
  float *dst_kernel = dst + left * sw_param->block_channel_;
  for (int ow = left; ow < right; ow += ow_block) {
    int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
    int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
    int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
    const float *src_w = src_h + iw * sw_param->block_channel_;
    const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_;
    const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM * oc_block;
    kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_block,
           oc_block, sw_param->block_channel_, sw_param->in_kw_step_, sw_param->in_kh_step_, sw_param->in_sw_step_,
           (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block);
    dst_kernel += ow_block * sw_param->block_channel_;
  }  // width loop
}

void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                        const ConvParameter *conv_param, const SlidingWindowParam *sw_param, int task_id) {
  int oh_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int oh_start = oh_step * task_id;
  int oh_end = MSMIN(oh_start + oh_step, conv_param->output_h_);
  if (oh_start >= oh_end) {
    return;
  }
  // depthwise sw in x86 avx instructions
  int oc_tile_ = C8NUM;  // oc is aligned to C8NUM on x86_64 AVX
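  // act_type encodes the fused activation for the micro-kernels:
  // 0 = none, 2 = relu (lower bound only), 3 = relu6 (lower and upper bound).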
1059   int act_type = 0;
1060   if (conv_param->act_type_ == ActType_Relu6) {
1061     act_type += 1;
1062   }
1063   if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) {
1064     act_type += 2;
1065   }
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int output_w = conv_param->output_w_;
  int oc_algin = sw_param->block_channel_;
  int oc_num = sw_param->c_block_;
  int in_step = sw_param->in_step_;
  int out_step = sw_param->out_step_;
  int in_sw_step = sw_param->in_sw_step_;
  int in_kw_step = sw_param->in_kw_step_;
  int in_kh_step = sw_param->in_kh_step_;
  int in_sh_step = sw_param->in_sh_step_;
  int out_right = sw_param->right_;
  int out_left = sw_param->left_;
  int out_top = sw_param->top_;
  int out_bottom = sw_param->bottom_;
  int kernel_step = sw_param->kernel_step_;
  int out_h_step = sw_param->out_h_step_;
  int in_h_start = out_top * conv_param->stride_h_ - conv_param->pad_u_;
  int in_w_start = out_left * conv_param->stride_w_ - conv_param->pad_l_;
  int in_start = in_h_start * sw_param->in_h_step_ + in_w_start * oc_algin;
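  // Kernel dispatch table indexed by [oc_block - 1][wide]: the row picks how many 8-float channel groups are
  // handled per call (1..4); column 0 is the single-output-column fallback used for borders and tails, while
  // column 1 computes ow_block_num[oc_block - 1] output columns at once.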
  const int ow_block_num[4] = {8, 4, 4, 3};
  const DepthwiseSWKernel kernel[4][2] = {{DepthwiseSW1x8Kernel, DepthwiseSW8x8Kernel},
                                          {DepthwiseSW1x16Kernel, DepthwiseSW4x16Kernel},
                                          {DepthwiseSW1x24Kernel, DepthwiseSW4x24Kernel},
                                          {DepthwiseSW1x32Kernel, DepthwiseSW3x32Kernel}};
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oh = oh_start; oh < oh_end; ++oh) {
      float *dst_oh = output_data + oh * out_h_step;
      const float *src_h = input_data + in_start + (oh - out_top) * in_sh_step;
      int oc_block = 0;
      const float *bias = bias_data;
      for (int oc = 0; oc < oc_num; oc += oc_block) {
        oc_block = MSMIN(C4NUM, oc_num - oc);  // 4 3 2 1
        int oc_step = oc * oc_tile_;
        const float *weight = weight_data + oc * kernel_step;
        if (bias != NULL) {
          bias = bias_data + oc_step;
        }
        float *dst_w = dst_oh + oc_step;
        const DepthwiseSWKernel kernel_border = kernel[oc_block - 1][0];
        if (oh < out_top || oh >= out_bottom) {  // oh in top or bottom border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, output_w, conv_param, sw_param,
                                 kernel_border, act_type, 1, oc_block);
        } else {  // oh in center
          // ow in left border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, out_left, conv_param, sw_param,
                                 kernel_border, act_type, 1, oc_block);
          // ow in center
          const float *src_w = src_h + oc_step;
          int ow_block = ow_block_num[oc_block - 1];                 // 8 4 4 3
          for (int ow = out_left; ow < out_right; ow += ow_block) {  // left ~ right
            if (ow_block > out_right - ow) {  // less than a full block left; fall back to one column at a time
              ow_block = 1;
            }
            kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]](
              dst_w + ow * oc_algin, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, oc_algin,
              in_kw_step, in_kh_step, in_sw_step, 0);
            src_w += ow_block * in_sw_step;
          }
          // ow in right border
          DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, out_right, output_w, conv_param,
                                 sw_param, kernel_border, act_type, 1, oc_block);
        }
      }
    }  // output h loop
    input_data += in_step;
    output_data += out_step;
  }  // batch loop
}

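// Naming convention for the kernels below: DepthwiseSW{M}x{N}Kernel computes M output columns by N output
// channels per call, i.e. N / 8 ymm accumulators per column; M * N / 8 stays within 12 accumulators so the
// remaining ymm registers are free as scratch for weight and activation data.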
#ifdef ENABLE_DEBUG
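// Generic C kernel for W output columns by K channel groups, compiled only in debug builds; it mirrors the
// hand-written assembly kernels below and is presumably kept as a readable reference for validating them.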
void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                          size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  __m256 dst_data[12];
  __m256 src_data;
  const float *src_kh[12];
  const float *src_kw[12];
  __m256 weight_data[4];
  for (int i = 0; i < ow_block; ++i) {
    if (bias != NULL) {
      for (int j = 0; j < oc_block; ++j) {
        dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * 8);
      }
    } else {
      for (int j = 0; j < oc_block; ++j) {
        dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f);
      }
    }
    src_kh[i] = src + i * in_sw_step;
    src_kw[i] = NULL;
  }
  const float *weight_kernel = weight;
  for (int kh = 0; kh < kernel_h; kh++) {
    for (int i = 0; i < ow_block; ++i) {
      src_kw[i] = src_kh[i];
    }
    for (int kw = 0; kw < kernel_w; kw++) {
      for (int j = 0; j < oc_block; ++j) {
        weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
      }
      for (int i = 0; i < ow_block; ++i) {  // loop ow
        for (int j = 0; j < oc_block; ++j) {
          src_data = _mm256_loadu_ps(src_kw[i] + j * C8NUM);
          dst_data[i * oc_block + j] += src_data * weight_data[j];  // element-wise FMA via GCC vector extensions
        }
      }
      for (int i = 0; i < ow_block; ++i) {
        src_kw[i] += in_kw_step;  // ic8 * dilation_w
      }
      weight_kernel += oc_block * C8NUM;
    }  // kernel_w loop
    weight_kernel += kw_remainder;
    for (int i = 0; i < ow_block; ++i) {
      src_kh[i] += in_kh_step;  // advance to the next kernel row
    }
  }  // kernel_h loop
  // apply activation clamps and store (bias was already folded in at initialization)
  for (int i = 0; i < ow_block; ++i) {
    for (int j = 0; j < oc_block; ++j) {
      if (0x1 & act_flag) {  // relu6: min(x, 6)
        dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f));
      }
      if (0x2 & act_flag) {  // relu: max(x, 0)
        dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f));
      }
      _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]);
    }
  }
}
#endif

void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_sw_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "vmovups 0x40(%2), %%ymm2\n"
    "vmovups 0x60(%2), %%ymm3\n"
    "vmovups (%2), %%ymm4\n"
    "vmovups 0x20(%2), %%ymm5\n"
    "vmovups 0x40(%2), %%ymm6\n"
    "vmovups 0x60(%2), %%ymm7\n"
    "vmovups (%2), %%ymm8\n"
    "vmovups 0x20(%2), %%ymm9\n"
    "vmovups 0x40(%2), %%ymm10\n"
    "vmovups 0x60(%2), %%ymm11\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "vxorps %%ymm8, %%ymm8, %%ymm8\n"
    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
    "vxorps %%ymm11, %%ymm11, %%ymm11\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // kernel_w
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW

    "vmovups (%1), %%ymm12\n"
    "vmovups (%%rcx), %%ymm13\n"
    "vmovups (%%rcx, %7), %%ymm14\n"
    "vmovups (%%rcx, %7, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"

    "vmovups 0x20(%1), %%ymm12\n"
    "vmovups 0x20(%%rcx), %%ymm13\n"
    "vmovups 0x20(%%rcx, %7), %%ymm14\n"
    "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n"

    "vmovups 0x40(%1), %%ymm12\n"
    "vmovups 0x40(%%rcx), %%ymm13\n"
    "vmovups 0x40(%%rcx, %7), %%ymm14\n"
    "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n"

    "vmovups 0x60(%1), %%ymm12\n"
    "vmovups 0x60(%%rcx), %%ymm13\n"
    "vmovups 0x60(%%rcx, %7), %%ymm14\n"
    "vmovups 0x60(%%rcx, %7, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n"
    "addq $128, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %8, %1\n"  // kw_remainder
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder)                               // 8
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9",
      "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // relu: max(x, 0)
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
    "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
    "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
    "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
    "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
    "vmaxps %%ymm12, %%ymm8, %%ymm8\n"
    "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
    "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
    "vmaxps %%ymm12, %%ymm11, %%ymm11\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6: min(x, 6)
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
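    // 0x40C00000 is the IEEE-754 bit pattern of 6.0f; vpermps with the zeroed index register ymm12
    // broadcasts lane 0 of ymm14, producing 6.0f in every lane for the upper clamp.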
1311     "vminps %%ymm14, %%ymm0, %%ymm0\n"
1312     "vminps %%ymm14, %%ymm1, %%ymm1\n"
1313     "vminps %%ymm14, %%ymm2, %%ymm2\n"
1314     "vminps %%ymm14, %%ymm3, %%ymm3\n"
1315     "vminps %%ymm14, %%ymm4, %%ymm4\n"
1316     "vminps %%ymm14, %%ymm5, %%ymm5\n"
1317     "vminps %%ymm14, %%ymm6, %%ymm6\n"
1318     "vminps %%ymm14, %%ymm7, %%ymm7\n"
1319     "vminps %%ymm14, %%ymm8, %%ymm8\n"
1320     "vminps %%ymm14, %%ymm9, %%ymm9\n"
1321     "vminps %%ymm14, %%ymm10, %%ymm10\n"
1322     "vminps %%ymm14, %%ymm11, %%ymm11\n"
1323 
1324     "0:\n"
1325     "vmovups %%ymm0, (%2)\n"  // dst_0
1326     "vmovups %%ymm1, 0x20(%2)\n"
1327     "vmovups %%ymm2, 0x40(%2)\n"
1328     "vmovups %%ymm3, 0x60(%2)\n"
1329     "vmovups %%ymm4, (%2, %1, 1)\n"
1330     "vmovups %%ymm5, 0x20(%2, %1, 1)\n"
1331     "vmovups %%ymm6, 0x40(%2, %1, 1)\n"
1332     "vmovups %%ymm7, 0x60(%2, %1, 1)\n"
1333     "vmovups %%ymm8, (%2, %1, 2)\n"
1334     "vmovups %%ymm9, 0x20(%2, %1, 2)\n"
1335     "vmovups %%ymm10, 0x40(%2, %1, 2)\n"
1336     "vmovups %%ymm11, 0x60(%2, %1, 2)\n"
1337     :
1338     : "a"(act_flag), "r"(oc_algin), "r"(dst)
1339     : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
1340       "%ymm11", "%ymm12", "%ymm14");
1341 }
1342 
void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "vmovups 0x40(%2), %%ymm2\n"
    "vmovups 0x60(%2), %%ymm3\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // kernel_w
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%%rcx), %%ymm4\n"
    "vmovups 0x20(%%rcx), %%ymm5\n"
    "vmovups 0x40(%%rcx), %%ymm6\n"
    "vmovups 0x60(%%rcx), %%ymm7\n"
    // Weight data is used directly as a memory operand of vfmadd231ps instead of being preloaded into registers.
    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
    "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
    "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n"
    "vfmadd231ps 0x60(%1), %%ymm7, %%ymm3\n"
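    // The memory-operand form saves four separate vmovups per iteration and lowers register pressure;
    // the weight pointer still advances 128 bytes (4 groups of 8 floats) per kw step.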
1376     "addq $128, %1\n"
1377 
1378     "addq %5, %%rcx\n"  // in_kw_step
1379     "dec %%rsi\n"
1380     "jg 2b\n"
1381 
1382     "addq %6, %0\n"  // in_kh_step
1383     "addq %7, %1\n"
1384     "dec %3\n"
1385     "jg 1b\n"
1386     :
1387     : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
1388       "r"(in_kh_step), "r"(kw_remainder)                                                // 7
1389     : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
1390 
1391   asm volatile(
1392     "and $0x3, %%eax\n"
1393     "je 0f\n"
1394     // Relu
1395     "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1396     "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1397     "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1398     "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1399     "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1400 
1401     "and $0x1, %%eax\n"
1402     "je 0f\n"
1403     // relu6
1404     "mov $0x40C00000, %%ecx\n"
1405     "vmovd %%ecx, %%xmm14\n"
1406     "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1407     "vminps %%ymm14, %%ymm0, %%ymm0\n"
1408     "vminps %%ymm14, %%ymm1, %%ymm1\n"
1409     "vminps %%ymm14, %%ymm2, %%ymm2\n"
1410     "vminps %%ymm14, %%ymm3, %%ymm3\n"
1411 
1412     "0:\n"
1413     "vmovups %%ymm0, (%2)\n"  // dst_0
1414     "vmovups %%ymm1, 0x20(%2)\n"
1415     "vmovups %%ymm2, 0x40(%2)\n"
1416     "vmovups %%ymm3, 0x60(%2)\n"
1417     :
1418     : "a"(act_flag), "r"(oc_algin), "r"(dst)
1419     : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14");
1420 }
1421 
void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  in_sw_step *= sizeof(float);
  kw_remainder *= sizeof(float);
  size_t src_3_step = 3 * in_sw_step;
  float *dst_3 = dst + 3 * oc_algin;
  oc_algin *= sizeof(float);
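  // src_3_step and dst_3 are precomputed because x86 addressing scales an index only by 1, 2, 4, or 8:
  // columns 0-2 are reached via (%rcx), (%rcx, step, 1) and (%rcx, step, 2), so the 3x multiple of the
  // stride needs its own offset.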
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "vmovups 0x40(%2), %%ymm2\n"
    // Reloading the bias for each output column costs extra loads; copying ymm0-ymm2 register-to-register
    // would avoid them, but no suitable instruction was found.
    "vmovups (%2), %%ymm3\n"
    "vmovups 0x20(%2), %%ymm4\n"
    "vmovups 0x40(%2), %%ymm5\n"
    "vmovups (%2), %%ymm6\n"
    "vmovups 0x20(%2), %%ymm7\n"
    "vmovups 0x40(%2), %%ymm8\n"
    "vmovups (%2), %%ymm9\n"
    "vmovups 0x20(%2), %%ymm10\n"
    "vmovups 0x40(%2), %%ymm11\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "vxorps %%ymm8, %%ymm8, %%ymm8\n"
    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
    "vxorps %%ymm11, %%ymm11, %%ymm11\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // kernel_w
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%1), %%ymm12\n"
    "vmovups (%%rcx), %%ymm13\n"
    "vmovups (%%rcx, %7, 1), %%ymm14\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
    "vmovups (%%rcx, %7, 2), %%ymm15\n"
    "vmovups (%%rcx, %9), %%ymm13\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n"

    "vmovups 0x20(%1), %%ymm12\n"
    "vmovups 0x20(%%rcx), %%ymm13\n"
    "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
    "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
    "vmovups 0x20(%%rcx, %9), %%ymm13\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n"

    "vmovups 0x40(%1), %%ymm12\n"
    "vmovups 0x40(%%rcx), %%ymm13\n"
    "vmovups 0x40(%%rcx, %7, 1), %%ymm14\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
    "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
    "vmovups 0x40(%%rcx, %9), %%ymm13\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n"

    "addq $96, %1\n"
    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %8, %1\n"  // skip kw remainder weights (nonzero only for border calls)
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step)              // 9
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9",
      "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // relu: max(x, 0)
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
    "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
    "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
    "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
    "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
    "vmaxps %%ymm12, %%ymm8, %%ymm8\n"
    "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
    "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
    "vmaxps %%ymm12, %%ymm11, %%ymm11\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6: min(x, 6)
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm2, %%ymm2\n"
    "vminps %%ymm14, %%ymm3, %%ymm3\n"
    "vminps %%ymm14, %%ymm4, %%ymm4\n"
    "vminps %%ymm14, %%ymm5, %%ymm5\n"
    "vminps %%ymm14, %%ymm6, %%ymm6\n"
    "vminps %%ymm14, %%ymm7, %%ymm7\n"
    "vminps %%ymm14, %%ymm8, %%ymm8\n"
    "vminps %%ymm14, %%ymm9, %%ymm9\n"
    "vminps %%ymm14, %%ymm10, %%ymm10\n"
    "vminps %%ymm14, %%ymm11, %%ymm11\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    "vmovups %%ymm2, 0x40(%2)\n"
    "vmovups %%ymm3, (%2, %1, 1)\n"
    "vmovups %%ymm4, 0x20(%2, %1, 1)\n"
    "vmovups %%ymm5, 0x40(%2, %1, 1)\n"
    "vmovups %%ymm6, (%2, %1, 2)\n"
    "vmovups %%ymm7, 0x20(%2, %1, 2)\n"
    "vmovups %%ymm8, 0x40(%2, %1, 2)\n"
    "vmovups %%ymm9, (%3)\n"  // dst_3
    "vmovups %%ymm10, 0x20(%3)\n"
    "vmovups %%ymm11, 0x40(%3)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
    : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
      "%ymm11", "%ymm12", "%ymm14");
}

void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "vmovups 0x40(%2), %%ymm2\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // kernel_w
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%%rcx), %%ymm4\n"
    "vmovups 0x20(%%rcx), %%ymm5\n"
    "vmovups 0x40(%%rcx), %%ymm6\n"
    // Weight data is used directly as a memory operand of vfmadd231ps instead of being preloaded into registers.
    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
    "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
    "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n"
    "addq $96, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %7, %1\n"  // kw_remainder
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(kw_remainder)                                                // 7
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // relu: max(x, 0)
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6: min(x, 6)
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm2, %%ymm2\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    "vmovups %%ymm2, 0x40(%2)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst)
    : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14");
}

void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  in_sw_step *= sizeof(float);
  kw_remainder *= sizeof(float);
  size_t src_3_step = 3 * in_sw_step;
  float *dst_3 = dst + 3 * oc_algin;
  oc_algin *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    // Reloading the bias for each output column costs extra loads; copying ymm0/ymm1 register-to-register
    // would avoid them, but no suitable instruction was found.
    "vmovups (%2), %%ymm3\n"
    "vmovups 0x20(%2), %%ymm4\n"
    "vmovups (%2), %%ymm6\n"
    "vmovups 0x20(%2), %%ymm7\n"
    "vmovups (%2), %%ymm9\n"
    "vmovups 0x20(%2), %%ymm10\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // kernel_w
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%1), %%ymm12\n"
    "vmovups (%%rcx), %%ymm13\n"
    "vmovups (%%rcx, %7, 1), %%ymm14\n"
    "vmovups (%%rcx, %7, 2), %%ymm15\n"
    "vmovups (%%rcx, %9), %%ymm2\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
    "vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n"

    "vmovups 0x20(%1), %%ymm12\n"
    "vmovups 0x20(%%rcx), %%ymm13\n"
    "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
    "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
    "vmovups 0x20(%%rcx, %9), %%ymm2\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
    "vfmadd231ps %%ymm12, %%ymm2, %%ymm10\n"

    "addq $64, %1\n"
    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %8, %1\n"  // skip kw remainder weights (nonzero only for border calls)
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step)              // 9
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12",
      "%ymm13", "%ymm14", "%ymm15");  // %ymm2 is used as scratch above, so it is listed as clobbered

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // relu: max(x, 0)
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
    "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
    "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
    "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
    "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
    "vmaxps %%ymm12, %%ymm10, %%ymm10\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6: min(x, 6)
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm3, %%ymm3\n"
    "vminps %%ymm14, %%ymm4, %%ymm4\n"
    "vminps %%ymm14, %%ymm6, %%ymm6\n"
    "vminps %%ymm14, %%ymm7, %%ymm7\n"
    "vminps %%ymm14, %%ymm9, %%ymm9\n"
    "vminps %%ymm14, %%ymm10, %%ymm10\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    "vmovups %%ymm3, (%2, %1, 1)\n"
    "vmovups %%ymm4, 0x20(%2, %1, 1)\n"
    "vmovups %%ymm6, (%2, %1, 2)\n"
    "vmovups %%ymm7, 0x20(%2, %1, 2)\n"
    "vmovups %%ymm9, (%3)\n"  // dst_3
    "vmovups %%ymm10, 0x20(%3)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
    : "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14");
}

void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                           size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                           size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "vmovups 0x20(%2), %%ymm1\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // kernel_w
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%%rcx), %%ymm4\n"
    "vmovups 0x20(%%rcx), %%ymm5\n"
    // Weight data is used directly as a memory operand of vfmadd231ps instead of being preloaded into registers.
    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
    "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
    "addq $64, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %7, %1\n"  // kw_remainder
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(kw_remainder)                                                // 7
    : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm4", "%ymm5");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // relu: max(x, 0)
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6: min(x, 6)
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, 0x20(%2)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst)
    : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14");
}

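// DepthwiseSW8x8Kernel computes 8 output columns of one 8-channel group; its bias initialization and compute
// loop live in separate asm blocks, so it uses named labels (LoopH/LoopW/Write) rather than the numeric local
// labels of the kernels above.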
void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                          size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_sw_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  kw_remainder *= sizeof(float);
  size_t src_3_step = 3 * in_sw_step;
  float *dst_3 = dst + 3 * oc_algin;
  float *dst_5 = dst + 5 * oc_algin;
  oc_algin *= sizeof(float);
  asm volatile(
    "cmpq $0, %0\n"
    "je 0f\n"
    "vmovups (%0), %%ymm0\n"
    "vmovups (%0), %%ymm1\n"
    "vmovups (%0), %%ymm2\n"
    "vmovups (%0), %%ymm3\n"
    "vmovups (%0), %%ymm4\n"
    "vmovups (%0), %%ymm5\n"
    "vmovups (%0), %%ymm6\n"
    "vmovups (%0), %%ymm7\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "1:\n"
    :
    : "r"(bias)
    : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");

  asm volatile(
    "LoopH:\n"
    "movq %3, %%rsi\n"  // kernel_w
    "movq %0, %%rcx\n"  // src_h
    "LoopW:\n"
    "movq %%rcx, %%rax\n"
    "vmovups (%1), %%ymm12\n"
    "vmovups (%%rax), %%ymm13\n"
    "vmovups (%%rax, %6), %%ymm14\n"
    "vmovups (%%rax, %6, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n"
    "addq %7, %%rax\n"
    "vmovups (%%rax), %%ymm13\n"
    "vmovups (%%rax, %6), %%ymm14\n"
    "vmovups (%%rax, %6, 2), %%ymm15\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
    "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n"
    "addq %7, %%rax\n"
    "vmovups (%%rax), %%ymm13\n"
    "vmovups (%%rax, %6), %%ymm14\n"
    "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n"
    "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"

    "addq $32, %1\n"
    "addq %4, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg LoopW\n"

    "addq %5, %0\n"  // in_kh_step
    "addq %8, %1\n"  // skip kw remainder weights (nonzero only for border calls)
    "dec %2\n"
    "jg LoopH\n"
    :
    : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), "r"(in_kh_step),  // 5
      "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder)                                     // 8
    : "%rcx", "%rsi", "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12",
      "%ymm13", "%ymm14", "%ymm15");

  asm volatile(
    "and $0x3, %%eax\n"
    "je Write\n"
    // relu: max(x, 0)
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
    "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
    "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
    "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
    "vmaxps %%ymm12, %%ymm7, %%ymm7\n"

    "and $0x1, %%eax\n"
    "je Write\n"
    // relu6: min(x, 6)
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"
    "vminps %%ymm14, %%ymm1, %%ymm1\n"
    "vminps %%ymm14, %%ymm2, %%ymm2\n"
    "vminps %%ymm14, %%ymm3, %%ymm3\n"
    "vminps %%ymm14, %%ymm4, %%ymm4\n"
    "vminps %%ymm14, %%ymm5, %%ymm5\n"
    "vminps %%ymm14, %%ymm6, %%ymm6\n"
    "vminps %%ymm14, %%ymm7, %%ymm7\n"

    "Write:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    "vmovups %%ymm1, (%2, %1)\n"
    "vmovups %%ymm2, (%2, %1, 2)\n"
    "vmovups %%ymm3, (%3)\n"  // dst_3
    "vmovups %%ymm4, (%2, %1, 4)\n"
    "vmovups %%ymm5, (%4)\n"  // dst_5
    "vmovups %%ymm6, (%4, %1, 1)\n"
    "vmovups %%ymm7, (%4, %1, 2)\n"
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5)
    : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14");
}

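// DepthwiseSW1x8Kernel is the smallest fallback: one output column of one 8-channel group; it serves as the
// border kernel for oc_block == 1 and as the single-column tail path of the 8x8 dispatch row.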
void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
                          size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
                          size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
  in_kh_step *= sizeof(float);
  in_kw_step *= sizeof(float);
  oc_algin *= sizeof(float);
  kw_remainder *= sizeof(float);
  asm volatile(
    "cmpq $0, %2\n"
    "je 0f\n"
    "vmovups (%2), %%ymm0\n"
    "jmp 1f\n"
    "0:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "1:\n"              // LoopH
    "movq %4, %%rsi\n"  // kernel_w
    "movq %0, %%rcx\n"  // src_h
    "2:\n"              // LoopW
    "vmovups (%%rcx), %%ymm4\n"
    // Weight data is used directly as a memory operand of vfmadd231ps instead of being preloaded into registers.
    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
    "addq $32, %1\n"

    "addq %5, %%rcx\n"  // in_kw_step
    "dec %%rsi\n"
    "jg 2b\n"

    "addq %6, %0\n"  // in_kh_step
    "addq %7, %1\n"  // kw_remainder
    "dec %3\n"
    "jg 1b\n"
    :
    : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step),  // 5
      "r"(in_kh_step), "r"(kw_remainder)                                                // 7
    : "%rcx", "%rsi", "%ymm0", "%ymm4");

  asm volatile(
    "and $0x3, %%eax\n"
    "je 0f\n"
    // relu: max(x, 0)
    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"

    "and $0x1, %%eax\n"
    "je 0f\n"
    // relu6: min(x, 6)
    "mov $0x40C00000, %%ecx\n"
    "vmovd %%ecx, %%xmm14\n"
    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
    "vminps %%ymm14, %%ymm0, %%ymm0\n"

    "0:\n"
    "vmovups %%ymm0, (%2)\n"  // dst_0
    :
    : "a"(act_flag), "r"(oc_algin), "r"(dst)
    : "%ecx", "%ymm0", "%ymm12", "%ymm14");
}
#endif