1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/fp32/pack_fp32.h"
18 #include "nnacl/intrinsics/ms_simd_instructions.h"
19
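// Repack weight from (K, HW) row-major to (HW, K) by reusing PackNCHWToNHWCFp32 with batch = 1.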
20 void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
21 PackNCHWToNHWCFp32(src, dst, 1, plane, channel, 0, 0);
22 }
23
24 void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) {
25 for (int i = 0; i < height; ++i) {
26 for (int j = 0; j < width; ++j) {
27 memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float));
28 }
29 }
30 }
31
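// Pack NHWC to NC4HW4 without padding the channel tail: each full group of 4 channels is stored as a
// C4-interleaved block of plane * 4 floats, and the remaining channel % 4 channels are stored compactly after them.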
32 void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) {
33 if (channel <= C4NUM) {
34 memcpy(dst, src, batch * plane * channel * sizeof(float));
35 return;
36 }
37 int tmp = DOWN_DIV(channel, C4NUM);
38 int c_res = channel - tmp * C4NUM;
39 int c4_block = tmp * plane * C4NUM;
40 for (int b = 0; b < batch; b++) {
41 int batch_oc_offset = b * plane * channel;
42 for (int k = 0; k < plane; k++) {
43 int src_kernel_offset = batch_oc_offset + k * channel;
44 int dst_kernel_offset = batch_oc_offset + k * C4NUM;
45 int c = 0;
46 for (; c <= channel - C4NUM; c += C4NUM) {
47 #if defined(ENABLE_SSE) || defined(ENABLE_ARM)
48 MS_FLOAT32X4 src_data = MS_LDQ_F32(src + src_kernel_offset + c);
49 MS_STQ_F32(dst + dst_kernel_offset + c * plane, src_data);
50 #else
51 for (int k1 = 0; k1 < C4NUM; ++k1) {
52 (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1];
53 }
54 #endif
55 }
56 for (; c < channel; ++c) {
57 dst[batch_oc_offset + c4_block + k * c_res + c - tmp * C4NUM] = src[src_kernel_offset + c];
58 }
59 }
60 }
61 }
62
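// Same layout transform as the C4 variant above, but with 8-channel blocks (AVX 256-bit fast path).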
63 void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) {
64 if (channel <= C8NUM) {
65 memcpy(dst, src, batch * plane * channel * sizeof(float));
66 return;
67 }
68 int tmp = DOWN_DIV(channel, C8NUM);
69 int c_res = channel - tmp * C8NUM;
70 int c8_block = tmp * plane * C8NUM;
71 for (int b = 0; b < batch; b++) {
72 int batch_oc_offset = b * plane * channel;
73 for (int k = 0; k < plane; k++) {
74 int src_kernel_offset = batch_oc_offset + k * channel;
75 int dst_kernel_offset = batch_oc_offset + k * C8NUM;
76 int c = 0;
77 for (; c <= channel - C8NUM; c += C8NUM) {
78 #ifdef ENABLE_AVX
79 MS_FLOAT32X8 src_data = MS_LD256_F32(src + src_kernel_offset + c);
80 MS_ST256_F32(dst + dst_kernel_offset + c * plane, src_data);
81 #else
82 for (int k1 = 0; k1 < C8NUM; ++k1) {
83 (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1];
84 }
85 #endif
86 }
87 for (; c < channel; ++c) {
88 dst[batch_oc_offset + c8_block + k * c_res + c - tmp * C8NUM] = src[src_kernel_offset + c];
89 }
90 }
91 }
92 }
93
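// The *Parallel variants below transform only rows [row_start, row_end), so callers can split the work
// across threads; the plain wrappers near the end of this file call them with the full range [0, row).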
94 void RowMajor2ColMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
95 for (int r = row_start; r < row_end; ++r) {
96 for (int c = 0; c < col; ++c) {
97 dst_ptr[c * row + r] = src_ptr[r * col + c];
98 }
99 }
100 }
101
102 void RowMajor2RowMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
103 if (row_end > row_start) {
104 src_ptr += row_start * col;
105 dst_ptr += row_start * col;
106 memcpy(dst_ptr, src_ptr, (row_end - row_start) * col * (int)(sizeof(float)));
107 }
108 }
109
110 void RowMajor2Row4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
111 for (int r = row_start; r < row_end; r++) {
112 const float *src = src_ptr + r * col;
113 int c = 0;
114 for (; c < col; c++) {
115 int cd4 = c / C4NUM;
116 int cm4 = c % C4NUM;
117 dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c];
118 }
119 for (; c < UP_ROUND(col, C4NUM); c++) {
120 int cd4 = c / C4NUM;
121 int cm4 = c % C4NUM;
122 dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = 0;
123 }
124 }
125 return;
126 }
127
128 void RowMajor2Row6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
129 for (int r = row_start; r < row_end; r++) {
130 const float *src = src_ptr + r * col;
131 int c = 0;
132 for (; c < col; c++) {
133 int cd6 = c / C6NUM;
134 int cm6 = c % C6NUM;
135 dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c];
136 }
137 for (; c < UP_ROUND(col, C6NUM); c++) {
138 int cd6 = c / C6NUM;
139 int cm6 = c % C6NUM;
140 dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = 0;
141 }
142 }
143 return;
144 }
145
146 void RowMajor2Row8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
147 for (int r = row_start; r < row_end; r++) {
148 const float *src = src_ptr + r * col;
149 int c = 0;
150 for (; c < col; c++) {
151 int cd8 = c / C8NUM;
152 int cm8 = c % C8NUM;
153 dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c];
154 }
155 for (; c < UP_ROUND(col, C8NUM); c++) {
156 int cd8 = c / C8NUM;
157 int cm8 = c % C8NUM;
158 dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = 0;
159 }
160 }
161 return;
162 }
163
164 void RowMajor2Row12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
165 for (int r = row_start; r < row_end; r++) {
166 const float *src = src_ptr + r * col;
167 int c = 0;
168 for (; c < col; c++) {
169 int cd12 = c / C12NUM;
170 int cm12 = c % C12NUM;
171 dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
172 }
173 for (; c < UP_ROUND(col, C12NUM); c++) {
174 int cd12 = c / C12NUM;
175 int cm12 = c % C12NUM;
176 dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = 0;
177 }
178 }
179 return;
180 }
181
182 void RowMajor2Row16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
183 for (int r = row_start; r < row_end; r++) {
184 const float *src = src_ptr + r * col;
185 int c = 0;
186 for (; c < col; c++) {
187 int cd16 = c / C16NUM;
188 int cm16 = c % C16NUM;
189 dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c];
190 }
191 for (; c < UP_ROUND(col, C16NUM); c++) {
192 int cd16 = c / C16NUM;
193 int cm16 = c % C16NUM;
194 dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0;
195 }
196 }
197 return;
198 }
199
200 void RowMajor2Row32MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) {
201 // Not exactly aligned to 32; falls back to blocks of 24, 16 or 8 when a full 32 is not available.
202 int row_block_num = UP_DIV(row, C8NUM);
203 int row_block = C4NUM;
204 for (int i = 0; i < row_block_num; i += row_block) {
205 row_block = MSMIN(C4NUM, row_block_num - i); // max_tile = 4
206 int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM);
207 dst_ptr += col_start * row_block * C8NUM;
208 for (int oc = col_start; oc < col_end; ++oc) {
209 memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float));
210 dst_ptr += row_block * C8NUM;
211 }
212 dst_ptr += (col - col_end) * row_block * C8NUM;
213 }
214 }
215
216 void RowMajor2Row64MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) {
217 // Not exactly aligned to 64; falls back to blocks of 48, 32 or 16 when a full 64 is not available.
218 int row_block_num = UP_DIV(row, C16NUM);
219 int row_block = C4NUM;
220 for (int i = 0; i < row_block_num; i += row_block) {
221 row_block = MSMIN(C4NUM, row_block_num - i); // max_tile = 4
222 int row_remainder = MSMIN(row_block * C16NUM, row - i * C16NUM);
223 dst_ptr += col_start * row_block * C16NUM;
224 for (int oc = col_start; oc < col_end; ++oc) {
225 memcpy(dst_ptr, src_ptr + oc * row + i * C16NUM, row_remainder * sizeof(float));
226 dst_ptr += row_block * C16NUM;
227 }
228 dst_ptr += (col - col_end) * row_block * C16NUM;
229 }
230 }
231
232 #ifdef ENABLE_ARM64
233 void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) {
234 size_t stride = col * sizeof(float);
235 asm volatile(
236 "mov x10, %[src_c]\n"
237 "mov x11, %[dst_c]\n"
238
239 "ld1 {v0.4s}, [x10], %[stride]\n"
240 "ld1 {v1.4s}, [x10], %[stride]\n"
241 "ld1 {v2.4s}, [x10], %[stride]\n"
242 "ld1 {v3.4s}, [x10], %[stride]\n"
243
244 "ld1 {v4.4s}, [x10], %[stride]\n"
245 "ld1 {v5.4s}, [x10], %[stride]\n"
246 "ld1 {v6.4s}, [x10], %[stride]\n"
247 "ld1 {v7.4s}, [x10], %[stride]\n"
248
249 "zip1 v12.4s, v0.4s, v1.4s\n"
250 "zip2 v13.4s, v0.4s, v1.4s\n"
251 "zip1 v14.4s, v2.4s, v3.4s\n"
252 "zip2 v15.4s, v2.4s, v3.4s\n"
253
254 "ld1 {v8.4s}, [x10], %[stride]\n"
255 "ld1 {v9.4s}, [x10], %[stride]\n"
256 "ld1 {v10.4s}, [x10], %[stride]\n"
257 "ld1 {v11.4s}, [x10], %[stride]\n"
258
259 "zip1 v16.4s, v4.4s, v5.4s\n"
260 "zip2 v17.4s, v4.4s, v5.4s\n"
261 "zip1 v18.4s, v6.4s, v7.4s\n"
262 "zip2 v19.4s, v6.4s, v7.4s\n"
263
264 "trn1 v20.2d, v12.2d, v14.2d\n"
265 "trn2 v23.2d, v12.2d, v14.2d\n"
266 "trn1 v26.2d, v13.2d, v15.2d\n"
267 "trn2 v29.2d, v13.2d, v15.2d\n"
268
269 "trn1 v21.2d, v16.2d, v18.2d\n"
270 "trn2 v24.2d, v16.2d, v18.2d\n"
271 "trn1 v27.2d, v17.2d, v19.2d\n"
272 "trn2 v30.2d, v17.2d, v19.2d\n"
273
274 "zip1 v12.4s, v8.4s, v9.4s\n"
275 "zip2 v13.4s, v8.4s, v9.4s\n"
276 "zip1 v14.4s, v10.4s, v11.4s\n"
277 "zip2 v15.4s, v10.4s, v11.4s\n"
278
279 "trn1 v22.2d, v12.2d, v14.2d\n"
280 "trn2 v25.2d, v12.2d, v14.2d\n"
281 "trn1 v28.2d, v13.2d, v15.2d\n"
282 "trn2 v31.2d, v13.2d, v15.2d\n"
283
284 "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
285 "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
286 "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
287
288 :
289 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
290 : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
291 "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
292 "v31");
293 return;
294 }
295 #endif
296 #ifdef ENABLE_ARM32
297 void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) {
298 size_t stride = col * sizeof(float);
299 asm volatile(
300 "mov r10, %[src_c]\n"
301 "mov r12, %[dst_c]\n"
302
303 "vld1.32 {q0}, [r10], %[stride]\n"
304 "vld1.32 {q3}, [r10], %[stride]\n"
305 "vld1.32 {q10}, [r10], %[stride]\n"
306 "vld1.32 {q13}, [r10], %[stride]\n"
307
308 "vtrn.32 d0, d6\n"
309 "vtrn.32 d1, d7\n"
310 "vtrn.32 d20, d26\n"
311 "vtrn.32 d21, d27\n"
312
313 "vld1.32 {q1}, [r10], %[stride]\n"
314 "vld1.32 {q8}, [r10], %[stride]\n"
315 "vld1.32 {q11}, [r10], %[stride]\n"
316 "vld1.32 {q14}, [r10], %[stride]\n"
317
318 "vswp d1, d20\n"
319 "vswp d7, d26\n"
320
321 "vld1.32 {q2}, [r10], %[stride]\n"
322 "vld1.32 {q9}, [r10], %[stride]\n"
323 "vld1.32 {q12}, [r10], %[stride]\n"
324 "vld1.32 {q15}, [r10], %[stride]\n"
325
326 "vtrn.32 d2, d16\n"
327 "vtrn.32 d3, d17\n"
328 "vtrn.32 d22, d28\n"
329 "vtrn.32 d23, d29\n"
330
331 "vswp d3, d22\n"
332 "vswp d17, d28\n"
333
334 "vtrn.32 d4, d18\n"
335 "vtrn.32 d5, d19\n"
336 "vtrn.32 d24, d30\n"
337 "vtrn.32 d25, d31\n"
338
339 "vswp d5, d24\n"
340 "vswp d19, d30\n"
341
342 "vst1.32 {q0, q1}, [r12]!\n"
343 "vst1.32 {q2, q3}, [r12]!\n"
344 "vst1.32 {q8, q9}, [r12]!\n"
345 "vst1.32 {q10, q11}, [r12]!\n"
346 "vst1.32 {q12, q13}, [r12]!\n"
347 "vst1.32 {q14, q15}, [r12]!\n"
348
349 :
350 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
351 : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
352 return;
353 }
354 #endif
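// Pack a row-major matrix into 12-row tiles that are column-major inside each tile (Col12 layout);
// the thread handling the last rows (row_end == row) also zero-pads the final partial tile.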
355 void RowMajor2Col12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
356 int ri = (row_start / C12NUM * C12NUM);
357 float *dst_r = dst_ptr + ri * col;
358 const float *src_r = src_ptr + ri * col;
359 for (; ri < (row_end / C12NUM * C12NUM); ri += C12NUM) {
360 int ci = 0;
361 for (; ci < (col / C4NUM * C4NUM); ci += C4NUM) {
362 const float *src_c = src_r + ci;
363 float *dst_c = dst_r + ci * C12NUM;
364 #ifdef ENABLE_ARM64
365 RowMajor2Col12Major_arm64(src_c, dst_c, col);
366 #elif ENABLE_ARM32
367 RowMajor2Col12Major_arm32(src_c, dst_c, col);
368 #elif ENABLE_SSE
369 __m128 src1 = _mm_loadu_ps(src_c);
370 __m128 src2 = _mm_loadu_ps(src_c + col);
371 __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
372 __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
373 src_c += C4NUM * col;
374 __m128 src12L = _mm_unpacklo_ps(src1, src2);
375 __m128 src12H = _mm_unpackhi_ps(src1, src2);
376 __m128 src34L = _mm_unpacklo_ps(src3, src4);
377 __m128 src34H = _mm_unpackhi_ps(src3, src4);
378
379 __m128 dst0 = _mm_movelh_ps(src12L, src34L);
380 __m128 dst3 = _mm_movehl_ps(src34L, src12L);
381 __m128 dst6 = _mm_movelh_ps(src12H, src34H);
382 __m128 dst9 = _mm_movehl_ps(src34H, src12H);
383
384 __m128 src5 = _mm_loadu_ps(src_c);
385 __m128 src6 = _mm_loadu_ps(src_c + col);
386 __m128 src7 = _mm_loadu_ps(src_c + 2 * col);
387 __m128 src8 = _mm_loadu_ps(src_c + 3 * col);
388 src_c += C4NUM * col;
389 __m128 src56L = _mm_unpacklo_ps(src5, src6);
390 __m128 src56H = _mm_unpackhi_ps(src5, src6);
391 __m128 src78L = _mm_unpacklo_ps(src7, src8);
392 __m128 src78H = _mm_unpackhi_ps(src7, src8);
393 __m128 dst1 = _mm_movelh_ps(src56L, src78L);
394 __m128 dst4 = _mm_movehl_ps(src78L, src56L);
395 __m128 dst7 = _mm_movelh_ps(src56H, src78H);
396 __m128 dst10 = _mm_movehl_ps(src78H, src56H);
397
398 __m128 src9 = _mm_loadu_ps(src_c);
399 __m128 src10 = _mm_loadu_ps(src_c + col);
400 __m128 src11 = _mm_loadu_ps(src_c + 2 * col);
401 __m128 src12 = _mm_loadu_ps(src_c + 3 * col);
402 src_c += C4NUM * col;
403 __m128 src910L = _mm_unpacklo_ps(src9, src10);
404 __m128 src910H = _mm_unpackhi_ps(src9, src10);
405 __m128 src1112L = _mm_unpacklo_ps(src11, src12);
406 __m128 src1112H = _mm_unpackhi_ps(src11, src12);
407 __m128 dst2 = _mm_movelh_ps(src910L, src1112L);
408 __m128 dst5 = _mm_movehl_ps(src1112L, src910L);
409 __m128 dst8 = _mm_movelh_ps(src910H, src1112H);
410 __m128 dst11 = _mm_movehl_ps(src1112H, src910H);
411
412 _mm_storeu_ps(dst_c, dst0);
413 _mm_storeu_ps(dst_c + 4, dst1);
414 _mm_storeu_ps(dst_c + 8, dst2);
415 _mm_storeu_ps(dst_c + 12, dst3);
416 _mm_storeu_ps(dst_c + 16, dst4);
417 _mm_storeu_ps(dst_c + 20, dst5);
418 _mm_storeu_ps(dst_c + 24, dst6);
419 _mm_storeu_ps(dst_c + 28, dst7);
420 _mm_storeu_ps(dst_c + 32, dst8);
421 _mm_storeu_ps(dst_c + 36, dst9);
422 _mm_storeu_ps(dst_c + 40, dst10);
423 _mm_storeu_ps(dst_c + 44, dst11);
424 #else
425 for (int tr = 0; tr < C12NUM; tr++) {
426 for (int tc = 0; tc < C4NUM; tc++) {
427 dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
428 }
429 }
430 #endif
431 }
432 for (; ci < col; ci++) {
433 const float *src_c = src_r + ci;
434 float *dst_c = dst_r + ci * C12NUM;
435 for (int i = 0; i < C12NUM; i++) {
436 dst_c[i] = src_c[i * col];
437 }
438 }
439 src_r += C12NUM * col;
440 dst_r += C12NUM * col;
441 }
442 if (row_end == row) {
443 for (; ri < row_end; ri++, dst_r++, src_r += col) {
444 for (int i = 0; i < col; i++) {
445 dst_r[i * C12NUM] = src_r[i];
446 }
447 }
448 for (; ri < UP_ROUND(row, C12NUM); ri++, dst_r++) {
449 for (int i = 0; i < col; i++) {
450 dst_r[i * C12NUM] = 0;
451 }
452 }
453 }
454 }
455
456 #ifdef ENABLE_ARM64
457 void RowMajor2Col8Major_arm64(const float *src_c, float *dst_c, size_t col) {
458 size_t stride = col * sizeof(float);
459 asm volatile(
460 "mov x10, %[src_c]\n"
461 "mov x11, %[dst_c]\n"
462
463 "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n"
464 "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n"
465 "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n"
466 "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n"
467
468 "zip1 v8.4s, v0.4s, v2.4s\n"
469 "zip2 v9.4s, v0.4s, v2.4s\n"
470 "zip1 v10.4s, v4.4s, v6.4s\n"
471 "zip2 v11.4s, v4.4s, v6.4s\n"
472
473 "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n"
474 "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n"
475 "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n"
476 "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n"
477
478 "zip1 v12.4s, v1.4s, v3.4s\n"
479 "zip2 v13.4s, v1.4s, v3.4s\n"
480 "zip1 v14.4s, v5.4s, v7.4s\n"
481 "zip2 v15.4s, v5.4s, v7.4s\n"
482
483 "trn1 v0.2d, v8.2d, v10.2d\n"
484 "trn2 v1.2d, v8.2d, v10.2d\n"
485 "trn1 v2.2d, v9.2d, v11.2d\n"
486 "trn2 v3.2d, v9.2d, v11.2d\n"
487
488 "zip1 v24.4s, v16.4s, v18.4s\n"
489 "zip2 v25.4s, v16.4s, v18.4s\n"
490 "zip1 v26.4s, v20.4s, v22.4s\n"
491 "zip2 v27.4s, v20.4s, v22.4s\n"
492
493 "trn1 v4.2d, v12.2d, v14.2d\n"
494 "trn2 v5.2d, v12.2d, v14.2d\n"
495 "trn1 v6.2d, v13.2d, v15.2d\n"
496 "trn2 v7.2d, v13.2d, v15.2d\n"
497
498 "zip1 v28.4s, v17.4s, v19.4s\n"
499 "zip2 v29.4s, v17.4s, v19.4s\n"
500 "zip1 v30.4s, v21.4s, v23.4s\n"
501 "zip2 v31.4s, v21.4s, v23.4s\n"
502
503 "trn1 v16.2d, v24.2d, v26.2d\n"
504 "trn2 v17.2d, v24.2d, v26.2d\n"
505 "trn1 v18.2d, v25.2d, v27.2d\n"
506 "trn2 v19.2d, v25.2d, v27.2d\n"
507
508 "trn1 v20.2d, v28.2d, v30.2d\n"
509 "trn2 v21.2d, v28.2d, v30.2d\n"
510 "trn1 v22.2d, v29.2d, v31.2d\n"
511 "trn2 v23.2d, v29.2d, v31.2d\n"
512
513 "st1 {v0.4s}, [x11], #16\n"
514 "st1 {v16.4s}, [x11], #16\n"
515 "st1 {v1.4s}, [x11], #16\n"
516 "st1 {v17.4s}, [x11], #16\n"
517 "st1 {v2.4s}, [x11], #16\n"
518 "st1 {v18.4s}, [x11], #16\n"
519 "st1 {v3.4s}, [x11], #16\n"
520 "st1 {v19.4s}, [x11], #16\n"
521 "st1 {v4.4s}, [x11], #16\n"
522 "st1 {v20.4s}, [x11], #16\n"
523 "st1 {v5.4s}, [x11], #16\n"
524 "st1 {v21.4s}, [x11], #16\n"
525 "st1 {v6.4s}, [x11], #16\n"
526 "st1 {v22.4s}, [x11], #16\n"
527 "st1 {v7.4s}, [x11], #16\n"
528 "st1 {v23.4s}, [x11], #16\n"
529
530 :
531 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
532 : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
533 "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
534 "v31");
535 return;
536 }
537 #endif
538 #ifdef ENABLE_ARM32
539 #ifndef SUPPORT_NNIE
540 void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
541 size_t stride = col * sizeof(float);
542 asm volatile(
543 "mov r10, %[src_c]\n"
544 "mov r11, %[dst_c]\n"
545
546 "vld1.32 {q0}, [r10], %[stride]\n"
547 "vld1.32 {q2}, [r10], %[stride]\n"
548 "vld1.32 {q4}, [r10], %[stride]\n"
549 "vld1.32 {q6}, [r10], %[stride]\n"
550
551 "vtrn.32 d0, d4\n"
552 "vtrn.32 d1, d5\n"
553 "vtrn.32 d8, d12\n"
554 "vtrn.32 d9, d13\n"
555
556 "vld1.32 {q1}, [r10], %[stride]\n"
557 "vld1.32 {q3}, [r10], %[stride]\n"
558 "vld1.32 {q5}, [r10], %[stride]\n"
559 "vld1.32 {q7}, [r10], %[stride]\n"
560
561 "vswp d1, d8\n"
562 "vswp d5, d12\n"
563
564 "vtrn.32 d2, d6\n"
565 "vtrn.32 d3, d7\n"
566 "vtrn.32 d10, d14\n"
567 "vtrn.32 d11, d15\n"
568
569 "vswp d3, d10\n"
570 "vswp d7, d14\n"
571
572 "vst1.32 {q0, q1}, [r11]!\n"
573 "vst1.32 {q2, q3}, [r11]!\n"
574 "vst1.32 {q4, q5}, [r11]!\n"
575 "vst1.32 {q6, q7}, [r11]!\n"
576
577 :
578 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
579 : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
580 return;
581 }
582 #else
583 void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
584 size_t stride = col * sizeof(float);
585 asm volatile(
586 "mov r10, %[src_c]\n"
587 "mov r7, %[dst_c]\n"
588
589 "vld1.32 {q0}, [r10], %[stride]\n"
590 "vld1.32 {q2}, [r10], %[stride]\n"
591 "vld1.32 {q4}, [r10], %[stride]\n"
592 "vld1.32 {q6}, [r10], %[stride]\n"
593
594 "vtrn.32 d0, d4\n"
595 "vtrn.32 d1, d5\n"
596 "vtrn.32 d8, d12\n"
597 "vtrn.32 d9, d13\n"
598
599 "vld1.32 {q1}, [r10], %[stride]\n"
600 "vld1.32 {q3}, [r10], %[stride]\n"
601 "vld1.32 {q5}, [r10], %[stride]\n"
602 "vld1.32 {q7}, [r10], %[stride]\n"
603
604 "vswp d1, d8\n"
605 "vswp d5, d12\n"
606
607 "vtrn.32 d2, d6\n"
608 "vtrn.32 d3, d7\n"
609 "vtrn.32 d10, d14\n"
610 "vtrn.32 d11, d15\n"
611
612 "vswp d3, d10\n"
613 "vswp d7, d14\n"
614
615 "vst1.32 {q0, q1}, [r7]!\n"
616 "vst1.32 {q2, q3}, [r7]!\n"
617 "vst1.32 {q4, q5}, [r7]!\n"
618 "vst1.32 {q6, q7}, [r7]!\n"
619
620 :
621 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
622 : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
623 return;
624 }
625 #endif
626 #endif
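// Col8 packing: 8-row tiles, column-major inside each tile, with ARM assembly / SSE fast paths for
// aligned column blocks and a scalar fallback otherwise.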
627 void RowMajor2Col8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
628 int row8 = row_end / C8NUM * C8NUM;
629 #ifdef ENABLE_ARM64
630 int col_skip = col / C8NUM * C8NUM;
631 int skip_size = C8NUM;
632 #else
633 int col_skip = col / C4NUM * C4NUM;
634 int skip_size = C4NUM;
635 #endif
636 int ri = (row_start / C8NUM * C8NUM);
637 const float *src_r = src_ptr + ri * col;
638 float *dst_r = dst_ptr + ri * col;
639
640 for (; ri < row8; ri += C8NUM) {
641 int ci = 0;
642 for (; ci < col_skip; ci += skip_size) {
643 const float *src_c = src_r + ci;
644 float *dst_c = dst_r + ci * C8NUM;
645
646 #ifdef ENABLE_ARM64
647 RowMajor2Col8Major_arm64(src_c, dst_c, col);
648 #elif ENABLE_ARM32
649 RowMajor2Col8Major_arm32(src_c, dst_c, col);
650 #elif ENABLE_SSE
651 __m128 src1 = _mm_loadu_ps(src_c);
652 __m128 src2 = _mm_loadu_ps(src_c + col);
653 __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
654 __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
655 src_c += C4NUM * col;
656 __m128 src12L = _mm_unpacklo_ps(src1, src2); // x5
657 __m128 src12H = _mm_unpackhi_ps(src1, src2); // x1
658 __m128 src34L = _mm_unpacklo_ps(src3, src4); // x
659 __m128 src34H = _mm_unpackhi_ps(src3, src4);
660 _mm_storeu_ps(dst_c, _mm_movelh_ps(src12L, src34L));
661 _mm_storeu_ps(dst_c + C8NUM, _mm_movehl_ps(src34L, src12L));
662 _mm_storeu_ps(dst_c + C16NUM, _mm_movelh_ps(src12H, src34H));
663 _mm_storeu_ps(dst_c + C24NUM, _mm_movehl_ps(src34H, src12H));
664
665 __m128 src5 = _mm_loadu_ps(src_c);
666 __m128 src6 = _mm_loadu_ps(src_c + col);
667 __m128 src7 = _mm_loadu_ps(src_c + 2 * col);
668 __m128 src8 = _mm_loadu_ps(src_c + 3 * col);
669 src_c += C4NUM * col;
670 __m128 src56L = _mm_unpacklo_ps(src5, src6);
671 __m128 src56H = _mm_unpackhi_ps(src5, src6);
672 __m128 src78L = _mm_unpacklo_ps(src7, src8);
673 __m128 src78H = _mm_unpackhi_ps(src7, src8);
674 _mm_storeu_ps(dst_c + C4NUM, _mm_movelh_ps(src56L, src78L));
675 _mm_storeu_ps(dst_c + C12NUM, _mm_movehl_ps(src78L, src56L));
676 _mm_storeu_ps(dst_c + 20, _mm_movelh_ps(src56H, src78H));
677 _mm_storeu_ps(dst_c + 28, _mm_movehl_ps(src78H, src56H));
678 #else
679 for (int tr = 0; tr < 8; tr++) {
680 for (int tc = 0; tc < 4; tc++) {
681 dst_c[tc * 8 + tr] = src_c[tr * col + tc];
682 }
683 }
684 #endif
685 }
686 for (; ci < col; ci++) {
687 const float *src_c = src_r + ci;
688 float *dst_c = dst_r + ci * C8NUM;
689 for (int i = 0; i < C8NUM; i++) {
690 dst_c[i] = src_c[i * col];
691 }
692 }
693 src_r += C8NUM * col;
694 dst_r += C8NUM * col;
695 }
696 if (row_end == row) {
697 for (; ri < row; ri++, src_r += col, dst_r++) {
698 for (int i = 0; i < col; i++) {
699 dst_r[i * C8NUM] = src_r[i];
700 }
701 }
702
703 for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) {
704 for (int i = 0; i < col; i++) {
705 dst_r[i * C8NUM] = 0;
706 }
707 }
708 }
709 }
710
711 void RowMajor2Col16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
712 int row16 = row_end / C16NUM * C16NUM;
713 int ri = row_start / C16NUM * C16NUM;
714 int col8 = col / C8NUM * C8NUM;
715 const float *src_r = src_ptr + ri * col;
716 float *dst_r = dst_ptr + ri * col;
717
718 for (; ri < row16; ri += C16NUM) {
719 int ci = 0;
720 for (; ci < col8; ci += C8NUM) {
721 const float *src_c = src_r + ci;
722 float *dst_c = dst_r + ci * C16NUM;
723 #ifdef ENABLE_AVX
724 Transpose8X8Fp32Avx(src_c, dst_c, col, C16NUM);
725 Transpose8X8Fp32Avx(src_c + C8NUM * col, dst_c + C8NUM, col, C16NUM);
726 #else
727 for (int tr = 0; tr < C16NUM; tr++) {
728 for (int tc = 0; tc < C8NUM; tc++) {
729 dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
730 }
731 }
732 #endif
733 }
734 for (; ci < col; ci++) {
735 const float *src_c = src_r + ci;
736 float *dst_c = dst_r + ci * C16NUM;
737 for (int i = 0; i < C16NUM; i++) {
738 dst_c[i] = src_c[i * col];
739 }
740 }
741 src_r += C16NUM * col;
742 dst_r += C16NUM * col;
743 }
744 if (row_end == row) {
745 for (; ri < row; ri++) {
746 for (int i = 0; i < col; i++) {
747 dst_r[i * C16NUM] = src_r[i];
748 }
749 src_r += col;
750 dst_r += 1;
751 }
752 int total_row = UP_ROUND(row, C16NUM);
753 for (; ri < total_row; ri++) {
754 for (int i = 0; i < col; i++) {
755 dst_r[i * C16NUM] = 0;
756 }
757 dst_r += 1;
758 }
759 }
760 }
761
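// Col32 packing: rows are grouped into tiles of up to 4 x 8 = 32; a smaller tile (24, 16 or 8 rows) is used
// for the trailing group, and elements are stored column-major inside each tile.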
762 void RowMajor2Col32MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
763 // Not exactly aligned to 32; falls back to blocks of 24, 16 or 8 when a full 32 is not available.
764 #ifdef ENABLE_AVX
765 int col8 = col / C8NUM * C8NUM;
766 #endif
767 int all_block_num = UP_DIV(row, C8NUM);
768 int cur_block = C4NUM;
769 row_start = UP_DIV(row_start, C8NUM);
770 row_end = UP_DIV(row_end, C8NUM);
771 for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) {
772 cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4
773 int dst_stride = cur_block * C8NUM;
774 int row_num = MSMIN(dst_stride, row - i * C8NUM);
775 #ifdef ENABLE_AVX
776 int row8_num = row_num / C8NUM * C8NUM;
777 #endif
778 const float *src = src_ptr + i * C8NUM * col;
779 float *dst = dst_ptr + i * C8NUM * col;
780 int r = 0;
781 #ifdef ENABLE_AVX
782 for (; r < row8_num; r += C8NUM) {
783 int c = 0;
784 for (; c < col8; c += C8NUM) {
785 Transpose8X8Fp32Avx(src + r * col + c, dst + c * dst_stride + r, col, dst_stride);
786 }
787 for (; c < col; ++c) {
788 for (int k = 0; k < C8NUM; ++k) {
789 dst[c * dst_stride + r + k] = src[r * col + c + k * col];
790 }
791 }
792 }
793 #endif
794 for (; r < row_num; r++) {
795 for (int c = 0; c < col; ++c) {
796 dst[c * dst_stride + r] = src[r * col + c];
797 }
798 }
799 }
800 }
801
802 void RowMajor2Col64MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
803 // Not exactly aligned to 64; falls back to blocks of 48, 32 or 16 when a full 64 is not available.
804 int all_block_num = UP_DIV(row, C16NUM);
805 int cur_block = C4NUM;
806 row_start = UP_DIV(row_start, C16NUM);
807 row_end = UP_DIV(row_end, C16NUM);
808 for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) {
809 cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4
810 int dst_stride = cur_block * C16NUM;
811 int row_num = MSMIN(dst_stride, row - i * C16NUM);
812 const float *src = src_ptr + i * C16NUM * col;
813 float *dst = dst_ptr + i * C16NUM * col;
814 int r = 0;
815 for (; r < row_num; r++) {
816 for (int c = 0; c < col; ++c) {
817 dst[c * dst_stride + r] = src[r * col + c];
818 }
819 }
820 }
821 }
822
823 void RowMajor2Col6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
824 int row6 = row_end / C6NUM * C6NUM;
825 int ri = row_start / C6NUM * C6NUM;
826 int col8 = col / C8NUM * C8NUM;
827 const float *src_r = src_ptr + ri * col;
828 float *dst_r = dst_ptr + ri * col;
829
830 for (; ri < row6; ri += C6NUM) {
831 int ci = 0;
832 for (; ci < col8; ci += C8NUM) {
833 const float *src_c = src_r + ci;
834 float *dst_c = dst_r + ci * C6NUM;
835
836 #ifdef ENABLE_AVX
837 __m256 src0 = _mm256_loadu_ps(src_c);
838 __m256 src1 = _mm256_loadu_ps(src_c + col);
839 __m256 src2 = _mm256_loadu_ps(src_c + 2 * col);
840 __m256 src3 = _mm256_loadu_ps(src_c + 3 * col);
841 __m256 src4 = _mm256_loadu_ps(src_c + 4 * col);
842 __m256 src5 = _mm256_loadu_ps(src_c + 5 * col);
843 __m256 trans0 = _mm256_unpacklo_ps(src0, src1);
844 __m256 trans1 = _mm256_unpacklo_ps(src2, src3);
845 __m256 trans2 = _mm256_unpacklo_ps(src4, src5);
846 __m256 trans3 = _mm256_unpackhi_ps(src0, src1);
847 __m256 trans4 = _mm256_unpackhi_ps(src2, src3);
848 __m256 trans5 = _mm256_unpackhi_ps(src4, src5);
849 __m128 lo0 = _mm256_castps256_ps128(trans0);
850 __m128 lo1 = _mm256_castps256_ps128(trans1);
851 __m128 lo2 = _mm256_castps256_ps128(trans2);
852 __m128 lo3 = _mm256_castps256_ps128(trans3);
853 __m128 lo4 = _mm256_castps256_ps128(trans4);
854 __m128 lo5 = _mm256_castps256_ps128(trans5);
855 __m128 hi0 = _mm256_extractf128_ps(trans0, 1);
856 __m128 hi1 = _mm256_extractf128_ps(trans1, 1);
857 __m128 hi2 = _mm256_extractf128_ps(trans2, 1);
858 __m128 hi3 = _mm256_extractf128_ps(trans3, 1);
859 __m128 hi4 = _mm256_extractf128_ps(trans4, 1);
860 __m128 hi5 = _mm256_extractf128_ps(trans5, 1);
861 __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0));
862 __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0));
863 __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2));
864 __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0));
865 __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0));
866 __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2));
867 __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0));
868 __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0));
869 __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2));
870 __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0));
871 __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0));
872 __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2));
873 _mm_storeu_ps(dst_c, res0);
874 _mm_storeu_ps(dst_c + 4, res1);
875 _mm_storeu_ps(dst_c + 8, res2);
876 _mm_storeu_ps(dst_c + 12, res3);
877 _mm_storeu_ps(dst_c + 16, res4);
878 _mm_storeu_ps(dst_c + 20, res5);
879 _mm_storeu_ps(dst_c + 24, res6);
880 _mm_storeu_ps(dst_c + 28, res7);
881 _mm_storeu_ps(dst_c + 32, res8);
882 _mm_storeu_ps(dst_c + 36, res9);
883 _mm_storeu_ps(dst_c + 40, res10);
884 _mm_storeu_ps(dst_c + 44, res11);
885 #else
886 for (int tr = 0; tr < C6NUM; tr++) {
887 for (int tc = 0; tc < C8NUM; tc++) {
888 dst_c[tc * C6NUM + tr] = src_c[tr * col + tc];
889 }
890 }
891 #endif
892 }
893 for (; ci < col; ci++) {
894 const float *src_c = src_r + ci;
895 float *dst_c = dst_r + ci * C6NUM;
896 for (int i = 0; i < C6NUM; i++) {
897 dst_c[i] = src_c[i * col];
898 }
899 }
900 src_r += C6NUM * col;
901 dst_r += C6NUM * col;
902 }
903
904 if (row_end == row) {
905 for (; ri < row_end; ri++) {
906 for (int i = 0; i < col; i++) {
907 dst_r[i * C6NUM] = src_r[i];
908 }
909 src_r += col;
910 dst_r += 1;
911 }
912
913 int totalRow = UP_ROUND(row, C6NUM);
914 for (; ri < totalRow; ri++) {
915 for (int i = 0; i < col; i++) {
916 dst_r[i * C6NUM] = 0;
917 }
918 dst_r += 1;
919 }
920 }
921 }
922
923 void RowMajor2Col4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) {
924 int row4 = row_end / C4NUM * C4NUM;
925 int ri = row_start / C4NUM * C4NUM;
926 int col4 = col / C4NUM * C4NUM;
927 const float *src_r = src_ptr + ri * col;
928 float *dst_r = dst_ptr + ri * col;
929
930 for (; ri < row4; ri += C4NUM) {
931 int ci = 0;
932 for (; ci < col4; ci += C4NUM) {
933 const float *src_c = src_r + ci;
934 float *dst_c = dst_r + ci * C4NUM;
935
936 #ifdef ENABLE_ARM32
937 int stride = col * 4;
938 asm volatile(
939 "mov r10, %[src_c]\n"
940 "mov r12, %[dst_c]\n"
941
942 "vld1.32 {q0}, [r10], %[stride]\n"
943 "vld1.32 {q1}, [r10], %[stride]\n"
944 "vld1.32 {q2}, [r10], %[stride]\n"
945 "vld1.32 {q3}, [r10], %[stride]\n"
946
947 "vtrn.32 d0, d2\n"
948 "vtrn.32 d1, d3\n"
949 "vtrn.32 d4, d6\n"
950 "vtrn.32 d5, d7\n"
951
952 "vswp d1, d4\n"
953 "vswp d3, d6\n"
954
955 "vst1.32 {q0}, [r12]!\n"
956 "vst1.32 {q1}, [r12]!\n"
957 "vst1.32 {q2}, [r12]!\n"
958 "vst1.32 {q3}, [r12]!\n"
959
960 :
961 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
962 : "r10", "r12", "q0", "q1", "q2", "q3");
963 #elif ENABLE_SSE
964 __m128 src1 = _mm_loadu_ps(src_c);
965 __m128 src2 = _mm_loadu_ps(src_c + col);
966 __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
967 __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
968 src_c += C4NUM * col;
969 __m128 src12L = _mm_unpacklo_ps(src1, src2);
970 __m128 src12H = _mm_unpackhi_ps(src1, src2);
971 __m128 src34L = _mm_unpacklo_ps(src3, src4);
972 __m128 src34H = _mm_unpackhi_ps(src3, src4);
973
974 __m128 dst0 = _mm_movelh_ps(src12L, src34L);
975 __m128 dst1 = _mm_movehl_ps(src34L, src12L);
976 __m128 dst2 = _mm_movelh_ps(src12H, src34H);
977 __m128 dst3 = _mm_movehl_ps(src34H, src12H);
978
979 _mm_storeu_ps(dst_c, dst0);
980 _mm_storeu_ps(dst_c + 4, dst1);
981 _mm_storeu_ps(dst_c + 8, dst2);
982 _mm_storeu_ps(dst_c + 12, dst3);
983 #else
984 for (size_t tr = 0; tr < C4NUM; tr++) {
985 for (size_t tc = 0; tc < C4NUM; tc++) {
986 dst_c[tc * C4NUM + tr] = src_c[tr * col + tc];
987 }
988 }
989 #endif
990 }
991 for (; ci < col; ci++) {
992 const float *src_c = src_r + ci;
993 float *dst_c = dst_r + ci * C4NUM;
994 for (int i = 0; i < C4NUM; i++) {
995 dst_c[i] = src_c[i * col];
996 }
997 }
998 src_r += C4NUM * col;
999 dst_r += C4NUM * col;
1000 }
1001 if (row_end == row) {
1002 for (; ri < row; ri++) {
1003 for (int i = 0; i < col; i++) {
1004 dst_r[i * C4NUM] = src_r[i];
1005 }
1006 src_r += col;
1007 dst_r += 1;
1008 }
1009
1010 int total_row = UP_ROUND(row, C4NUM);
1011 for (; ri < total_row; ri++) {
1012 for (int i = 0; i < col; i++) {
1013 dst_r[i * C4NUM] = 0;
1014 }
1015 dst_r += 1;
1016 }
1017 }
1018 }
1019
1020 void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
1021 RowMajor2ColMajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1022 }
1023 void RowMajor2RowMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
1024 RowMajor2RowMajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1025 }
1026 void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1027 RowMajor2Row4MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1028 }
1029 void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1030 RowMajor2Row6MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1031 }
1032 void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1033 RowMajor2Row8MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1034 }
1035 void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1036 RowMajor2Row12MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1037 }
1038 void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1039 RowMajor2Row16MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1040 }
1041 void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int col, int row) {
1042 RowMajor2Row32MajorParallel(src_ptr, dst_ptr, col, row, 0, col);
1043 }
1044 void RowMajor2Row64Major(const float *src_ptr, float *dst_ptr, int col, int row) {
1045 RowMajor2Row64MajorParallel(src_ptr, dst_ptr, col, row, 0, col);
1046 }
1047 void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1048 RowMajor2Col12MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1049 }
1050 void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1051 RowMajor2Col8MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1052 }
1053 void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1054 RowMajor2Col16MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1055 }
1056 void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1057 RowMajor2Col32MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1058 }
1059 void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1060 RowMajor2Col64MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1061 }
1062 void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1063 RowMajor2Col6MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1064 }
1065 void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
1066 RowMajor2Col4MajorParallel(src_ptr, dst_ptr, row, col, 0, row);
1067 }
1068
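// Pack NHWC to NC4HW4: each group of 4 channels becomes a plane-major block of plane * 4 floats;
// the last (possibly partial) channel group is written element by element.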
1069 void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1070 int c4 = UP_DIV(channel, C4NUM);
1071 int c4_minus = c4 - 1;
1072 for (int b = 0; b < batch; b++) {
1073 int src_oc_offset = b * plane * channel;
1074 int dst_oc_offset = b * plane * c4 * C4NUM;
1075 for (int k = 0; k < plane; k++) {
1076 int src_kernel_offset = src_oc_offset + k * channel;
1077 int dst_kernel_offset = dst_oc_offset + k * C4NUM;
1078 for (int j = 0; j < c4_minus; ++j) {
1079 int src_ic_offset = src_kernel_offset + j * C4NUM;
1080 int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM;
1081 #ifdef ENABLE_ARM
1082 vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset));
1083 #else
1084 for (int i = 0; i < C4NUM; ++i) {
1085 ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
1086 }
1087 #endif
1088 }
1089 int tmp_c = c4_minus * C4NUM;
1090 int tmp_c_offset = tmp_c * plane;
1091 int res_c = channel - tmp_c;
1092 if (res_c > channel) {
1093 return;
1094 }
1095 for (int l = 0; l < res_c; ++l) {
1096 int src_ic_offset = src_kernel_offset + tmp_c + l;
1097 int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
1098 ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
1099 }
1100 }
1101 }
1102 }
1103
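// NCHW -> NC4HW4 is just a per-batch Col4 packing of the (channel x plane) matrix.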
1104 void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1105 int c4 = UP_DIV(channel, C4NUM);
1106 for (int b = 0; b < batch; b++) {
1107 int src_offset = b * plane * channel;
1108 int dst_offset = b * plane * c4 * C4NUM;
1109 RowMajor2Col4Major((const float *)src + src_offset, (float *)dst + dst_offset, channel, plane);
1110 }
1111 }
1112
1113 void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1114 int oc_block = UP_DIV(channel, C4NUM);
1115 int oc_block_channel = oc_block * C4NUM;
1116 int ic_remainder_ = channel % C4NUM;
1117 if (ic_remainder_ != 0) {
1118 for (int b = 0; b < batch; b++) {
1119 int dst_batch_offset = b * oc_block_channel * plane;
1120 int batch_offset = b * channel * plane;
1121 for (int i = 0; i < plane; i++) {
1122 float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel;
1123 memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
1124 memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float));
1125 }
1126 }
1127 } else {
1128 size_t ori_input_size = batch * plane * channel * sizeof(float);
1129 memcpy((float *)dst, (float *)src, ori_input_size);
1130 }
1131 }
1132
1133 void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile) {
1134 int oc_block = UP_DIV(channel, oc_tile);
1135 int oc_block_channel = oc_block * oc_tile;
1136 int ic_remainder_ = channel % oc_tile;
1137 if (ic_remainder_ != 0) {
1138 for (int b = 0; b < batch; b++) {
1139 int dst_batch_offset = b * oc_block_channel * plane;
1140 int batch_offset = b * channel * plane;
1141 for (int i = 0; i < plane; i++) {
1142 float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel;
1143 memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
1144 memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float));
1145 }
1146 }
1147 } else {
1148 size_t ori_input_size = batch * plane * channel * sizeof(float);
1149 memcpy((float *)dst, (float *)src, ori_input_size);
1150 }
1151 }
1152
1153 #if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
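// Specialization of PackNHWCToNXHWCXFp32 for 1x1 kernels (plane == 1), which skips the inner hw loop.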
1154 void PackNHWCToNXHWCXFp32H1W1(int output_channel, int oc_block_num, int input_channel, float *tmp_weight,
1155 const float *src, int oc_block_unit, Transpose8X8Fp32Func transpose_func) {
1156 int oc_block8 = DOWN_DIV(output_channel, C8NUM);
1157 int oc = 0;
1158 int oc_block = 0;
1159 int ic8 = DOWN_ROUND(input_channel, C8NUM);
1160 int oc_remainder_step = 0;
1161 if (oc_block8 != oc_block_num) {
1162 oc_block8 = oc_block8 / oc_block_unit * oc_block_unit;
1163 oc_remainder_step = (oc_block_num - oc_block8) * C8NUM;
1164 }
1165 for (; oc < oc_block8; oc += (oc_block / C8NUM)) {
1166 oc_block = MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM; // max_tile = 32 ==> 24 ==> 16 ==> 8
1167 for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) {
1168 int ic = 0;
1169 for (; ic < ic8; ic += C8NUM) {
1170 transpose_func(src + ic, tmp_weight + ic * oc_block + oc_tmp, input_channel, oc_block);
1171 }
1172 for (; ic < input_channel; ++ic) {
1173 for (int j = 0; j < C8NUM; ++j) {
1174 tmp_weight[ic * oc_block + oc_tmp + j] = src[ic + input_channel * j];
1175 }
1176 }
1177 src += C8NUM * input_channel;
1178 }
1179 tmp_weight += oc_block * input_channel;
1180 }
1181 oc = output_channel - oc_block8 * C8NUM;
1182 for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) {
1183 for (int ic = 0; ic < input_channel; ++ic) {
1184 tmp_weight[oc_remainder + oc_remainder_step * ic] = src[ic + oc_remainder * input_channel];
1185 }
1186 }
1187 }
1188
1189 // PackNHWCToNXHWCXFp32 is the asm-optimized counterpart of SWPackNHWCToNXHWCXFp32.
1190 void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
1191 float *tmp_weight, const float *src) {
1192 #ifdef ENABLE_ARM64
1193 Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Arm64;
1194 int oc_block_unit = C2NUM;
1195 #elif defined(ENABLE_AVX)
1196 Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Avx;
1197 int oc_block_unit = C4NUM;
1198 #endif
1199 // pack weight from NHWC to N32HWC32 / N24HWC24 / N16HWC16 / N8HWC8
1200 // output_channel plays the role of the N (batch) dimension of the weight
1201 int plane = kernel_w * kernel_h;
1202 if (plane == 1) { // conv 1x1 weight pack
1203 PackNHWCToNXHWCXFp32H1W1(output_channel, oc_block_num, input_channel, tmp_weight, src, oc_block_unit,
1204 transpose_func);
1205 return;
1206 }
1207
1208 int ic8 = DOWN_ROUND(input_channel, C8NUM);
1209 int oc_block8 = DOWN_DIV(output_channel, C8NUM);
1210 int oc_block = 0;
1211 int oc = 0;
1212 int oc_remainder_step = 0;
1213 if (oc_block8 != oc_block_num) {
1214 oc_block8 = oc_block8 / oc_block_unit * oc_block_unit;
1215 oc_remainder_step = (oc_block_num - oc_block8) * C8NUM;
1216 }
1217 for (; oc < oc_block8; oc += (oc_block / C8NUM)) {
1218 oc_block = MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM; // max_tile = 32 ==> 24 ==> 16 ==> 8
1219 for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) {
1220 for (int hw = 0; hw < plane; ++hw) {
1221 int ic = 0;
1222 for (; ic < ic8; ic += C8NUM) {
1223 transpose_func(src + hw * input_channel + ic,
1224 tmp_weight + hw * oc_block * input_channel + ic * oc_block + oc_tmp, input_channel * plane,
1225 oc_block);
1226 }
1227 for (; ic < input_channel; ++ic) {
1228 for (int j = 0; j < C8NUM; ++j) {
1229 tmp_weight[ic * oc_block + oc_tmp + j + hw * oc_block * input_channel] =
1230 src[ic + input_channel * j * plane + hw * input_channel];
1231 }
1232 }
1233 }
1234 src += C8NUM * plane * input_channel;
1235 }
1236 tmp_weight += oc_block * input_channel * plane;
1237 }
1238 oc = output_channel - oc_block8 * C8NUM;
1239 for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) {
1240 for (int hw = 0; hw < plane; ++hw) {
1241 for (int ic = 0; ic < input_channel; ++ic) {
1242 tmp_weight[oc_remainder + oc_remainder_step * ic + hw * input_channel * oc_remainder_step] =
1243 src[ic + (oc_remainder * plane + hw) * input_channel];
1244 }
1245 }
1246 }
1247 }
1248
1249 #ifdef ENABLE_DEBUG
1250 void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel,
1251 float *tmp_weight, const float *src) {
1252 // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8
1253 int oc_block = 0;
1254 for (int i = 0; i < oc_block_num; i += oc_block) {
1255 oc_block = MSMIN(C4NUM, oc_block_num - i); // max_tile = 4
1256 int index = i * C8NUM * kernel_h * kernel_w * input_channel;
1257 int oc_remainder = MSMIN(C8NUM * oc_block, output_channel - i * C8NUM);
1258 for (int h = 0; h < kernel_h; ++h) {
1259 for (int w = 0; w < kernel_w; ++w) {
1260 int w_index = (h * kernel_w + w) * input_channel + index;
1261 for (int ic = 0; ic < input_channel; ++ic) {
1262 int ic_index = ic + w_index;
1263 for (int oc = 0; oc < oc_remainder; ++oc) {
1264 int oc_index = oc * kernel_w * kernel_h * input_channel + ic_index;
1265 tmp_weight[oc] = src[oc_index];
1266 }
1267 tmp_weight += oc_block * C8NUM;
1268 }
1269 }
1270 }
1271 }
1272 }
1273 #endif
1274 #endif
1275
1276 void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1277 int c8 = UP_DIV(channel, C8NUM);
1278 int c8_channel = c8 * C8NUM;
1279 int nhwc8_batch_unit_offset = c8 * C8NUM * plane;
1280 int ic_remainder_ = channel % C8NUM;
1281 if (ic_remainder_ != 0) {
1282 int nhwc8_batch_offset = 0;
1283 for (int b = 0; b < batch; b++) {
1284 int batch_offset = b * channel * plane;
1285 for (int i = 0; i < plane; i++) {
1286 float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel;
1287 memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float));
1288 for (int j = channel; j < c8_channel; ++j) {
1289 dst_per_plane[j] = 0;
1290 }
1291 }
1292 nhwc8_batch_offset += nhwc8_batch_unit_offset;
1293 }
1294 } else {
1295 size_t ori_input_size = batch * plane * channel * sizeof(float);
1296 memcpy((float *)dst, (float *)src, ori_input_size);
1297 }
1298 }
1299
1300 void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num) {
1301 int c_algin = UP_DIV(channel, cx_num);
1302 int ic_remainder_ = channel % cx_num;
1303 if (ic_remainder_ != 0) {
1304 int nhwc_batch_unit_offset = channel * plane;
1305 for (int b = 0; b < batch; b++) {
1306 int batch_offset = b * c_algin * cx_num * plane;
1307 for (int i = 0; i < plane; i++) {
1308 memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel,
1309 (float *)src + batch_offset + i * c_algin * cx_num, channel * sizeof(float));
1310 }
1311 }
1312 } else {
1313 size_t ori_input_size = batch * plane * channel * sizeof(float);
1314 memcpy((float *)dst, (float *)src, ori_input_size);
1315 }
1316 }
1317
1318 void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1319 int c4 = UP_DIV(channel, C4NUM);
1320 for (int b = 0; b < batch; b++) {
1321 int src_offset = b * plane * c4 * C4NUM;
1322 int dst_offset = b * plane * channel;
1323 for (int c = 0; c < channel; c++) {
1324 int c4_block_num = c / C4NUM;
1325 int c4_block_res = c % C4NUM;
1326 int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res;
1327 int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res;
1328 for (int k = 0; k < plane; k++) {
1329 int src_kernel_offset = src_c_offset + k * C4NUM;
1330 int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM;
1331 ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0];
1332 }
1333 }
1334 }
1335 }
1336
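// Unpack one batch from the C4-blocked layout (C4HW4) back to plain channel-major CHW.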
1337 void UnPackC4Uint(const void *src, void *dst, size_t plane, size_t channel) {
1338 const float *fp32_src = (const float *)src;
1339 float *fp32_dst = (float *)dst;
1340 for (size_t c = 0; c < channel; c++) {
1341 size_t c_div = c / C4NUM;
1342 size_t c_mod = c % C4NUM;
1343 for (size_t p = 0; p < plane; p++) {
1344 int src_offset = c_div * plane * C4NUM + p * C4NUM + c_mod;
1345 int dst_offset = c * plane + p;
1346 fp32_dst[dst_offset] = fp32_src[src_offset];
1347 }
1348 }
1349 }
1350
1351 void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
1352 int c4 = UP_ROUND(channel, C4NUM);
1353 for (int b = 0; b < batch; b++) {
1354 int src_offset = b * plane * c4;
1355 int dst_offset = b * plane * channel;
1356 UnPackC4Uint((const float *)src + src_offset, (float *)dst + dst_offset, plane, channel);
1357 }
1358 }
1359
1360 void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
1361 int c4 = UP_DIV(channel, C4NUM);
1362 for (int b = 0; b < batch; b++) {
1363 int src_offset = b * plane * c4 * C4NUM;
1364 int dst_offset = b * plane * channel;
1365 for (int k = 0; k < plane; k++) {
1366 int src_kernel_offset = src_offset + k * C4NUM;
1367 int dst_kernel_offset = dst_offset + k * channel;
1368 for (int c = 0; c < c4 - 1; c++) {
1369 int src_c_offset = src_kernel_offset + c * plane * C4NUM;
1370 int dst_c_offset = dst_kernel_offset + c * C4NUM;
1371 #if defined(ENABLE_NEON) || defined(ENABLE_SSE)
1372 MS_STQ_F32((float *)dst + dst_c_offset, MS_LDQ_F32((float *)src + src_c_offset));
1373 #else
1374 ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0];
1375 ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1];
1376 ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2];
1377 ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3];
1378 #endif
1379 }
1380 // res part
1381 int res_c = channel - (c4 - 1) * C4NUM;
1382 for (int i = 0; i < res_c; i++) {
1383 int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i;
1384 int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i;
1385 ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
1386 }
1387 }
1388 }
1389 }
1390
1391 void PackNC8HW8ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
1392 int c8 = UP_ROUND(channel, C8NUM);
1393 for (int b = 0; b < batch; b++) {
1394 int src_offset = b * plane * c8;
1395 int dst_offset = b * plane * channel;
1396
1397 const float *fp32_src = (const float *)src + src_offset;
1398 float *fp32_dst = (float *)dst + dst_offset;
1399 for (size_t c = 0; c < channel; c++) {
1400 size_t c_div = c / C8NUM;
1401 size_t c_mod = c % C8NUM;
1402 for (size_t p = 0; p < plane; p++) {
1403 int src_offset_c = c_div * plane * C8NUM + p * C8NUM + c_mod;
1404 int dst_offset_c = c * plane + p;
1405 fp32_dst[dst_offset_c] = fp32_src[src_offset_c];
1406 }
1407 }
1408 }
1409 }
1410
1411 void PackNHWCToNC8HW8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1412 int c8 = UP_DIV(channel, C8NUM);
1413 int c8_minus = c8 - 1;
1414 for (int b = 0; b < batch; b++) {
1415 int src_oc_offset = b * plane * channel;
1416 int dst_oc_offset = b * plane * c8 * C8NUM;
1417 for (int k = 0; k < plane; k++) {
1418 int src_kernel_offset = src_oc_offset + k * channel;
1419 int dst_kernel_offset = dst_oc_offset + k * C8NUM;
1420 for (int j = 0; j < c8_minus; ++j) {
1421 int src_ic_offset = src_kernel_offset + j * C8NUM;
1422 int dst_ic_offset = dst_kernel_offset + j * plane * C8NUM;
1423 for (int i = 0; i < C8NUM; ++i) {
1424 ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i];
1425 }
1426 }
1427 int tmp_c = c8_minus * C8NUM;
1428 int tmp_c_offset = tmp_c * plane;
1429 int res_c = channel - tmp_c;
1430 if (res_c > channel) {
1431 return;
1432 }
1433 for (int l = 0; l < res_c; ++l) {
1434 int src_ic_offset = src_kernel_offset + tmp_c + l;
1435 int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l;
1436 ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0];
1437 }
1438 }
1439 }
1440 }
1441
1442 void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) {
1443 int c8 = UP_DIV(channel, C8NUM);
1444 for (int b = 0; b < batch; b++) {
1445 int src_offset = b * plane * c8 * C8NUM;
1446 int dst_offset = b * plane * channel;
1447 for (int k = 0; k < plane; k++) {
1448 int src_kernel_offset = src_offset + k * C8NUM;
1449 int dst_kernel_offset = dst_offset + k * channel;
1450 for (int c = 0; c < c8 - 1; c++) {
1451 int src_c_offset = src_kernel_offset + c * plane * C8NUM;
1452 int dst_c_offset = dst_kernel_offset + c * C8NUM;
1453
1454 ((float *)dst + dst_c_offset)[Index0] = ((float *)src + src_c_offset)[Index0];
1455 ((float *)dst + dst_c_offset)[Index1] = ((float *)src + src_c_offset)[Index1];
1456 ((float *)dst + dst_c_offset)[Index2] = ((float *)src + src_c_offset)[Index2];
1457 ((float *)dst + dst_c_offset)[Index3] = ((float *)src + src_c_offset)[Index3];
1458 }
1459 // res part
1460 int res_c = channel - (c8 - 1) * C8NUM;
1461 for (int i = 0; i < res_c; i++) {
1462 int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i;
1463 int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i;
1464 ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0];
1465 }
1466 }
1467 }
1468 }
1469
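/*
 * PackNC8HW8AlignedToNC8HW8NotAlignedFp32 compacts an NC8HW8 buffer whose channel dimension was padded
 * up to a multiple of C16NUM: the full C8NUM blocks are copied with a single memcpy per batch, and the
 * remaining (channel - down_channel_8) channels of every plane position are then copied out of the next
 * C8NUM block, dropping the padding lanes.
 */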
1470 void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const int batch, const int plane,
1471 const int channel) {
1472 int down_channel_8 = DOWN_ROUND(channel, C8NUM);
1473 int up_channel_16 = UP_ROUND(channel, C16NUM);
1474 size_t dst_batch_offset = (size_t)(plane * channel) * sizeof(float);
1475 size_t src_batch_offset = (size_t)(plane * up_channel_16) * sizeof(float);
1476 size_t unaligned_channel_size = (size_t)(channel - down_channel_8) * sizeof(float);
1477 size_t aligned_channel_size = (size_t)(down_channel_8 * plane) * sizeof(float);
1478 size_t src_p_offset = C8NUM * sizeof(float);
1479 for (size_t b = 0; b < (size_t)(batch); ++b) {
1480 const char *src_batch = (char *)(src) + b * src_batch_offset;
1481     char *dst_batch = (char *)(dst) + b * dst_batch_offset;
1482     memcpy(dst_batch, src_batch, aligned_channel_size);
1483     src_batch += aligned_channel_size;
1484     dst_batch += aligned_channel_size;
1485     for (int p = 0; p < plane; ++p) {
1486       memcpy(dst_batch + p * unaligned_channel_size, src_batch + p * src_p_offset, unaligned_channel_size);
1487 }
1488 }
1489 }
1490
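/*
 * PackNHWCToC8HWN8Fp32 produces the C8HWN8 layout: element (n, hw, c) is written to
 *   (c / C8NUM) * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + (c % C8NUM),
 * and the channel positions from channel up to UP_ROUND(channel, C8NUM) are zero-filled.
 */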
1491 void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
1492 int channel_up8 = UP_ROUND(channel, C8NUM);
1493 for (int n = 0; n < batch; n++) {
1494 for (int hw = 0; hw < plane; hw++) {
1495 int c = 0;
1496 for (; c < channel; c++) {
1497 int c8div = c / C8NUM;
1498 int c8mod = c % C8NUM;
1499 int src_index = n * plane * channel + hw * channel + c;
1500 int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
1501 ((float *)dst)[dst_index] = ((float *)src)[src_index];
1502 }
1503 for (; c < channel_up8; c++) {
1504 int c8div = c / C8NUM;
1505 int c8mod = c % C8NUM;
1506 int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
1507 ((float *)dst)[dst_index] = 0;
1508 }
1509 }
1510 }
1511 }
1512
1513 void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
1514   // Pack weight from NHWC to CxHWNx: prefer C24HWN24, fall back to C16HWN16 when 24 output channels are
1514   // not available, then to C8HWN8 when 16 are not (a worked example follows).
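  // Worked example (the channel count is chosen here for illustration): channel = 41 gives
  // oc_block_num = UP_DIV(41, 8) = 6 blocks, so the loop forms two tiles with oc_block = 3:
  // one 24-channel tile covering channels [0, 24) and one covering the remaining 17 channels [24, 41).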
1515 #ifdef ENABLE_AVX
1516 int oc_block_num = UP_DIV(channel, C8NUM);
1517 int plane16 = plane / C16NUM * C16NUM;
1518 for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) {
1519 oc_block = MSMIN(C3NUM, oc_block_num - i);
1520 int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
1521 int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM;
1522 int p = 0;
1523 for (; p < plane16; p += C16NUM) {
1524 int index_plane = i * C8NUM + p * channel;
1525 for (int b = 0; b < batch; ++b) {
1526 int index_batch = index_plane + b * plane * channel;
1527 int oc = 0;
1528 int stride = oc_block * C8NUM * batch;
1529 for (; oc < oc_remainder_c8; oc += C8NUM) {
1530 const float *cur_src = src + index_batch + oc;
1531 float *cur_dst = dst + oc;
1532 MS_LOAD256X16_F32(r, cur_src, channel);
1533 STORE256X16_F32(cur_dst, stride, r);
1534 }
1535 for (; oc < oc_remainder; ++oc) {
1536 for (int k = 0; k < C16NUM; ++k) {
1537 dst[oc + stride * k] = src[index_batch + oc + channel * k];
1538 }
1539 }
1540 for (; oc < C8NUM; ++oc) {
1541 for (int k = 0; k < C16NUM; ++k) {
1542 dst[oc + stride * k] = 0;
1543 }
1544 }
1545 dst += oc_block * C8NUM;
1546 }
1547 dst += (C16NUM - 1) * oc_block * C8NUM * batch;
1548 }
1549 for (; p < plane; ++p) {
1550 int index_plane = i * C8NUM + p * channel;
1551 for (int b = 0; b < batch; ++b) {
1552 int index_batch = index_plane + b * plane * channel;
1553 int oc = 0;
1554 for (; oc < oc_remainder; ++oc) {
1555 dst[oc] = src[index_batch + oc];
1556 }
1557 for (; oc < C8NUM; ++oc) {
1558 dst[oc] = 0;
1559 }
1560 dst += oc_block * C8NUM;
1561 }
1562 }
1563 }
1564 #else
1565 int oc_block = 0;
1566 int oc_block_num = UP_DIV(channel, C8NUM);
1567 for (int i = 0; i < oc_block_num; i += oc_block) {
1568     oc_block = MSMIN(C3NUM, oc_block_num - i);  // at most C3NUM (3) blocks of C8NUM channels per tile
1569 int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
1570 for (int p = 0; p < plane; ++p) {
1571 int index_plane = i * C8NUM + p * channel;
1572 for (int b = 0; b < batch; ++b) {
1573 int index_batch = index_plane + b * plane * channel;
1574 for (int oc = 0; oc < oc_remainder; ++oc) {
1575 dst[oc] = src[index_batch + oc];
1576 }
1577 dst += oc_block * C8NUM;
1578 }
1579 }
1580 }
1581 #endif
1582 }
1583
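/*
 * PackDepthwiseIndirectWeightC4Fp32 reorders depthwise weights from [channel][kh][kw] into
 * [c4][kw][kh][C4NUM]. Note that every 4-channel block reads C4NUM full source channels, so the caller is
 * presumably expected to pass a source padded up to a multiple of C4NUM channels; the C8 variant below
 * follows the same scheme with C8NUM lanes.
 */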
1584 void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
1585 int c4 = UP_DIV(channel, C4NUM);
1586 for (int c = 0; c < c4; c++) {
1587 int dst_off_c = c * C4NUM * height * width;
1588 for (int i = 0; i < C4NUM; i++) {
1589 int src_off_c = (c * C4NUM + i) * height * width;
1590 for (int kh = 0; kh < height; kh++) {
1591 int src_off_kh = src_off_c + kh * width;
1592 for (int kw = 0; kw < width; kw++) {
1593 int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i;
1594 ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
1595 }
1596 }
1597 }
1598 }
1599 }
1600
1601 void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) {
1602 int c8 = UP_DIV(channel, C8NUM);
1603 for (int c = 0; c < c8; c++) {
1604 int dst_off_c = c * C8NUM * height * width;
1605 for (int i = 0; i < C8NUM; i++) {
1606 int src_off_c = (c * C8NUM + i) * height * width;
1607 for (int kh = 0; kh < height; kh++) {
1608 int src_off_kh = src_off_c + kh * width;
1609 for (int kw = 0; kw < width; kw++) {
1610 int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i;
1611 ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw];
1612 }
1613 }
1614 }
1615 }
1616 }
1617
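/*
 * PackNHWCToNCHWFp32 is a plane x channel transpose per batch. When thread_count > 0 the plane dimension
 * is split into C8NUM-aligned slices of UP_DIV(plane / C8NUM, thread_count) * C8NUM rows per task, and the
 * last task also takes the tail that is not a multiple of C8NUM. A single-threaded call is simply
 * (the argument names here are placeholders):
 *
 *   PackNHWCToNCHWFp32(src, dst, batches, plane, channel, 0, 0);  // task_id = 0, thread_count = 0
 */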
1618 void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel, int task_id,
1619 int thread_count) {
1620 #ifdef ENABLE_ARM64
1621 Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64;
1622 #elif defined(ENABLE_ARM32)
1623 Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32;
1624 #elif defined(ENABLE_AVX)
1625 Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx;
1626 #elif defined(ENABLE_SSE) && !defined(ENABLE_AVX)
1627 Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse;
1628 #endif
1629 int hw8 = plane / C8NUM;
1630 int task_start = 0;
1631 int task_end = plane;
1632 if (thread_count > 0) {
1633 int offset_hw = UP_DIV(hw8, thread_count) * C8NUM;
1634 task_start = offset_hw * task_id;
1635 int count = plane - task_start;
1636 if (count <= 0) {
1637 return;
1638 }
1639 task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw);
1640 hw8 = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0);
1641 } else {
1642 hw8 *= C8NUM;
1643 }
1644 int c8 = channel / C8NUM * C8NUM;
1645 int batch = plane * channel;
1646 for (int n = 0; n < batches; n++) {
1647 const float *src_batch = (const float *)src + n * batch;
1648 float *dst_batch = (float *)dst + n * batch;
1649 int hw = task_start;
1650 for (; hw < hw8; hw += C8NUM) {
1651 int c = 0;
1652 for (; c < c8; c += C8NUM) {
1653 const float *src_ptr = src_batch + hw * channel + c;
1654 float *dst_ptr = dst_batch + c * plane + hw;
1655 #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32)
1656 Transpose8X8Fp32Func_(src_ptr, dst_ptr, channel, plane);
1657 #else
1658 for (int tr = 0; tr < C8NUM; tr++) {
1659 for (int tc = 0; tc < C8NUM; tc++) {
1660 dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
1661 }
1662 }
1663 #endif
1664 }
1665 for (; c < channel; c++) {
1666 const float *src_ptr = src_batch + hw * channel + c;
1667 float *dst_ptr = dst_batch + c * plane + hw;
1668 for (size_t i = 0; i < C8NUM; i++) {
1669 dst_ptr[i] = src_ptr[i * channel];
1670 }
1671 }
1672 }
1673 for (; hw < task_end; hw++) {
1674 const float *src_ptr = src_batch + hw * channel;
1675 float *dst_ptr = dst_batch + hw;
1676 for (size_t i = 0; i < channel; i++) {
1677 dst_ptr[i * plane] = src_ptr[i];
1678 }
1679 }
1680 }
1681 }
1682
1683 /*
1684 |<---------------- plane --------------->|
1685 +---------------------------+------------+ ---
1686 | | | | | ↑
1687 |8x8-blocks| ... |8x8-blocks| right | |
1688 | | | | | |
1689 + - - - - -+ + - - - - -+ | |
1690 | ... ... ... | top | channel
1691 + - - - - -+ + - - - - -| | |
1692 | | | | tails | |
1693 |8x8-blocks| ... |8x8-blocks| | |
1694 +---------------------------+------------+ |
1695 | |right bottom| |
1696 | left bottom tails | tails | ↓
1697 +---------------------------+------------+ ---
1698 */
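/*
 * TransposeFp32 walks the (m, n, batch) block grid above in the order noted below (m fastest, then n,
 * then batch) so that dst is written consecutively; each task index in [start, end) handles one 8x8
 * block or one tail region. For example (sizes chosen here for illustration), channel = 20 and plane = 19
 * give m_pad = 3 and n_pad = 3, so one batch contributes 9 tasks: four full 8x8 blocks plus the bottom,
 * right and corner tails.
 */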
1699 void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end) {
1700 #ifdef ENABLE_ARM64
1701 Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64;
1702 #elif defined(ENABLE_ARM32)
1703 Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32;
1704 #elif defined(ENABLE_AVX)
1705 Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx;
1706 #elif defined(ENABLE_SSE) && !defined(ENABLE_AVX)
1707 Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse;
1708 #endif
1709 int m_pad = UP_DIV(channel, C8NUM);
1710 int n_pad = UP_DIV(plane, C8NUM);
1711 int m_blk = channel / C8NUM;
1712 int n_blk = plane / C8NUM;
1713 int b_stride = plane * channel;
1715 int b = 0, m = 0, n = 0;
1716 // To make write dst consecutively, (m,n):(0,0)->(1,0)->...->(0,1)->(1,1)->...
1717 offset_to_index_init(start, 6, &m, m_pad, &n, n_pad, &b, batches);
1718 for (int task = start; task < end; task++) {
1719 const float *src_batch = (const float *)src + b * b_stride;
1720 float *dst_batch = (float *)dst + b * b_stride;
1721 int m_start = m * C8NUM;
1722 int n_start = n * C8NUM;
1723 if (m < m_blk && n < n_blk) {
1724 // process 8x8-blocks
1725 const float *from = src_batch + m_start * plane + n_start;
1726 float *to = dst_batch + n_start * channel + m_start;
1727 #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32)
1728 Transpose8X8Fp32Func_(from, to, plane, channel);
1729 #else
1730 for (int tr = 0; tr < C8NUM; tr++) {
1731 for (int tc = 0; tc < C8NUM; tc++) {
1732 to[tc * channel + tr] = from[tr * plane + tc];
1733 }
1734 }
1735 #endif
1736 } else {
1737 // process right bottom tails
1738 const float *from = src_batch;
1739 float *to = dst_batch;
1740 int i_start = m_start;
1741 int i_end = channel;
1742 int j_start = n_start;
1743 int j_end = plane;
1744 if (m >= m_blk && n < n_blk) {
1745 // process left bottom tails
1746 from = src_batch + n_start;
1747 to = dst_batch + n_start * channel;
1748 j_start = 0;
1749 j_end = C8NUM;
1750 } else if (m < m_blk && n >= n_blk) {
1751 // process right top tails
1752 from = src_batch + m_start * plane;
1753 to = dst_batch + m_start;
1754 i_start = 0;
1755 i_end = C8NUM;
1756 }
1757 transpose_tail(from, to, j_start, j_end, i_start, i_end, channel, plane);
1758 }
1759 offset_to_index_step(6, &m, m_pad, &n, n_pad, &b, batches);
1760 }
1761 }
1762
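// NCHW -> NHWC is the same two-dimensional transpose with the plane and channel arguments swapped,
// so it simply forwards to PackNHWCToNCHWFp32.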
1763 void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) {
1764 PackNHWCToNCHWFp32(src, dst, batch, channel, plane, task_id, thread_count);
1765 }
1766
1767 #ifdef ENABLE_ARM64
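/* In-register 8x8 fp32 transpose: zip1/zip2 interleave pairs of rows at 32-bit granularity and trn1/trn2
   then exchange the 64-bit halves, so the eight strided ld1 loads come back out as eight transposed st1
   stores. */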
1768 inline void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
1769 size_t srcStride = src_stride * sizeof(float);
1770 size_t dstStride = dst_stride * sizeof(float);
1771 asm volatile(
1772 "mov x10, %[src_ptr]\n"
1773 "mov x11, %[dst_ptr]\n"
1774
1775 "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
1776 "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
1777
1778 "zip1 v8.4s, v0.4s, v2.4s\n"
1779 "zip2 v9.4s, v0.4s, v2.4s\n"
1780 "zip1 v12.4s, v1.4s, v3.4s\n"
1781 "zip2 v13.4s, v1.4s, v3.4s\n"
1782
1783 "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
1784 "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
1785
1786 "zip1 v10.4s, v4.4s, v6.4s\n"
1787 "zip2 v11.4s, v4.4s, v6.4s\n"
1788 "zip1 v14.4s, v5.4s, v7.4s\n"
1789 "zip2 v15.4s, v5.4s, v7.4s\n"
1790
1791 "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n"
1792 "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n"
1793
1794 "trn1 v16.2d, v8.2d, v10.2d\n"
1795 "trn2 v18.2d, v8.2d, v10.2d\n"
1796 "trn1 v20.2d, v9.2d, v11.2d\n"
1797 "trn2 v22.2d, v9.2d, v11.2d\n"
1798
1799 "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n"
1800 "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n"
1801
1802 "trn1 v24.2d, v12.2d, v14.2d\n"
1803 "trn2 v26.2d, v12.2d, v14.2d\n"
1804 "trn1 v28.2d, v13.2d, v15.2d\n"
1805 "trn2 v30.2d, v13.2d, v15.2d\n"
1806
1807 "zip1 v8.4s, v0.4s, v2.4s\n"
1808 "zip2 v9.4s, v0.4s, v2.4s\n"
1809 "zip1 v12.4s, v1.4s, v3.4s\n"
1810 "zip2 v13.4s, v1.4s, v3.4s\n"
1811
1812 "zip1 v10.4s, v4.4s, v6.4s\n"
1813 "zip2 v11.4s, v4.4s, v6.4s\n"
1814 "zip1 v14.4s, v5.4s, v7.4s\n"
1815 "zip2 v15.4s, v5.4s, v7.4s\n"
1816
1817 "trn1 v17.2d, v8.2d, v10.2d\n"
1818 "trn2 v19.2d, v8.2d, v10.2d\n"
1819 "trn1 v21.2d, v9.2d, v11.2d\n"
1820 "trn2 v23.2d, v9.2d, v11.2d\n"
1821
1822 "trn1 v25.2d, v12.2d, v14.2d\n"
1823 "trn2 v27.2d, v12.2d, v14.2d\n"
1824 "trn1 v29.2d, v13.2d, v15.2d\n"
1825 "trn2 v31.2d, v13.2d, v15.2d\n"
1826
1827 "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n"
1828 "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n"
1829 "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n"
1830 "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n"
1831 "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n"
1832 "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n"
1833 "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n"
1834 "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n"
1835
1836 :
1837 : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
1838 : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
1839 "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
1840 "v31");
1841 }
1842 #endif
1843
1844 #ifdef ENABLE_ARM32
1845 inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
1846 size_t srcStride = src_stride * sizeof(float);
1847 size_t dstStride = dst_stride * sizeof(float);
1848 asm volatile(
1849 "mov r10, %[src_ptr]\n"
1850 "mov r12, %[dst_ptr]\n"
1851
1852 "vld1.32 {q0, q1}, [r10], %[srcStride]\n"
1853 "vld1.32 {q2, q3}, [r10], %[srcStride]\n"
1854
1855 "vtrn.32 d0, d4\n"
1856 "vtrn.32 d1, d5\n"
1857 "vtrn.32 d2, d6\n"
1858 "vtrn.32 d3, d7\n"
1859
1860 "vld1.32 {q4, q5}, [r10], %[srcStride]\n"
1861 "vld1.32 {q6, q7}, [r10], %[srcStride]\n"
1862
1863 "vtrn.32 d8, d12\n"
1864 "vtrn.32 d9, d13\n"
1865 "vtrn.32 d10, d14\n"
1866 "vtrn.32 d11, d15\n"
1867
1868 "vld1.32 {q8, q9}, [r10], %[srcStride]\n"
1869 "vld1.32 {q10, q11}, [r10], %[srcStride]\n"
1870
1871 "vswp d1, d8\n"
1872 "vswp d3, d10\n"
1873 "vswp d5, d12\n"
1874 "vswp d7, d14\n"
1875
1876 "vtrn.32 d16, d20\n"
1877 "vtrn.32 d17, d21\n"
1878 "vtrn.32 d18, d22\n"
1879 "vtrn.32 d19, d23\n"
1880
1881 "vld1.32 {q12, q13}, [r10], %[srcStride]\n"
1882 "vld1.32 {q14, q15}, [r10], %[srcStride]\n"
1883
1884 "vtrn.32 d24, d28\n"
1885 "vtrn.32 d25, d29\n"
1886 "vtrn.32 d26, d30\n"
1887 "vtrn.32 d27, d31\n"
1888
1889 "vswp d17, d24\n"
1890 "vswp d19, d26\n"
1891 "vswp d21, d28\n"
1892 "vswp d23, d30\n"
1893
1894 "add r10, r12, #16\n"
1895 "vst1.32 {q0}, [r12], %[dstStride]\n"
1896 "vst1.32 {q8}, [r10], %[dstStride]\n"
1897 "vst1.32 {q2}, [r12], %[dstStride]\n"
1898 "vst1.32 {q10}, [r10], %[dstStride]\n"
1899 "vst1.32 {q4}, [r12], %[dstStride]\n"
1900 "vst1.32 {q12}, [r10], %[dstStride]\n"
1901 "vst1.32 {q6}, [r12], %[dstStride]\n"
1902 "vst1.32 {q14}, [r10], %[dstStride]\n"
1903 "vst1.32 {q1}, [r12], %[dstStride]\n"
1904 "vst1.32 {q9}, [r10], %[dstStride]\n"
1905 "vst1.32 {q3}, [r12], %[dstStride]\n"
1906 "vst1.32 {q11}, [r10], %[dstStride]\n"
1907 "vst1.32 {q5}, [r12], %[dstStride]\n"
1908 "vst1.32 {q13}, [r10], %[dstStride]\n"
1909 "vst1.32 {q7}, [r12], %[dstStride]\n"
1910 "vst1.32 {q15}, [r10], %[dstStride]\n"
1911
1912 :
1913 : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
1914 : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
1915 "q15");
1916 }
1917 #endif
1918
1919 #ifdef ENABLE_AVX
1920 /*
1921    Interleave the 128-bit halves with _mm256_insertf128_ps while loading, instead of recombining lanes with
1922    _mm256_permute2f128_ps at the end; overall this saves 4 vinsertf128 and 4 vperm2f128 instructions.
1923 */
1924 inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
1925 const float *src1 = src_ptr + 0 * src_stride;
1926 const float *src2 = src_ptr + 1 * src_stride;
1927 const float *src3 = src_ptr + 2 * src_stride;
1928 const float *src4 = src_ptr + 3 * src_stride;
1929 const float *src5 = src_ptr + 4 * src_stride;
1930 const float *src6 = src_ptr + 5 * src_stride;
1931 const float *src7 = src_ptr + 6 * src_stride;
1932 const float *src8 = src_ptr + 7 * src_stride;
1933
1934 __m256 r1, r2, r3, r4, r5, r6, r7, r8;
1935 __m256 t1, t2, t3, t4, t5, t6, t7, t8;
1936   // _mm256_castps128_ps256 is a compile-time cast only: it generates no instructions and adds no latency.
1937 r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 0)), _mm_loadu_ps(src5 + 0), 1);
1938 r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 0)), _mm_loadu_ps(src6 + 0), 1);
1939 r3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 0)), _mm_loadu_ps(src7 + 0), 1);
1940 r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 0)), _mm_loadu_ps(src8 + 0), 1);
1941 r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 4)), _mm_loadu_ps(src5 + 4), 1);
1942 r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 4)), _mm_loadu_ps(src6 + 4), 1);
1943 r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 4)), _mm_loadu_ps(src7 + 4), 1);
1944 r8 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 4)), _mm_loadu_ps(src8 + 4), 1);
1945
1946 t1 = _mm256_unpacklo_ps(r1, r2);
1947 t2 = _mm256_unpackhi_ps(r1, r2);
1948 t3 = _mm256_unpacklo_ps(r3, r4);
1949 t4 = _mm256_unpackhi_ps(r3, r4);
1950 t5 = _mm256_unpacklo_ps(r5, r6);
1951 t6 = _mm256_unpackhi_ps(r5, r6);
1952 t7 = _mm256_unpacklo_ps(r7, r8);
1953 t8 = _mm256_unpackhi_ps(r7, r8);
1954
1955 __m256 v;
1956 v = _mm256_shuffle_ps(t1, t3, 0x4E);
1957 r1 = _mm256_blend_ps(t1, v, 0xCC);
1958 r2 = _mm256_blend_ps(t3, v, 0x33);
1959
1960 v = _mm256_shuffle_ps(t2, t4, 0x4E);
1961 r3 = _mm256_blend_ps(t2, v, 0xCC);
1962 r4 = _mm256_blend_ps(t4, v, 0x33);
1963
1964 v = _mm256_shuffle_ps(t5, t7, 0x4E);
1965 r5 = _mm256_blend_ps(t5, v, 0xCC);
1966 r6 = _mm256_blend_ps(t7, v, 0x33);
1967
1968 v = _mm256_shuffle_ps(t6, t8, 0x4E);
1969 r7 = _mm256_blend_ps(t6, v, 0xCC);
1970 r8 = _mm256_blend_ps(t8, v, 0x33);
1971
1972 STORE256X8_F32(dst_ptr, dst_stride, r);
1973 }
1974 #endif
1975
1976 #if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
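/* Without AVX, the 8x8 transpose is assembled from four 4x4 transposes (unpacklo/unpackhi followed by
   movelh/movehl), one per quadrant of the 8x8 tile. */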
1977 inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
1978 __m128 v0_ma = _mm_loadu_ps(src_ptr);
1979 __m128 v1_ma = _mm_loadu_ps(src_ptr + src_stride);
1980 __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride);
1981 __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride);
1982
1983 __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
1984 __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
1985 __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
1986 __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
1987
1988 __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
1989 __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
1990 __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
1991 __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
1992
1993 _mm_storeu_ps(dst_ptr, v8_ma);
1994 _mm_storeu_ps(dst_ptr + dst_stride, v9_ma);
1995 _mm_storeu_ps(dst_ptr + 2 * dst_stride, v10_ma);
1996 _mm_storeu_ps(dst_ptr + 3 * dst_stride, v11_ma);
1997
1998 v0_ma = _mm_loadu_ps(src_ptr + C4NUM);
1999 v1_ma = _mm_loadu_ps(src_ptr + src_stride + C4NUM);
2000 v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride + C4NUM);
2001 v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride + C4NUM);
2002
2003 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
2004 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
2005 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
2006 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
2007
2008 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
2009 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
2010 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
2011 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
2012
2013 _mm_storeu_ps(dst_ptr + C4NUM * dst_stride, v8_ma);
2014 _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride, v9_ma);
2015 _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride, v10_ma);
2016 _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride, v11_ma);
2017
2018 v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride);
2019 v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride);
2020 v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride);
2021 v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride);
2022
2023 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
2024 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
2025 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
2026 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
2027
2028 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
2029 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
2030 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
2031 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
2032
2033 _mm_storeu_ps(dst_ptr + C4NUM, v8_ma);
2034 _mm_storeu_ps(dst_ptr + dst_stride + C4NUM, v9_ma);
2035 _mm_storeu_ps(dst_ptr + 2 * dst_stride + C4NUM, v10_ma);
2036 _mm_storeu_ps(dst_ptr + 3 * dst_stride + C4NUM, v11_ma);
2037
2038 v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride + C4NUM);
2039 v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride + C4NUM);
2040 v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride + C4NUM);
2041 v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride + C4NUM);
2042
2043 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma);
2044 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma);
2045 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma);
2046 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma);
2047
2048 v8_ma = _mm_movelh_ps(v4_ma, v6_ma);
2049 v9_ma = _mm_movehl_ps(v6_ma, v4_ma);
2050 v10_ma = _mm_movelh_ps(v5_ma, v7_ma);
2051 v11_ma = _mm_movehl_ps(v7_ma, v5_ma);
2052
2053 _mm_storeu_ps(dst_ptr + C4NUM * dst_stride + C4NUM, v8_ma);
2054 _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride + C4NUM, v9_ma);
2055 _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride + C4NUM, v10_ma);
2056 _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma);
2057 }
2058 #endif
2059
2060 #if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
2061 void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) {
2062   // Pack NCHW 3x3 depthwise weights to NC4HW4, applying the 1D Winograd F(2,3) weight transform to each
2062   // kernel row (the transform matrix is shown below).
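  // The four values stored per kernel row are G * g for the 1D Winograd F(2,3) weight transform, with
  //       [  1     0     0  ]
  //   G = [ 1/2   1/2   1/2 ]
  //       [ 1/2  -1/2   1/2 ]
  //       [  0     0     1  ]
  // applied to each row g = (g0, g1, g2)^T, matching the four dst_kernel writes below.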
2063 for (int i = 0; i < channel; i++) {
2064 float *src_kernel = (float *)src + i * 9;
2065 float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4;
2066 for (int y = 0; y < 3; y++) {
2067 float g0 = src_kernel[3 * y];
2068 float g1 = src_kernel[3 * y + 1];
2069 float g2 = src_kernel[3 * y + 2];
2070
2071 dst_kernel[16 * y] = g0;
2072 dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2);
2073 dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2);
2074 dst_kernel[16 * y + 12] = g2;
2075 }
2076 }
2077 }
2078 #endif
2079