1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/fp32/pack_fp32.h"
18 #include "nnacl/intrinsics/ms_simd_instructions.h"
19 
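// Repacks a KHW weight (channel-major) to HWK by reusing the NCHW-to-NHWC pack with batch = 1.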
20 void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
21   PackNCHWToNHWCFp32(src, dst, 1, plane, channel, 0, 0);
22 }
23 
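// Transposes the H and W axes of an HWC tensor; each channel vector is copied as one contiguous block of channel floats.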
24 void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) {
25   for (int i = 0; i < height; ++i) {
26     for (int j = 0; j < width; ++j) {
27       memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float));
28     }
29   }
30 }
31 
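// Packs NHWC to NC4HW4 without padding the channel tail: full 4-channel blocks are stored plane-interleaved,
// and the remaining channel % 4 values are stored compactly after them.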
32 void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) {
33   if (channel <= C4NUM) {
34     memcpy(dst, src, batch * plane * channel * sizeof(float));
35     return;
36   }
37   int tmp = DOWN_DIV(channel, C4NUM);
38   int c_res = channel - tmp * C4NUM;
39   int c4_block = tmp * plane * C4NUM;
40   for (int b = 0; b < batch; b++) {
41     int batch_oc_offset = b * plane * channel;
42     for (int k = 0; k < plane; k++) {
43       int src_kernel_offset = batch_oc_offset + k * channel;
44       int dst_kernel_offset = batch_oc_offset + k * C4NUM;
45       int c = 0;
46       for (; c <= channel - C4NUM; c += C4NUM) {
47 #if defined(ENABLE_SSE) || defined(ENABLE_ARM)
48         MS_FLOAT32X4 src_data = MS_LDQ_F32(src + src_kernel_offset + c);
49         MS_STQ_F32(dst + dst_kernel_offset + c * plane, src_data);
50 #else
51         for (int k1 = 0; k1 < C4NUM; ++k1) {
52           (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1];
53         }
54 #endif
55       }
56       for (; c < channel; ++c) {
57         dst[batch_oc_offset + c4_block + k * c_res + c - tmp * C4NUM] = src[src_kernel_offset + c];
58       }
59     }
60   }
61 }
62 
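// 8-channel (AVX-width) counterpart of the NC4HW4 variant above.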
63 void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) {
64   if (channel <= C8NUM) {
65     memcpy(dst, src, batch * plane * channel * sizeof(float));
66     return;
67   }
68   int tmp = DOWN_DIV(channel, C8NUM);
69   int c_res = channel - tmp * C8NUM;
70   int c8_block = tmp * plane * C8NUM;
71   for (int b = 0; b < batch; b++) {
72     int batch_oc_offset = b * plane * channel;
73     for (int k = 0; k < plane; k++) {
74       int src_kernel_offset = batch_oc_offset + k * channel;
75       int dst_kernel_offset = batch_oc_offset + k * C8NUM;
76       int c = 0;
77       for (; c <= channel - C8NUM; c += C8NUM) {
78 #ifdef ENABLE_AVX
79         MS_FLOAT32X8 src_data = MS_LD256_F32(src + src_kernel_offset + c);
80         MS_ST256_F32(dst + dst_kernel_offset + c * plane, src_data);
81 #else
82         for (int k1 = 0; k1 < C8NUM; ++k1) {
83           (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1];
84         }
85 #endif
86       }
87       for (; c < channel; ++c) {
88         dst[batch_oc_offset + c8_block + k * c_res + c - tmp * C8NUM] = src[src_kernel_offset + c];
89       }
90     }
91   }
92 }
93 
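// Plain transpose of the rows in [row_start, row_end): dst_ptr is the column-major view of src_ptr.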
94 void RowMajor2ColMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
95   for (int r = row_start; r < row_end; ++r) {
96     for (int c = 0; c < col; ++c) {
97       dst_ptr[c * row + r] = src_ptr[r * col + c];
98     }
99   }
100 }
101 
102 void RowMajor2RowMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
103   if (row_end > row_start) {
104     src_ptr += row_start * col;
105     dst_ptr += row_start * col;
106     memcpy(dst_ptr, src_ptr, (row_end - row_start) * col * (int)(sizeof(float)));
107   }
108 }
109 
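// The RowMajor2Row{4,6,8,12,16}MajorParallel functions below repack a row-major matrix so that, within each
// group of N columns, a row's N values are contiguous; the final group is zero-padded up to a multiple of N columns.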
110 void RowMajor2Row4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
111   for (int r = row_start; r < row_end; r++) {
112     const float *src = src_ptr + r * col;
113     int c = 0;
114     for (; c < col; c++) {
115       int cd4 = c / C4NUM;
116       int cm4 = c % C4NUM;
117       dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c];
118     }
119     for (; c < UP_ROUND(col, C4NUM); c++) {
120       int cd4 = c / C4NUM;
121       int cm4 = c % C4NUM;
122       dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = 0;
123     }
124   }
125   return;
126 }
127 
128 void RowMajor2Row6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
129   for (int r = row_start; r < row_end; r++) {
130     const float *src = src_ptr + r * col;
131     int c = 0;
132     for (; c < col; c++) {
133       int cd6 = c / C6NUM;
134       int cm6 = c % C6NUM;
135       dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c];
136     }
137     for (; c < UP_ROUND(col, C6NUM); c++) {
138       int cd6 = c / C6NUM;
139       int cm6 = c % C6NUM;
140       dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = 0;
141     }
142   }
143   return;
144 }
145 
146 void RowMajor2Row8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
147   for (int r = row_start; r < row_end; r++) {
148     const float *src = src_ptr + r * col;
149     int c = 0;
150     for (; c < col; c++) {
151       int cd8 = c / C8NUM;
152       int cm8 = c % C8NUM;
153       dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c];
154     }
155     for (; c < UP_ROUND(col, C8NUM); c++) {
156       int cd8 = c / C8NUM;
157       int cm8 = c % C8NUM;
158       dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = 0;
159     }
160   }
161   return;
162 }
163 
164 void RowMajor2Row12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
165   for (int r = row_start; r < row_end; r++) {
166     const float *src = src_ptr + r * col;
167     int c = 0;
168     for (; c < col; c++) {
169       int cd12 = c / C12NUM;
170       int cm12 = c % C12NUM;
171       dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
172     }
173     for (; c < UP_ROUND(col, C12NUM); c++) {
174       int cd12 = c / C12NUM;
175       int cm12 = c % C12NUM;
176       dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = 0;
177     }
178   }
179   return;
180 }
181 
182 void RowMajor2Row16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
183   for (int r = row_start; r < row_end; r++) {
184     const float *src = src_ptr + r * col;
185     int c = 0;
186     for (; c < col; c++) {
187       int cd16 = c / C16NUM;
188       int cm16 = c % C16NUM;
189       dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c];
190     }
191     for (; c < UP_ROUND(col, C16NUM); c++) {
192       int cd16 = c / C16NUM;
193       int cm16 = c % C16NUM;
194       dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0;
195     }
196   }
197   return;
198 }
199 
200 void RowMajor2Row32MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) {
201   // Not exactly aligned to 32: when fewer than 32 rows remain, the block shrinks to 24, 16, or 8.
202   int row_block_num = UP_DIV(row, C8NUM);
203   int row_block = C4NUM;
204   for (int i = 0; i < row_block_num; i += row_block) {
205     row_block = MSMIN(C4NUM, row_block_num - i);  // max_tile = 4
206     int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM);
207     dst_ptr += col_start * row_block * C8NUM;
208     for (int oc = col_start; oc < col_end; ++oc) {
209       memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float));
210       dst_ptr += row_block * C8NUM;
211     }
212     dst_ptr += (col - col_end) * row_block * C8NUM;
213   }
214 }
215 
216 void RowMajor2Row64MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) {
217   // Not exactly aligned to 64: when fewer than 64 rows remain, the block shrinks to 48, 32, or 16.
218   int row_block_num = UP_DIV(row, C16NUM);
219   int row_block = C4NUM;
220   for (int i = 0; i < row_block_num; i += row_block) {
221     row_block = MSMIN(C4NUM, row_block_num - i);  // max_tile = 4
222     int row_remainder = MSMIN(row_block * C16NUM, row - i * C16NUM);
223     dst_ptr += col_start * row_block * C16NUM;
224     for (int oc = col_start; oc < col_end; ++oc) {
225       memcpy(dst_ptr, src_ptr + oc * row + i * C16NUM, row_remainder * sizeof(float));
226       dst_ptr += row_block * C16NUM;
227     }
228     dst_ptr += (col - col_end) * row_block * C16NUM;
229   }
230 }
231 
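// The following assembly kernels transpose a 12x4 tile in registers: 12 rows of 4 floats (row stride = col floats)
// are loaded and rewritten so that each of the 4 columns has its 12 row values contiguous.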
232 #ifdef ENABLE_ARM64
233 void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) {
234   size_t stride = col * sizeof(float);
235   asm volatile(
236     "mov x10, %[src_c]\n"
237     "mov x11, %[dst_c]\n"
238 
239     "ld1 {v0.4s}, [x10], %[stride]\n"
240     "ld1 {v1.4s}, [x10], %[stride]\n"
241     "ld1 {v2.4s}, [x10], %[stride]\n"
242     "ld1 {v3.4s}, [x10], %[stride]\n"
243 
244     "ld1 {v4.4s}, [x10], %[stride]\n"
245     "ld1 {v5.4s}, [x10], %[stride]\n"
246     "ld1 {v6.4s}, [x10], %[stride]\n"
247     "ld1 {v7.4s}, [x10], %[stride]\n"
248 
249     "zip1 v12.4s, v0.4s, v1.4s\n"
250     "zip2 v13.4s, v0.4s, v1.4s\n"
251     "zip1 v14.4s, v2.4s, v3.4s\n"
252     "zip2 v15.4s, v2.4s, v3.4s\n"
253 
254     "ld1 {v8.4s}, [x10], %[stride]\n"
255     "ld1 {v9.4s}, [x10], %[stride]\n"
256     "ld1 {v10.4s}, [x10], %[stride]\n"
257     "ld1 {v11.4s}, [x10], %[stride]\n"
258 
259     "zip1 v16.4s, v4.4s, v5.4s\n"
260     "zip2 v17.4s, v4.4s, v5.4s\n"
261     "zip1 v18.4s, v6.4s, v7.4s\n"
262     "zip2 v19.4s, v6.4s, v7.4s\n"
263 
264     "trn1 v20.2d, v12.2d, v14.2d\n"
265     "trn2 v23.2d, v12.2d, v14.2d\n"
266     "trn1 v26.2d, v13.2d, v15.2d\n"
267     "trn2 v29.2d, v13.2d, v15.2d\n"
268 
269     "trn1 v21.2d, v16.2d, v18.2d\n"
270     "trn2 v24.2d, v16.2d, v18.2d\n"
271     "trn1 v27.2d, v17.2d, v19.2d\n"
272     "trn2 v30.2d, v17.2d, v19.2d\n"
273 
274     "zip1 v12.4s, v8.4s, v9.4s\n"
275     "zip2 v13.4s, v8.4s, v9.4s\n"
276     "zip1 v14.4s, v10.4s, v11.4s\n"
277     "zip2 v15.4s, v10.4s, v11.4s\n"
278 
279     "trn1 v22.2d, v12.2d, v14.2d\n"
280     "trn2 v25.2d, v12.2d, v14.2d\n"
281     "trn1 v28.2d, v13.2d, v15.2d\n"
282     "trn2 v31.2d, v13.2d, v15.2d\n"
283 
284     "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
285     "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
286     "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
287 
288     :
289     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
290     : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
291       "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
292       "v31");
293   return;
294 }
295 #endif
296 #ifdef ENABLE_ARM32
297 void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) {
298   size_t stride = col * sizeof(float);
299   asm volatile(
300     "mov r10, %[src_c]\n"
301     "mov r12, %[dst_c]\n"
302 
303     "vld1.32 {q0}, [r10], %[stride]\n"
304     "vld1.32 {q3}, [r10], %[stride]\n"
305     "vld1.32 {q10}, [r10], %[stride]\n"
306     "vld1.32 {q13}, [r10], %[stride]\n"
307 
308     "vtrn.32 d0, d6\n"
309     "vtrn.32 d1, d7\n"
310     "vtrn.32 d20, d26\n"
311     "vtrn.32 d21, d27\n"
312 
313     "vld1.32 {q1}, [r10], %[stride]\n"
314     "vld1.32 {q8}, [r10], %[stride]\n"
315     "vld1.32 {q11}, [r10], %[stride]\n"
316     "vld1.32 {q14}, [r10], %[stride]\n"
317 
318     "vswp d1, d20\n"
319     "vswp d7, d26\n"
320 
321     "vld1.32 {q2}, [r10], %[stride]\n"
322     "vld1.32 {q9}, [r10], %[stride]\n"
323     "vld1.32 {q12}, [r10], %[stride]\n"
324     "vld1.32 {q15}, [r10], %[stride]\n"
325 
326     "vtrn.32 d2, d16\n"
327     "vtrn.32 d3, d17\n"
328     "vtrn.32 d22, d28\n"
329     "vtrn.32 d23, d29\n"
330 
331     "vswp d3, d22\n"
332     "vswp d17, d28\n"
333 
334     "vtrn.32 d4, d18\n"
335     "vtrn.32 d5, d19\n"
336     "vtrn.32 d24, d30\n"
337     "vtrn.32 d25, d31\n"
338 
339     "vswp d5, d24\n"
340     "vswp d19, d30\n"
341 
342     "vst1.32 {q0, q1}, [r12]!\n"
343     "vst1.32 {q2, q3}, [r12]!\n"
344     "vst1.32 {q8, q9}, [r12]!\n"
345     "vst1.32 {q10, q11}, [r12]!\n"
346     "vst1.32 {q12, q13}, [r12]!\n"
347     "vst1.32 {q14, q15}, [r12]!\n"
348 
349     :
350     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
351     : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
352   return;
353 }
354 #endif
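// Packs rows into 12-row blocks: within each block, a column's 12 row values are contiguous; the last partial
// block is zero-padded only when this slice reaches the end of the matrix (row_end == row).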
355 void RowMajor2Col12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
356   int ri = (row_start / C12NUM * C12NUM);
357   float *dst_r = dst_ptr + ri * col;
358   const float *src_r = src_ptr + ri * col;
359   for (; ri < (row_end / C12NUM * C12NUM); ri += C12NUM) {
360     int ci = 0;
361     for (; ci < (col / C4NUM * C4NUM); ci += C4NUM) {
362       const float *src_c = src_r + ci;
363       float *dst_c = dst_r + ci * C12NUM;
364 #ifdef ENABLE_ARM64
365       RowMajor2Col12Major_arm64(src_c, dst_c, col);
366 #elif ENABLE_ARM32
367       RowMajor2Col12Major_arm32(src_c, dst_c, col);
368 #elif ENABLE_SSE
369       __m128 src1 = _mm_loadu_ps(src_c);
370       __m128 src2 = _mm_loadu_ps(src_c + col);
371       __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
372       __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
373       src_c += C4NUM * col;
374       __m128 src12L = _mm_unpacklo_ps(src1, src2);
375       __m128 src12H = _mm_unpackhi_ps(src1, src2);
376       __m128 src34L = _mm_unpacklo_ps(src3, src4);
377       __m128 src34H = _mm_unpackhi_ps(src3, src4);
378 
379       __m128 dst0 = _mm_movelh_ps(src12L, src34L);
380       __m128 dst3 = _mm_movehl_ps(src34L, src12L);
381       __m128 dst6 = _mm_movelh_ps(src12H, src34H);
382       __m128 dst9 = _mm_movehl_ps(src34H, src12H);
383 
384       __m128 src5 = _mm_loadu_ps(src_c);
385       __m128 src6 = _mm_loadu_ps(src_c + col);
386       __m128 src7 = _mm_loadu_ps(src_c + 2 * col);
387       __m128 src8 = _mm_loadu_ps(src_c + 3 * col);
388       src_c += C4NUM * col;
389       __m128 src56L = _mm_unpacklo_ps(src5, src6);
390       __m128 src56H = _mm_unpackhi_ps(src5, src6);
391       __m128 src78L = _mm_unpacklo_ps(src7, src8);
392       __m128 src78H = _mm_unpackhi_ps(src7, src8);
393       __m128 dst1 = _mm_movelh_ps(src56L, src78L);
394       __m128 dst4 = _mm_movehl_ps(src78L, src56L);
395       __m128 dst7 = _mm_movelh_ps(src56H, src78H);
396       __m128 dst10 = _mm_movehl_ps(src78H, src56H);
397 
398       __m128 src9 = _mm_loadu_ps(src_c);
399       __m128 src10 = _mm_loadu_ps(src_c + col);
400       __m128 src11 = _mm_loadu_ps(src_c + 2 * col);
401       __m128 src12 = _mm_loadu_ps(src_c + 3 * col);
402       src_c += C4NUM * col;
403       __m128 src910L = _mm_unpacklo_ps(src9, src10);
404       __m128 src910H = _mm_unpackhi_ps(src9, src10);
405       __m128 src1112L = _mm_unpacklo_ps(src11, src12);
406       __m128 src1112H = _mm_unpackhi_ps(src11, src12);
407       __m128 dst2 = _mm_movelh_ps(src910L, src1112L);
408       __m128 dst5 = _mm_movehl_ps(src1112L, src910L);
409       __m128 dst8 = _mm_movelh_ps(src910H, src1112H);
410       __m128 dst11 = _mm_movehl_ps(src1112H, src910H);
411 
412       _mm_storeu_ps(dst_c, dst0);
413       _mm_storeu_ps(dst_c + 4, dst1);
414       _mm_storeu_ps(dst_c + 8, dst2);
415       _mm_storeu_ps(dst_c + 12, dst3);
416       _mm_storeu_ps(dst_c + 16, dst4);
417       _mm_storeu_ps(dst_c + 20, dst5);
418       _mm_storeu_ps(dst_c + 24, dst6);
419       _mm_storeu_ps(dst_c + 28, dst7);
420       _mm_storeu_ps(dst_c + 32, dst8);
421       _mm_storeu_ps(dst_c + 36, dst9);
422       _mm_storeu_ps(dst_c + 40, dst10);
423       _mm_storeu_ps(dst_c + 44, dst11);
424 #else
425       for (int tr = 0; tr < C12NUM; tr++) {
426         for (int tc = 0; tc < C4NUM; tc++) {
427           dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
428         }
429       }
430 #endif
431     }
432     for (; ci < col; ci++) {
433       const float *src_c = src_r + ci;
434       float *dst_c = dst_r + ci * C12NUM;
435       for (int i = 0; i < C12NUM; i++) {
436         dst_c[i] = src_c[i * col];
437       }
438     }
439     src_r += C12NUM * col;
440     dst_r += C12NUM * col;
441   }
442   if (row_end == row) {
443     for (; ri < row_end; ri++, dst_r++, src_r += col) {
444       for (int i = 0; i < col; i++) {
445         dst_r[i * C12NUM] = src_r[i];
446       }
447     }
448     for (; ri < UP_ROUND(row, C12NUM); ri++, dst_r++) {
449       for (int i = 0; i < col; i++) {
450         dst_r[i * C12NUM] = 0;
451       }
452     }
453   }
454 }
455 
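// In-register tile transposes for RowMajor2Col8MajorParallel: the ARM64 kernel handles an 8x8 tile,
// the ARM32 kernels an 8x4 tile (row stride = col floats).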
456 #ifdef ENABLE_ARM64
457 void RowMajor2Col8Major_arm64(const float *src_c, float *dst_c, size_t col) {
458   size_t stride = col * sizeof(float);
459   asm volatile(
460     "mov x10, %[src_c]\n"
461     "mov x11, %[dst_c]\n"
462 
463     "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n"
464     "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n"
465     "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n"
466     "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n"
467 
468     "zip1 v8.4s, v0.4s, v2.4s\n"
469     "zip2 v9.4s, v0.4s, v2.4s\n"
470     "zip1 v10.4s, v4.4s, v6.4s\n"
471     "zip2 v11.4s, v4.4s, v6.4s\n"
472 
473     "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n"
474     "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n"
475     "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n"
476     "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n"
477 
478     "zip1 v12.4s, v1.4s, v3.4s\n"
479     "zip2 v13.4s, v1.4s, v3.4s\n"
480     "zip1 v14.4s, v5.4s, v7.4s\n"
481     "zip2 v15.4s, v5.4s, v7.4s\n"
482 
483     "trn1 v0.2d, v8.2d, v10.2d\n"
484     "trn2 v1.2d, v8.2d, v10.2d\n"
485     "trn1 v2.2d, v9.2d, v11.2d\n"
486     "trn2 v3.2d, v9.2d, v11.2d\n"
487 
488     "zip1 v24.4s, v16.4s, v18.4s\n"
489     "zip2 v25.4s, v16.4s, v18.4s\n"
490     "zip1 v26.4s, v20.4s, v22.4s\n"
491     "zip2 v27.4s, v20.4s, v22.4s\n"
492 
493     "trn1 v4.2d, v12.2d, v14.2d\n"
494     "trn2 v5.2d, v12.2d, v14.2d\n"
495     "trn1 v6.2d, v13.2d, v15.2d\n"
496     "trn2 v7.2d, v13.2d, v15.2d\n"
497 
498     "zip1 v28.4s, v17.4s, v19.4s\n"
499     "zip2 v29.4s, v17.4s, v19.4s\n"
500     "zip1 v30.4s, v21.4s, v23.4s\n"
501     "zip2 v31.4s, v21.4s, v23.4s\n"
502 
503     "trn1 v16.2d, v24.2d, v26.2d\n"
504     "trn2 v17.2d, v24.2d, v26.2d\n"
505     "trn1 v18.2d, v25.2d, v27.2d\n"
506     "trn2 v19.2d, v25.2d, v27.2d\n"
507 
508     "trn1 v20.2d, v28.2d, v30.2d\n"
509     "trn2 v21.2d, v28.2d, v30.2d\n"
510     "trn1 v22.2d, v29.2d, v31.2d\n"
511     "trn2 v23.2d, v29.2d, v31.2d\n"
512 
513     "st1 {v0.4s}, [x11], #16\n"
514     "st1 {v16.4s}, [x11], #16\n"
515     "st1 {v1.4s}, [x11], #16\n"
516     "st1 {v17.4s}, [x11], #16\n"
517     "st1 {v2.4s}, [x11], #16\n"
518     "st1 {v18.4s}, [x11], #16\n"
519     "st1 {v3.4s}, [x11], #16\n"
520     "st1 {v19.4s}, [x11], #16\n"
521     "st1 {v4.4s}, [x11], #16\n"
522     "st1 {v20.4s}, [x11], #16\n"
523     "st1 {v5.4s}, [x11], #16\n"
524     "st1 {v21.4s}, [x11], #16\n"
525     "st1 {v6.4s}, [x11], #16\n"
526     "st1 {v22.4s}, [x11], #16\n"
527     "st1 {v7.4s}, [x11], #16\n"
528     "st1 {v23.4s}, [x11], #16\n"
529 
530     :
531     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
532     : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
533       "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
534       "v31");
535   return;
536 }
537 #endif
538 #ifdef ENABLE_ARM32
539 #ifndef SUPPORT_NNIE
540 void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
541   size_t stride = col * sizeof(float);
542   asm volatile(
543     "mov r10, %[src_c]\n"
544     "mov r11, %[dst_c]\n"
545 
546     "vld1.32 {q0}, [r10], %[stride]\n"
547     "vld1.32 {q2}, [r10], %[stride]\n"
548     "vld1.32 {q4}, [r10], %[stride]\n"
549     "vld1.32 {q6}, [r10], %[stride]\n"
550 
551     "vtrn.32 d0, d4\n"
552     "vtrn.32 d1, d5\n"
553     "vtrn.32 d8, d12\n"
554     "vtrn.32 d9, d13\n"
555 
556     "vld1.32 {q1}, [r10], %[stride]\n"
557     "vld1.32 {q3}, [r10], %[stride]\n"
558     "vld1.32 {q5}, [r10], %[stride]\n"
559     "vld1.32 {q7}, [r10], %[stride]\n"
560 
561     "vswp d1, d8\n"
562     "vswp d5, d12\n"
563 
564     "vtrn.32 d2, d6\n"
565     "vtrn.32 d3, d7\n"
566     "vtrn.32 d10, d14\n"
567     "vtrn.32 d11, d15\n"
568 
569     "vswp d3, d10\n"
570     "vswp d7, d14\n"
571 
572     "vst1.32 {q0, q1}, [r11]!\n"
573     "vst1.32 {q2, q3}, [r11]!\n"
574     "vst1.32 {q4, q5}, [r11]!\n"
575     "vst1.32 {q6, q7}, [r11]!\n"
576 
577     :
578     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
579     : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
580   return;
581 }
582 #else
583 void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
584   size_t stride = col * sizeof(float);
585   asm volatile(
586     "mov r10, %[src_c]\n"
587     "mov r7, %[dst_c]\n"
588 
589     "vld1.32 {q0}, [r10], %[stride]\n"
590     "vld1.32 {q2}, [r10], %[stride]\n"
591     "vld1.32 {q4}, [r10], %[stride]\n"
592     "vld1.32 {q6}, [r10], %[stride]\n"
593 
594     "vtrn.32 d0, d4\n"
595     "vtrn.32 d1, d5\n"
596     "vtrn.32 d8, d12\n"
597     "vtrn.32 d9, d13\n"
598 
599     "vld1.32 {q1}, [r10], %[stride]\n"
600     "vld1.32 {q3}, [r10], %[stride]\n"
601     "vld1.32 {q5}, [r10], %[stride]\n"
602     "vld1.32 {q7}, [r10], %[stride]\n"
603 
604     "vswp d1, d8\n"
605     "vswp d5, d12\n"
606 
607     "vtrn.32 d2, d6\n"
608     "vtrn.32 d3, d7\n"
609     "vtrn.32 d10, d14\n"
610     "vtrn.32 d11, d15\n"
611 
612     "vswp d3, d10\n"
613     "vswp d7, d14\n"
614 
615     "vst1.32 {q0, q1}, [r7]!\n"
616     "vst1.32 {q2, q3}, [r7]!\n"
617     "vst1.32 {q4, q5}, [r7]!\n"
618     "vst1.32 {q6, q7}, [r7]!\n"
619 
620     :
621     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
622     : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
623   return;
624 }
625 #endif
626 #endif
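// 8-row-block counterpart of RowMajor2Col12MajorParallel; the tail rows are zero-padded only when row_end == row.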
627 void RowMajor2Col8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
628   int row8 = row_end / C8NUM * C8NUM;
629 #ifdef ENABLE_ARM64
630   int col_skip = col / C8NUM * C8NUM;
631   int skip_size = C8NUM;
632 #else
633   int col_skip = col / C4NUM * C4NUM;
634   int skip_size = C4NUM;
635 #endif
636   int ri = (row_start / C8NUM * C8NUM);
637   const float *src_r = src_ptr + ri * col;
638   float *dst_r = dst_ptr + ri * col;
639 
640   for (; ri < row8; ri += C8NUM) {
641     int ci = 0;
642     for (; ci < col_skip; ci += skip_size) {
643       const float *src_c = src_r + ci;
644       float *dst_c = dst_r + ci * C8NUM;
645 
646 #ifdef ENABLE_ARM64
647       RowMajor2Col8Major_arm64(src_c, dst_c, col);
648 #elif ENABLE_ARM32
649       RowMajor2Col8Major_arm32(src_c, dst_c, col);
650 #elif ENABLE_SSE
651       __m128 src1 = _mm_loadu_ps(src_c);
652       __m128 src2 = _mm_loadu_ps(src_c + col);
653       __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
654       __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
655       src_c += C4NUM * col;
656       __m128 src12L = _mm_unpacklo_ps(src1, src2);  // x5
657       __m128 src12H = _mm_unpackhi_ps(src1, src2);  // x1
658       __m128 src34L = _mm_unpacklo_ps(src3, src4);  // x
659       __m128 src34H = _mm_unpackhi_ps(src3, src4);
660       _mm_storeu_ps(dst_c, _mm_movelh_ps(src12L, src34L));
661       _mm_storeu_ps(dst_c + C8NUM, _mm_movehl_ps(src34L, src12L));
662       _mm_storeu_ps(dst_c + C16NUM, _mm_movelh_ps(src12H, src34H));
663       _mm_storeu_ps(dst_c + C24NUM, _mm_movehl_ps(src34H, src12H));
664 
665       __m128 src5 = _mm_loadu_ps(src_c);
666       __m128 src6 = _mm_loadu_ps(src_c + col);
667       __m128 src7 = _mm_loadu_ps(src_c + 2 * col);
668       __m128 src8 = _mm_loadu_ps(src_c + 3 * col);
669       src_c += C4NUM * col;
670       __m128 src56L = _mm_unpacklo_ps(src5, src6);
671       __m128 src56H = _mm_unpackhi_ps(src5, src6);
672       __m128 src78L = _mm_unpacklo_ps(src7, src8);
673       __m128 src78H = _mm_unpackhi_ps(src7, src8);
674       _mm_storeu_ps(dst_c + C4NUM, _mm_movelh_ps(src56L, src78L));
675       _mm_storeu_ps(dst_c + C12NUM, _mm_movehl_ps(src78L, src56L));
676       _mm_storeu_ps(dst_c + 20, _mm_movelh_ps(src56H, src78H));
677       _mm_storeu_ps(dst_c + 28, _mm_movehl_ps(src78H, src56H));
678 #else
679       for (int tr = 0; tr < 8; tr++) {
680         for (int tc = 0; tc < 4; tc++) {
681           dst_c[tc * 8 + tr] = src_c[tr * col + tc];
682         }
683       }
684 #endif
685     }
686     for (; ci < col; ci++) {
687       const float *src_c = src_r + ci;
688       float *dst_c = dst_r + ci * C8NUM;
689       for (int i = 0; i < C8NUM; i++) {
690         dst_c[i] = src_c[i * col];
691       }
692     }
693     src_r += C8NUM * col;
694     dst_r += C8NUM * col;
695   }
696   if (row_end == row) {
697     for (; ri < row; ri++, src_r += col, dst_r++) {
698       for (int i = 0; i < col; i++) {
699         dst_r[i * C8NUM] = src_r[i];
700       }
701     }
702 
703     for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) {
704       for (int i = 0; i < col; i++) {
705         dst_r[i * C8NUM] = 0;
706       }
707     }
708   }
709 }
710 
711 void RowMajor2Col16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
712   int row16 = row_end / C16NUM * C16NUM;
713   int ri = row_start / C16NUM * C16NUM;
714   int col8 = col / C8NUM * C8NUM;
715   const float *src_r = src_ptr + ri * col;
716   float *dst_r = dst_ptr + ri * col;
717 
718   for (; ri < row16; ri += C16NUM) {
719     int ci = 0;
720     for (; ci < col8; ci += C8NUM) {
721       const float *src_c = src_r + ci;
722       float *dst_c = dst_r + ci * C16NUM;
723 #ifdef ENABLE_AVX
724       Transpose8X8Fp32Avx(src_c, dst_c, col, C16NUM);
725       Transpose8X8Fp32Avx(src_c + C8NUM * col, dst_c + C8NUM, col, C16NUM);
726 #else
727       for (int tr = 0; tr < C16NUM; tr++) {
728         for (int tc = 0; tc < C8NUM; tc++) {
729           dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
730         }
731       }
732 #endif
733     }
734     for (; ci < col; ci++) {
735       const float *src_c = src_r + ci;
736       float *dst_c = dst_r + ci * C16NUM;
737       for (int i = 0; i < C16NUM; i++) {
738         dst_c[i] = src_c[i * col];
739       }
740     }
741     src_r += C16NUM * col;
742     dst_r += C16NUM * col;
743   }
744   if (row_end == row) {
745     for (; ri < row; ri++) {
746       for (int i = 0; i < col; i++) {
747         dst_r[i * C16NUM] = src_r[i];
748       }
749       src_r += col;
750       dst_r += 1;
751     }
752     int total_row = UP_ROUND(row, C16NUM);
753     for (; ri < total_row; ri++) {
754       for (int i = 0; i < col; i++) {
755         dst_r[i * C16NUM] = 0;
756       }
757       dst_r += 1;
758     }
759   }
760 }
761 
762 void RowMajor2Col32MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
763   // Not exactly aligned to 32: when fewer than 32 rows remain, the block shrinks to 24, 16, or 8.
764 #ifdef ENABLE_AVX
765   int col8 = col / C8NUM * C8NUM;
766 #endif
767   int all_block_num = UP_DIV(row, C8NUM);
768   int cur_block = C4NUM;
769   row_start = UP_DIV(row_start, C8NUM);
770   row_end = UP_DIV(row_end, C8NUM);
771   for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) {
772     cur_block = MSMIN(C4NUM, all_block_num - i);  // max_tile = 4
773     int dst_stride = cur_block * C8NUM;
774     int row_num = MSMIN(dst_stride, row - i * C8NUM);
775 #ifdef ENABLE_AVX
776     int row8_num = row_num / C8NUM * C8NUM;
777 #endif
778     const float *src = src_ptr + i * C8NUM * col;
779     float *dst = dst_ptr + i * C8NUM * col;
780     int r = 0;
781 #ifdef ENABLE_AVX
782     for (; r < row8_num; r += C8NUM) {
783       int c = 0;
784       for (; c < col8; c += C8NUM) {
785         Transpose8X8Fp32Avx(src + r * col + c, dst + c * dst_stride + r, col, dst_stride);
786       }
787       for (; c < col; ++c) {
788         for (int k = 0; k < C8NUM; ++k) {
789           dst[c * dst_stride + r + k] = src[r * col + c + k * col];
790         }
791       }
792     }
793 #endif
794     for (; r < row_num; r++) {
795       for (int c = 0; c < col; ++c) {
796         dst[c * dst_stride + r] = src[r * col + c];
797       }
798     }
799   }
800 }
801 
802 void RowMajor2Col64MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
803   // Not exactly aligned to 64: when fewer than 64 rows remain, the block shrinks to 48, 32, or 16.
804   int all_block_num = UP_DIV(row, C16NUM);
805   int cur_block = C4NUM;
806   row_start = UP_DIV(row_start, C16NUM);
807   row_end = UP_DIV(row_end, C16NUM);
808   for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) {
809     cur_block = MSMIN(C4NUM, all_block_num - i);  // max_tile = 4
810     int dst_stride = cur_block * C16NUM;
811     int row_num = MSMIN(dst_stride, row - i * C16NUM);
812     const float *src = src_ptr + i * C16NUM * col;
813     float *dst = dst_ptr + i * C16NUM * col;
814     int r = 0;
815     for (; r < row_num; r++) {
816       for (int c = 0; c < col; ++c) {
817         dst[c * dst_stride + r] = src[r * col + c];
818       }
819     }
820   }
821 }
822 
823 void RowMajor2Col6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
824   int row6 = row_end / C6NUM * C6NUM;
825   int ri = row_start / C6NUM * C6NUM;
826   int col8 = col / C8NUM * C8NUM;
827   const float *src_r = src_ptr + ri * col;
828   float *dst_r = dst_ptr + ri * col;
829 
830   for (; ri < row6; ri += C6NUM) {
831     int ci = 0;
832     for (; ci < col8; ci += C8NUM) {
833       const float *src_c = src_r + ci;
834       float *dst_c = dst_r + ci * C6NUM;
835 
836 #ifdef ENABLE_AVX
837       __m256 src0 = _mm256_loadu_ps(src_c);
838       __m256 src1 = _mm256_loadu_ps(src_c + col);
839       __m256 src2 = _mm256_loadu_ps(src_c + 2 * col);
840       __m256 src3 = _mm256_loadu_ps(src_c + 3 * col);
841       __m256 src4 = _mm256_loadu_ps(src_c + 4 * col);
842       __m256 src5 = _mm256_loadu_ps(src_c + 5 * col);
843       __m256 trans0 = _mm256_unpacklo_ps(src0, src1);
844       __m256 trans1 = _mm256_unpacklo_ps(src2, src3);
845       __m256 trans2 = _mm256_unpacklo_ps(src4, src5);
846       __m256 trans3 = _mm256_unpackhi_ps(src0, src1);
847       __m256 trans4 = _mm256_unpackhi_ps(src2, src3);
848       __m256 trans5 = _mm256_unpackhi_ps(src4, src5);
849       __m128 lo0 = _mm256_castps256_ps128(trans0);
850       __m128 lo1 = _mm256_castps256_ps128(trans1);
851       __m128 lo2 = _mm256_castps256_ps128(trans2);
852       __m128 lo3 = _mm256_castps256_ps128(trans3);
853       __m128 lo4 = _mm256_castps256_ps128(trans4);
854       __m128 lo5 = _mm256_castps256_ps128(trans5);
855       __m128 hi0 = _mm256_extractf128_ps(trans0, 1);
856       __m128 hi1 = _mm256_extractf128_ps(trans1, 1);
857       __m128 hi2 = _mm256_extractf128_ps(trans2, 1);
858       __m128 hi3 = _mm256_extractf128_ps(trans3, 1);
859       __m128 hi4 = _mm256_extractf128_ps(trans4, 1);
860       __m128 hi5 = _mm256_extractf128_ps(trans5, 1);
861       __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0));
862       __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0));
863       __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2));
864       __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0));
865       __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0));
866       __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2));
867       __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0));
868       __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0));
869       __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2));
870       __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0));
871       __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0));
872       __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2));
873       _mm_storeu_ps(dst_c, res0);
874       _mm_storeu_ps(dst_c + 4, res1);
875       _mm_storeu_ps(dst_c + 8, res2);
876       _mm_storeu_ps(dst_c + 12, res3);
877       _mm_storeu_ps(dst_c + 16, res4);
878       _mm_storeu_ps(dst_c + 20, res5);
879       _mm_storeu_ps(dst_c + 24, res6);
880       _mm_storeu_ps(dst_c + 28, res7);
881       _mm_storeu_ps(dst_c + 32, res8);
882       _mm_storeu_ps(dst_c + 36, res9);
883       _mm_storeu_ps(dst_c + 40, res10);
884       _mm_storeu_ps(dst_c + 44, res11);
885 #else
886       for (int tr = 0; tr < C6NUM; tr++) {
887         for (int tc = 0; tc < C8NUM; tc++) {
888           dst_c[tc * C6NUM + tr] = src_c[tr * col + tc];
889         }
890       }
891 #endif
892     }
893     for (; ci < col; ci++) {
894       const float *src_c = src_r + ci;
895       float *dst_c = dst_r + ci * C6NUM;
896       for (int i = 0; i < C6NUM; i++) {
897         dst_c[i] = src_c[i * col];
898       }
899     }
900     src_r += C6NUM * col;
901     dst_r += C6NUM * col;
902   }
903 
904   if (row_end == row) {
905     for (; ri < row_end; ri++) {
906       for (int i = 0; i < col; i++) {
907         dst_r[i * C6NUM] = src_r[i];
908       }
909       src_r += col;
910       dst_r += 1;
911     }
912 
913     int totalRow = UP_ROUND(row, C6NUM);
914     for (; ri < totalRow; ri++) {
915       for (int i = 0; i < col; i++) {
916         dst_r[i * C6NUM] = 0;
917       }
918       dst_r += 1;
919     }
920   }
921 }
922 
923 void RowMajor2Col4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
924   int row4 = row_end / C4NUM * C4NUM;
925   int ri = row_start / C4NUM * C4NUM;
926   int col4 = col / C4NUM * C4NUM;
927   const float *src_r = src_ptr + ri * col;
928   float *dst_r = dst_ptr + ri * col;
929 
930   for (; ri < row4; ri += C4NUM) {
931     int ci = 0;
932     for (; ci < col4; ci += C4NUM) {
933       const float *src_c = src_r + ci;
934       float *dst_c = dst_r + ci * C4NUM;
935 
936 #ifdef ENABLE_ARM32
937       int stride = col * 4;
938       asm volatile(
939         "mov r10, %[src_c]\n"
940         "mov r12, %[dst_c]\n"
941 
942         "vld1.32 {q0}, [r10], %[stride]\n"
943         "vld1.32 {q1}, [r10], %[stride]\n"
944         "vld1.32 {q2}, [r10], %[stride]\n"
945         "vld1.32 {q3}, [r10], %[stride]\n"
946 
947         "vtrn.32 d0, d2\n"
948         "vtrn.32 d1, d3\n"
949         "vtrn.32 d4, d6\n"
950         "vtrn.32 d5, d7\n"
951 
952         "vswp d1, d4\n"
953         "vswp d3, d6\n"
954 
955         "vst1.32 {q0}, [r12]!\n"
956         "vst1.32 {q1}, [r12]!\n"
957         "vst1.32 {q2}, [r12]!\n"
958         "vst1.32 {q3}, [r12]!\n"
959 
960         :
961         : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
962         : "r10", "r12", "q0", "q1", "q2", "q3");
963 #elif ENABLE_SSE
964       __m128 src1 = _mm_loadu_ps(src_c);
965       __m128 src2 = _mm_loadu_ps(src_c + col);
966       __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
967       __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
968       src_c += C4NUM * col;
969       __m128 src12L = _mm_unpacklo_ps(src1, src2);
970       __m128 src12H = _mm_unpackhi_ps(src1, src2);
971       __m128 src34L = _mm_unpacklo_ps(src3, src4);
972       __m128 src34H = _mm_unpackhi_ps(src3, src4);
973 
974       __m128 dst0 = _mm_movelh_ps(src12L, src34L);
975       __m128 dst1 = _mm_movehl_ps(src34L, src12L);
976       __m128 dst2 = _mm_movelh_ps(src12H, src34H);
977       __m128 dst3 = _mm_movehl_ps(src34H, src12H);
978 
979       _mm_storeu_ps(dst_c, dst0);
980       _mm_storeu_ps(dst_c + 4, dst1);
981       _mm_storeu_ps(dst_c + 8, dst2);
982       _mm_storeu_ps(dst_c + 12, dst3);
983 #else
984       for (size_t tr = 0; tr < C4NUM; tr++) {
985         for (size_t tc = 0; tc < C4NUM; tc++) {
986           dst_c[tc * C4NUM + tr] = src_c[tr * col + tc];
987         }
988       }
989 #endif
990     }
991     for (; ci < col; ci++) {
992       const float *src_c = src_r + ci;
993       float *dst_c = dst_r + ci * C4NUM;
994       for (int i = 0; i < C4NUM; i++) {
995         dst_c[i] = src_c[i * col];
996       }
997     }
998     src_r += C4NUM * col;
999     dst_r += C4NUM * col;
1000   }
1001   if (row_end == row) {
1002     for (; ri < row; ri++) {
1003       for (int i = 0; i < col; i++) {
1004         dst_r[i * C4NUM] = src_r[i];
1005       }
1006       src_r += col;
1007       dst_r += 1;
1008     }
1009 
1010     int total_row = UP_ROUND(row, C4NUM);
1011     for (; ri < total_row; ri++) {
1012       for (int i = 0; i < col; i++) {
1013         dst_r[i * C4NUM] = 0;
1014       }
1015       dst_r += 1;
1016     }
1017   }
1018 }
1019 
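// Non-parallel wrappers: each delegates to its *Parallel counterpart over the full range.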
1020 void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
1021   RowMajor2ColMajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1022 }
1023 void RowMajor2RowMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
1024   RowMajor2RowMajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1025 }
1026 void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1027   RowMajor2Row4MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1028 }
1029 void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1030   RowMajor2Row6MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1031 }
1032 void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1033   RowMajor2Row8MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1034 }
1035 void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1036   RowMajor2Row12MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1037 }
1038 void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1039   RowMajor2Row16MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1040 }
1041 void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int col, int row) {
1042   RowMajor2Row32MajorParallel(src_ptr, dst_ptr, col, row, 0, col);
1043 }
1044 void RowMajor2Row64Major(const float *src_ptr, float *dst_ptr, int col, int row) {
1045   RowMajor2Row64MajorParallel(src_ptr, dst_ptr, col, row, 0, col);
1046 }
1047 void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1048   RowMajor2Col12MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1049 }
1050 void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1051   RowMajor2Col8MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1052 }
1053 void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1054   RowMajor2Col16MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1055 }
1056 void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1057   RowMajor2Col32MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1058 }
1059 void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1060   RowMajor2Col64MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1061 }
1062 void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1063   RowMajor2Col6MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1064 }
1065 void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1066   RowMajor2Col4MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1067 }
1068 
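// Packs NHWC to NC4HW4 with the channel dimension rounded up to a multiple of 4; full 4-channel blocks are
// copied vector-wise and the channel tail element by element.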
1069 void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1070   int c4 = UP_DIV(channel, C4NUM);
1071   int c4_minus = c4 - 1;
1072   for (int b = 0; b < batch; b++) {
1073     int src_oc_offset = b * plane * channel;
1074     int dst_oc_offset = b * plane * c4 * C4NUM;
1075     for (int k = 0; k < plane; k++) {
1076       int src_kernel_offset = src_oc_offset + k * channel;
1077       int dst_kernel_offset = dst_oc_offset + k * C4NUM;
1078       for (int j = 0; j < c4_minus; ++j) {
1079         int src_ic_offset = src_kernel_offset + j * C4NUM;
1080         int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM;
1081 #ifdef ENABLE_ARM
1082         vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset));
1083 #else
1084         for (int i = 0; i < C4NUM; ++i) {
1085           ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
1086         }
1087 #endif
1088       }
1089       int tmp_c = c4_minus * C4NUM;
1090       int tmp_c_offset = tmp_c * plane;
1091       int res_c = channel - tmp_c;
1092       if (res_c > channel) {
1093         return;
1094       }
1095       for (int l = 0; l < res_c; ++l) {
1096         int src_ic_offset = src_kernel_offset + tmp_c + l;
1097         int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
1098         ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
1099       }
1100     }
1101   }
1102 }
1103 
1104 void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1105   int c4 = UP_DIV(channel, C4NUM);
1106   for (int b = 0; b < batch; b++) {
1107     int src_offset = b * plane * channel;
1108     int dst_offset = b * plane * c4 * C4NUM;
1109     RowMajor2Col4Major((const float *)src + src_offset, (float *)dst + dst_offset, channel, plane);
1110   }
1111 }
1112 
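// Pads each pixel's channel vector with zeros up to a multiple of 4; when channel is already a multiple of 4
// the whole tensor is copied with a single memcpy.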
1113 void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1114   int oc_block = UP_DIV(channel, C4NUM);
1115   int oc_block_channel = oc_block * C4NUM;
1116   int ic_remainder_ = channel % C4NUM;
1117   if (ic_remainder_ != 0) {
1118     for (int b = 0; b < batch; b++) {
1119       int dst_batch_offset = b * oc_block_channel * plane;
1120       int batch_offset = b * channel * plane;
1121       for (int i = 0; i < plane; i++) {
1122         float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel;
1123         memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
1124         memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float));
1125       }
1126     }
1127   } else {
1128     size_t ori_input_size = batch * plane * channel * sizeof(float);
1129     memcpy((float *)dst, (float *)src, ori_input_size);
1130   }
1131 }
1132 
1133 void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile) {
1134   int oc_block = UP_DIV(channel, oc_tile);
1135   int oc_block_channel = oc_block * oc_tile;
1136   int ic_remainder_ = channel % oc_tile;
1137   if (ic_remainder_ != 0) {
1138     for (int b = 0; b < batch; b++) {
1139       int dst_batch_offset = b * oc_block_channel * plane;
1140       int batch_offset = b * channel * plane;
1141       for (int i = 0; i < plane; i++) {
1142         float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel;
1143         memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
1144         memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float));
1145       }
1146     }
1147   } else {
1148     size_t ori_input_size = batch * plane * channel * sizeof(float);
1149     memcpy((float *)dst, (float *)src, ori_input_size);
1150   }
1151 }
1152 
1153 #if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
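// 1x1-convolution weight packing: output channels are grouped into blocks of 32/24/16/8 and each block is
// transposed so that the output channel becomes the fastest-varying dimension.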
1154 void PackNHWCToNXHWCXFp32H1W1(int output_channel, int oc_block_num, int input_channel, float *tmp_weight,
1155                               const float *src, int oc_block_unit, Transpose8X8Fp32Func transpose_func) {
1156   int oc_block8 = DOWN_DIV(output_channel, C8NUM);
1157   int oc = 0;
1158   int oc_block = 0;
1159   int ic8 = DOWN_ROUND(input_channel, C8NUM);
1160   int oc_remainder_step = 0;
1161   if (oc_block8 != oc_block_num) {
1162     oc_block8 = oc_block8 / oc_block_unit * oc_block_unit;
1163     oc_remainder_step = (oc_block_num - oc_block8) * C8NUM;
1164   }
1165   for (; oc < oc_block8; oc += (oc_block / C8NUM)) {
1166     oc_block = MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM;  // max_tile = 32 ==> 24 ==> 16 ==> 8
1167     for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) {
1168       int ic = 0;
1169       for (; ic < ic8; ic += C8NUM) {
1170         transpose_func(src + ic, tmp_weight + ic * oc_block + oc_tmp, input_channel, oc_block);
1171       }
1172       for (; ic < input_channel; ++ic) {
1173         for (int j = 0; j < C8NUM; ++j) {
1174           tmp_weight[ic * oc_block + oc_tmp + j] = src[ic + input_channel * j];
1175         }
1176       }
1177       src += C8NUM * input_channel;
1178     }
1179     tmp_weight += oc_block * input_channel;
1180   }
1181   oc = output_channel - oc_block8 * C8NUM;
1182   for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) {
1183     for (int ic = 0; ic < input_channel; ++ic) {
1184       tmp_weight[oc_remainder + oc_remainder_step * ic] = src[ic + oc_remainder * input_channel];
1185     }
1186   }
1187 }
1188 
1189 // PackNHWCToNXHWCXFp32 is the SIMD-optimized (asm/AVX) version of SWPackNHWCToNXHWCXFp32.
1190 void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
1191                           float *tmp_weight, const float *src) {
1192 #ifdef ENABLE_ARM64
1193   Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Arm64;
1194   int oc_block_unit = C2NUM;
1195 #elif defined(ENABLE_AVX)
1196   Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Avx;
1197   int oc_block_unit = C4NUM;
1198 #endif
1199   // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8
1200   // output_channel plays the role of the batch dimension here
1201   int plane = kernel_w * kernel_h;
1202   if (plane == 1) {  // conv 1x1 weight pack
1203     PackNHWCToNXHWCXFp32H1W1(output_channel, oc_block_num, input_channel, tmp_weight, src, oc_block_unit,
1204                              transpose_func);
1205     return;
1206   }
1207 
1208   int ic8 = DOWN_ROUND(input_channel, C8NUM);
1209   int oc_block8 = DOWN_DIV(output_channel, C8NUM);
1210   int oc_block = 0;
1211   int oc = 0;
1212   int oc_remainder_step = 0;
1213   if (oc_block8 != oc_block_num) {
1214     oc_block8 = oc_block8 / oc_block_unit * oc_block_unit;
1215     oc_remainder_step = (oc_block_num - oc_block8) * C8NUM;
1216   }
1217   for (; oc < oc_block8; oc += (oc_block / C8NUM)) {
1218     oc_block = MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM;  // max_tile = 32 ==> 24 ==> 16 ==> 8
1219     for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) {
1220       for (int hw = 0; hw < plane; ++hw) {
1221         int ic = 0;
1222         for (; ic < ic8; ic += C8NUM) {
1223           transpose_func(src + hw * input_channel + ic,
1224                          tmp_weight + hw * oc_block * input_channel + ic * oc_block + oc_tmp, input_channel * plane,
1225                          oc_block);
1226         }
1227         for (; ic < input_channel; ++ic) {
1228           for (int j = 0; j < C8NUM; ++j) {
1229             tmp_weight[ic * oc_block + oc_tmp + j + hw * oc_block * input_channel] =
1230               src[ic + input_channel * j * plane + hw * input_channel];
1231           }
1232         }
1233       }
1234       src += C8NUM * plane * input_channel;
1235     }
1236     tmp_weight += oc_block * input_channel * plane;
1237   }
1238   oc = output_channel - oc_block8 * C8NUM;
1239   for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) {
1240     for (int hw = 0; hw < plane; ++hw) {
1241       for (int ic = 0; ic < input_channel; ++ic) {
1242         tmp_weight[oc_remainder + oc_remainder_step * ic + hw * input_channel * oc_remainder_step] =
1243           src[ic + (oc_remainder * plane + hw) * input_channel];
1244       }
1245     }
1246   }
1247 }
1248 
1249 #ifdef ENABLE_DEBUG
1250 void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
1251                             float *tmp_weight, const float *src) {
1252   // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8
1253   int oc_block = 0;
1254   for (int i = 0; i < oc_block_num; i += oc_block) {
1255     oc_block = MSMIN(C4NUM, oc_block_num - i);  // max_tile = 4
1256     int index = i * C8NUM * kernel_h * kernel_w * input_channel;
1257     int oc_remainder = MSMIN(C8NUM * oc_block, output_channel - i * C8NUM);
1258     for (int h = 0; h < kernel_h; ++h) {
1259       for (int w = 0; w < kernel_w; ++w) {
1260         int w_index = (h * kernel_w + w) * input_channel + index;
1261         for (int ic = 0; ic < input_channel; ++ic) {
1262           int ic_index = ic + w_index;
1263           for (int oc = 0; oc < oc_remainder; ++oc) {
1264             int oc_index = oc * kernel_w * kernel_h * input_channel + ic_index;
1265             tmp_weight[oc] = src[oc_index];
1266           }
1267           tmp_weight += oc_block * C8NUM;
1268         }
1269       }
1270     }
1271   }
1272 }
1273 #endif
1274 #endif
1275 
1276 void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1277   int c8 = UP_DIV(channel, C8NUM);
1278   int c8_channel = c8 * C8NUM;
1279   int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
1280   int ic_remainder_ = channel % C8NUM;
1281   if (ic_remainder_ != 0) {
1282     int nhwc8_batch_offset = 0;
1283     for (int b = 0; b < batch; b++) {
1284       int batch_offset = b * channel * plane;
1285       for (int i = 0; i < plane; i++) {
1286         float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel;
1287         memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
1288         for (int j = channel; j < c8_channel; ++j) {
1289           dst_per_plane[j] = 0;
1290         }
1291       }
1292       nhwc8_batch_offset += nhwc8_batch_unit_offset;
1293     }
1294   } else {
1295     size_t ori_input_size = batch * plane * channel * sizeof(float);
1296     memcpy((float *)dst, (float *)src, ori_input_size);
1297   }
1298 }
1299 
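// Inverse of PackNHWCToNHWCXFp32: strips the per-pixel channel padding (aligned to cx_num) back to a dense NHWC layout.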
1300 void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num) {
1301   int c_algin = UP_DIV(channel, cx_num);
1302   int ic_remainder_ = channel % cx_num;
1303   if (ic_remainder_ != 0) {
1304     int nhwc_batch_unit_offset = channel * plane;
1305     for (int b = 0; b < batch; b++) {
1306       int batch_offset = b * c_algin * cx_num * plane;
1307       for (int i = 0; i < plane; i++) {
1308         memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel,
1309                (float *)src + batch_offset + i * c_algin * cx_num, channel * sizeof(float));
1310       }
1311     }
1312   } else {
1313     size_t ori_input_size = batch * plane * channel * sizeof(float);
1314     memcpy((float *)dst, (float *)src, ori_input_size);
1315   }
1316 }
1317 
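// Converts NC4HW4 back to NHWC4: each pixel's channels become contiguous again, still grouped in blocks of 4.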
1318 void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1319   int c4 = UP_DIV(channel, C4NUM);
1320   for (int b = 0; b < batch; b++) {
1321     int src_offset = b * plane * c4 * C4NUM;
1322     int dst_offset = b * plane * channel;
1323     for (int c = 0; c < channel; c++) {
1324       int c4_block_num = c / C4NUM;
1325       int c4_block_res = c % C4NUM;
1326       int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
1327       int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
1328       for (int k = 0; k < plane; k++) {
1329         int src_kernel_offset = src_c_offset + k * C4NUM;
1330         int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
1331         ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
1332       }
1333     }
1334   }
1335 }
1336 
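/* Unpack a single NC4HW4 image into planar NCHW. Despite the name, this helper reads and writes
   fp32 data. */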
1337 void UnPackC4Uint(const void *src, void *dst, size_t plane, size_t channel) {
1338   const float *fp32_src = (const float *)src;
1339   float *fp32_dst = (float *)dst;
1340   for (size_t c = 0; c < channel; c++) {
1341     size_t c_div = c / C4NUM;
1342     size_t c_mod = c % C4NUM;
1343     for (size_t p = 0; p < plane; p++) {
1344       int src_offset = c_div * plane * C4NUM + p * C4NUM + c_mod;
1345       int dst_offset = c * plane + p;
1346       fp32_dst[dst_offset] = fp32_src[src_offset];
1347     }
1348   }
1349 }
1350 
1351 void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
1352   int c4 = UP_ROUND(channel, C4NUM);
1353   for (int b = 0; b < batch; b++) {
1354     int src_offset = b * plane * c4;
1355     int dst_offset = b * plane * channel;
1356     UnPackC4Uint((const float *)src + src_offset, (float *)dst + dst_offset, plane, channel);
1357   }
1358 }
1359 
1360 void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
1361   int c4 = UP_DIV(channel, C4NUM);
1362   for (int b = 0; b < batch; b++) {
1363     int src_offset = b * plane * c4 * C4NUM;
1364     int dst_offset = b * plane * channel;
1365     for (int k = 0; k < plane; k++) {
1366       int src_kernel_offset = src_offset + k * C4NUM;
1367       int dst_kernel_offset = dst_offset + k * channel;
1368       for (int c = 0; c < c4 - 1; c++) {
1369         int src_c_offset = src_kernel_offset + c * plane * C4NUM;
1370         int dst_c_offset = dst_kernel_offset + c * C4NUM;
1371 #if defined(ENABLE_NEON) || defined(ENABLE_SSE)
1372         MS_STQ_F32((float *)dst + dst_c_offset, MS_LDQ_F32((float *)src + src_c_offset));
1373 #else
1374         ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
1375         ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
1376         ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
1377         ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
1378 #endif
1379       }
1380       // res part
1381       int res_c = channel - (c4 - 1) * C4NUM;
1382       for (int i = 0; i < res_c; i++) {
1383         int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
1384         int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
1385         ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
1386       }
1387     }
1388   }
1389 }
1390 
1391 void PackNC8HW8ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
1392   int c8 = UP_ROUND(channel, C8NUM);
1393   for (int b = 0; b < batch; b++) {
1394     int src_offset = b * plane * c8;
1395     int dst_offset = b * plane * channel;
1396 
1397     const float *fp32_src = (const float *)src + src_offset;
1398     float *fp32_dst = (float *)dst + dst_offset;
1399     for (size_t c = 0; c < channel; c++) {
1400       size_t c_div = c / C8NUM;
1401       size_t c_mod = c % C8NUM;
1402       for (size_t p = 0; p < plane; p++) {
1403         int src_offset_c = c_div * plane * C8NUM + p * C8NUM + c_mod;
1404         int dst_offset_c = c * plane + p;
1405         fp32_dst[dst_offset_c] = fp32_src[src_offset_c];
1406       }
1407     }
1408   }
1409 }
1410 
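/* Pack NHWC data into NC8HW8: full C8NUM channel blocks are stored plane-major; the final partial
   block keeps the C8NUM per-pixel stride but only its first channel % C8NUM lanes are written, and
   the remaining lanes are left untouched. */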
1411 void PackNHWCToNC8HW8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1412   int c8 = UP_DIV(channel, C8NUM);
1413   int c8_minus = c8 - 1;
1414   for (int b = 0; b < batch; b++) {
1415     int src_oc_offset = b * plane * channel;
1416     int dst_oc_offset = b * plane * c8 * C8NUM;
1417     for (int k = 0; k < plane; k++) {
1418       int src_kernel_offset = src_oc_offset + k * channel;
1419       int dst_kernel_offset = dst_oc_offset + k * C8NUM;
1420       for (int j = 0; j < c8_minus; ++j) {
1421         int src_ic_offset = src_kernel_offset + j * C8NUM;
1422         int dst_ic_offset = dst_kernel_offset + j * plane * C8NUM;
1423         for (int i = 0; i < C8NUM; ++i) {
1424           ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
1425         }
1426       }
1427       int tmp_c = c8_minus * C8NUM;
1428       int tmp_c_offset = tmp_c * plane;
1429       int res_c = channel - tmp_c;
1430       if (res_c > channel) {
1431         return;
1432       }
1433       for (int l = 0; l < res_c; ++l) {
1434         int src_ic_offset = src_kernel_offset + tmp_c + l;
1435         int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
1436         ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
1437       }
1438     }
1439   }
1440 }
1441 
1442 void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
1443   int c8 = UP_DIV(channel, C8NUM);
1444   for (int b = 0; b < batch; b++) {
1445     int src_offset = b * plane * c8 * C8NUM;
1446     int dst_offset = b * plane * channel;
1447     for (int k = 0; k < plane; k++) {
1448       int src_kernel_offset = src_offset + k * C8NUM;
1449       int dst_kernel_offset = dst_offset + k * channel;
1450       for (int c = 0; c < c8 - 1; c++) {
1451         int src_c_offset = src_kernel_offset + c * plane * C8NUM;
1452         int dst_c_offset = dst_kernel_offset + c * C8NUM;
1453 
1454         ((float *)dst + dst_c_offset)[Index0] = ((float *)src + src_c_offset)[Index0];
1455         ((float *)dst + dst_c_offset)[Index1] = ((float *)src + src_c_offset)[Index1];
1456         ((float *)dst + dst_c_offset)[Index2] = ((float *)src + src_c_offset)[Index2];
1457         ((float *)dst + dst_c_offset)[Index3] = ((float *)src + src_c_offset)[Index3];
        ((float *)dst + dst_c_offset + C4NUM)[Index0] = ((float *)src + src_c_offset + C4NUM)[Index0];  // upper half of the C8 block
        ((float *)dst + dst_c_offset + C4NUM)[Index1] = ((float *)src + src_c_offset + C4NUM)[Index1];
        ((float *)dst + dst_c_offset + C4NUM)[Index2] = ((float *)src + src_c_offset + C4NUM)[Index2];
        ((float *)dst + dst_c_offset + C4NUM)[Index3] = ((float *)src + src_c_offset + C4NUM)[Index3];
1458       }
1459       // res part
1460       int res_c = channel - (c8 - 1) * C8NUM;
1461       for (int i = 0; i < res_c; i++) {
1462         int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i;
1463         int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i;
1464         ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
1465       }
1466     }
1467   }
1468 }
1469 
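/* Repack NC8HW8 data whose channel dimension was rounded up to C16NUM into the unaligned NC8HW8 form:
   the full C8NUM blocks are copied wholesale, and the remaining channel - DOWN_ROUND(channel, C8NUM)
   values of the last block are stored compactly per pixel. */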
1470 void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const int batch, const int plane,
1471                                              const int channel) {
1472   int down_channel_8 = DOWN_ROUND(channel, C8NUM);
1473   int up_channel_16 = UP_ROUND(channel, C16NUM);
1474   size_t dst_batch_offset = (size_t)(plane * channel) * sizeof(float);
1475   size_t src_batch_offset = (size_t)(plane * up_channel_16) * sizeof(float);
1476   size_t unaligned_channel_size = (size_t)(channel - down_channel_8) * sizeof(float);
1477   size_t aligned_channel_size = (size_t)(down_channel_8 * plane) * sizeof(float);
1478   size_t src_p_offset = C8NUM * sizeof(float);
1479   for (size_t b = 0; b < (size_t)(batch); ++b) {
1480     const char *src_batch = (char *)(src) + b * src_batch_offset;
1481     char *dst_batch = (char *)(dst) + b * dst_batch_offset;
1482     memcpy(dst_batch, src_batch, aligned_channel_size);
1483     src_batch += aligned_channel_size;
1484     dst_batch += aligned_channel_size;
1485     for (int p = 0; p < plane; ++p) {
1486       memcpy(dst_batch + p * unaligned_channel_size, src_batch + p * src_p_offset, unaligned_channel_size);
1487     }
1488   }
1489 }
1490 
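/* Pack weights from NHWC into C8HWN8: output channels are grouped into C8NUM blocks, the layout becomes
   [c8 block][plane][batch][lane], and channels from `channel` up to UP_ROUND(channel, C8NUM) are
   zero-filled. */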
1491 void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1492   int channel_up8 = UP_ROUND(channel, C8NUM);
1493   for (int n = 0; n < batch; n++) {
1494     for (int hw = 0; hw < plane; hw++) {
1495       int c = 0;
1496       for (; c < channel; c++) {
1497         int c8div = c / C8NUM;
1498         int c8mod = c % C8NUM;
1499         int src_index = n * plane * channel + hw * channel + c;
1500         int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
1501         ((float *)dst)[dst_index] = ((float *)src)[src_index];
1502       }
1503       for (; c < channel_up8; c++) {
1504         int c8div = c / C8NUM;
1505         int c8mod = c % C8NUM;
1506         int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
1507         ((float *)dst)[dst_index] = 0;
1508       }
1509     }
1510   }
1511 }
1512 
1513 void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
1514   // Pack weights from NHWC to CxHWNx: prefer C24HWN24 blocks and fall back to C16HWN16, then C8HWN8, when fewer channel blocks remain (see the worked example after this function).
1515 #ifdef ENABLE_AVX
1516   int oc_block_num = UP_DIV(channel, C8NUM);
1517   int plane16 = plane / C16NUM * C16NUM;
1518   for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) {
1519     oc_block = MSMIN(C3NUM, oc_block_num - i);
1520     int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
1521     int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM;
1522     int p = 0;
1523     for (; p < plane16; p += C16NUM) {
1524       int index_plane = i * C8NUM + p * channel;
1525       for (int b = 0; b < batch; ++b) {
1526         int index_batch = index_plane + b * plane * channel;
1527         int oc = 0;
1528         int stride = oc_block * C8NUM * batch;
1529         for (; oc < oc_remainder_c8; oc += C8NUM) {
1530           const float *cur_src = src + index_batch + oc;
1531           float *cur_dst = dst + oc;
1532           MS_LOAD256X16_F32(r, cur_src, channel);
1533           STORE256X16_F32(cur_dst, stride, r);
1534         }
1535         for (; oc < oc_remainder; ++oc) {
1536           for (int k = 0; k < C16NUM; ++k) {
1537             dst[oc + stride * k] = src[index_batch + oc + channel * k];
1538           }
1539         }
1540         for (; oc < C8NUM; ++oc) {
1541           for (int k = 0; k < C16NUM; ++k) {
1542             dst[oc + stride * k] = 0;
1543           }
1544         }
1545         dst += oc_block * C8NUM;
1546       }
1547       dst += (C16NUM - 1) * oc_block * C8NUM * batch;
1548     }
1549     for (; p < plane; ++p) {
1550       int index_plane = i * C8NUM + p * channel;
1551       for (int b = 0; b < batch; ++b) {
1552         int index_batch = index_plane + b * plane * channel;
1553         int oc = 0;
1554         for (; oc < oc_remainder; ++oc) {
1555           dst[oc] = src[index_batch + oc];
1556         }
1557         for (; oc < C8NUM; ++oc) {
1558           dst[oc] = 0;
1559         }
1560         dst += oc_block * C8NUM;
1561       }
1562     }
1563   }
1564 #else
1565   int oc_block = 0;
1566   int oc_block_num = UP_DIV(channel, C8NUM);
1567   for (int i = 0; i < oc_block_num; i += oc_block) {
1568     oc_block = MSMIN(C3NUM, oc_block_num - i);  // pack at most C3NUM (3) C8-blocks, i.e. 24 channels, per group
1569     int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
1570     for (int p = 0; p < plane; ++p) {
1571       int index_plane = i * C8NUM + p * channel;
1572       for (int b = 0; b < batch; ++b) {
1573         int index_batch = index_plane + b * plane * channel;
1574         for (int oc = 0; oc < oc_remainder; ++oc) {
1575           dst[oc] = src[index_batch + oc];
1576         }
1577         dst += oc_block * C8NUM;
1578       }
1579     }
1580   }
1581 #endif
1582 }
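
/*
 * Worked example for the block-size priority above (numbers are illustrative only): with channel = 52,
 * oc_block_num = UP_DIV(52, C8NUM) = 7, so the outer loop packs groups of MSMIN(C3NUM, remaining)
 * C8-blocks: 3 + 3 + 1, i.e. output-channel groups of 24, 24 and 8, where the last group holds the four
 * remaining real channels (in the AVX path its unused lanes are zero-filled).
 */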
1583 
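/* Repack depthwise weights from per-channel HW planes into the C4-blocked layout used by the indirect
   depthwise kernels: within each block of C4NUM channels the elements are ordered kw-major, then kh,
   then channel lane. PackDepthwiseIndirectWeightC8Fp32 below is the C8NUM counterpart. */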
1584 void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
1585   int c4 = UP_DIV(channel, C4NUM);
1586   for (int c = 0; c < c4; c++) {
1587     int dst_off_c = c * C4NUM * height * width;
1588     for (int i = 0; i < C4NUM; i++) {
1589       int src_off_c = (c * C4NUM + i) * height * width;
1590       for (int kh = 0; kh < height; kh++) {
1591         int src_off_kh = src_off_c + kh * width;
1592         for (int kw = 0; kw < width; kw++) {
1593           int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i;
1594           ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
1595         }
1596       }
1597     }
1598   }
1599 }
1600 
1601 void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) {
1602   int c8 = UP_DIV(channel, C8NUM);
1603   for (int c = 0; c < c8; c++) {
1604     int dst_off_c = c * C8NUM * height * width;
1605     for (int i = 0; i < C8NUM; i++) {
1606       int src_off_c = (c * C8NUM + i) * height * width;
1607       for (int kh = 0; kh < height; kh++) {
1608         int src_off_kh = src_off_c + kh * width;
1609         for (int kw = 0; kw < width; kw++) {
1610           int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i;
1611           ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
1612         }
1613       }
1614     }
1615   }
1616 }
1617 
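/* Transpose NHWC to NCHW. When thread_count > 0 the plane dimension is split among the workers in
   chunks aligned to C8NUM rows; each worker transposes its chunk with the platform 8x8 kernels where
   available and falls back to scalar loops for the channel and plane tails. With thread_count <= 0 a
   single call handles the whole plane. */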
1618 void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel, int task_id,
1619                         int thread_count) {
1620 #ifdef ENABLE_ARM64
1621   Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64;
1622 #elif defined(ENABLE_ARM32)
1623   Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32;
1624 #elif defined(ENABLE_AVX)
1625   Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx;
1626 #elif defined(ENABLE_SSE) && !defined(ENABLE_AVX)
1627   Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse;
1628 #endif
1629   int hw8 = plane / C8NUM;
1630   int task_start = 0;
1631   int task_end = plane;
1632   if (thread_count > 0) {
1633     int offset_hw = UP_DIV(hw8, thread_count) * C8NUM;
1634     task_start = offset_hw * task_id;
1635     int count = plane - task_start;
1636     if (count <= 0) {
1637       return;
1638     }
1639     task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw);
1640     hw8 = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0);
1641   } else {
1642     hw8 *= C8NUM;
1643   }
1644   int c8 = channel / C8NUM * C8NUM;
1645   int batch = plane * channel;
1646   for (int n = 0; n < batches; n++) {
1647     const float *src_batch = (const float *)src + n * batch;
1648     float *dst_batch = (float *)dst + n * batch;
1649     int hw = task_start;
1650     for (; hw < hw8; hw += C8NUM) {
1651       int c = 0;
1652       for (; c < c8; c += C8NUM) {
1653         const float *src_ptr = src_batch + hw * channel + c;
1654         float *dst_ptr = dst_batch + c * plane + hw;
1655 #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32)
1656         Transpose8X8Fp32Func_(src_ptr, dst_ptr, channel, plane);
1657 #else
1658         for (int tr = 0; tr < C8NUM; tr++) {
1659           for (int tc = 0; tc < C8NUM; tc++) {
1660             dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
1661           }
1662         }
1663 #endif
1664       }
1665       for (; c < channel; c++) {
1666         const float *src_ptr = src_batch + hw * channel + c;
1667         float *dst_ptr = dst_batch + c * plane + hw;
1668         for (size_t i = 0; i < C8NUM; i++) {
1669           dst_ptr[i] = src_ptr[i * channel];
1670         }
1671       }
1672     }
1673     for (; hw < task_end; hw++) {
1674       const float *src_ptr = src_batch + hw * channel;
1675       float *dst_ptr = dst_batch + hw;
1676       for (size_t i = 0; i < channel; i++) {
1677         dst_ptr[i * plane] = src_ptr[i];
1678       }
1679     }
1680   }
1681 }
1682 
1683 /*
1684 |<---------------- plane --------------->|
1685 +---------------------------+------------+  ---
1686 |          |     |          |            |   ↑
1687 |8x8-blocks| ... |8x8-blocks|   right    |   |
1688 |          |     |          |            |   |
1689 + - - - - -+     + - - - - -+            |   |
1690 |   ...      ...     ...    |    top     | channel
1691 + - - - - -+     + - - - - -|            |   |
1692 |          |     |          |   tails    |   |
1693 |8x8-blocks| ... |8x8-blocks|            |   |
1694 +---------------------------+------------+   |
1695 |                           |right bottom|   |
1696 |     left bottom tails     |   tails    |   ↓
1697 +---------------------------+------------+  ---
1698 */
1699 void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end) {
1700 #ifdef ENABLE_ARM64
1701   Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64;
1702 #elif defined(ENABLE_ARM32)
1703   Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32;
1704 #elif defined(ENABLE_AVX)
1705   Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx;
1706 #elif defined(ENABLE_SSE) && !defined(ENABLE_AVX)
1707   Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse;
1708 #endif
1709   int m_pad = UP_DIV(channel, C8NUM);
1710   int n_pad = UP_DIV(plane, C8NUM);
1711   int m_blk = channel / C8NUM;
1712   int n_blk = plane / C8NUM;
1713   int b_stride = plane * channel;
1715   int b = 0, m = 0, n = 0;
1716   // To make write dst consecutively, (m,n):(0,0)->(1,0)->...->(0,1)->(1,1)->...
1717   offset_to_index_init(start, 6, &m, m_pad, &n, n_pad, &b, batches);
1718   for (int task = start; task < end; task++) {
1719     const float *src_batch = (const float *)src + b * b_stride;
1720     float *dst_batch = (float *)dst + b * b_stride;
1721     int m_start = m * C8NUM;
1722     int n_start = n * C8NUM;
1723     if (m < m_blk && n < n_blk) {
1724       // process 8x8-blocks
1725       const float *from = src_batch + m_start * plane + n_start;
1726       float *to = dst_batch + n_start * channel + m_start;
1727 #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32)
1728       Transpose8X8Fp32Func_(from, to, plane, channel);
1729 #else
1730       for (int tr = 0; tr < C8NUM; tr++) {
1731         for (int tc = 0; tc < C8NUM; tc++) {
1732           to[tc * channel + tr] = from[tr * plane + tc];
1733         }
1734       }
1735 #endif
1736     } else {
1737       // process right bottom tails
1738       const float *from = src_batch;
1739       float *to = dst_batch;
1740       int i_start = m_start;
1741       int i_end = channel;
1742       int j_start = n_start;
1743       int j_end = plane;
1744       if (m >= m_blk && n < n_blk) {
1745         // process left bottom tails
1746         from = src_batch + n_start;
1747         to = dst_batch + n_start * channel;
1748         j_start = 0;
1749         j_end = C8NUM;
1750       } else if (m < m_blk && n >= n_blk) {
1751         // process right top tails
1752         from = src_batch + m_start * plane;
1753         to = dst_batch + m_start;
1754         i_start = 0;
1755         i_end = C8NUM;
1756       }
1757       transpose_tail(from, to, j_start, j_end, i_start, i_end, channel, plane);
1758     }
1759     offset_to_index_step(6, &m, m_pad, &n, n_pad, &b, batches);
1760   }
1761 }
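
/*
 * Minimal single-threaded usage sketch for TransposeFp32 (illustrative only; `src_buf` and `dst_buf`
 * are hypothetical caller buffers). The [start, end) range indexes 8x8 tiles across
 * (channel, plane, batch), so one call covering every tile looks like:
 *
 *   int total_tasks = batches * UP_DIV(channel, C8NUM) * UP_DIV(plane, C8NUM);
 *   TransposeFp32(src_buf, dst_buf, batches, channel, plane, 0, total_tasks);
 */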
1762 
1763 void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) {
1764   PackNHWCToNCHWFp32(src, dst, batch, channel, plane, task_id, thread_count);
1765 }
1766 
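/* The platform-specific Transpose8X8Fp32* kernels below all implement the same contract as the scalar
   fallback above: dst_ptr[c * dst_stride + r] = src_ptr[r * src_stride + c] for 0 <= r, c < C8NUM,
   with both strides given in float elements. */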
1767 #ifdef ENABLE_ARM64
1768 inline void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
1769   size_t srcStride = src_stride * sizeof(float);
1770   size_t dstStride = dst_stride * sizeof(float);
1771   asm volatile(
1772     "mov x10, %[src_ptr]\n"
1773     "mov x11, %[dst_ptr]\n"
1774 
1775     "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
1776     "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
1777 
1778     "zip1 v8.4s, v0.4s, v2.4s\n"
1779     "zip2 v9.4s, v0.4s, v2.4s\n"
1780     "zip1 v12.4s, v1.4s, v3.4s\n"
1781     "zip2 v13.4s, v1.4s, v3.4s\n"
1782 
1783     "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
1784     "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
1785 
1786     "zip1 v10.4s, v4.4s, v6.4s\n"
1787     "zip2 v11.4s, v4.4s, v6.4s\n"
1788     "zip1 v14.4s, v5.4s, v7.4s\n"
1789     "zip2 v15.4s, v5.4s, v7.4s\n"
1790 
1791     "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
1792     "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
1793 
1794     "trn1 v16.2d, v8.2d, v10.2d\n"
1795     "trn2 v18.2d, v8.2d, v10.2d\n"
1796     "trn1 v20.2d, v9.2d, v11.2d\n"
1797     "trn2 v22.2d, v9.2d, v11.2d\n"
1798 
1799     "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
1800     "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
1801 
1802     "trn1 v24.2d, v12.2d, v14.2d\n"
1803     "trn2 v26.2d, v12.2d, v14.2d\n"
1804     "trn1 v28.2d, v13.2d, v15.2d\n"
1805     "trn2 v30.2d, v13.2d, v15.2d\n"
1806 
1807     "zip1 v8.4s, v0.4s, v2.4s\n"
1808     "zip2 v9.4s, v0.4s, v2.4s\n"
1809     "zip1 v12.4s, v1.4s, v3.4s\n"
1810     "zip2 v13.4s, v1.4s, v3.4s\n"
1811 
1812     "zip1 v10.4s, v4.4s, v6.4s\n"
1813     "zip2 v11.4s, v4.4s, v6.4s\n"
1814     "zip1 v14.4s, v5.4s, v7.4s\n"
1815     "zip2 v15.4s, v5.4s, v7.4s\n"
1816 
1817     "trn1 v17.2d, v8.2d, v10.2d\n"
1818     "trn2 v19.2d, v8.2d, v10.2d\n"
1819     "trn1 v21.2d, v9.2d, v11.2d\n"
1820     "trn2 v23.2d, v9.2d, v11.2d\n"
1821 
1822     "trn1 v25.2d, v12.2d, v14.2d\n"
1823     "trn2 v27.2d, v12.2d, v14.2d\n"
1824     "trn1 v29.2d, v13.2d, v15.2d\n"
1825     "trn2 v31.2d, v13.2d, v15.2d\n"
1826 
1827     "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
1828     "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
1829     "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
1830     "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
1831     "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
1832     "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
1833     "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
1834     "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"
1835 
1836     :
1837     : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
1838     : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
1839       "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
1840       "v31");
1841 }
1842 #endif
1843 
1844 #ifdef ENABLE_ARM32
1845 inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
1846   size_t srcStride = src_stride * sizeof(float);
1847   size_t dstStride = dst_stride * sizeof(float);
1848   asm volatile(
1849     "mov r10, %[src_ptr]\n"
1850     "mov r12, %[dst_ptr]\n"
1851 
1852     "vld1.32 {q0, q1}, [r10], %[srcStride]\n"
1853     "vld1.32 {q2, q3}, [r10], %[srcStride]\n"
1854 
1855     "vtrn.32 d0, d4\n"
1856     "vtrn.32 d1, d5\n"
1857     "vtrn.32 d2, d6\n"
1858     "vtrn.32 d3, d7\n"
1859 
1860     "vld1.32 {q4, q5}, [r10], %[srcStride]\n"
1861     "vld1.32 {q6, q7}, [r10], %[srcStride]\n"
1862 
1863     "vtrn.32 d8, d12\n"
1864     "vtrn.32 d9, d13\n"
1865     "vtrn.32 d10, d14\n"
1866     "vtrn.32 d11, d15\n"
1867 
1868     "vld1.32 {q8, q9}, [r10], %[srcStride]\n"
1869     "vld1.32 {q10, q11}, [r10], %[srcStride]\n"
1870 
1871     "vswp d1, d8\n"
1872     "vswp d3, d10\n"
1873     "vswp d5, d12\n"
1874     "vswp d7, d14\n"
1875 
1876     "vtrn.32 d16, d20\n"
1877     "vtrn.32 d17, d21\n"
1878     "vtrn.32 d18, d22\n"
1879     "vtrn.32 d19, d23\n"
1880 
1881     "vld1.32 {q12, q13}, [r10], %[srcStride]\n"
1882     "vld1.32 {q14, q15}, [r10], %[srcStride]\n"
1883 
1884     "vtrn.32 d24, d28\n"
1885     "vtrn.32 d25, d29\n"
1886     "vtrn.32 d26, d30\n"
1887     "vtrn.32 d27, d31\n"
1888 
1889     "vswp d17, d24\n"
1890     "vswp d19, d26\n"
1891     "vswp d21, d28\n"
1892     "vswp d23, d30\n"
1893 
1894     "add r10, r12, #16\n"
1895     "vst1.32 {q0}, [r12], %[dstStride]\n"
1896     "vst1.32 {q8}, [r10], %[dstStride]\n"
1897     "vst1.32 {q2}, [r12], %[dstStride]\n"
1898     "vst1.32 {q10}, [r10], %[dstStride]\n"
1899     "vst1.32 {q4}, [r12], %[dstStride]\n"
1900     "vst1.32 {q12}, [r10], %[dstStride]\n"
1901     "vst1.32 {q6}, [r12], %[dstStride]\n"
1902     "vst1.32 {q14}, [r10], %[dstStride]\n"
1903     "vst1.32 {q1}, [r12], %[dstStride]\n"
1904     "vst1.32 {q9}, [r10], %[dstStride]\n"
1905     "vst1.32 {q3}, [r12], %[dstStride]\n"
1906     "vst1.32 {q11}, [r10], %[dstStride]\n"
1907     "vst1.32 {q5}, [r12], %[dstStride]\n"
1908     "vst1.32 {q13}, [r10], %[dstStride]\n"
1909     "vst1.32 {q7}, [r12], %[dstStride]\n"
1910     "vst1.32 {q15}, [r10], %[dstStride]\n"
1911 
1912     :
1913     : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
1914     : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
1915       "q15");
1916 }
1917 #endif
1918 
1919 #ifdef ENABLE_AVX
1920 /*
1921   Use _mm256_insertf128_ps while loading at the beginning instead of recombining the lanes with
1922   _mm256_permute2f128_ps at the end; on the whole, four fewer vinsertf128 and four fewer vperm2f128 are used than before.
1923 */
1924 inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
1925   const float *src1 = src_ptr + 0 * src_stride;
1926   const float *src2 = src_ptr + 1 * src_stride;
1927   const float *src3 = src_ptr + 2 * src_stride;
1928   const float *src4 = src_ptr + 3 * src_stride;
1929   const float *src5 = src_ptr + 4 * src_stride;
1930   const float *src6 = src_ptr + 5 * src_stride;
1931   const float *src7 = src_ptr + 6 * src_stride;
1932   const float *src8 = src_ptr + 7 * src_stride;
1933 
1934   __m256 r1, r2, r3, r4, r5, r6, r7, r8;
1935   __m256 t1, t2, t3, t4, t5, t6, t7, t8;
1936   // _mm256_castps128_ps256 is only for compilation and generates no instructions, thus it has zero latency.
1937   r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 0)), _mm_loadu_ps(src5 + 0), 1);
1938   r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 0)), _mm_loadu_ps(src6 + 0), 1);
1939   r3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 0)), _mm_loadu_ps(src7 + 0), 1);
1940   r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 0)), _mm_loadu_ps(src8 + 0), 1);
1941   r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 4)), _mm_loadu_ps(src5 + 4), 1);
1942   r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 4)), _mm_loadu_ps(src6 + 4), 1);
1943   r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 4)), _mm_loadu_ps(src7 + 4), 1);
1944   r8 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 4)), _mm_loadu_ps(src8 + 4), 1);
1945 
1946   t1 = _mm256_unpacklo_ps(r1, r2);
1947   t2 = _mm256_unpackhi_ps(r1, r2);
1948   t3 = _mm256_unpacklo_ps(r3, r4);
1949   t4 = _mm256_unpackhi_ps(r3, r4);
1950   t5 = _mm256_unpacklo_ps(r5, r6);
1951   t6 = _mm256_unpackhi_ps(r5, r6);
1952   t7 = _mm256_unpacklo_ps(r7, r8);
1953   t8 = _mm256_unpackhi_ps(r7, r8);
1954 
1955   __m256 v;
1956   v = _mm256_shuffle_ps(t1, t3, 0x4E);
1957   r1 = _mm256_blend_ps(t1, v, 0xCC);
1958   r2 = _mm256_blend_ps(t3, v, 0x33);
1959 
1960   v = _mm256_shuffle_ps(t2, t4, 0x4E);
1961   r3 = _mm256_blend_ps(t2, v, 0xCC);
1962   r4 = _mm256_blend_ps(t4, v, 0x33);
1963 
1964   v = _mm256_shuffle_ps(t5, t7, 0x4E);
1965   r5 = _mm256_blend_ps(t5, v, 0xCC);
1966   r6 = _mm256_blend_ps(t7, v, 0x33);
1967 
1968   v = _mm256_shuffle_ps(t6, t8, 0x4E);
1969   r7 = _mm256_blend_ps(t6, v, 0xCC);
1970   r8 = _mm256_blend_ps(t8, v, 0x33);
1971 
1972   STORE256X8_F32(dst_ptr, dst_stride, r);
1973 }
1974 #endif
1975 
1976 #if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
1977 inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
1978   __m128 v0_ma = _mm_loadu_ps(src_ptr);
1979   __m128 v1_ma = _mm_loadu_ps(src_ptr + src_stride);
1980   __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride);
1981   __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride);
1982 
1983   __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
1984   __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
1985   __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
1986   __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
1987 
1988   __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
1989   __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
1990   __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
1991   __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
1992 
1993   _mm_storeu_ps(dst_ptr, v8_ma);
1994   _mm_storeu_ps(dst_ptr + dst_stride, v9_ma);
1995   _mm_storeu_ps(dst_ptr + 2 * dst_stride, v10_ma);
1996   _mm_storeu_ps(dst_ptr + 3 * dst_stride, v11_ma);
1997 
1998   v0_ma = _mm_loadu_ps(src_ptr + C4NUM);
1999   v1_ma = _mm_loadu_ps(src_ptr + src_stride + C4NUM);
2000   v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride + C4NUM);
2001   v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride + C4NUM);
2002 
2003   v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
2004   v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
2005   v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
2006   v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
2007 
2008   v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
2009   v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
2010   v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
2011   v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
2012 
2013   _mm_storeu_ps(dst_ptr + C4NUM * dst_stride, v8_ma);
2014   _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride, v9_ma);
2015   _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride, v10_ma);
2016   _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride, v11_ma);
2017 
2018   v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride);
2019   v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride);
2020   v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride);
2021   v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride);
2022 
2023   v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
2024   v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
2025   v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
2026   v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
2027 
2028   v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
2029   v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
2030   v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
2031   v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
2032 
2033   _mm_storeu_ps(dst_ptr + C4NUM, v8_ma);
2034   _mm_storeu_ps(dst_ptr + dst_stride + C4NUM, v9_ma);
2035   _mm_storeu_ps(dst_ptr + 2 * dst_stride + C4NUM, v10_ma);
2036   _mm_storeu_ps(dst_ptr + 3 * dst_stride + C4NUM, v11_ma);
2037 
2038   v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride + C4NUM);
2039   v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride + C4NUM);
2040   v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride + C4NUM);
2041   v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride + C4NUM);
2042 
2043   v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
2044   v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
2045   v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
2046   v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
2047 
2048   v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
2049   v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
2050   v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
2051   v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
2052 
2053   _mm_storeu_ps(dst_ptr + C4NUM * dst_stride + C4NUM, v8_ma);
2054   _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride + C4NUM, v9_ma);
2055   _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride + C4NUM, v10_ma);
2056   _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma);
2057 }
2058 #endif
2059 
2060 #if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
2061 void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) {
2062   // nchw to nc4hw4 with 1D F(2,3)
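  /*
   * Each 1x3 kernel row g = (g0, g1, g2) is transformed with the Winograd F(2,3) weight matrix
   *   G = [  1    0    0  ]
   *       [ 1/2  1/2  1/2 ]
   *       [ 1/2 -1/2  1/2 ]
   *       [  0    0    1  ]
   * The four transformed taps of a row are stored C4NUM floats apart; consecutive rows (y) are
   * 16 floats apart.
   */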
2063   for (int i = 0; i < channel; i++) {
2064     float *src_kernel = (float *)src + i * 9;
2065     float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4;
2066     for (int y = 0; y < 3; y++) {
2067       float g0 = src_kernel[3 * y];
2068       float g1 = src_kernel[3 * y + 1];
2069       float g2 = src_kernel[3 * y + 2];
2070 
2071       dst_kernel[16 * y] = g0;
2072       dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2);
2073       dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2);
2074       dst_kernel[16 * y + 12] = g2;
2075     }
2076   }
2077 }
2078 #endif
2079