/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"

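// Transposes a [channel][plane] (KHW) weight to [plane][channel] (HWK) order by reusing the
// batch-1 NCHW -> NHWC transpose.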
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
  PackNCHWToNHWCFp32(src, dst, 1, plane, channel, 0, 0);
}

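// Swaps the height and width axes of an HWC tensor; each pixel's channel vector moves as one
// contiguous block.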
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) {
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float));
    }
  }
}

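// Im2ColPackUnitFp32 packs one tile of real_cal_num output pixels, starting at flat output index
// block_index, into packed_input. For each output pixel it copies the in-bounds part of the
// receptive field (kernel_h x kernel_w x in_channel) so the convolution can be computed as a
// matrix multiplication. Out-of-bounds (padded) taps are skipped rather than written, so the
// caller is expected to provide a zero-initialized buffer.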
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
                        int block_index) {
  // input format : nhwc
  int kernel_h = conv_param->kernel_h_;
  int kernel_w = conv_param->kernel_w_;
  int kernel_plane = kernel_h * kernel_w;
  int dilation_h = conv_param->dilation_h_;
  int dilation_w = conv_param->dilation_w_;
  int out_w = conv_param->output_w_;
  if (dilation_h == 0 || dilation_w == 0 || out_w == 0) {
    return;
  }
  int in_channel = conv_param->input_channel_;
  int in_w = conv_param->input_w_;
  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
    int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
    int input_stride = (input_h * in_w + input_w) * in_channel;
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    if (dilation_w == 1 && dilation_h == 1) {
      // contiguous taps: one memcpy per kernel row
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * in_w * in_channel + input_stride;
        int input_x_stride = input_y_stride + kw_s * in_channel;
        int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
        memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
               (kw_e - kw_s) * in_channel * sizeof(float));
      }  // kernel_h loop
    } else {
      // dilated kernel: one memcpy per tap
      for (int j = kh_s; j < kh_e; j++) {
        int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
        for (int k = kw_s; k < kw_e; ++k) {
          int input_x_stride = input_y_stride + k * dilation_w * in_channel;
          int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
          memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
        }
      }  // kernel_h loop
    }
  }  // tile num loop
}

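#ifdef ENABLE_DEBUG
// Usage sketch (illustrative only, assuming ConvParameter can be zero-initialized and that only
// the fields read by Im2ColPackUnitFp32 need to be set): pack the first four output pixels of a
// 3x3, stride-1, pad-1 convolution over a 1x4x4x2 NHWC input. The buffer is pre-zeroed because
// the packer leaves padded taps untouched.
static void Im2ColPackUnitFp32Example(const float input[4 * 4 * 2]) {
  ConvParameter conv = {0};
  conv.kernel_h_ = 3;
  conv.kernel_w_ = 3;
  conv.stride_h_ = 1;
  conv.stride_w_ = 1;
  conv.dilation_h_ = 1;
  conv.dilation_w_ = 1;
  conv.pad_u_ = 1;
  conv.pad_l_ = 1;
  conv.input_h_ = 4;
  conv.input_w_ = 4;
  conv.input_channel_ = 2;
  conv.output_w_ = 4;
  float packed[4 * 3 * 3 * 2] = {0};  // real_cal_num * kernel_plane * in_channel
  Im2ColPackUnitFp32(input, &conv, packed, 4, 0);
}
#endif

// NHWC -> NC4HW4: interleaves 4 channels per block, plane-major within each block; the partial
// last block is written channel by channel.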
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  int c4_minus = c4 - 1;
  for (int b = 0; b < batch; b++) {
    int src_oc_offset = b * plane * channel;
    int dst_oc_offset = b * plane * c4 * C4NUM;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_oc_offset + k * channel;
      int dst_kernel_offset = dst_oc_offset + k * C4NUM;
      // full 4-channel blocks
      for (int j = 0; j < c4_minus; ++j) {
        int src_ic_offset = src_kernel_offset + j * C4NUM;
        int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM;
#ifdef ENABLE_ARM
        vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset));
#else
        for (int i = 0; i < C4NUM; ++i) {
          ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
        }
#endif
      }
      // remaining 1..4 channels of the last block
      int tmp_c = c4_minus * C4NUM;
      int tmp_c_offset = tmp_c * plane;
      int res_c = channel - tmp_c;
      if (res_c > channel) {  // defensive check; never true since tmp_c >= 0
        return;
      }
      for (int l = 0; l < res_c; ++l) {
        int src_ic_offset = src_kernel_offset + tmp_c + l;
        int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
        ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
      }
    }
  }
}

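// NCHW -> NC4HW4: each batch is a channel x plane row-major matrix; RowMajor2Col4Major packs it
// into 4-channel-interleaved blocks, i.e. [c4][plane][4].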
void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * channel;
    int dst_offset = b * plane * c4 * C4NUM;
    RowMajor2Col4Major((const float *)src + src_offset, (float *)dst + dst_offset, channel, plane);
  }
}

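// NHWC -> NHWC4: pads each pixel's channels up to a multiple of 4 with zeros; degenerates to a
// single memcpy when the channel count is already 4-aligned.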
void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int oc_block = UP_DIV(channel, C4NUM);
  int oc_block_channel = oc_block * C4NUM;
  int ic_remainder_ = channel % C4NUM;
  if (ic_remainder_ != 0) {
    for (int b = 0; b < batch; b++) {
      int dst_batch_offset = b * oc_block_channel * plane;
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel;
        memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
        memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float));
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}

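// Generalization of PackNHWCToNHWC4Fp32 to an arbitrary channel tile oc_tile.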
void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile) {
  int oc_block = UP_DIV(channel, oc_tile);
  int oc_block_channel = oc_block * oc_tile;
  int ic_remainder_ = channel % oc_tile;
  if (ic_remainder_ != 0) {
    for (int b = 0; b < batch; b++) {
      int dst_batch_offset = b * oc_block_channel * plane;
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel;
        memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
        memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float));
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}

#ifdef ENABLE_AVX
// PackNHWCToNXHWCXFp32 is the assembly/intrinsics-optimized counterpart of SWPackNHWCToNXHWCXFp32.
void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
                          float *tmp_weight, const float *src) {
  // pack weight NHWC to N32HWC32, N24HWC24, N16HWC16 or N8HWC8
  // output_channel acts as the batch (N) dimension of the weight
  int ic8 = DOWN_ROUND(input_channel, C8NUM);
  int oc_block8 = DOWN_DIV(output_channel, C8NUM);
  int oc_block = 0;
  int oc = 0;
  int oc_remainder_step = 0;
  if (oc_block8 != oc_block_num) {
    oc_block8 = oc_block8 / C4NUM * C4NUM;
    oc_remainder_step = (oc_block_num - oc_block8) * C8NUM;
  }
  int plane = kernel_w * kernel_h;
  if (plane == 1) {  // conv 1x1 weight pack
    for (; oc < oc_block8; oc += (oc_block / C8NUM)) {
      oc_block = MSMIN(C4NUM, oc_block8 - oc) * C8NUM;  // max_tile = 32 ==> 24 ==> 16 ==> 8
      for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) {
        int ic = 0;
        for (; ic < ic8; ic += C8NUM) {
          Transpose8X8Fp32Avx(src + ic, tmp_weight + ic * oc_block + oc_tmp, input_channel, oc_block);
        }
        for (; ic < input_channel; ++ic) {
          for (int j = 0; j < C8NUM; ++j) {
            tmp_weight[ic * oc_block + oc_tmp + j] = src[ic + input_channel * j];
          }
        }
        src += C8NUM * input_channel;
      }
      tmp_weight += oc_block * input_channel;
    }
    oc = output_channel - oc_block8 * C8NUM;
    for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) {
      for (int ic = 0; ic < input_channel; ++ic) {
        tmp_weight[oc_remainder + oc_remainder_step * ic] = src[ic + oc_remainder * input_channel];
      }
    }
  } else {
    for (; oc < oc_block8; oc += (oc_block / C8NUM)) {
      oc_block = MSMIN(C4NUM, oc_block8 - oc) * C8NUM;  // max_tile = 32 ==> 24 ==> 16 ==> 8
      for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) {
        for (int hw = 0; hw < plane; ++hw) {
          int ic = 0;
          for (; ic < ic8; ic += C8NUM) {
            Transpose8X8Fp32Avx(src + hw * input_channel + ic,
                                tmp_weight + hw * oc_block * input_channel + ic * oc_block + oc_tmp,
                                input_channel * plane, oc_block);
          }
          for (; ic < input_channel; ++ic) {
            for (int j = 0; j < C8NUM; ++j) {
              tmp_weight[ic * oc_block + oc_tmp + j + hw * oc_block * input_channel] =
                src[ic + input_channel * j * plane + hw * input_channel];
            }
          }
        }
        src += C8NUM * plane * input_channel;
      }
      tmp_weight += oc_block * input_channel * plane;
    }
    oc = output_channel - oc_block8 * C8NUM;
    for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) {
      for (int hw = 0; hw < plane; ++hw) {
        for (int ic = 0; ic < input_channel; ++ic) {
          tmp_weight[oc_remainder + oc_remainder_step * ic + hw * input_channel * oc_remainder_step] =
            src[ic + (oc_remainder * plane + hw) * input_channel];
        }
      }
    }
  }
}

#ifdef ENABLE_DEBUG
void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
                            float *tmp_weight, const float *src) {
  // pack weight NHWC to N32HWC32, N24HWC24, N16HWC16 or N8HWC8
  int oc_block = 0;
  for (int i = 0; i < oc_block_num; i += oc_block) {
    oc_block = MSMIN(C4NUM, oc_block_num - i);  // max_tile = 4
    int index = i * C8NUM * kernel_h * kernel_w * input_channel;
    int oc_remainder = MSMIN(C8NUM * oc_block, output_channel - i * C8NUM);
    for (int h = 0; h < kernel_h; ++h) {
      for (int w = 0; w < kernel_w; ++w) {
        int w_index = (h * kernel_w + w) * input_channel + index;
        for (int ic = 0; ic < input_channel; ++ic) {
          int ic_index = ic + w_index;
          for (int oc = 0; oc < oc_remainder; ++oc) {
            int oc_index = oc * kernel_w * kernel_h * input_channel + ic_index;
            tmp_weight[oc] = src[oc_index];
          }
          tmp_weight += oc_block * C8NUM;
        }
      }
    }
  }
}
#endif  // ENABLE_DEBUG
#endif  // ENABLE_AVX

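// NHWC -> NHWC8: the 8-channel counterpart of PackNHWCToNHWC4Fp32.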
void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  int c8_channel = c8 * C8NUM;
  int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
  int ic_remainder_ = channel % C8NUM;
  if (ic_remainder_ != 0) {
    int nhwc8_batch_offset = 0;
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * channel * plane;
      for (int i = 0; i < plane; i++) {
        float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel;
        memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
        for (int j = channel; j < c8_channel; ++j) {
          dst_per_plane[j] = 0;
        }
      }
      nhwc8_batch_offset += nhwc8_batch_unit_offset;
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}

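// Inverse of the NHWC -> NHWCX packers: strips the per-pixel channel padding of a cx_num-aligned
// tensor back to tightly packed NHWC.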
void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num) {
  int c_align = UP_DIV(channel, cx_num);
  int ic_remainder_ = channel % cx_num;
  if (ic_remainder_ != 0) {
    int nhwc_batch_unit_offset = channel * plane;
    for (int b = 0; b < batch; b++) {
      int batch_offset = b * c_align * cx_num * plane;
      for (int i = 0; i < plane; i++) {
        memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel,
               (float *)src + batch_offset + i * c_align * cx_num, channel * sizeof(float));
      }
    }
  } else {
    size_t ori_input_size = batch * plane * channel * sizeof(float);
    memcpy((float *)dst, (float *)src, ori_input_size);
  }
}

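// NC4HW4 -> NHWC4: scatters each channel from its plane-major block position to the channel-last
// position, keeping the 4-channel padding.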
void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int c = 0; c < channel; c++) {
      int c4_block_num = c / C4NUM;
      int c4_block_res = c % C4NUM;
      int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
      int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
      for (int k = 0; k < plane; k++) {
        int src_kernel_offset = src_c_offset + k * C4NUM;
        int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
        ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
      }
    }
  }
}

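// NC4HW4 -> NHWC: like PackNC4HW4ToNHWC4Fp32 but also drops the channel padding; full 4-channel
// blocks move with one vector copy where SIMD is available.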
void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c4 * C4NUM;
    int dst_offset = b * plane * channel;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_offset + k * C4NUM;
      int dst_kernel_offset = dst_offset + k * channel;
      for (int c = 0; c < c4 - 1; c++) {
        int src_c_offset = src_kernel_offset + c * plane * C4NUM;
        int dst_c_offset = dst_kernel_offset + c * C4NUM;
#if defined(ENABLE_NEON) || defined(ENABLE_SSE)
        MS_STQ_F32((float *)dst + dst_c_offset, MS_LDQ_F32((float *)src + src_c_offset));
#else
        ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
        ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
        ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
        ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
#endif
      }
      // res part
      int res_c = channel - (c4 - 1) * C4NUM;
      for (int i = 0; i < res_c; i++) {
        int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
        int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
      }
    }
  }
}

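// NC8HW8 -> NHWC: the 8-channel (AVX-width) counterpart of PackNC4HW4ToNHWCFp32.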
void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  for (int b = 0; b < batch; b++) {
    int src_offset = b * plane * c8 * C8NUM;
    int dst_offset = b * plane * channel;
    for (int k = 0; k < plane; k++) {
      int src_kernel_offset = src_offset + k * C8NUM;
      int dst_kernel_offset = dst_offset + k * channel;
      for (int c = 0; c < c8 - 1; c++) {
        int src_c_offset = src_kernel_offset + c * plane * C8NUM;
        int dst_c_offset = dst_kernel_offset + c * C8NUM;
#ifdef ENABLE_AVX
        MS_ST256_F32((float *)dst + dst_c_offset, MS_LD256_F32((float *)src + src_c_offset));
#else
        ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
        ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
        ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
        ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
        ((float *)dst + dst_c_offset)[4] = ((float *)src + src_c_offset)[4];
        ((float *)dst + dst_c_offset)[5] = ((float *)src + src_c_offset)[5];
        ((float *)dst + dst_c_offset)[6] = ((float *)src + src_c_offset)[6];
        ((float *)dst + dst_c_offset)[7] = ((float *)src + src_c_offset)[7];
#endif
      }
      // res part
      int res_c = channel - (c8 - 1) * C8NUM;
      for (int i = 0; i < res_c; i++) {
        int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i;
        int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i;
        ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
      }
    }
  }
}

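// Repacks an NC8HW8 buffer whose channel dimension was padded up to a multiple of 16 into a
// tight layout: the fully aligned C8 blocks are copied in one memcpy per batch, then the partial
// last block is copied plane by plane.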
void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const int batch, const int plane,
                                             const int channel) {
  int down_channel_8 = DOWN_ROUND(channel, C8NUM);
  int up_channel_16 = UP_ROUND(channel, C16NUM);
  size_t dst_batch_offset = (size_t)(plane * channel) * sizeof(float);
  size_t src_batch_offset = (size_t)(plane * up_channel_16) * sizeof(float);
  size_t unaligned_channel_size = (size_t)(channel - down_channel_8) * sizeof(float);
  size_t aligned_channel_size = (size_t)(down_channel_8 * plane) * sizeof(float);
  size_t src_p_offset = C8NUM * sizeof(float);
  for (size_t b = 0; b < (size_t)(batch); ++b) {
    const char *src_batch = (const char *)(src) + b * src_batch_offset;
    char *dst_batch = (char *)(dst) + b * dst_batch_offset;
    memcpy(dst_batch, src_batch, aligned_channel_size);
    src_batch += aligned_channel_size;
    dst_batch += aligned_channel_size;
    for (int p = 0; p < plane; ++p) {
      memcpy(dst_batch + p * unaligned_channel_size, src_batch + p * src_p_offset, unaligned_channel_size);
    }
  }
}

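// NHWC -> C8HWN8 weight layout: 8-channel blocks outermost, then plane, then batch; channels past
// the real channel count are zero-filled up to the next multiple of 8.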
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int channel_up8 = UP_ROUND(channel, C8NUM);
  for (int n = 0; n < batch; n++) {
    for (int hw = 0; hw < plane; hw++) {
      int c = 0;
      for (; c < channel; c++) {
        int c8div = c / C8NUM;
        int c8mod = c % C8NUM;
        int src_index = n * plane * channel + hw * channel + c;
        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
        ((float *)dst)[dst_index] = ((float *)src)[src_index];
      }
      for (; c < channel_up8; c++) {
        int c8div = c / C8NUM;
        int c8mod = c % C8NUM;
        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
        ((float *)dst)[dst_index] = 0;
      }
    }
  }
}

void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
  // pack weight NHWC to C24HWN24; fall back to C16HWN16 and then C8HWN8 when fewer channels remain
#ifdef ENABLE_AVX
  int oc_block_num = UP_DIV(channel, C8NUM);
  int plane16 = plane / C16NUM * C16NUM;
  for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) {
    oc_block = MSMIN(C3NUM, oc_block_num - i);  // max 3 blocks of 8 channels (C24)
    int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
    int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM;
    int p = 0;
    for (; p < plane16; p += C16NUM) {
      int index_plane = i * C8NUM + p * channel;
      for (int b = 0; b < batch; ++b) {
        int index_batch = index_plane + b * plane * channel;
        int oc = 0;
        int stride = oc_block * C8NUM * batch;
        for (; oc < oc_remainder_c8; oc += C8NUM) {
          const float *cur_src = src + index_batch + oc;
          float *cur_dst = dst + oc;
          LOAD256X16_F32(r, cur_src, channel);
          STORE256X16_F32(cur_dst, stride, r);
        }
        for (; oc < oc_remainder; ++oc) {
          for (int k = 0; k < C16NUM; ++k) {
            dst[oc + stride * k] = src[index_batch + oc + channel * k];
          }
        }
        for (; oc < C8NUM; ++oc) {
          for (int k = 0; k < C16NUM; ++k) {
            dst[oc + stride * k] = 0;
          }
        }
        dst += oc_block * C8NUM;
      }
      dst += (C16NUM - 1) * oc_block * C8NUM * batch;
    }
    for (; p < plane; ++p) {
      int index_plane = i * C8NUM + p * channel;
      for (int b = 0; b < batch; ++b) {
        int index_batch = index_plane + b * plane * channel;
        int oc = 0;
        for (; oc < oc_remainder; ++oc) {
          dst[oc] = src[index_batch + oc];
        }
        for (; oc < C8NUM; ++oc) {
          dst[oc] = 0;
        }
        dst += oc_block * C8NUM;
      }
    }
  }
#else
  int oc_block = 0;
  int oc_block_num = UP_DIV(channel, C8NUM);
  for (int i = 0; i < oc_block_num; i += oc_block) {
    oc_block = MSMIN(C3NUM, oc_block_num - i);  // max 3 blocks of 8 channels (C24)
    int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
    for (int p = 0; p < plane; ++p) {
      int index_plane = i * C8NUM + p * channel;
      for (int b = 0; b < batch; ++b) {
        int index_batch = index_plane + b * plane * channel;
        for (int oc = 0; oc < oc_remainder; ++oc) {
          dst[oc] = src[index_batch + oc];
        }
        dst += oc_block * C8NUM;
      }
    }
  }
#endif
}

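// Packs a depthwise weight (channel x height x width) for the indirect convolution kernels:
// 4-channel blocks outermost, then kernel width, kernel height, and the lane within the block.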
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
  int c4 = UP_DIV(channel, C4NUM);
  for (int c = 0; c < c4; c++) {
    int dst_off_c = c * C4NUM * height * width;
    for (int i = 0; i < C4NUM; i++) {
      int src_off_c = (c * C4NUM + i) * height * width;
      for (int kh = 0; kh < height; kh++) {
        int src_off_kh = src_off_c + kh * width;
        for (int kw = 0; kw < width; kw++) {
          int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i;
          ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
        }
      }
    }
  }
}

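// Same layout as PackDepthwiseIndirectWeightC4Fp32, with 8-channel blocks.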
void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) {
  int c8 = UP_DIV(channel, C8NUM);
  for (int c = 0; c < c8; c++) {
    int dst_off_c = c * C8NUM * height * width;
    for (int i = 0; i < C8NUM; i++) {
      int src_off_c = (c * C8NUM + i) * height * width;
      for (int kh = 0; kh < height; kh++) {
        int src_off_kh = src_off_c + kh * width;
        for (int kw = 0; kw < width; kw++) {
          int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i;
          ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
        }
      }
    }
  }
}

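// NHWC -> NCHW transpose of each batch, parallelizable over the plane dimension: when
// thread_count > 0, each task handles a C8NUM-aligned slice [task_start, task_end) of the plane.
// 8x8 tiles go through a platform-specific transpose kernel; edge rows and columns fall back to
// scalar copies.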
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel, int task_id,
                        int thread_count) {
#ifdef ENABLE_ARM64
  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64;
#elif defined(ENABLE_ARM32)
  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32;
#elif defined(ENABLE_AVX)
  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx;
#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX)
  Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse;
#endif
  int hw8 = plane / C8NUM;
  int task_start = 0;
  int task_end = plane;
  if (thread_count > 0) {
    int offset_hw = UP_DIV(hw8, thread_count) * C8NUM;  // plane rows per task, C8NUM-aligned
    task_start = offset_hw * task_id;
    int count = plane - task_start;
    if (count <= 0) {
      return;
    }
    task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw);
    hw8 = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0);  // end of this task's 8x8-tiled region
  } else {
    hw8 *= C8NUM;
  }
  int c8 = channel / C8NUM * C8NUM;
  int batch = plane * channel;
  for (int n = 0; n < batches; n++) {
    const float *src_batch = (const float *)src + n * batch;
    float *dst_batch = (float *)dst + n * batch;
    int hw = task_start;
    for (; hw < hw8; hw += C8NUM) {
      int c = 0;
      for (; c < c8; c += C8NUM) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32)
        Transpose8X8Fp32Func_(src_ptr, dst_ptr, channel, plane);
#else
        for (int tr = 0; tr < C8NUM; tr++) {
          for (int tc = 0; tc < C8NUM; tc++) {
            dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
          }
        }
#endif
      }
      for (; c < channel; c++) {
        const float *src_ptr = src_batch + hw * channel + c;
        float *dst_ptr = dst_batch + c * plane + hw;
        for (size_t i = 0; i < C8NUM; i++) {
          dst_ptr[i] = src_ptr[i * channel];
        }
      }
    }
    for (; hw < task_end; hw++) {
      const float *src_ptr = src_batch + hw * channel;
      float *dst_ptr = dst_batch + hw;
      for (size_t i = 0; i < channel; i++) {
        dst_ptr[i * plane] = src_ptr[i];
      }
    }
  }
}

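// NCHW -> NHWC is the same transpose with the plane and channel arguments swapped.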
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) {
  PackNHWCToNCHWFp32(src, dst, batch, channel, plane, task_id, thread_count);
}

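#ifdef ENABLE_DEBUG
// Usage sketch (illustrative only): transpose a single-batch plane=2, channel=3 NHWC tensor to
// NCHW on the calling thread. thread_count = 0 disables the plane partitioning above.
static void PackNHWCToNCHWFp32Example(void) {
  const float nhwc[6] = {1, 2, 3, 4, 5, 6};
  float nchw[6] = {0};
  PackNHWCToNCHWFp32(nhwc, nchw, 1, 2, 3, 0, 0);
  // nchw now holds {1, 4, 2, 5, 3, 6}.
}
#endif
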
#ifdef ENABLE_ARM64
inline void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  size_t srcStride = src_stride * sizeof(float);
  size_t dstStride = dst_stride * sizeof(float);
  asm volatile(
    "mov x10, %[src_ptr]\n"
    "mov x11, %[dst_ptr]\n"

    "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
    "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"

    "zip1 v8.4s, v0.4s, v2.4s\n"
    "zip2 v9.4s, v0.4s, v2.4s\n"
    "zip1 v12.4s, v1.4s, v3.4s\n"
    "zip2 v13.4s, v1.4s, v3.4s\n"

    "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
    "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"

    "zip1 v10.4s, v4.4s, v6.4s\n"
    "zip2 v11.4s, v4.4s, v6.4s\n"
    "zip1 v14.4s, v5.4s, v7.4s\n"
    "zip2 v15.4s, v5.4s, v7.4s\n"

    "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
    "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"

    "trn1 v16.2d, v8.2d, v10.2d\n"
    "trn2 v18.2d, v8.2d, v10.2d\n"
    "trn1 v20.2d, v9.2d, v11.2d\n"
    "trn2 v22.2d, v9.2d, v11.2d\n"

    "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
    "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"

    "trn1 v24.2d, v12.2d, v14.2d\n"
    "trn2 v26.2d, v12.2d, v14.2d\n"
    "trn1 v28.2d, v13.2d, v15.2d\n"
    "trn2 v30.2d, v13.2d, v15.2d\n"

    "zip1 v8.4s, v0.4s, v2.4s\n"
    "zip2 v9.4s, v0.4s, v2.4s\n"
    "zip1 v12.4s, v1.4s, v3.4s\n"
    "zip2 v13.4s, v1.4s, v3.4s\n"

    "zip1 v10.4s, v4.4s, v6.4s\n"
    "zip2 v11.4s, v4.4s, v6.4s\n"
    "zip1 v14.4s, v5.4s, v7.4s\n"
    "zip2 v15.4s, v5.4s, v7.4s\n"

    "trn1 v17.2d, v8.2d, v10.2d\n"
    "trn2 v19.2d, v8.2d, v10.2d\n"
    "trn1 v21.2d, v9.2d, v11.2d\n"
    "trn2 v23.2d, v9.2d, v11.2d\n"

    "trn1 v25.2d, v12.2d, v14.2d\n"
    "trn2 v27.2d, v12.2d, v14.2d\n"
    "trn1 v29.2d, v13.2d, v15.2d\n"
    "trn2 v31.2d, v13.2d, v15.2d\n"

    "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
    "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
    "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
    "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
    "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
    "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
    "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
    "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"

    :
    : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
    : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
      "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
      "v31");
}
#endif

#ifdef ENABLE_ARM32
inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  size_t srcStride = src_stride * sizeof(float);
  size_t dstStride = dst_stride * sizeof(float);
  asm volatile(
    "mov r10, %[src_ptr]\n"
    "mov r12, %[dst_ptr]\n"

    "vld1.32 {q0, q1}, [r10], %[srcStride]\n"
    "vld1.32 {q2, q3}, [r10], %[srcStride]\n"

    "vtrn.32 d0, d4\n"
    "vtrn.32 d1, d5\n"
    "vtrn.32 d2, d6\n"
    "vtrn.32 d3, d7\n"

    "vld1.32 {q4, q5}, [r10], %[srcStride]\n"
    "vld1.32 {q6, q7}, [r10], %[srcStride]\n"

    "vtrn.32 d8, d12\n"
    "vtrn.32 d9, d13\n"
    "vtrn.32 d10, d14\n"
    "vtrn.32 d11, d15\n"

    "vld1.32 {q8, q9}, [r10], %[srcStride]\n"
    "vld1.32 {q10, q11}, [r10], %[srcStride]\n"

    "vswp d1, d8\n"
    "vswp d3, d10\n"
    "vswp d5, d12\n"
    "vswp d7, d14\n"

    "vtrn.32 d16, d20\n"
    "vtrn.32 d17, d21\n"
    "vtrn.32 d18, d22\n"
    "vtrn.32 d19, d23\n"

    "vld1.32 {q12, q13}, [r10], %[srcStride]\n"
    "vld1.32 {q14, q15}, [r10], %[srcStride]\n"

    "vtrn.32 d24, d28\n"
    "vtrn.32 d25, d29\n"
    "vtrn.32 d26, d30\n"
    "vtrn.32 d27, d31\n"

    "vswp d17, d24\n"
    "vswp d19, d26\n"
    "vswp d21, d28\n"
    "vswp d23, d30\n"

    "add r10, r12, #16\n"
    "vst1.32 {q0}, [r12], %[dstStride]\n"
    "vst1.32 {q8}, [r10], %[dstStride]\n"
    "vst1.32 {q2}, [r12], %[dstStride]\n"
    "vst1.32 {q10}, [r10], %[dstStride]\n"
    "vst1.32 {q4}, [r12], %[dstStride]\n"
    "vst1.32 {q12}, [r10], %[dstStride]\n"
    "vst1.32 {q6}, [r12], %[dstStride]\n"
    "vst1.32 {q14}, [r10], %[dstStride]\n"
    "vst1.32 {q1}, [r12], %[dstStride]\n"
    "vst1.32 {q9}, [r10], %[dstStride]\n"
    "vst1.32 {q3}, [r12], %[dstStride]\n"
    "vst1.32 {q11}, [r10], %[dstStride]\n"
    "vst1.32 {q5}, [r12], %[dstStride]\n"
    "vst1.32 {q13}, [r10], %[dstStride]\n"
    "vst1.32 {q7}, [r12], %[dstStride]\n"
    "vst1.32 {q15}, [r10], %[dstStride]\n"

    :
    : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
    : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
      "q15");
}
#endif

#ifdef ENABLE_AVX
inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  LOAD256X8_F32(src, src_ptr, src_stride)
  // stage 1: interleave adjacent rows 32-bit-wise
  __m256 r1 = _mm256_unpacklo_ps(src1, src2);
  __m256 r2 = _mm256_unpackhi_ps(src1, src2);
  __m256 r3 = _mm256_unpacklo_ps(src3, src4);
  __m256 r4 = _mm256_unpackhi_ps(src3, src4);
  __m256 r5 = _mm256_unpacklo_ps(src5, src6);
  __m256 r6 = _mm256_unpackhi_ps(src5, src6);
  __m256 r7 = _mm256_unpacklo_ps(src7, src8);
  __m256 r8 = _mm256_unpackhi_ps(src7, src8);

  // stage 2: combine 64-bit pairs via shuffle + blend
  __m256 v;
  v = _mm256_shuffle_ps(r1, r3, 0x4E);
  src1 = _mm256_blend_ps(r1, v, 0xCC);
  src2 = _mm256_blend_ps(r3, v, 0x33);

  v = _mm256_shuffle_ps(r2, r4, 0x4E);
  src3 = _mm256_blend_ps(r2, v, 0xCC);
  src4 = _mm256_blend_ps(r4, v, 0x33);

  v = _mm256_shuffle_ps(r5, r7, 0x4E);
  src5 = _mm256_blend_ps(r5, v, 0xCC);
  src6 = _mm256_blend_ps(r7, v, 0x33);

  v = _mm256_shuffle_ps(r6, r8, 0x4E);
  src7 = _mm256_blend_ps(r6, v, 0xCC);
  src8 = _mm256_blend_ps(r8, v, 0x33);

  // stage 3: swap 128-bit lanes to finish the 8x8 transpose
  r1 = _mm256_permute2f128_ps(src1, src5, 0x20);
  r2 = _mm256_permute2f128_ps(src2, src6, 0x20);
  r3 = _mm256_permute2f128_ps(src3, src7, 0x20);
  r4 = _mm256_permute2f128_ps(src4, src8, 0x20);
  r5 = _mm256_permute2f128_ps(src1, src5, 0x31);
  r6 = _mm256_permute2f128_ps(src2, src6, 0x31);
  r7 = _mm256_permute2f128_ps(src3, src7, 0x31);
  r8 = _mm256_permute2f128_ps(src4, src8, 0x31);

  STORE256X8_F32(dst_ptr, dst_stride, r);
}
#endif

#if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  // top-left 4x4 block -> top-left
  __m128 v0_ma = _mm_loadu_ps(src_ptr);
  __m128 v1_ma = _mm_loadu_ps(src_ptr + src_stride);
  __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride);
  __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride);

  __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
  __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
  __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
  __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

  __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
  __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
  __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
  __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

  _mm_storeu_ps(dst_ptr, v8_ma);
  _mm_storeu_ps(dst_ptr + dst_stride, v9_ma);
  _mm_storeu_ps(dst_ptr + 2 * dst_stride, v10_ma);
  _mm_storeu_ps(dst_ptr + 3 * dst_stride, v11_ma);

  // top-right 4x4 block -> bottom-left
  v0_ma = _mm_loadu_ps(src_ptr + C4NUM);
  v1_ma = _mm_loadu_ps(src_ptr + src_stride + C4NUM);
  v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride + C4NUM);
  v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride + C4NUM);

  v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
  v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
  v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
  v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

  v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
  v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
  v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
  v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

  _mm_storeu_ps(dst_ptr + C4NUM * dst_stride, v8_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride, v9_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride, v10_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride, v11_ma);

  // bottom-left 4x4 block -> top-right
  v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride);
  v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride);
  v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride);
  v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride);

  v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
  v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
  v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
  v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

  v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
  v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
  v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
  v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

  _mm_storeu_ps(dst_ptr + C4NUM, v8_ma);
  _mm_storeu_ps(dst_ptr + dst_stride + C4NUM, v9_ma);
  _mm_storeu_ps(dst_ptr + 2 * dst_stride + C4NUM, v10_ma);
  _mm_storeu_ps(dst_ptr + 3 * dst_stride + C4NUM, v11_ma);

  // bottom-right 4x4 block -> bottom-right
  v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride + C4NUM);
  v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride + C4NUM);
  v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride + C4NUM);
  v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride + C4NUM);

  v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
  v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
  v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
  v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);

  v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
  v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
  v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
  v11_ma = _mm_movehl_ps(v7_ma, v5_ma);

  _mm_storeu_ps(dst_ptr + C4NUM * dst_stride + C4NUM, v8_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride + C4NUM, v9_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride + C4NUM, v10_ma);
  _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma);
}
#endif

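// Transforms a 3x3 depthwise weight from NCHW to NC4HW4 order while applying the 1D Winograd
// F(2,3) filter transform G = [[1, 0, 0], [1/2, 1/2, 1/2], [1/2, -1/2, 1/2], [0, 0, 1]] to each
// kernel row, yielding 3 rows x 4 transformed taps per channel.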
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) {
  // nchw to nc4hw4 with 1D F(2,3) winograd weight transform
  for (int i = 0; i < channel; i++) {
    const float *src_kernel = (const float *)src + i * 9;  // 3x3 kernel of channel i
    float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4;  // 48 = 3 rows * 4 taps * C4NUM
    for (int y = 0; y < 3; y++) {
      float g0 = src_kernel[3 * y];
      float g1 = src_kernel[3 * y + 1];
      float g2 = src_kernel[3 * y + 2];

      dst_kernel[16 * y] = g0;
      dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2);
      dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2);
      dst_kernel[16 * y + 12] = g2;
    }
  }
}
#endif
