/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"

void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
  PackNCHWToNHWCFp32(src, dst, 1, plane, channel, 0, 0);
}

void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) {
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float));
    }
  }
}

void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
                        int block_index) {
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int kernel_plane = kernel_h * kernel_w;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int out_w = conv_param->output_w_;
  if (dilation_h == 0 || dilation_w == 0 || out_w == 0) {
    return;
  }
  int in_channel = conv_param->input_channel_;
  int in_w = conv_param->input_w_;
  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
    int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
    int input_stride = (input_h * in_w + input_w) * in_channel;
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
               (kw_e - kw_s) * in_channel * sizeof(float));
      }  // kernel_h loop
    } else {
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
        }
      }  // kernel_h loop
    }
  }  // tile num loop
}
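
// Layout note (added for clarity): for each of the real_cal_num output positions in
// this tile, the receptive-field rows are gathered so that element
// (tile i, kernel position (kh, kw), channel c) of packed_input lands at
// i * kernel_plane * in_channel + (kh * kernel_w + kw) * in_channel + c.
// E.g. with a 3x3 kernel and in_channel = 8, tile element i = 2 occupies floats
// [144, 216) of packed_input, ready to be consumed as one GEMM row. Border positions
// are clipped via kh_s/kh_e/kw_s/kw_e, so out-of-image taps are left unwritten
// (callers are assumed to zero packed_input beforehand; this function does not).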

void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  int c4_minus = c4 - 1;
  for (int b = 0; b < batch; b++) {
    int src_oc_offset = b * plane * channel;
    int dst_oc_offset = b * plane * c4 * C4NUM;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_oc_offset + k * channel;
      int dst_kernel_offset = dst_oc_offset + k * C4NUM;
      for (int j = 0; j < c4_minus; ++j) {
        int src_ic_offset = src_kernel_offset + j * C4NUM;
        int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM;
#ifdef ENABLE_ARM
        vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset));
#else
        for (int i = 0; i < C4NUM; ++i) {
          ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
        }
#endif
      }
      int tmp_c = c4_minus * C4NUM;
      int tmp_c_offset = tmp_c * plane;
      int res_c = channel - tmp_c;
      if (res_c > channel) {
        return;
      }
      for (int l = 0; l < res_c; ++l) {
        int src_ic_offset = src_kernel_offset + tmp_c + l;
        int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
        ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
      }
    }
  }
}
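
// Layout note (added for clarity): NC4HW4 groups channels into blocks of C4NUM = 4
// and stores each block plane-contiguously, i.e. element (b, p, c) maps to
// b * plane * c4 * C4NUM + (c / C4NUM) * plane * C4NUM + p * C4NUM + (c % C4NUM).
// For channel = 6 and plane = 2 that gives two blocks: {c0..c3} for both planes
// first, then {c4, c5, pad, pad}; the tail block's padding lanes are not written
// by the function above.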

void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * channel;
    int dst_offset = b * plane * c4 * C4NUM;
    RowMajor2Col4Major((const float *)src + src_offset, (float *)dst + dst_offset, channel, plane);
  }
}

void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int oc_block = UP_DIV(channel, C4NUM);
  int oc_block_channel = oc_block * C4NUM;
  int ic_remainder_ = channel % C4NUM;
  if (ic_remainder_ != 0) {
    for (int b = 0; b < batch; b++) {
      int dst_batch_offset = b * oc_block_channel * plane;
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel;
        memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
        memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float));
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}

void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile) {
  int oc_block = UP_DIV(channel, oc_tile);
  int oc_block_channel = oc_block * oc_tile;
  int ic_remainder_ = channel % oc_tile;
  if (ic_remainder_ != 0) {
    for (int b = 0; b < batch; b++) {
      int dst_batch_offset = b * oc_block_channel * plane;
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel;
        memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
        memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float));
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}
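
// Note (added): PackNHWCToNHWC4Fp32 above is exactly PackNHWCToNHWCXFp32
// specialized to oc_tile == C4NUM; both right-pad each pixel's channel vector
// with zeros up to the tile boundary, e.g. channel = 6 with oc_tile = 8 stores
// 6 values followed by 2 zeros per plane position.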

#ifdef ENABLE_AVX
// PackNHWCToNXHWCXFp32 is the SIMD-optimized counterpart of SWPackNHWCToNXHWCXFp32 below.
void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
                          float *tmp_weight, const float *src) {
  // pack weight NHWC to N32HWC32 / N24HWC24 / N16HWC16 / N8HWC8
  // output_channel plays the role of the batch (N) dimension of the weight
  int ic8 = DOWN_ROUND(input_channel, C8NUM);
  int oc_block8 = DOWN_DIV(output_channel, C8NUM);
  int oc_block = 0;
  int oc = 0;
  int oc_remainder_step = 0;
  if (oc_block8 != oc_block_num) {
    oc_block8 = oc_block8 / C4NUM * C4NUM;
    oc_remainder_step = (oc_block_num - oc_block8) * C8NUM;
  }
  int plane = kernel_w * kernel_h;
  if (plane == 1) {  // conv 1x1 weight pack
    for (; oc < oc_block8; oc += (oc_block / C8NUM)) {
      oc_block = MSMIN(C4NUM, oc_block8 - oc) * C8NUM;  // max_tile = 32 ==> 24 ==> 16 ==> 8
      for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) {
        int ic = 0;
        for (; ic < ic8; ic += C8NUM) {
          Transpose8X8Fp32Avx(src + ic, tmp_weight + ic * oc_block + oc_tmp, input_channel, oc_block);
        }
        for (; ic < input_channel; ++ic) {
          for (int j = 0; j < C8NUM; ++j) {
            tmp_weight[ic * oc_block + oc_tmp + j] = src[ic + input_channel * j];
          }
        }
        src += C8NUM * input_channel;
      }
      tmp_weight += oc_block * input_channel;
    }
    oc = output_channel - oc_block8 * C8NUM;
    for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) {
      for (int ic = 0; ic < input_channel; ++ic) {
        tmp_weight[oc_remainder + oc_remainder_step * ic] = src[ic + oc_remainder * input_channel];
      }
    }
  } else {
    for (; oc < oc_block8; oc += (oc_block / C8NUM)) {
      oc_block = MSMIN(C4NUM, oc_block8 - oc) * C8NUM;  // max_tile = 32 ==> 24 ==> 16 ==> 8
      for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) {
        for (int hw = 0; hw < plane; ++hw) {
          int ic = 0;
          for (; ic < ic8; ic += C8NUM) {
            Transpose8X8Fp32Avx(src + hw * input_channel + ic,
                                tmp_weight + hw * oc_block * input_channel + ic * oc_block + oc_tmp,
                                input_channel * plane, oc_block);
          }
          for (; ic < input_channel; ++ic) {
            for (int j = 0; j < C8NUM; ++j) {
              tmp_weight[ic * oc_block + oc_tmp + j + hw * oc_block * input_channel] =
                src[ic + input_channel * j * plane + hw * input_channel];
            }
          }
        }
        src += C8NUM * plane * input_channel;
      }
      tmp_weight += oc_block * input_channel * plane;
    }
    oc = output_channel - oc_block8 * C8NUM;
    for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) {
      for (int hw = 0; hw < plane; ++hw) {
        for (int ic = 0; ic < input_channel; ++ic) {
          tmp_weight[oc_remainder + oc_remainder_step * ic + hw * input_channel * oc_remainder_step] =
            src[ic + (oc_remainder * plane + hw) * input_channel];
        }
      }
    }
  }
}

#ifdef ENABLE_DEBUG
void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
                            float *tmp_weight, const float *src) {
  // pack weight NHWC to N32HWC32 / N24HWC24 / N16HWC16 / N8HWC8
  int oc_block = 0;
  for (int i = 0; i < oc_block_num; i += oc_block) {
    oc_block = MSMIN(C4NUM, oc_block_num - i);  // max_tile = 4
    int index = i * C8NUM * kernel_h * kernel_w * input_channel;
    int oc_remainder = MSMIN(C8NUM * oc_block, output_channel - i * C8NUM);
    for (int h = 0; h < kernel_h; ++h) {
      for (int w = 0; w < kernel_w; ++w) {
        int w_index = (h * kernel_w + w) * input_channel + index;
        for (int ic = 0; ic < input_channel; ++ic) {
          int ic_index = ic + w_index;
          for (int oc = 0; oc < oc_remainder; ++oc) {
            int oc_index = oc * kernel_w * kernel_h * input_channel + ic_index;
            tmp_weight[oc] = src[oc_index];
          }
          tmp_weight += oc_block * C8NUM;
        }
      }
    }
  }
}
#endif
#endif

void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  int c8_channel = c8 * C8NUM;
  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  int ic_remainder_ = channel % C8NUM;
  if (ic_remainder_ != 0) {
    int nhwc8_batch_offset = 0;
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel;
        memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
        for (int j = channel; j < c8_channel; ++j) {
          dst_per_plane[j] = 0;
        }
      }
      nhwc8_batch_offset += nhwc8_batch_unit_offset;
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}

void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num) {
  int c_align = UP_DIV(channel, cx_num);
  int ic_remainder_ = channel % cx_num;
  if (ic_remainder_ != 0) {
    int nhwc_batch_unit_offset = channel * plane;
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * c_align * cx_num * plane;
      for (int i = 0; i < plane; i++) {
        memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel,
               (float *)src + batch_offset + i * c_align * cx_num, channel * sizeof(float));
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}

void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int c = 0; c < channel; c++) {
      int c4_block_num = c / C4NUM;
      int c4_block_res = c % C4NUM;
      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
      int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
      for (int k = 0; k < plane; k++) {
        int src_kernel_offset = src_c_offset + k * C4NUM;
        int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
        ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
      }
    }
  }
}

void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_offset + k * C4NUM;
      int dst_kernel_offset = dst_offset + k * channel;
      for (int c = 0; c < c4 - 1; c++) {
        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
        int dst_c_offset = dst_kernel_offset + c * C4NUM;
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
        MS_STQ_F32((float *)dst + dst_c_offset, MS_LDQ_F32((float *)src + src_c_offset));
#else
        ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
        ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
        ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
        ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
#endif
      }
      // remainder channels of the last, possibly partial C4 block
      int res_c = channel - (c4 - 1) * C4NUM;
      for (int i = 0; i < res_c; i++) {
        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
      }
    }
  }
}
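
#ifdef ENABLE_DEBUG
// Reference sketch (added; an assumption, not part of the upstream build): this
// debug-only helper and its name PackNC4HW4ToNHWCRefFp32 are ours, guarded with
// ENABLE_DEBUG in the same spirit as SWPackNHWCToNXHWCXFp32 above. It restates
// the NC4HW4 -> NHWC mapping of PackNC4HW4ToNHWCFp32 as one plain index formula,
// useful for verifying the blocked version: element (b, p, c) lives at block
// c / C4NUM, lane c % C4NUM in the source.
void PackNC4HW4ToNHWCRefFp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    for (int p = 0; p < plane; p++) {
      for (int c = 0; c < channel; c++) {
        int src_index = b * plane * c4 * C4NUM + (c / C4NUM) * plane * C4NUM + p * C4NUM + (c % C4NUM);
        int dst_index = b * plane * channel + p * channel + c;
        ((float *)dst)[dst_index] = ((const float *)src)[src_index];
      }
    }
  }
}
#endif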

void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c8 * C8NUM;
    int dst_offset = b * plane * channel;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_offset + k * C8NUM;
      int dst_kernel_offset = dst_offset + k * channel;
      for (int c = 0; c < c8 - 1; c++) {
        int src_c_offset = src_kernel_offset + c * plane * C8NUM;
        int dst_c_offset = dst_kernel_offset + c * C8NUM;
#ifdef ENABLE_AVX
        MS_ST256_F32((float *)dst + dst_c_offset, MS_LD256_F32((float *)src + src_c_offset));
#else
        ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
        ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
        ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
        ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
        ((float *)dst + dst_c_offset)[4] = ((float *)src + src_c_offset)[4];
        ((float *)dst + dst_c_offset)[5] = ((float *)src + src_c_offset)[5];
        ((float *)dst + dst_c_offset)[6] = ((float *)src + src_c_offset)[6];
        ((float *)dst + dst_c_offset)[7] = ((float *)src + src_c_offset)[7];
#endif
      }
      // remainder channels of the last, possibly partial C8 block
      int res_c = channel - (c8 - 1) * C8NUM;
      for (int i = 0; i < res_c; i++) {
        int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i;
        int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i;
        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
      }
    }
  }
}

void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const int batch, const int plane,
                                             const int channel) {
  int down_channel_8 = DOWN_ROUND(channel, C8NUM);
  int up_channel_16 = UP_ROUND(channel, C16NUM);
  size_t dst_batch_offset = (size_t)(plane * channel) * sizeof(float);
  size_t src_batch_offset = (size_t)(plane * up_channel_16) * sizeof(float);
  size_t unaligned_channel_size = (size_t)(channel - down_channel_8) * sizeof(float);
  size_t aligned_channel_size = (size_t)(down_channel_8 * plane) * sizeof(float);
  size_t src_p_offset = C8NUM * sizeof(float);
  for (size_t b = 0; b < (size_t)(batch); ++b) {
    const char *src_batch = (const char *)(src) + b * src_batch_offset;
    char *dst_batch = (char *)(dst) + b * dst_batch_offset;
    memcpy(dst_batch, src_batch, aligned_channel_size);
    src_batch += aligned_channel_size;
    dst_batch += aligned_channel_size;
    for (int p = 0; p < plane; ++p) {
      memcpy(dst_batch + p * unaligned_channel_size, src_batch + p * src_p_offset, unaligned_channel_size);
    }
  }
}

void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int channel_up8 = UP_ROUND(channel, C8NUM);
  for (int n = 0; n < batch; n++) {
    for (int hw = 0; hw < plane; hw++) {
      int c = 0;
      for (; c < channel; c++) {
        int c8div = c / C8NUM;
        int c8mod = c % C8NUM;
        int src_index = n * plane * channel + hw * channel + c;
        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
        ((float *)dst)[dst_index] = ((float *)src)[src_index];
      }
      for (; c < channel_up8; c++) {
        int c8div = c / C8NUM;
        int c8mod = c % C8NUM;
        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
        ((float *)dst)[dst_index] = 0;
      }
    }
  }
}
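
// Index note (added for clarity): C8HWN8 makes the batch dimension the
// fastest-varying one inside each channel block, so element (n, hw, c) maps to
// (c / C8NUM) * plane * batch * C8NUM + hw * batch * C8NUM + n * C8NUM + (c % C8NUM),
// and channels beyond `channel` are explicitly zero-filled up to channel_up8.
// E.g. batch = 2, plane = 4, c = 10 puts block c8div = 1 at base offset 64,
// lane c8mod = 2.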

void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
  // pack weight NHWC to C24HWN24; when fewer than 24 channels remain, fall back to C16HWN16, then C8HWN8
#ifdef ENABLE_AVX
  int oc_block_num = UP_DIV(channel, C8NUM);
  int plane16 = plane / C16NUM * C16NUM;
  for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) {
    oc_block = MSMIN(C3NUM, oc_block_num - i);
    int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
    int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM;
    int p = 0;
    for (; p < plane16; p += C16NUM) {
      int index_plane = i * C8NUM + p * channel;
      for (int b = 0; b < batch; ++b) {
        int index_batch = index_plane + b * plane * channel;
        int oc = 0;
        int stride = oc_block * C8NUM * batch;
        for (; oc < oc_remainder_c8; oc += C8NUM) {
          const float *cur_src = src + index_batch + oc;
          float *cur_dst = dst + oc;
          LOAD256X16_F32(r, cur_src, channel);
          STORE256X16_F32(cur_dst, stride, r);
        }
        for (; oc < oc_remainder; ++oc) {
          for (int k = 0; k < C16NUM; ++k) {
            dst[oc + stride * k] = src[index_batch + oc + channel * k];
          }
        }
        for (; oc < C8NUM; ++oc) {
          for (int k = 0; k < C16NUM; ++k) {
            dst[oc + stride * k] = 0;
          }
        }
        dst += oc_block * C8NUM;
      }
      dst += (C16NUM - 1) * oc_block * C8NUM * batch;
    }
    for (; p < plane; ++p) {
      int index_plane = i * C8NUM + p * channel;
      for (int b = 0; b < batch; ++b) {
        int index_batch = index_plane + b * plane * channel;
        int oc = 0;
        for (; oc < oc_remainder; ++oc) {
          dst[oc] = src[index_batch + oc];
        }
        for (; oc < C8NUM; ++oc) {
          dst[oc] = 0;
        }
        dst += oc_block * C8NUM;
      }
    }
  }
#else
  int oc_block = 0;
  int oc_block_num = UP_DIV(channel, C8NUM);
  for (int i = 0; i < oc_block_num; i += oc_block) {
    oc_block = MSMIN(C3NUM, oc_block_num - i);  // max_tile = 3, i.e. up to 24 channels per block
    int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
    for (int p = 0; p < plane; ++p) {
      int index_plane = i * C8NUM + p * channel;
      for (int b = 0; b < batch; ++b) {
        int index_batch = index_plane + b * plane * channel;
        for (int oc = 0; oc < oc_remainder; ++oc) {
          dst[oc] = src[index_batch + oc];
        }
        dst += oc_block * C8NUM;
      }
    }
  }
#endif
}

void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int c = 0; c < c4; c++) {
    int dst_off_c = c * C4NUM * height * width;
    for (int i = 0; i < C4NUM; i++) {
      int src_off_c = (c * C4NUM + i) * height * width;
      for (int kh = 0; kh < height; kh++) {
        int src_off_kh = src_off_c + kh * width;
        for (int kw = 0; kw < width; kw++) {
          int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i;
          ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
        }
      }
    }
  }
}

void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  for (int c = 0; c < c8; c++) {
    int dst_off_c = c * C8NUM * height * width;
    for (int i = 0; i < C8NUM; i++) {
      int src_off_c = (c * C8NUM + i) * height * width;
      for (int kh = 0; kh < height; kh++) {
        int src_off_kh = src_off_c + kh * width;
        for (int kw = 0; kw < width; kw++) {
          int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i;
          ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
        }
      }
    }
  }
}

void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel, int task_id,
                        int thread_count) {
#ifdef ENABLE_ARM64
  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64;
#elif defined(ENABLE_ARM32)
  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32;
#elif defined(ENABLE_AVX)
  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx;
#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX)
  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse;
#endif
  int hw8 = plane / C8NUM;
  int task_start = 0;
  int task_end = plane;
  if (thread_count > 0) {
    int offset_hw = UP_DIV(hw8, thread_count) * C8NUM;
    task_start = offset_hw * task_id;
    int count = plane - task_start;
    if (count <= 0) {
      return;
    }
    task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw);
    hw8 = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0);
  } else {
    hw8 *= C8NUM;
  }
  int c8 = channel / C8NUM * C8NUM;
  int batch = plane * channel;
  for (int n = 0; n < batches; n++) {
    const float *src_batch = (const float *)src + n * batch;
    float *dst_batch = (float *)dst + n * batch;
    int hw = task_start;
    for (; hw < hw8; hw += C8NUM) {
      int c = 0;
      for (; c < c8; c += C8NUM) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32)
        Transpose8X8Fp32Func_(src_ptr, dst_ptr, channel, plane);
#else
        for (int tr = 0; tr < C8NUM; tr++) {
          for (int tc = 0; tc < C8NUM; tc++) {
            dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
          }
        }
#endif
      }
      for (; c < channel; c++) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
        for (size_t i = 0; i < C8NUM; i++) {
          dst_ptr[i] = src_ptr[i * channel];
        }
      }
    }
    for (; hw < task_end; hw++) {
      const float *src_ptr = src_batch + hw * channel;
      float *dst_ptr = dst_batch + hw;
      for (size_t i = 0; i < channel; i++) {
        dst_ptr[i * plane] = src_ptr[i];
      }
    }
  }
}
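
// Scheduling note (added for clarity): when thread_count > 0 the plane dimension
// is split into chunks of UP_DIV(hw8, thread_count) * C8NUM rows, so every thread
// starts on an 8-row boundary and can use the 8x8 transpose kernel; the last
// thread also sweeps up the remainder rows one by one. E.g. plane = 100 and
// thread_count = 3 gives offset_hw = UP_DIV(12, 3) * 8 = 32, i.e. ranges
// [0, 32), [32, 64), and [64, 100), with rows 96..99 handled by the scalar tail.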

void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) {
  PackNHWCToNCHWFp32(src, dst, batch, channel, plane, task_id, thread_count);
}

#ifdef ENABLE_ARM64
inline void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  size_t srcStride = src_stride * sizeof(float);
  size_t dstStride = dst_stride * sizeof(float);
  asm volatile(
    "mov x10, %[src_ptr]\n"
    "mov x11, %[dst_ptr]\n"

    "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
    "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"

    "zip1 v8.4s, v0.4s, v2.4s\n"
    "zip2 v9.4s, v0.4s, v2.4s\n"
    "zip1 v12.4s, v1.4s, v3.4s\n"
    "zip2 v13.4s, v1.4s, v3.4s\n"

    "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
    "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"

    "zip1 v10.4s, v4.4s, v6.4s\n"
    "zip2 v11.4s, v4.4s, v6.4s\n"
    "zip1 v14.4s, v5.4s, v7.4s\n"
    "zip2 v15.4s, v5.4s, v7.4s\n"

    "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
    "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"

    "trn1 v16.2d, v8.2d, v10.2d\n"
    "trn2 v18.2d, v8.2d, v10.2d\n"
    "trn1 v20.2d, v9.2d, v11.2d\n"
    "trn2 v22.2d, v9.2d, v11.2d\n"

    "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
    "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"

    "trn1 v24.2d, v12.2d, v14.2d\n"
    "trn2 v26.2d, v12.2d, v14.2d\n"
    "trn1 v28.2d, v13.2d, v15.2d\n"
    "trn2 v30.2d, v13.2d, v15.2d\n"

    "zip1 v8.4s, v0.4s, v2.4s\n"
    "zip2 v9.4s, v0.4s, v2.4s\n"
    "zip1 v12.4s, v1.4s, v3.4s\n"
    "zip2 v13.4s, v1.4s, v3.4s\n"

    "zip1 v10.4s, v4.4s, v6.4s\n"
    "zip2 v11.4s, v4.4s, v6.4s\n"
    "zip1 v14.4s, v5.4s, v7.4s\n"
    "zip2 v15.4s, v5.4s, v7.4s\n"

    "trn1 v17.2d, v8.2d, v10.2d\n"
    "trn2 v19.2d, v8.2d, v10.2d\n"
    "trn1 v21.2d, v9.2d, v11.2d\n"
    "trn2 v23.2d, v9.2d, v11.2d\n"

    "trn1 v25.2d, v12.2d, v14.2d\n"
    "trn2 v27.2d, v12.2d, v14.2d\n"
    "trn1 v29.2d, v13.2d, v15.2d\n"
    "trn2 v31.2d, v13.2d, v15.2d\n"

    "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
    "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
    "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
    "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
    "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
    "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
    "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
    "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"

    :
    : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
    : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
      "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
      "v31");
}
#endif
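
// Algorithm note (added for clarity): the ARM64 kernel above is a two-stage NEON
// transpose of the 8x8 float tile, handled as four 4x4 sub-blocks: zip1/zip2
// interleave row pairs at 32-bit granularity, then trn1/trn2 on .2d lanes
// exchange the remaining 64-bit halves, e.g. zip(v0, v2) followed by
// trn1(v8.2d, v10.2d) yields {r0c0, r1c0, r2c0, r3c0}. Loads and stores are
// interleaved with the shuffles to keep the memory and vector pipes busy.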

#ifdef ENABLE_ARM32
inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  size_t srcStride = src_stride * sizeof(float);
  size_t dstStride = dst_stride * sizeof(float);
  asm volatile(
    "mov r10, %[src_ptr]\n"
    "mov r12, %[dst_ptr]\n"

    "vld1.32 {q0, q1}, [r10], %[srcStride]\n"
    "vld1.32 {q2, q3}, [r10], %[srcStride]\n"

    "vtrn.32 d0, d4\n"
    "vtrn.32 d1, d5\n"
    "vtrn.32 d2, d6\n"
    "vtrn.32 d3, d7\n"

    "vld1.32 {q4, q5}, [r10], %[srcStride]\n"
    "vld1.32 {q6, q7}, [r10], %[srcStride]\n"

    "vtrn.32 d8, d12\n"
    "vtrn.32 d9, d13\n"
    "vtrn.32 d10, d14\n"
    "vtrn.32 d11, d15\n"

    "vld1.32 {q8, q9}, [r10], %[srcStride]\n"
    "vld1.32 {q10, q11}, [r10], %[srcStride]\n"

    "vswp d1, d8\n"
    "vswp d3, d10\n"
    "vswp d5, d12\n"
    "vswp d7, d14\n"

    "vtrn.32 d16, d20\n"
    "vtrn.32 d17, d21\n"
    "vtrn.32 d18, d22\n"
    "vtrn.32 d19, d23\n"

    "vld1.32 {q12, q13}, [r10], %[srcStride]\n"
    "vld1.32 {q14, q15}, [r10], %[srcStride]\n"

    "vtrn.32 d24, d28\n"
    "vtrn.32 d25, d29\n"
    "vtrn.32 d26, d30\n"
    "vtrn.32 d27, d31\n"

    "vswp d17, d24\n"
    "vswp d19, d26\n"
    "vswp d21, d28\n"
    "vswp d23, d30\n"

    "add r10, r12, #16\n"
    "vst1.32 {q0}, [r12], %[dstStride]\n"
    "vst1.32 {q8}, [r10], %[dstStride]\n"
    "vst1.32 {q2}, [r12], %[dstStride]\n"
    "vst1.32 {q10}, [r10], %[dstStride]\n"
    "vst1.32 {q4}, [r12], %[dstStride]\n"
    "vst1.32 {q12}, [r10], %[dstStride]\n"
    "vst1.32 {q6}, [r12], %[dstStride]\n"
    "vst1.32 {q14}, [r10], %[dstStride]\n"
    "vst1.32 {q1}, [r12], %[dstStride]\n"
    "vst1.32 {q9}, [r10], %[dstStride]\n"
    "vst1.32 {q3}, [r12], %[dstStride]\n"
    "vst1.32 {q11}, [r10], %[dstStride]\n"
    "vst1.32 {q5}, [r12], %[dstStride]\n"
    "vst1.32 {q13}, [r10], %[dstStride]\n"
    "vst1.32 {q7}, [r12], %[dstStride]\n"
    "vst1.32 {q15}, [r10], %[dstStride]\n"

    :
    : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
    : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
      "q15");
}
#endif

#ifdef ENABLE_AVX
inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  LOAD256X8_F32(src, src_ptr, src_stride)
  __m256 r1 = _mm256_unpacklo_ps(src1, src2);
  __m256 r2 = _mm256_unpackhi_ps(src1, src2);
  __m256 r3 = _mm256_unpacklo_ps(src3, src4);
  __m256 r4 = _mm256_unpackhi_ps(src3, src4);
  __m256 r5 = _mm256_unpacklo_ps(src5, src6);
  __m256 r6 = _mm256_unpackhi_ps(src5, src6);
  __m256 r7 = _mm256_unpacklo_ps(src7, src8);
  __m256 r8 = _mm256_unpackhi_ps(src7, src8);

  __m256 v;
  v = _mm256_shuffle_ps(r1, r3, 0x4E);
  src1 = _mm256_blend_ps(r1, v, 0xCC);
  src2 = _mm256_blend_ps(r3, v, 0x33);

  v = _mm256_shuffle_ps(r2, r4, 0x4E);
  src3 = _mm256_blend_ps(r2, v, 0xCC);
  src4 = _mm256_blend_ps(r4, v, 0x33);

  v = _mm256_shuffle_ps(r5, r7, 0x4E);
  src5 = _mm256_blend_ps(r5, v, 0xCC);
  src6 = _mm256_blend_ps(r7, v, 0x33);

  v = _mm256_shuffle_ps(r6, r8, 0x4E);
  src7 = _mm256_blend_ps(r6, v, 0xCC);
  src8 = _mm256_blend_ps(r8, v, 0x33);

  r1 = _mm256_permute2f128_ps(src1, src5, 0x20);
  r2 = _mm256_permute2f128_ps(src2, src6, 0x20);
  r3 = _mm256_permute2f128_ps(src3, src7, 0x20);
  r4 = _mm256_permute2f128_ps(src4, src8, 0x20);
  r5 = _mm256_permute2f128_ps(src1, src5, 0x31);
  r6 = _mm256_permute2f128_ps(src2, src6, 0x31);
  r7 = _mm256_permute2f128_ps(src3, src7, 0x31);
  r8 = _mm256_permute2f128_ps(src4, src8, 0x31);

  STORE256X8_F32(dst_ptr, dst_stride, r);
}
#endif

#if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  __m128 v0_ma = _mm_loadu_ps(src_ptr);
  __m128 v1_ma = _mm_loadu_ps(src_ptr + src_stride);
  __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride);
  __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride);

  __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
  __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
  __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
  __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

  __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
  __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
  __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
  __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

  _mm_storeu_ps(dst_ptr, v8_ma);
  _mm_storeu_ps(dst_ptr + dst_stride, v9_ma);
  _mm_storeu_ps(dst_ptr + 2 * dst_stride, v10_ma);
  _mm_storeu_ps(dst_ptr + 3 * dst_stride, v11_ma);

  v0_ma = _mm_loadu_ps(src_ptr + C4NUM);
  v1_ma = _mm_loadu_ps(src_ptr + src_stride + C4NUM);
  v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride + C4NUM);
  v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride + C4NUM);

  v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
  v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
  v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
  v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

  v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
  v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
  v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
  v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

  _mm_storeu_ps(dst_ptr + C4NUM * dst_stride, v8_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride, v9_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride, v10_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride, v11_ma);

  v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride);
  v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride);
  v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride);
  v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride);

  v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
  v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
  v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
  v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

  v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
  v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
  v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
  v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

  _mm_storeu_ps(dst_ptr + C4NUM, v8_ma);
  _mm_storeu_ps(dst_ptr + dst_stride + C4NUM, v9_ma);
  _mm_storeu_ps(dst_ptr + 2 * dst_stride + C4NUM, v10_ma);
  _mm_storeu_ps(dst_ptr + 3 * dst_stride + C4NUM, v11_ma);

  v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride + C4NUM);
  v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride + C4NUM);
  v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride + C4NUM);
  v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride + C4NUM);

  v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
  v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
  v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
  v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

  v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
  v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
  v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
  v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

  _mm_storeu_ps(dst_ptr + C4NUM * dst_stride + C4NUM, v8_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride + C4NUM, v9_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride + C4NUM, v10_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma);
}
#endif

#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) {
  // nchw to nc4hw4 with 1D F(2,3)
  for (int i = 0; i < channel; i++) {
    float *src_kernel = (float *)src + i * 9;
    float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4;
    for (int y = 0; y < 3; y++) {
      float g0 = src_kernel[3 * y];
      float g1 = src_kernel[3 * y + 1];
      float g2 = src_kernel[3 * y + 2];

      dst_kernel[16 * y] = g0;
      dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2);
      dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2);
      dst_kernel[16 * y + 12] = g2;
    }
  }
}
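
// Math note (added for clarity): the four values written per kernel row above are
// the 1D Winograd F(2,3) weight transform G * g, with
//   G = [ 1     0     0
//         1/2   1/2   1/2
//         1/2  -1/2   1/2
//         0     0     1 ],
// matching the 0.5f * (g0 +/- g1 + g2) expressions; each channel's transformed
// taps land in lane i % 4 of consecutive 4-float groups (stride 4 per tap).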
#endif