/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp32/conv_depthwise_fp32.h"
#include "nnacl/common_func.h"
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/activation_fp32.h"

#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels,
                   int output_channel, int input_step) {
  for (int i = 0; i < num_pixels; i++) {
    for (int c = 0; c < output_channel; c++) {
      *output_ptr++ += weight_ptr[c] * input_ptr[c];
    }
    input_ptr += input_step;
  }
}
#endif
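
/* ConvDwFp32Row accumulates a single kernel tap into a run of output pixels:
 *   out[i * output_channel + c] += weight[c] * in[i * input_step + c]
 * for i in [0, num_pixels). input_step is stride_w_ * input_channel_, so
 * consecutive output pixels read stride-separated input pixels. ARM/SSE
 * builds provide an intrinsic/assembly version with the same contract. */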

int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
           const ConvParameter *conv_param, int task_id) {
  if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) {
    return NNACL_ERR;
  }
  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;

      int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
      int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
      int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));

      for (int ow = 0; ow < conv_param->output_w_; ow++) {
        memcpy(dst_data + ow * conv_param->output_channel_, bias_data,
               conv_param->output_channel_ * (int)(sizeof(float)));
      }
      for (int kh = start_kh; kh < end_kh; kh++) {
        int ih = ih_origin + conv_param->dilation_h_ * kh;

        const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
        const float *dw_weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;

        int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
        for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
          int out_w_start = MSMAX(
            0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
          int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
                                                        conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
                                                         conv_param->stride_w_);

          float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
          int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;

          const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
          int num_pixels = out_w_end - out_w_start;

          ConvDwFp32Row(dst_w, src_kw, dw_weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
          dw_weight_kh += conv_param->output_channel_;
        }
      }
      if (relu) {
        Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
      if (relu6) {
        Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
    }
  }
  return NNACL_OK;
}
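
/* ConvDw splits output rows across threads: task t covers rows
 * [t * h_step, min((t + 1) * h_step, output_h_)) with
 * h_step = UP_DIV(output_h_, thread_num_). The kh clamping above solves
 * 0 <= ih_origin + kh * dilation_h_ < input_h_ for kh, and the
 * out_w_start / out_w_end bounds do the same for iw, so the row kernel never
 * touches padding. A minimal driver (a sketch; the real runtime dispatches
 * tasks through its thread pool) could look like:
 *   for (int t = 0; t < conv_param->thread_num_; t++) {
 *     ConvDw(output, input, weight, bias, conv_param, t);  // run these in parallel
 *   }
 */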

#ifdef ENABLE_AVX512
int ConvDwAVX512(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                 const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param) {
  if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) {
    return NNACL_ERR;
  }
  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  int32_t *num_pixels = conv_dw_calc_param->num_pixels_;
  int32_t *out_w_start = conv_dw_calc_param->out_w_start_;
  int first_calc_kw = conv_dw_calc_param->first_calc_kw_;

  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;

      int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
      int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
      int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));

      bool first_calc_flag = true;
      if (first_calc_kw == -1) {
        first_calc_flag = false;
        for (int ow = 0; ow < conv_param->output_w_; ow++) {
          memcpy(dst_data + ow * conv_param->output_channel_, bias_data,
                 conv_param->output_channel_ * (int)(sizeof(float)));
        }
      }
      for (int kh = start_kh; kh < end_kh; kh++) {
        int ih = ih_origin + conv_param->dilation_h_ * kh;

        const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
        const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;

        int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;

        if (first_calc_flag) {
          int iw_origin = -conv_param->pad_l_ + conv_param->dilation_w_ * first_calc_kw;

          const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;

          ConvDwAVX512Fp32Row(dst_data, src_kw, weight_kh + first_calc_kw * conv_param->output_channel_,
                              conv_param->output_w_, conv_param->output_channel_, in_sw_step, true, bias_data);
        }
        for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
          if (first_calc_flag && (kw == first_calc_kw)) {
            first_calc_flag = false;
            weight_kh += conv_param->output_channel_;
            continue;
          }

          float *dst_w = dst_data + out_w_start[kw] * conv_param->output_channel_;
          int iw_origin = (out_w_start[kw] * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;

          const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;

          ConvDwAVX512Fp32Row(dst_w, src_kw, weight_kh, num_pixels[kw], conv_param->output_channel_, in_sw_step, false,
                              bias_data);
          weight_kh += conv_param->output_channel_;
        }
      }
      if (relu) {
        Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      } else if (relu6) {
        Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data);
      }
    }
  }
  return NNACL_OK;
}
#endif
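
/* ConvDwAVX512 hoists the per-tap bounds work into precomputed tables:
 * out_w_start_[kw] / num_pixels_[kw] give each kernel column its valid
 * output range. first_calc_kw_, when not -1, names a column whose range
 * covers the whole row; on the first kernel row processed, that tap runs
 * with first_calc = true so the row kernel writes bias + w * src directly,
 * fusing the bias initialization that ConvDw does as a separate memcpy pass. */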

void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  if (block == 0) {
    return;
  }
  int left = 0;
  int right = conv_param->output_w_;
  int top = 0;
  int bottom = conv_param->output_h_;

  while (left * conv_param->stride_w_ < conv_param->pad_l_) {
    left++;
  }
  while ((right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
           conv_param->input_w_ &&
         right > left) {
    right--;
  }
  while (top * conv_param->stride_h_ < conv_param->pad_u_) {
    top++;
  }
  while ((bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
           conv_param->input_h_ &&
         bottom > top) {
    bottom--;
  }
  sliding->left_ = left;
  sliding->right_ = right;
  sliding->top_ = top;
  sliding->bottom_ = bottom;
  sliding->c_block_ = UP_DIV(conv_param->output_channel_, block);
  sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block;
  sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_;
  if (conv_param->out_format_ == Format_NC4HW4) {
    // write to nc4hw4 / nc8hw8, depending on block
    sliding->out_h_step_ = conv_param->output_w_ * block;
    sliding->out_c_step_ = block * conv_param->output_h_ * conv_param->output_w_;
    sliding->out_w_step_ = block;
    sliding->out_block_step_ = sliding->out_c_step_;
  } else {
    // write to nhwc
    sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_;
    sliding->out_c_step_ = block;
    sliding->out_w_step_ = sliding->block_channel_;
    sliding->out_block_step_ = sliding->out_w_step_;
  }
}
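
/* left_/right_/top_/bottom_ bound the output rectangle whose kernel window
 * falls entirely inside the input: left is the smallest column with
 * left * stride_w_ >= pad_l_, and right the largest with
 * (right - 1) * stride_w_ - pad_l_ + kernel_w_ * dilation_w_ <= input_w_.
 * The kernel_w_ * dilation_w_ term over-estimates the true extent
 * (kernel_w_ - 1) * dilation_w_ + 1, which only shrinks the center region
 * and never breaks correctness. Everything outside this rectangle goes
 * through the slower border paths. */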

void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block,
                          int weight_block) {
  InitSlidingParam(sliding, conv_param, weight_block);
  AppendSlidingParamConv(sliding, conv_param, input_block, weight_block);
}

void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block,
                            int weight_block) {
  if (input_block == 0) {  // is not aligned
    sliding->ic_align_ = conv_param->input_channel_;
  } else {  // 1x1 input is aligned to input_block
    sliding->ic_align_ = UP_DIV(conv_param->input_channel_, input_block) * input_block;
  }
  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->ic_align_;  // for batch loop
  sliding->in_h_step_ = conv_param->input_w_ * sliding->ic_align_;
  sliding->in_sh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->stride_h_;    // stride H
  sliding->in_sw_step_ = sliding->ic_align_ * conv_param->stride_w_;                           // stride W
  sliding->in_kh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->dilation_h_;  // kernel H
  sliding->in_kw_step_ = sliding->ic_align_ * conv_param->dilation_w_;                         // kernel W
  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * sliding->ic_align_ * weight_block;
}

void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  InitSlidingParam(sliding, conv_param, block);
  AppendSlidingParamConvDw(sliding, conv_param, block);
}

void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) {
  sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_;  // for batch loop
  sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_;
  sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_;    // stride H
  sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_;                           // stride W
  sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_;  // kernel H
  sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_;                         // kernel W
  sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block;
}

/*conv depthwise fp32 begin*/
void ConvDwBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                       int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, bool is_relu6) {
  const float *src_kh = src;
  const float *weight_kh = weight;
  for (int c = 0; c < C4NUM; c++) {
    dst[c] = 0;
  }
  for (int kh = 0; kh < height; kh++) {
    const float *src_kw = src_kh;
    const float *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
      for (int c = 0; c < C4NUM; c++) {
        dst[c] += src_kw[c] * weight_kw[c];
      }
      src_kw += in_kw_step;
      weight_kw += C4NUM;
    }  // kernel_w loop
    src_kh += in_kh_step;
    weight_kh += kernel_w_step;
  }  // kernel_h loop
  for (int c = 0; c < C4NUM; c++) {
    dst[c] += bias[c];
    dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]);
    dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]);
  }
}

void ConvDwBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
                  int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
    return;
  }
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  float *dst_h = dst + top * sliding->out_h_step_;
  for (int oh = top; oh < bottom; oh++) {
    int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
    const float *src_h = src + ih * sliding->in_h_step_;

    float *dst_kernel = dst_h + left * sliding->block_channel_;
    for (int ow = left; ow < right; ow++) {
      int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
      const float *src_w = src_h + iw * sliding->block_channel_;

      const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
#ifdef ENABLE_AVX
      ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam));
      if (param == NULL) {
        return;
      }
      param->dst = dst_kernel;
      param->src = src_kernel;
      param->weight = weight_kernel;
      param->bias = bias;
      param->height = end_kh - start_kh;
      param->width = end_kw - start_kw;
      param->in_kh_step = sliding->in_kh_step_ * sizeof(float);
      param->in_kw_step = sliding->in_kw_step_ * sizeof(float);
      param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float);
      param->relu = relu;
      param->relu6 = relu6;
      ConvDwFp32Border(param);
      free(param);
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
      ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                       sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                       conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
#else
      ConvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
                        sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6);
#endif
      dst_kernel += sliding->block_channel_;
    }  // width loop
    dst_h += sliding->out_h_step_;
  }  // height loop
}

#ifndef ENABLE_ARM64
void ConvDwCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
                  int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step,
                  int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) {
  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      const float *src_kh = src_w;
      const float *weight_kh = weight;
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] = 0;
      }
      for (int kh = 0; kh < kernel_h; kh++) {
        const float *src_kw = src_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            dst_w[c] += src_kw[c] * weight_kw[c];
          }
          src_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        src_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      // add bias and relu
      for (int c = 0; c < C4NUM; c++) {
        dst_w[c] += bias[c];
        dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]);
        dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]);
      }
      dst_w += block_channel;
      src_w += in_sw_step;
    }  // dst_width loop
    dst_h += out_h_step;
    src_h += in_sh_step;
  }  // dst_height loop
}
#endif

// conv depthwise fp32: sliding window
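// ConvDwSWFp32 decomposes each output channel block into five regions: four
// border strips (top, bottom, left, right) go through ConvDwBorder with
// per-pixel kernel clamping, while the fully interior rectangle
// [top_, bottom_) x [left_, right_) runs the check-free center kernel.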
void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                  const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  if (conv_param->thread_num_ == 0) {
    return;
  }
  const float *src = input_data;
  float *dst = output_data;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src + oc * C4NUM;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = weight_data + oc * sliding->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      ConvDwBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, conv_param->output_w_,
                   conv_param, sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
                   sliding);
      ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_,
                   conv_param->output_w_, conv_param, sliding);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
        float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
        ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                         conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                         sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
                         sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
                         sliding->in_kw_step_ * sizeof(float), relu, relu6);
#else
        ConvDwCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                     conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
                     sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu,
                     relu6);
#endif
      }
    }  // output C4 loop
    src += sliding->in_step_;
    dst += sliding->out_step_;
  }  // batch loop
  // output nhwc4
}
/*conv depthwise fp32 end*/

/*conv depthwise 3x3 fp32 begin*/
bool CheckConvDwUse3X3(const ConvParameter *conv_param) {
  bool use_3x3 =
    conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
    (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
    (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && conv_param->stride_h_ == conv_param->stride_w_ &&
    (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) &&
    conv_param->pad_u_ == conv_param->pad_l_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1;
  if (!use_3x3 || conv_param->input_h_ == 1 || conv_param->input_w_ == 1) {
    return false;
  }
  const int in_h = (conv_param->output_h_ - 1) * conv_param->stride_h_ + conv_param->kernel_h_;
  const int in_w = (conv_param->output_w_ - 1) * conv_param->stride_w_ + conv_param->kernel_w_;
  return in_h == (conv_param->input_h_ + 2 * conv_param->pad_u_) &&
         in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_);
}
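
/* The 3x3 fast path requires a square 3x3 kernel, equal strides of 1 or 2,
 * symmetric padding of 0 or 1, no dilation, and an input larger than one
 * pixel in each direction. The final check enforces the exact-cover shape
 * relation in = (out - 1) * stride + kernel - 2 * pad, so the transform
 * kernels below need no per-pixel bounds checks beyond the dedicated
 * Left/Right/Top/Bottom edge variants. */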

#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
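/* The four Row helpers below compute, four channels at a time, the F(2,3)
 * Winograd input transform of four horizontally adjacent pixels
 * d = [d0 d1 d2 d3]:
 *   B^T d = [d0 - d2, d1 + d2, d2 - d1, d3 - d1]
 * (the standard transform up to the sign of the last term, which the
 * matching weight and output transforms absorb). Left/Right substitute zeros
 * for the padded columns, and Single handles the odd-width tail, whose lone
 * output pixel never consumes the fourth term. Each two-pixel output unit
 * occupies 16 floats in the line buffer (4 terms x 4 channels), and lw is
 * the per-channel-group stride within a line. */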
static void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  v0 = MS_MOVQ_F32(0.0f);
  int ic = 0;
  for (; ic < channel - 3; ic += 4) {
    v1 = MS_LDQ_F32(src + ic);
    v2 = MS_LDQ_F32(src + channel + ic);
    v3 = MS_LDQ_F32(src + 2 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d1 = src[i + ic];
      float d2 = src[i + ic + channel];
      float d3 = src[i + ic + 2 * channel];
      remain_line[i] = 0.0f - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = d3 - d1;
    }
  }
}

static void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  int ic = 0;
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    v2 = MS_LDQ_F32(src + 2 * channel + ic);
    v3 = MS_LDQ_F32(src + 3 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      float d2 = src[i + ic + 2 * channel];
      float d3 = src[i + ic + 3 * channel];
      remain_line[i] = d0 - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = d3 - d1;
    }
  }
}

static void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2, v3;
  int ic = 0;
  v3 = MS_MOVQ_F32(0.0f);
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    v2 = MS_LDQ_F32(src + 2 * channel + ic);
    MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
    MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
    MS_STQ_F32(line + lw * ic, b0);
    MS_STQ_F32(line + lw * ic + 4, b1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    MS_STQ_F32(line + lw * ic + 12, b3);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      float d2 = src[i + ic + 2 * channel];
      remain_line[i] = d0 - d2;
      remain_line[i + 4] = d1 + d2;
      remain_line[i + 8] = d2 - d1;
      remain_line[i + 12] = 0.0f - d1;
    }
  }
}

static void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) {
  MS_FLOAT32X4 v0, v1, v2;
  int ic = 0;
  v2 = MS_MOVQ_F32(0.0f);
  for (; ic < channel - 3; ic += 4) {
    v0 = MS_LDQ_F32(src + ic);
    v1 = MS_LDQ_F32(src + channel + ic);
    MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
    MS_STQ_F32(line + lw * ic, v0);
    MS_STQ_F32(line + lw * ic + 4, v1);
    MS_STQ_F32(line + lw * ic + 8, b2);
    memset(line + lw * ic + 12, 0, 16);
  }
  if (ic < channel) {
    float *remain_line = line + ic * lw;
    memset(remain_line, 0, 64);
    for (int i = 0; i < channel - ic; i++) {
      float d0 = src[i + ic];
      float d1 = src[i + ic + channel];
      remain_line[i] = d0;
      remain_line[i + 4] = d1;
      remain_line[i + 8] = 0.0f - d1;
    }
  }
}

static void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int c4 = UP_ROUND(channel, C4NUM);
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(line0, 0, c4 * lw * sizeof(float));
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  ConvDw3x3RowLeft(src - width * channel, line0, lw, channel);
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3Row(const float *src, float **lines, int width, int channel) {
  float *tmp = lines[0];
  lines[0] = lines[1];
  lines[1] = lines[2];
  lines[2] = tmp;
  int c4 = UP_ROUND(channel, C4NUM);
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(tmp, 0, c4 * lw * sizeof(float));
  ConvDw3x3RowLeft(src, tmp, lw, channel);
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  }
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
  }
}

static void ConvDw3x3Bottom(float **lines, int width, int channel) {
  float *tmp = lines[0];
  lines[0] = lines[1];
  lines[1] = lines[2];
  lines[2] = tmp;
  int c4 = UP_ROUND(channel, C4NUM);
  memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float));
}
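
/* lines[0..2] form a ring of transformed rows: ConvDw3x3Row and
 * ConvDw3x3Bottom rotate the three pointers so only the newly exposed input
 * row is re-transformed per output row, and the bottom padding row is
 * materialized as a zeroed line. */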

#ifndef ENABLE_ARM64
void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel,
                   bool relu, bool relu6) {
  int channel = ori_channel;
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  for (; channel > 0; channel -= 4) {
    MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data);
    bias_data += 4;
    MS_FLOAT32X4 g00 = MS_LDQ_F32(weight);
    MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4);
    MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8);
    MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12);
    MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16);
    MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20);
    MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24);
    MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28);
    MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32);
    MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36);
    MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40);
    MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44);
    weight += 48;
    float *cur_dst = dst;
    int ow = 0;
    for (; ow < width - 1; ow += 2) {
      MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
      MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
      MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
      MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03);
      line0 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
      acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13);
      line1 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
      acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23);
      line2 += 16;
      MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
      MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2));
      res0 = MS_ADDQ_F32(res0, bias);
      res1 = MS_ADDQ_F32(res1, bias);
      if (relu || relu6) {
        res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
        res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f));
      }
      if (relu6) {
        res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
        res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f));
      }
      if (channel >= 4) {
        MS_STQ_F32(cur_dst, res0);
        MS_STQ_F32(cur_dst + ori_channel, res1);
      } else {
        for (int i = 0; i < channel; i++) {
          cur_dst[i] = MS_F32X4_GETI(res0, i);
          cur_dst[ori_channel + i] = MS_F32X4_GETI(res1, i);
        }
      }
      cur_dst += 2 * ori_channel;
    }
    if (ow < width) {
      MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
      MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
      MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
      line0 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
      line1 += 16;
      acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
      acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
      acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
      line2 += 16;
      MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
      res0 = MS_ADDQ_F32(res0, bias);
      if (relu || relu6) {
        res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
      }
      if (relu6) {
        res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
      }
      if (channel >= 4) {
        MS_STQ_F32(cur_dst, res0);
      } else {
        for (int i = 0; i < channel; i++) {
          cur_dst[i] = MS_F32X4_GETI(res0, i);
        }
      }
    }
    dst += 4;
  }
}
#endif
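
/* ConvDw3x3Line folds the element-wise products m0..m3 with the matching
 * F(2,3) output transform: each unit yields two pixels
 *   y0 = m0 + m1 + m2,  y1 = m1 - m2 + m3
 * (the sign of m3 matches the d3 - d1 convention of the row transforms),
 * followed by bias and the optional relu/relu6 clamps; the odd-width tail
 * evaluates only y0. */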

void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
               const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) {
  int units = UP_DIV(conv_param->output_w_, C2NUM);
  int c4 = UP_ROUND(conv_param->input_channel_, C4NUM);
  int line = conv_param->input_channel_ * conv_param->input_w_;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  for (int b = 0; b < conv_param->output_batch_; b++) {
    const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
    float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    float *line0 = buffer;
    float *line1 = buffer + units * c4 * C4NUM;
    float *line2 = buffer + units * c4 * C8NUM;
    float *lines[3] = {line0, line1, line2};
    int oh = start_oh;
    if (oh == 0) {
      // input trans
      ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_);
    } else {
      // input trans
      ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_);
    }
    // dst calc and trans
    ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                  relu, relu6);
    for (oh = start_oh + 1; oh < end_oh - 1; oh++) {
      // input trans
      ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
      // dst calc and trans
      ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                    relu, relu6);
    }
    if (oh == conv_param->output_h_ - 1) {
      // input trans
      ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_);
    } else {
      // input trans
      ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
    }
    // dst calc and trans
    ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
                  relu, relu6);
  }
}
#endif
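
/* Sizing note (an inferred contract, not stated elsewhere in this file): the
 * scratch `buffer` passed to ConvDw3x3 holds three transformed lines laid
 * out back to back, so the caller must provide at least
 *   3 * UP_DIV(output_w_, C2NUM) * UP_ROUND(input_channel_, C4NUM) * C4NUM
 * floats; line2 starts at units * c4 * C8NUM = 2 * units * c4 * C4NUM. */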

/*conv depthwise indirect buffer fp32 begin*/
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) {
  bool use_indirect = (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) ||
                      (conv_param->kernel_h_ == 5 && conv_param->kernel_w_ == 5);
  return use_indirect;
}

void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
                           int step_h, int step_w) {
#ifdef ENABLE_AVX
  int div = C8NUM;
#else
  int div = C4NUM;
#endif

  int ic_div = UP_DIV(conv_param->input_channel_, div) * div;
  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *input = src + b * conv_param->input_h_ * conv_param->input_w_ * ic_div;
    for (int oh = 0; oh < conv_param->output_h_; oh++) {
      for (int kh = 0; kh < conv_param->kernel_h_; kh++) {
        int ih = oh * conv_param->stride_h_ + kh * conv_param->dilation_h_ - conv_param->pad_u_;
        if (ih < conv_param->input_h_ && ih >= 0) {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int iw = ow * conv_param->stride_w_ + kw * conv_param->dilation_w_ - conv_param->pad_l_;
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              if (iw < conv_param->input_w_ && iw >= 0) {
                indirect[index] = input + (ih * conv_param->input_w_ + iw) * ic_div;
              } else {
                indirect[index] = zero_ptr;
              }
            }
          }
        } else {
          for (int ow = 0; ow < conv_param->output_w_; ow++) {
            for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
              int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh;
              indirect[index] = zero_ptr;
            }
          }
        }
      }
    }
  }
}
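
/* Indirection-buffer layout: entry
 *   indirect[oh * step_h + ow * step_w * kernel_h + kw * kernel_h + kh]
 * points at the input pixel feeding tap (kh, kw) of output (oh, ow), or at
 * zero_ptr for padding taps, so the row kernels simply dereference pointers
 * with no bounds logic. With dilation_w_ == 1, ConvDwIndirection picks
 * step_w = stride_w_, so neighbouring output columns share their overlapping
 * kernel-column pointers instead of duplicating them. */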

#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX)
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  do {
    float **in = input;
    size_t c = (size_t)channels;
    const float *w = weights;
    float *out = output;
    memcpy(out, bias, channels * (int)sizeof(float));
    for (; c >= C4NUM; c -= C4NUM) {
      for (int i = 0; i < C4NUM; i++) {
        for (int k = 0; k < kernel; k++) {
          out[i] += in[k][i] * w[i + k * C4NUM];
        }
      }
      w += kernel * C4NUM;
      out += C4NUM;
      for (int k = 0; k < kernel; k++) {
        in[k] += C4NUM;
      }
    }
    for (int i = 0; i < c; i++) {
      for (int k = 0; k < kernel; k++) {
        out[i] += in[k][i] * w[i + k * C4NUM];
      }
    }
    if (relu) {
      Fp32Relu(output, channels, output);
    }
    if (relu6) {
      Fp32Relu6(output, channels, output);
    }
    output += channels;
    input = input + input_stride;
  } while (--output_width != 0);
}
#endif

#ifdef ENABLE_ARM64
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Indirect3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  } else if (kernel == 25) {
    ConvDwFp32Indirect5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu,
                          relu6);
  }
}
#endif

#ifdef ENABLE_AVX
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
  if (kernel == 9) {
    ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
  } else if (kernel == 25) {
    ConvDwFp32Avx5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6);
  }
}
#endif

void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
                       float *zero_ptr, const ConvParameter *conv_param, int task_id) {
  if (conv_param->thread_num_ == 0) {
    return;
  }
  int step_w = conv_param->dilation_w_ == 1 ? conv_param->stride_w_ : conv_param->kernel_w_;
  int step_h =
    (conv_param->kernel_h_ * conv_param->kernel_w_) + (conv_param->output_w_ - 1) * step_w * conv_param->kernel_h_;
  int input_stride = conv_param->kernel_h_ * step_w;

  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

  int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int h_start = h_step * task_id;
  int h_end = MSMIN(h_start + h_step, conv_param->output_h_);

  for (int b = 0; b < conv_param->output_batch_; b++) {
    float **indirect_b = indirect_buffer + b * conv_param->output_h_ * step_h;
    float *output_b = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
    for (int oh = h_start; oh < h_end; oh++) {
      float **indirect = indirect_b + oh * step_h;
      float *output_h = output_b + oh * conv_param->output_w_ * conv_param->output_channel_;
      if (conv_param->kernel_w_ == 3) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 9);
      } else if (conv_param->kernel_w_ == 5) {
        ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_,
                              conv_param->output_w_, input_stride, relu, relu6, 25);
      }
    }
  }
}
/*conv depthwise indirect buffer fp32 end*/

/*deconv depthwise fp32 begin*/
void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step,
                         int in_kw_step, int kernel_w_step) {
  float *dst_kh = dst;
  const float *weight_kh = weight;
  for (int kh = 0; kh < height; kh++) {
    float *dst_kw = dst_kh;
    const float *weight_kw = weight_kh;
    for (int kw = 0; kw < width; kw++) {
#ifdef ENABLE_ARM64
      float32x4_t src_4 = vld1q_f32(src);
      float32x4_t weight_4 = vld1q_f32(weight_kw);
      float32x4_t dst_4 = vld1q_f32(dst_kw);
      dst_4 = vfmaq_f32(dst_4, src_4, weight_4);
      vst1q_f32(dst_kw, dst_4);
#else
      for (int c = 0; c < C4NUM; c++) {
        dst_kw[c] += src[c] * weight_kw[c];
      }
#endif
      dst_kw += in_kw_step;
      weight_kw += C4NUM;
    }  // kernel_w loop
    dst_kh += in_kh_step;
    weight_kh += kernel_w_step;
  }  // kernel_h loop
}

void DeconvDwBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right,
                    const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
  if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) {
    return;
  }
  const float *src_h = src + top * sliding->out_h_step_;
  for (int ih = top; ih < bottom; ih++) {
    int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
    int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
    int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
    float *dst_h = dst + oh * sliding->in_h_step_;

    const float *src_kernel = src_h + left * sliding->block_channel_;
    for (int iw = left; iw < right; iw++) {
      int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
      int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
      int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
      float *dst_w = dst_h + ow * sliding->block_channel_;

      const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
      float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
#ifdef ENABLE_ARM64
      DeconvDwFp32Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw,
                         sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
                         conv_param->kernel_w_ * C4NUM * sizeof(float));
#else
      DeconvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw,
                          sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM);
#endif
      src_kernel += sliding->block_channel_;
    }  // width loop
    src_h += sliding->out_h_step_;
  }  // height loop
}

#ifndef ENABLE_ARM64
void DeconvDwCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h,
                    int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step,
                    int in_kw_step) {
  float *dst_h = dst;
  const float *src_h = src;
  for (int oh = 0; oh < height; oh++) {
    float *dst_w = dst_h;
    const float *src_w = src_h;
    for (int ow = 0; ow < width; ow++) {
      float *dst_kh = dst_w;
      const float *weight_kh = weight;
      for (int kh = 0; kh < kernel_h; kh++) {
        float *dst_kw = dst_kh;
        const float *weight_kw = weight_kh;
        for (int kw = 0; kw < kernel_w; kw++) {
          for (int c = 0; c < C4NUM; c++) {
            dst_kw[c] += src_w[c] * weight_kw[c];
          }
          dst_kw += in_kw_step;
          weight_kw += C4NUM;
        }  // kernel_w loop
        dst_kh += in_kh_step;
        weight_kh += kernel_w * C4NUM;
      }  // kernel_h loop
      dst_w += in_sw_step;
      src_w += block_channel;
    }  // dst_width loop
    dst_h += in_sh_step;
    src_h += out_h_step;
  }  // dst_height loop
}
#endif

void DeconvDwPost(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) {
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;
  float *dst_k = dst;
  for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
    for (int c = 0; c < C4NUM; c++) {
      dst_k[c] += bias[c];
      dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
      dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
    }
    dst_k += block_channel;
  }
}
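
/* Deconv depthwise runs the sliding window in reverse: each input pixel is
 * scattered into the kernel_h x kernel_w output footprint it influences, so
 * the in/out step roles swap (DeconvDwSWFp32 advances src by out_step_ and
 * dst by in_step_). Bias and activation are deferred to DeconvDwPost, once
 * every overlapping contribution has been accumulated. */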

// deconv depthwise fp32: sliding window
void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                    const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
  const float *src = input_data;
  float *dst = output_data;
  if (conv_param->thread_num_ == 0) {
    return;
  }
  for (int b = 0; b < conv_param->output_batch_; b++) {
    for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) {
      const float *src_data = src + oc * C4NUM;
      float *dst_data = dst + oc * C4NUM;
      const float *weight = weight_data + oc * sliding->kernel_step_;
      const float *bias = bias_data + oc * C4NUM;
      DeconvDwBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_,
                     conv_param, sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param,
                     sliding);
      DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, conv_param->input_w_,
                     conv_param, sliding);

      if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
        int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
        int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
        float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
        const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
        DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                           conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
                           sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
                           sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
                           sliding->in_kw_step_ * sizeof(float));
#else
        DeconvDwCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
                       conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
                       sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_);
#endif
      }
      DeconvDwPost(dst_data, bias, sliding->block_channel_, conv_param);
    }  // output C4 loop
    src += sliding->out_step_;
    dst += sliding->in_step_;
  }  // batch loop
  // output nhwc4
}
/*deconv depthwise fp32 end*/

#ifdef ENABLE_AVX
void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left,
                            int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param,
                            const DepthwiseSWKernel kernel, int act_type, int ow_block, int oc_block) {
  // dw border compute
  int ih = top * conv_param->stride_h_ - conv_param->pad_u_;
  int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
  int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
  const float *src_h = src + ih * sw_param->in_h_step_;
  float *dst_kernel = dst + left * sw_param->block_channel_;
  for (int ow = left; ow < right; ow += ow_block) {
    int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
    int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
    int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
    const float *src_w = src_h + iw * sw_param->block_channel_;
    const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_;
    const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM * oc_block;
    kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_block,
           oc_block, sw_param->block_channel_, sw_param->in_kw_step_, sw_param->in_kh_step_, sw_param->in_sw_step_,
           (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block);
    dst_kernel += ow_block * sw_param->block_channel_;
  }  // width loop
}

void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
                        const ConvParameter *conv_param, const SlidingWindowParam *sw_param, int task_id) {
  int oh_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
  int oh_start = oh_step * task_id;
  int oh_end = MSMIN(oh_start + oh_step, conv_param->output_h_);
  if (oh_start >= oh_end) {
    return;
  }
  // depthwise sliding window using x86 AVX instructions
  int oc_tile_ = C8NUM;  // oc is aligned to C8NUM on x86_64 AVX
  int act_type = 0;
  if (conv_param->act_type_ == ActType_Relu6) {
    act_type += 1;
  }
  if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) {
    act_type += 2;
  }
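  // act_type encodes the clamps as bit flags consumed by the kernels below:
  // 0x1 = upper clamp at 6 (relu6), 0x2 = lower clamp at 0; relu6 sets both.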
1144 int kernel_h = conv_param->kernel_h_;
1145 int kernel_w = conv_param->kernel_w_;
1146 int output_w = conv_param->output_w_;
1147 int oc_algin = sw_param->block_channel_;
1148 int oc_num = sw_param->c_block_;
1149 int in_step = sw_param->in_step_;
1150 int out_step = sw_param->out_step_;
1151 int in_sw_step = sw_param->in_sw_step_;
1152 int in_kw_step = sw_param->in_kw_step_;
1153 int in_kh_step = sw_param->in_kh_step_;
1154 int in_sh_step = sw_param->in_sh_step_;
1155 int out_right = sw_param->right_;
1156 int out_left = sw_param->left_;
1157 int out_top = sw_param->top_;
1158 int out_bottom = sw_param->bottom_;
1159 int kernel_step = sw_param->kernel_step_;
1160 int out_h_step = sw_param->out_h_step_;
1161 int in_h_start = out_top * conv_param->stride_h_ - conv_param->pad_u_;
1162 int in_w_start = out_left * conv_param->stride_w_ - conv_param->pad_l_;
1163 int in_start = in_h_start * sw_param->in_h_step_ + in_w_start * oc_algin;
1164 const int ow_block_num[4] = {8, 4, 4, 3};
1165 const DepthwiseSWKernel kernel[4][2] = {{DepthwiseSW1x8Kernel, DepthwiseSW8x8Kernel},
1166 {DepthwiseSW1x16Kernel, DepthwiseSW4x16Kernel},
1167 {DepthwiseSW1x24Kernel, DepthwiseSW4x24Kernel},
1168 {DepthwiseSW1x32Kernel, DepthwiseSW3x32Kernel}};
1169 for (int b = 0; b < conv_param->output_batch_; b++) {
1170 for (int oh = oh_start; oh < oh_end; ++oh) {
1171 float *dst_oh = output_data + oh * out_h_step;
1172 const float *src_h = input_data + in_start + (oh - out_top) * in_sh_step;
1173 int oc_block = 0;
1174 const float *bias = bias_data;
1175 for (int oc = 0; oc < oc_num; oc += oc_block) {
1176 oc_block = MSMIN(C4NUM, oc_num - oc); // 4 3 2 1
1177 int oc_step = oc * oc_tile_;
1178 const float *weight = weight_data + oc * kernel_step;
1179 if (bias != NULL) {
1180 bias = bias_data + oc_step;
1181 }
1182 float *dst_w = dst_oh + oc_step;
1183 const DepthwiseSWKernel kernel_border = kernel[oc_block - 1][0];
1184 if (oh < out_top || oh >= out_bottom) { // oh in up or down border
1185 DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, output_w, conv_param, sw_param,
1186 kernel_border, act_type, 1, oc_block);
1187 } else { // oh in center
1188 // ow in right
1189 DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, out_left, conv_param, sw_param,
1190 kernel_border, act_type, 1, oc_block);
1191 // ow in center
1192 const float *src_w = src_h + oc_step;
1193 int ow_block = ow_block_num[oc_block - 1]; // 8 4 4 3
1194 for (int ow = out_left; ow < out_right; ow += ow_block) { // left ~ right
1195             if (ow_block > out_right - ow) {  // not enough ow left for a full block, fall back to one ow at a time
1196 ow_block = 1;
1197 }
1198 kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]](
1199 dst_w + ow * oc_algin, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, oc_algin,
1200 in_kw_step, in_kh_step, in_sw_step, 0);
1201 src_w += ow_block * in_sw_step;
1202 }
1203           // ow in the right border
1204 DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, out_right, output_w, conv_param,
1205 sw_param, kernel_border, act_type, 1, oc_block);
1206 }
1207 }
1208 } // output h loop
1209 input_data += in_step;
1210 output_data += out_step;
1211 } // batch loop
1212 }
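
/* A minimal scalar reference, for illustration only, of how the act_type bit
 * flags set above are interpreted by every kernel below: bit 1 (0x2) clamps
 * the result at 0 (Relu) and bit 0 (0x1) additionally clamps it at 6 (Relu6),
 * so Relu encodes as 2 and Relu6 as 3. The helper name is ours and is not
 * part of the nnacl API. */
#ifdef ENABLE_DEBUG
static float DepthwiseSWActRefDbg(float v, size_t act_flag) {
  if (act_flag & 0x2) {
    v = v < 0.0f ? 0.0f : v;  // Relu: max(v, 0)
  }
  if (act_flag & 0x1) {
    v = v > 6.0f ? 6.0f : v;  // Relu6: min(v, 6)
  }
  return v;
}
#endif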
1213
1214 #ifdef ENABLE_DEBUG
1215 void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1216 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1217 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1218 __m256 dst_data[12];
1219 __m256 src_data;
1220 const float *src_kh[12];
1221 const float *src_kw[12];
1222 __m256 weight_data[4];
1223 for (int i = 0; i < ow_block; ++i) {
1224 if (bias != NULL) {
1225 for (int j = 0; j < oc_block; ++j) {
1226 dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * 8);
1227 }
1228 } else {
1229 for (int j = 0; j < oc_block; ++j) {
1230 dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f);
1231 }
1232 }
1233 src_kh[i] = src + i * in_sw_step;
1234 src_kw[i] = NULL;
1235 }
1236 const float *weight_kernel = weight;
1237 for (int kh = 0; kh < kernel_h; kh++) {
1238 for (int i = 0; i < ow_block; ++i) {
1239 src_kw[i] = src_kh[i];
1240 }
1241 for (int kw = 0; kw < kernel_w; kw++) {
1242 for (int j = 0; j < oc_block; ++j) {
1243 weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
1244 }
1245 for (int i = 0; i < ow_block; ++i) { // loop ow
1246 for (int j = 0; j < oc_block; ++j) {
1247 src_data = _mm256_loadu_ps(src_kw[i] + j * C8NUM);
1248 dst_data[i * oc_block + j] += src_data * weight_data[j];
1249 }
1250 }
1251 for (int i = 0; i < ow_block; ++i) {
1252 src_kw[i] += in_kw_step; // ic8 * dilation_w
1253 }
1254 weight_kernel += oc_block * C8NUM;
1255 } // kernel_w loop
1256 weight_kernel += kw_remainder;
1257 for (int i = 0; i < ow_block; ++i) {
1258       src_kh[i] += in_kh_step;  // advance to the next kernel row
1259 }
1260 } // kernel_h loop
1261   // apply the activation and store the results
1262 for (int i = 0; i < ow_block; ++i) {
1263 for (int j = 0; j < oc_block; ++j) {
1264 if (0x1 & act_flag) { // relu6
1265 dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f));
1266 }
1267 if (0x2 & act_flag) { // relu
1268 dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f));
1269 }
1270 _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]);
1271 }
1272 }
1273 }
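
/* A hedged usage sketch of the generic kernel above: one output pixel, one
 * 8-float channel block, a 1x1 kernel and no activation, so the expected
 * result is dst[c] == bias[c] + src[c] * weight[c]. All buffers and sizes
 * are illustrative. Note that ow_block * oc_block must not exceed 12 (the
 * dst_data bound), and that the steps are counted in floats here because
 * this C kernel does plain pointer arithmetic, whereas the asm kernels below
 * scale their steps by sizeof(float) themselves. */
static void DepthwiseSWWxKKernelUsageDbg(void) {
  float src[8] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
  float weight[8] = {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f};
  float bias[8] = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
  float dst[8] = {0.0f};
  DepthwiseSWWxKKernel(dst, src, weight, bias, /*kernel_h=*/1, /*kernel_w=*/1,
                       /*act_flag=*/0, /*ow_block=*/1, /*oc_block=*/1,
                       /*oc_algin=*/C8NUM, /*in_kw_step=*/C8NUM,
                       /*in_kh_step=*/C8NUM, /*in_sw_step=*/C8NUM,
                       /*kw_remainder=*/0);
}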
1274 #endif
1275
1276 void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1277 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1278 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1279 in_kh_step *= sizeof(float);
1280 in_sw_step *= sizeof(float);
1281 in_kw_step *= sizeof(float);
1282 oc_algin *= sizeof(float);
1283 kw_remainder *= sizeof(float);
1284 asm volatile(
1285 "cmpq $0, %2\n"
1286 "je 0f\n"
1287 "vmovups (%2), %%ymm0\n"
1288 "vmovups 0x20(%2), %%ymm1\n"
1289 "vmovups 0x40(%2), %%ymm2\n"
1290 "vmovups 0x60(%2), %%ymm3\n"
1291 "vmovups (%2), %%ymm4\n"
1292 "vmovups 0x20(%2), %%ymm5\n"
1293 "vmovups 0x40(%2), %%ymm6\n"
1294 "vmovups 0x60(%2), %%ymm7\n"
1295 "vmovups (%2), %%ymm8\n"
1296 "vmovups 0x20(%2), %%ymm9\n"
1297 "vmovups 0x40(%2), %%ymm10\n"
1298 "vmovups 0x60(%2), %%ymm11\n"
1299 "jmp 1f\n"
1300 "0:\n"
1301 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1302 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1303 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1304 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1305 "vxorps %%ymm4, %%ymm4, %%ymm4\n"
1306 "vxorps %%ymm5, %%ymm5, %%ymm5\n"
1307 "vxorps %%ymm6, %%ymm6, %%ymm6\n"
1308 "vxorps %%ymm7, %%ymm7, %%ymm7\n"
1309 "vxorps %%ymm8, %%ymm8, %%ymm8\n"
1310 "vxorps %%ymm9, %%ymm9, %%ymm9\n"
1311 "vxorps %%ymm10, %%ymm10, %%ymm10\n"
1312 "vxorps %%ymm11, %%ymm11, %%ymm11\n"
1313 "1:\n" // LoopH
1314 "movq %4, %%rsi\n" // width
1315 "movq %0, %%rcx\n" // src_h
1316 "2:\n" // LoopW
1317
1318 "vmovups (%1), %%ymm12\n"
1319 "vmovups (%%rcx), %%ymm13\n"
1320 "vmovups (%%rcx, %7), %%ymm14\n"
1321 "vmovups (%%rcx, %7, 2), %%ymm15\n"
1322 "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
1323 "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
1324 "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
1325
1326 "vmovups 0x20(%1), %%ymm12\n"
1327 "vmovups 0x20(%%rcx), %%ymm13\n"
1328 "vmovups 0x20(%%rcx, %7), %%ymm14\n"
1329 "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
1330 "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
1331 "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
1332 "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n"
1333
1334 "vmovups 0x40(%1), %%ymm12\n"
1335 "vmovups 0x40(%%rcx), %%ymm13\n"
1336 "vmovups 0x40(%%rcx, %7), %%ymm14\n"
1337 "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
1338 "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
1339 "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n"
1340 "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n"
1341
1342 "vmovups 0x60(%1), %%ymm12\n"
1343 "vmovups 0x60(%%rcx), %%ymm13\n"
1344 "vmovups 0x60(%%rcx, %7), %%ymm14\n"
1345 "vmovups 0x60(%%rcx, %7, 2), %%ymm15\n"
1346 "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
1347 "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
1348 "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n"
1349 "addq $128, %1\n"
1350
1351 "addq %5, %%rcx\n" // in_kw_step
1352 "dec %%rsi\n"
1353 "jg 2b\n"
1354
1355 "addq %6, %0\n" // in_kh_step
1356     "addq %8, %1\n"  // kw_remainder
1357 "dec %3\n"
1358 "jg 1b\n"
1359 :
1360 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1361 "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 8
1362 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9",
1363 "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
1364
1365 asm volatile(
1366 "and $0x3, %%eax\n"
1367 "je 0f\n"
1368 // Relu
1369 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1370 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1371 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1372 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1373 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1374 "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
1375 "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
1376 "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
1377 "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
1378 "vmaxps %%ymm12, %%ymm8, %%ymm8\n"
1379 "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
1380 "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
1381 "vmaxps %%ymm12, %%ymm11, %%ymm11\n"
1382
1383 "and $0x1, %%eax\n"
1384 "je 0f\n"
1385 // relu6
1386 "mov $0x40C00000, %%ecx\n"
1387 "vmovd %%ecx, %%xmm14\n"
1388 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1389 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1390 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1391 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1392 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1393 "vminps %%ymm14, %%ymm4, %%ymm4\n"
1394 "vminps %%ymm14, %%ymm5, %%ymm5\n"
1395 "vminps %%ymm14, %%ymm6, %%ymm6\n"
1396 "vminps %%ymm14, %%ymm7, %%ymm7\n"
1397 "vminps %%ymm14, %%ymm8, %%ymm8\n"
1398 "vminps %%ymm14, %%ymm9, %%ymm9\n"
1399 "vminps %%ymm14, %%ymm10, %%ymm10\n"
1400 "vminps %%ymm14, %%ymm11, %%ymm11\n"
1401
1402 "0:\n"
1403 "vmovups %%ymm0, (%2)\n" // dst_0
1404 "vmovups %%ymm1, 0x20(%2)\n"
1405 "vmovups %%ymm2, 0x40(%2)\n"
1406 "vmovups %%ymm3, 0x60(%2)\n"
1407 "vmovups %%ymm4, (%2, %1, 1)\n"
1408 "vmovups %%ymm5, 0x20(%2, %1, 1)\n"
1409 "vmovups %%ymm6, 0x40(%2, %1, 1)\n"
1410 "vmovups %%ymm7, 0x60(%2, %1, 1)\n"
1411 "vmovups %%ymm8, (%2, %1, 2)\n"
1412 "vmovups %%ymm9, 0x20(%2, %1, 2)\n"
1413 "vmovups %%ymm10, 0x40(%2, %1, 2)\n"
1414 "vmovups %%ymm11, 0x60(%2, %1, 2)\n"
1415 :
1416 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1417 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
1418 "%ymm11", "%ymm12", "%ymm14");
1419 }
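
/* The relu6 path above materializes the constant 6.0f without a data load:
 * 0x40C00000 is the IEEE-754 single-precision bit pattern of 6.0f, and
 * vpermps with the all-zero index register ymm12 broadcasts lane 0 of ymm14
 * across the whole register. A minimal intrinsics equivalent, shown for
 * illustration only (the helper name is ours): */
#ifdef ENABLE_DEBUG
static inline __m256 DepthwiseSWSixDbg(void) {
  __m256i six_bits = _mm256_set1_epi32(0x40C00000);  // the bit pattern of 6.0f in every lane
  return _mm256_castsi256_ps(six_bits);              // reinterpret the integer bits as floats
}
#endif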
1420
1421 void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1422 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1423 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1424 in_kh_step *= sizeof(float);
1425 in_kw_step *= sizeof(float);
1426 oc_algin *= sizeof(float);
1427 kw_remainder *= sizeof(float);
1428 asm volatile(
1429 "cmpq $0, %2\n"
1430 "je 0f\n"
1431 "vmovups (%2), %%ymm0\n"
1432 "vmovups 0x20(%2), %%ymm1\n"
1433 "vmovups 0x40(%2), %%ymm2\n"
1434 "vmovups 0x60(%2), %%ymm3\n"
1435 "jmp 1f\n"
1436 "0:\n"
1437 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1438 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1439 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1440 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1441 "1:\n" // LoopH
1442 "movq %4, %%rsi\n" // width
1443 "movq %0, %%rcx\n" // src_h
1444     "2:\n"  // LoopW
1445 "vmovups (%%rcx), %%ymm4\n"
1446 "vmovups 0x20(%%rcx), %%ymm5\n"
1447 "vmovups 0x40(%%rcx), %%ymm6\n"
1448 "vmovups 0x60(%%rcx), %%ymm7\n"
1449     // The weight operand is read directly from memory by the FMA instead of being loaded into a register first.
1450 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1451 "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
1452 "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n"
1453 "vfmadd231ps 0x60(%1), %%ymm7, %%ymm3\n"
1454 "addq $128, %1\n"
1455
1456 "addq %5, %%rcx\n" // in_kw_step
1457 "dec %%rsi\n"
1458 "jg 2b\n"
1459
1460 "addq %6, %0\n" // in_kh_step
1461     "addq %7, %1\n"  // kw_remainder
1462 "dec %3\n"
1463 "jg 1b\n"
1464 :
1465 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1466 "r"(in_kh_step), "r"(kw_remainder) // 7
1467 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
1468
1469 asm volatile(
1470 "and $0x3, %%eax\n"
1471 "je 0f\n"
1472 // Relu
1473 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1474 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1475 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1476 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1477 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1478
1479 "and $0x1, %%eax\n"
1480 "je 0f\n"
1481 // relu6
1482 "mov $0x40C00000, %%ecx\n"
1483 "vmovd %%ecx, %%xmm14\n"
1484 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1485 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1486 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1487 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1488 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1489
1490 "0:\n"
1491 "vmovups %%ymm0, (%2)\n" // dst_0
1492 "vmovups %%ymm1, 0x20(%2)\n"
1493 "vmovups %%ymm2, 0x40(%2)\n"
1494 "vmovups %%ymm3, 0x60(%2)\n"
1495 :
1496 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1497 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14");
1498 }
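
/* Register-budget note: the blocked kernels in this family are sized so that
 * the accumulators plus a few temporaries fit in the sixteen ymm registers:
 * DepthwiseSW3x32Kernel keeps 3 (ow) x 4 (ymm per 32 channels) = 12
 * accumulators with ymm12-ymm15 as temporaries, DepthwiseSW4x24Kernel keeps
 * 4 x 3 = 12, DepthwiseSW4x16Kernel 4 x 2 = 8 and DepthwiseSW8x8Kernel
 * 8 x 1 = 8, which is why ow_block_num in DepthwiseSWAvxFp32 is {8, 4, 4, 3}. */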
1499
1500 void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1501 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1502 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1503 in_kh_step *= sizeof(float);
1504 in_kw_step *= sizeof(float);
1505 in_sw_step *= sizeof(float);
1506 kw_remainder *= sizeof(float);
1507 size_t src_3_step = 3 * in_sw_step;
1508 float *dst_3 = dst + 3 * oc_algin;
1509 oc_algin *= sizeof(float);
1510 asm volatile(
1511 "cmpq $0, %2\n"
1512 "je 0f\n"
1513 "vmovups (%2), %%ymm0\n"
1514 "vmovups 0x20(%2), %%ymm1\n"
1515 "vmovups 0x40(%2), %%ymm2\n"
1516     // Duplicating ymm0-ymm2 into the other accumulators with register-to-register moves would avoid these repeated loads, but no single instruction fans one register out to several, so the bias is reloaded from memory.
1517 "vmovups (%2), %%ymm3\n"
1518 "vmovups 0x20(%2), %%ymm4\n"
1519 "vmovups 0x40(%2), %%ymm5\n"
1520 "vmovups (%2), %%ymm6\n"
1521 "vmovups 0x20(%2), %%ymm7\n"
1522 "vmovups 0x40(%2), %%ymm8\n"
1523 "vmovups (%2), %%ymm9\n"
1524 "vmovups 0x20(%2), %%ymm10\n"
1525 "vmovups 0x40(%2), %%ymm11\n"
1526 "jmp 1f\n"
1527 "0:\n"
1528 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1529 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1530 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1531 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1532 "vxorps %%ymm4, %%ymm4, %%ymm4\n"
1533 "vxorps %%ymm5, %%ymm5, %%ymm5\n"
1534 "vxorps %%ymm6, %%ymm6, %%ymm6\n"
1535 "vxorps %%ymm7, %%ymm7, %%ymm7\n"
1536 "vxorps %%ymm8, %%ymm8, %%ymm8\n"
1537 "vxorps %%ymm9, %%ymm9, %%ymm9\n"
1538 "vxorps %%ymm10, %%ymm10, %%ymm10\n"
1539 "vxorps %%ymm11, %%ymm11, %%ymm11\n"
1540 "1:\n" // LoopH
1541 "movq %4, %%rsi\n" // width
1542 "movq %0, %%rcx\n" // src_h
1543 "2:\n" // LoopW
1544 "vmovups (%1), %%ymm12\n"
1545 "vmovups (%%rcx), %%ymm13\n"
1546 "vmovups (%%rcx, %7, 1), %%ymm14\n"
1547 "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
1548 "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
1549 "vmovups (%%rcx, %7, 2), %%ymm15\n"
1550 "vmovups (%%rcx, %9), %%ymm13\n"
1551 "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
1552 "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n"
1553
1554 "vmovups 0x20(%1), %%ymm12\n"
1555 "vmovups 0x20(%%rcx), %%ymm13\n"
1556 "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
1557 "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
1558 "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
1559 "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
1560 "vmovups 0x20(%%rcx, %9), %%ymm13\n"
1561 "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
1562 "vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n"
1563
1564 "vmovups 0x40(%1), %%ymm12\n"
1565 "vmovups 0x40(%%rcx), %%ymm13\n"
1566 "vmovups 0x40(%%rcx, %7, 1), %%ymm14\n"
1567 "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n"
1568 "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n"
1569 "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n"
1570 "vmovups 0x40(%%rcx, %9), %%ymm13\n"
1571 "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n"
1572 "vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n"
1573
1574 "addq $96, %1\n"
1575 "addq %5, %%rcx\n" // in_kw_step
1576 "dec %%rsi\n"
1577 "jg 2b\n"
1578
1579 "addq %6, %0\n" // in_kh_step
1580     "addq %8, %1\n"  // kw_remainder: in the border path the weight pointer must skip the trimmed columns
1581 "dec %3\n"
1582 "jg 1b\n"
1583 :
1584 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1585 "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9
1586 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9",
1587 "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
1588
1589 asm volatile(
1590 "and $0x3, %%eax\n"
1591 "je 0f\n"
1592 // Relu
1593 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1594 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1595 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1596 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1597 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1598 "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
1599 "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
1600 "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
1601 "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
1602 "vmaxps %%ymm12, %%ymm8, %%ymm8\n"
1603 "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
1604 "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
1605 "vmaxps %%ymm12, %%ymm11, %%ymm11\n"
1606
1607 "and $0x1, %%eax\n"
1608 "je 0f\n"
1609 // relu6
1610 "mov $0x40C00000, %%ecx\n"
1611 "vmovd %%ecx, %%xmm14\n"
1612 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1613 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1614 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1615 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1616 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1617 "vminps %%ymm14, %%ymm4, %%ymm4\n"
1618 "vminps %%ymm14, %%ymm5, %%ymm5\n"
1619 "vminps %%ymm14, %%ymm6, %%ymm6\n"
1620 "vminps %%ymm14, %%ymm7, %%ymm7\n"
1621 "vminps %%ymm14, %%ymm8, %%ymm8\n"
1622 "vminps %%ymm14, %%ymm9, %%ymm9\n"
1623 "vminps %%ymm14, %%ymm10, %%ymm10\n"
1624 "vminps %%ymm14, %%ymm11, %%ymm11\n"
1625
1626 "0:\n"
1627 "vmovups %%ymm0, (%2)\n" // dst_0
1628 "vmovups %%ymm1, 0x20(%2)\n"
1629 "vmovups %%ymm2, 0x40(%2)\n"
1630 "vmovups %%ymm3, (%2, %1, 1)\n"
1631 "vmovups %%ymm4, 0x20(%2, %1, 1)\n"
1632 "vmovups %%ymm5, 0x40(%2, %1, 1)\n"
1633 "vmovups %%ymm6, (%2, %1, 2)\n"
1634 "vmovups %%ymm7, 0x20(%2, %1, 2)\n"
1635 "vmovups %%ymm8, 0x40(%2, %1, 2)\n"
1636 "vmovups %%ymm9, (%3)\n" // dst+3
1637 "vmovups %%ymm10, 0x20(%3)\n"
1638 "vmovups %%ymm11, 0x40(%3)\n"
1639 :
1640 : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
1641 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
1642 "%ymm11", "%ymm12", "%ymm14");
1643 }
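
/* Weight-stride note: each LoopW iteration consumes one kernel column of
 * weights for the whole channel block, so the weight pointer advances by
 * oc_block * C8NUM * sizeof(float) bytes per step: $128 in the x32 kernels,
 * $96 here in the x24 kernels, $64 in the x16 kernels and $32 in the x8
 * kernels. */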
1644
1645 void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1646 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1647 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1648 in_kh_step *= sizeof(float);
1649 in_kw_step *= sizeof(float);
1650 oc_algin *= sizeof(float);
1651 kw_remainder *= sizeof(float);
1652 asm volatile(
1653 "cmpq $0, %2\n"
1654 "je 0f\n"
1655 "vmovups (%2), %%ymm0\n"
1656 "vmovups 0x20(%2), %%ymm1\n"
1657 "vmovups 0x40(%2), %%ymm2\n"
1658 "jmp 1f\n"
1659 "0:\n"
1660 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1661 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1662 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1663 "1:\n" // LoopH
1664 "movq %4, %%rsi\n" // width
1665 "movq %0, %%rcx\n" // src_h
1666     "2:\n"  // LoopW
1667 "vmovups (%%rcx), %%ymm4\n"
1668 "vmovups 0x20(%%rcx), %%ymm5\n"
1669 "vmovups 0x40(%%rcx), %%ymm6\n"
1670     // The weight operand is read directly from memory by the FMA instead of being loaded into a register first.
1671 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1672 "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
1673 "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n"
1674 "addq $96, %1\n"
1675
1676 "addq %5, %%rcx\n" // in_kw_step
1677 "dec %%rsi\n"
1678 "jg 2b\n"
1679
1680 "addq %6, %0\n" // in_kh_step
1681 "addq %7, %1\n" // kw_remainder
1682 "dec %3\n"
1683 "jg 1b\n"
1684 :
1685 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1686 "r"(in_kh_step), "r"(kw_remainder) // 7
1687 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6");
1688
1689 asm volatile(
1690 "and $0x3, %%eax\n"
1691 "je 0f\n"
1692 // Relu
1693 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1694 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1695 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1696 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1697
1698 "and $0x1, %%eax\n"
1699 "je 0f\n"
1700 // relu6
1701 "mov $0x40C00000, %%ecx\n"
1702 "vmovd %%ecx, %%xmm14\n"
1703 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1704 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1705 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1706 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1707
1708 "0:\n"
1709 "vmovups %%ymm0, (%2)\n" // dst_0
1710 "vmovups %%ymm1, 0x20(%2)\n"
1711 "vmovups %%ymm2, 0x40(%2)\n"
1712 :
1713 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1714 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14");
1715 }
1716
1717 void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1718 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1719 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1720 in_kh_step *= sizeof(float);
1721 in_kw_step *= sizeof(float);
1722 in_sw_step *= sizeof(float);
1723 kw_remainder *= sizeof(float);
1724 size_t src_3_step = 3 * in_sw_step;
1725 float *dst_3 = dst + 3 * oc_algin;
1726 oc_algin *= sizeof(float);
1727 asm volatile(
1728 "cmpq $0, %2\n"
1729 "je 0f\n"
1730 "vmovups (%2), %%ymm0\n"
1731 "vmovups 0x20(%2), %%ymm1\n"
1732     // Duplicating ymm0/ymm1 into the other accumulators with register-to-register moves would avoid these repeated loads, but no single instruction fans one register out to several, so the bias is reloaded from memory.
1733 "vmovups (%2), %%ymm3\n"
1734 "vmovups 0x20(%2), %%ymm4\n"
1735 "vmovups (%2), %%ymm6\n"
1736 "vmovups 0x20(%2), %%ymm7\n"
1737 "vmovups (%2), %%ymm9\n"
1738 "vmovups 0x20(%2), %%ymm10\n"
1739 "jmp 1f\n"
1740 "0:\n"
1741 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1742 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1743 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1744 "vxorps %%ymm4, %%ymm4, %%ymm4\n"
1745 "vxorps %%ymm6, %%ymm6, %%ymm6\n"
1746 "vxorps %%ymm7, %%ymm7, %%ymm7\n"
1747 "vxorps %%ymm9, %%ymm9, %%ymm9\n"
1748 "vxorps %%ymm10, %%ymm10, %%ymm10\n"
1749 "1:\n" // LoopH
1750 "movq %4, %%rsi\n" // width
1751 "movq %0, %%rcx\n" // src_h
1752 "2:\n" // LoopW
1753 "vmovups (%1), %%ymm12\n"
1754 "vmovups (%%rcx), %%ymm13\n"
1755 "vmovups (%%rcx, %7, 1), %%ymm14\n"
1756 "vmovups (%%rcx, %7, 2), %%ymm15\n"
1757 "vmovups (%%rcx, %9), %%ymm2\n"
1758 "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
1759 "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n"
1760 "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n"
1761 "vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n"
1762
1763 "vmovups 0x20(%1), %%ymm12\n"
1764 "vmovups 0x20(%%rcx), %%ymm13\n"
1765 "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n"
1766 "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n"
1767 "vmovups 0x20(%%rcx, %9), %%ymm2\n"
1768 "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n"
1769 "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
1770 "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
1771 "vfmadd231ps %%ymm12, %%ymm2, %%ymm10\n"
1772
1773 "addq $64, %1\n"
1774 "addq %5, %%rcx\n" // in_kw_step
1775 "dec %%rsi\n"
1776 "jg 2b\n"
1777
1778 "addq %6, %0\n" // in_kh_step
1779     "addq %8, %1\n"  // kw_remainder: in the border path the weight pointer must skip the trimmed columns
1780 "dec %3\n"
1781 "jg 1b\n"
1782 :
1783 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1784 "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9
1785     : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12",
1786       "%ymm13", "%ymm14", "%ymm15");  // ymm2 is clobbered as a temporary load above and must be listed
1787
1788 asm volatile(
1789 "and $0x3, %%eax\n"
1790 "je 0f\n"
1791 // Relu
1792 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1793 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1794 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1795 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1796 "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
1797 "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
1798 "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
1799 "vmaxps %%ymm12, %%ymm9, %%ymm9\n"
1800 "vmaxps %%ymm12, %%ymm10, %%ymm10\n"
1801
1802 "and $0x1, %%eax\n"
1803 "je 0f\n"
1804 // relu6
1805 "mov $0x40C00000, %%ecx\n"
1806 "vmovd %%ecx, %%xmm14\n"
1807 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1808 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1809 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1810 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1811 "vminps %%ymm14, %%ymm4, %%ymm4\n"
1812 "vminps %%ymm14, %%ymm6, %%ymm6\n"
1813 "vminps %%ymm14, %%ymm7, %%ymm7\n"
1814 "vminps %%ymm14, %%ymm9, %%ymm9\n"
1815 "vminps %%ymm14, %%ymm10, %%ymm10\n"
1816
1817 "0:\n"
1818 "vmovups %%ymm0, (%2)\n" // dst_0
1819 "vmovups %%ymm1, 0x20(%2)\n"
1820 "vmovups %%ymm3, (%2, %1, 1)\n"
1821 "vmovups %%ymm4, 0x20(%2, %1, 1)\n"
1822 "vmovups %%ymm6, (%2, %1, 2)\n"
1823 "vmovups %%ymm7, 0x20(%2, %1, 2)\n"
1824 "vmovups %%ymm9, (%3)\n" // dst+3
1825 "vmovups %%ymm10, 0x20(%3)\n"
1826 :
1827 : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3)
1828 : "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14");
1829 }
1830
1831 void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1832 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1833 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1834 in_kh_step *= sizeof(float);
1835 in_kw_step *= sizeof(float);
1836 oc_algin *= sizeof(float);
1837 kw_remainder *= sizeof(float);
1838 asm volatile(
1839 "cmpq $0, %2\n"
1840 "je 0f\n"
1841 "vmovups (%2), %%ymm0\n"
1842 "vmovups 0x20(%2), %%ymm1\n"
1843 "jmp 1f\n"
1844 "0:\n"
1845 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1846 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1847 "1:\n" // LoopH
1848 "movq %4, %%rsi\n" // width
1849 "movq %0, %%rcx\n" // src_h
1850     "2:\n"  // LoopW
1851 "vmovups (%%rcx), %%ymm4\n"
1852 "vmovups 0x20(%%rcx), %%ymm5\n"
1853     // The weight operand is read directly from memory by the FMA instead of being loaded into a register first.
1854 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1855 "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n"
1856 "addq $64, %1\n"
1857
1858 "addq %5, %%rcx\n" // in_kw_step
1859 "dec %%rsi\n"
1860 "jg 2b\n"
1861
1862 "addq %6, %0\n" // in_kh_step
1863 "addq %7, %1\n" // kw_remainder
1864 "dec %3\n"
1865 "jg 1b\n"
1866 :
1867 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
1868 "r"(in_kh_step), "r"(kw_remainder) // 7
1869 : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm4", "%ymm5");
1870
1871 asm volatile(
1872 "and $0x3, %%eax\n"
1873 "je 0f\n"
1874 // Relu
1875 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1876 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1877 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1878
1879 "and $0x1, %%eax\n"
1880 "je 0f\n"
1881 // relu6
1882 "mov $0x40C00000, %%ecx\n"
1883 "vmovd %%ecx, %%xmm14\n"
1884 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1885 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1886 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1887
1888 "0:\n"
1889 "vmovups %%ymm0, (%2)\n" // dst_0
1890 "vmovups %%ymm1, 0x20(%2)\n"
1891 :
1892 : "a"(act_flag), "r"(oc_algin), "r"(dst)
1893 : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14");
1894 }
1895
1896 void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
1897 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
1898 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
1899 in_kh_step *= sizeof(float);
1900 in_sw_step *= sizeof(float);
1901 in_kw_step *= sizeof(float);
1902 kw_remainder *= sizeof(float);
1903 size_t src_3_step = 3 * in_sw_step;
1904 float *dst_3 = dst + 3 * oc_algin;
1905 float *dst_5 = dst + 5 * oc_algin;
1906 oc_algin *= sizeof(float);
1907 asm volatile(
1908 "cmpq $0, %0\n"
1909 "je 0f\n"
1910 "vmovups (%0), %%ymm0\n"
1911 "vmovups (%0), %%ymm1\n"
1912 "vmovups (%0), %%ymm2\n"
1913 "vmovups (%0), %%ymm3\n"
1914 "vmovups (%0), %%ymm4\n"
1915 "vmovups (%0), %%ymm5\n"
1916 "vmovups (%0), %%ymm6\n"
1917 "vmovups (%0), %%ymm7\n"
1918 "jmp 1f\n"
1919 "0:\n"
1920 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1921 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1922 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1923 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1924 "vxorps %%ymm4, %%ymm4, %%ymm4\n"
1925 "vxorps %%ymm5, %%ymm5, %%ymm5\n"
1926 "vxorps %%ymm6, %%ymm6, %%ymm6\n"
1927 "vxorps %%ymm7, %%ymm7, %%ymm7\n"
1928 "1:\n"
1929 :
1930 : "r"(bias)
1931 : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
1932
1933 asm volatile(
1934     "1:\n"  // LoopH
1935 "movq %3, %%rsi\n" // width
1936 "movq %0, %%rcx\n" // src_h
1937     "2:\n"  // LoopW
1938 "movq %%rcx, %%rax\n"
1939 "vmovups (%1), %%ymm12\n"
1940 "vmovups (%%rax), %%ymm13\n"
1941 "vmovups (%%rax, %6), %%ymm14\n"
1942 "vmovups (%%rax, %6, 2), %%ymm15\n"
1943 "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n"
1944 "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n"
1945 "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n"
1946 "addq %7, %%rax\n"
1947 "vmovups (%%rax), %%ymm13\n"
1948 "vmovups (%%rax, %6), %%ymm14\n"
1949 "vmovups (%%rax, %6, 2), %%ymm15\n"
1950 "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n"
1951 "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n"
1952 "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n"
1953 "addq %7, %%rax\n"
1954 "vmovups (%%rax), %%ymm13\n"
1955 "vmovups (%%rax, %6), %%ymm14\n"
1956 "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n"
1957 "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n"
1958
1959 "addq $32, %1\n"
1960 "addq %4, %%rcx\n" // in_kw_step
1961 "dec %%rsi\n"
1962     "jg 2b\n"
1963
1964 "addq %5, %0\n" // in_kh_step
1965     "addq %8, %1\n"  // kw_remainder: in the border path the weight pointer must skip the trimmed columns
1966 "dec %2\n"
1967     "jg 1b\n"
1968 :
1969 : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), "r"(in_kh_step), // 5
1970 "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder) // 8
1971 : "%rcx", "%rsi", "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12",
1972 "%ymm13", "%ymm14", "%ymm15");
1973
1974 asm volatile(
1975 "and $0x3, %%eax\n"
1976     "je 0f\n"
1977 // Relu
1978 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1979 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1980 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1981 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1982 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1983 "vmaxps %%ymm12, %%ymm4, %%ymm4\n"
1984 "vmaxps %%ymm12, %%ymm5, %%ymm5\n"
1985 "vmaxps %%ymm12, %%ymm6, %%ymm6\n"
1986 "vmaxps %%ymm12, %%ymm7, %%ymm7\n"
1987
1988 "and $0x1, %%eax\n"
1989     "je 0f\n"
1990 // relu6
1991 "mov $0x40C00000, %%ecx\n"
1992 "vmovd %%ecx, %%xmm14\n"
1993 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1994 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1995 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1996 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1997 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1998 "vminps %%ymm14, %%ymm4, %%ymm4\n"
1999 "vminps %%ymm14, %%ymm5, %%ymm5\n"
2000 "vminps %%ymm14, %%ymm6, %%ymm6\n"
2001 "vminps %%ymm14, %%ymm7, %%ymm7\n"
2002
2003     "0:\n"  // Write
2004 "vmovups %%ymm0, (%2)\n" // dst_0
2005 "vmovups %%ymm1, (%2, %1)\n"
2006 "vmovups %%ymm2, (%2, %1, 2)\n"
2007 "vmovups %%ymm3, (%3)\n" // dst_3
2008 "vmovups %%ymm4, (%2, %1, 4)\n"
2009 "vmovups %%ymm5, (%4)\n" // dst_5
2010 "vmovups %%ymm6, (%4, %1, 1)\n"
2011 "vmovups %%ymm7, (%4, %1, 2)\n"
2012 :
2013 : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5)
2014 : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14");
2015 }
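
/* Addressing note: x86 scaled addressing only supports scale factors of 1,
 * 2, 4 and 8, so ow rows 3 and 5..7 of the 8-wide tile cannot be reached
 * from dst with a single (base, oc_algin, scale) operand. The kernel
 * therefore precomputes dst_3 and dst_5 as extra base pointers and writes
 * rows 0, 1, 2 and 4 relative to dst and rows 3, 5, 6 and 7 relative to
 * dst_3 and dst_5. */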
2016
2017 void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h,
2018 size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin,
2019 size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) {
2020 in_kh_step *= sizeof(float);
2021 in_kw_step *= sizeof(float);
2022 oc_algin *= sizeof(float);
2023 kw_remainder *= sizeof(float);
2024 asm volatile(
2025 "cmpq $0, %2\n"
2026 "je 0f\n"
2027 "vmovups (%2), %%ymm0\n"
2028 "jmp 1f\n"
2029 "0:\n"
2030 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
2031 "1:\n" // LoopH
2032 "movq %4, %%rsi\n" // width
2033 "movq %0, %%rcx\n" // src_h
2034     "2:\n"  // LoopW
2035 "vmovups (%%rcx), %%ymm4\n"
2036     // The weight operand is read directly from memory by the FMA instead of being loaded into a register first.
2037 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
2038 "addq $32, %1\n"
2039
2040 "addq %5, %%rcx\n" // in_kw_step
2041 "dec %%rsi\n"
2042 "jg 2b\n"
2043
2044 "addq %6, %0\n" // in_kh_step
2045 "addq %7, %1\n" // kw_remainder
2046 "dec %3\n"
2047 "jg 1b\n"
2048 :
2049 : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5
2050 "r"(in_kh_step), "r"(kw_remainder) // 7
2051 : "%rcx", "%rsi", "%ymm0", "%ymm4");
2052
2053 asm volatile(
2054 "and $0x3, %%eax\n"
2055 "je 0f\n"
2056 // Relu
2057 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
2058 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
2059
2060 "and $0x1, %%eax\n"
2061 "je 0f\n"
2062 // relu6
2063 "mov $0x40C00000, %%ecx\n"
2064 "vmovd %%ecx, %%xmm14\n"
2065 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
2066 "vminps %%ymm14, %%ymm0, %%ymm0\n"
2067
2068 "0:\n"
2069 "vmovups %%ymm0, (%2)\n" // dst_0
2070 :
2071 : "a"(act_flag), "r"(oc_algin), "r"(dst)
2072 : "%ecx", "%ymm0", "%ymm12", "%ymm14");
2073 }
2074 #endif
2075